diff --git a/arial.ttf b/arial.ttf new file mode 100644 index 0000000..7ff88f2 Binary files /dev/null and b/arial.ttf differ diff --git a/gradio_utils/__init__.py b/gradio_utils/__init__.py new file mode 100644 index 0000000..6a54920 --- /dev/null +++ b/gradio_utils/__init__.py @@ -0,0 +1,9 @@ +from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) + +__all__ = [ + 'draw_mask_on_image', 'draw_points_on_image', + 'on_change_single_global_state', 'get_latest_points_pair', + 'get_valid_mask', 'ImageMask' +] diff --git a/gradio_utils/utils.py b/gradio_utils/utils.py new file mode 100644 index 0000000..d4e760e --- /dev/null +++ b/gradio_utils/utils.py @@ -0,0 +1,154 @@ +import gradio as gr +import numpy as np +from PIL import Image, ImageDraw + + +class ImageMask(gr.components.Image): + """ + Sets: source="canvas", tool="sketch" + """ + + is_template = True + + def __init__(self, **kwargs): + super().__init__(source="upload", + tool="sketch", + interactive=False, + **kwargs) + + def preprocess(self, x): + if x is None: + return x + if self.tool == "sketch" and self.source in ["upload", "webcam" + ] and type(x) != dict: + decode_image = gr.processing_utils.decode_base64_to_image(x) + width, height = decode_image.size + mask = np.ones((height, width, 4), dtype=np.uint8) + mask[..., -1] = 255 + mask = self.postprocess(mask) + x = {'image': x, 'mask': mask} + return super().preprocess(x) + + +def get_valid_mask(mask: np.ndarray): + """Convert mask from gr.Image(0 to 255, RGBA) to binary mask. + """ + if mask.ndim == 3: + mask_pil = Image.fromarray(mask).convert('L') + mask = np.array(mask_pil) + if mask.max() == 255: + mask = mask / 255 + return mask + + +def draw_points_on_image(image, + points, + curr_point=None, + highlight_all=True, + radius_scale=0.01): + overlay_rgba = Image.new("RGBA", image.size, 0) + overlay_draw = ImageDraw.Draw(overlay_rgba) + for point_key, point in points.items(): + if ((curr_point is not None and curr_point == point_key) + or highlight_all): + p_color = (255, 0, 0) + t_color = (0, 0, 255) + + else: + p_color = (255, 0, 0, 35) + t_color = (0, 0, 255, 35) + + rad_draw = int(image.size[0] * radius_scale) + + p_start = point.get("start_temp", point["start"]) + p_target = point["target"] + + if p_start is not None and p_target is not None: + p_draw = int(p_start[0]), int(p_start[1]) + t_draw = int(p_target[0]), int(p_target[1]) + + overlay_draw.line( + (p_draw[0], p_draw[1], t_draw[0], t_draw[1]), + fill=(255, 255, 0), + width=2, + ) + + if p_start is not None: + p_draw = int(p_start[0]), int(p_start[1]) + overlay_draw.ellipse( + ( + p_draw[0] - rad_draw, + p_draw[1] - rad_draw, + p_draw[0] + rad_draw, + p_draw[1] + rad_draw, + ), + fill=p_color, + ) + + if curr_point is not None and curr_point == point_key: + # overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0)) + overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0)) + + if p_target is not None: + t_draw = int(p_target[0]), int(p_target[1]) + overlay_draw.ellipse( + ( + t_draw[0] - rad_draw, + t_draw[1] - rad_draw, + t_draw[0] + rad_draw, + t_draw[1] + rad_draw, + ), + fill=t_color, + ) + + if curr_point is not None and curr_point == point_key: + # overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0)) + overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0)) + + return Image.alpha_composite(image.convert("RGBA"), + overlay_rgba).convert("RGB") + + +def draw_mask_on_image(image, mask): + im_mask = np.uint8(mask * 255) + im_mask_rgba = np.concatenate( + ( + np.tile(im_mask[..., None], [1, 1, 3]), + 45 * np.ones( + (im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8), + ), + axis=-1, + ) + im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA") + + return Image.alpha_composite(image.convert("RGBA"), + im_mask_rgba).convert("RGB") + + +def on_change_single_global_state(keys, + value, + global_state, + map_transform=None): + if map_transform is not None: + value = map_transform(value) + + curr_state = global_state + if isinstance(keys, str): + last_key = keys + + else: + for k in keys[:-1]: + curr_state = curr_state[k] + + last_key = keys[-1] + + curr_state[last_key] = value + return global_state + + +def get_latest_points_pair(points_dict): + if not points_dict: + return None + point_idx = list(points_dict.keys()) + latest_point_idx = max(point_idx) + return latest_point_idx diff --git a/visualizer_drag_gradio.py b/visualizer_drag_gradio.py new file mode 100644 index 0000000..d1cdca9 --- /dev/null +++ b/visualizer_drag_gradio.py @@ -0,0 +1,866 @@ +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import gradio as gr +import numpy as np +import torch +from PIL import Image + +import dnnlib +from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) +from viz.renderer import Renderer, add_watermark_np + +parser = ArgumentParser() +parser.add_argument('--share', action='store_true') +parser.add_argument('--cache-dir', type=str, default='./checkpoints') +args = parser.parse_args() + +cache_dir = args.cache_dir + +device = 'cuda' + + +def reverse_point_pairs(points): + new_points = [] + for p in points: + new_points.append([p[1], p[0]]) + return new_points + + +def clear_state(global_state, target=None): + """Clear target history state from global_state + If target is not defined, points and mask will be both removed. + 1. set global_state['points'] as empty dict + 2. set global_state['mask'] as full-one mask. + """ + if target is None: + target = ['point', 'mask'] + if not isinstance(target, list): + target = [target] + if 'point' in target: + global_state['points'] = dict() + print('Clear Points State!') + if 'mask' in target: + image_raw = global_state["images"]["image_raw"] + global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), + dtype=np.uint8) + print('Clear mask State!') + + return global_state + + +def init_images(global_state): + """This function is called only ones with Gradio App is started. + 0. pre-process global_state, unpack value from global_state of need + 1. Re-init renderer + 2. run `renderer._render_drag_impl` with `is_drag=False` to generate + new image + 3. Assign images to global state and re-generate mask + """ + + if isinstance(global_state, gr.State): + state = global_state.value + else: + state = global_state + + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, + None, # w_load + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, + None, # input_transform + state['params']['lr'] # lr, + ) + + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) + return global_state + + +def update_image_draw(image, points, mask, show_mask, global_state=None): + + image_draw = draw_points_on_image(image, points) + if show_mask and mask is not None and not (mask == 0).all() and not ( + mask == 1).all(): + image_draw = draw_mask_on_image(image_draw, mask) + + image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) + if global_state is not None: + global_state['images']['image_show'] = image_draw + return image_draw + + +def preprocess_mask_info(global_state, image): + """Function to handle mask information. + 1. last_mask is None: Do not need to change mask, return mask + 2. last_mask is not None: + 2.1 global_state is remove_mask: + 2.2 global_state is add_mask: + """ + if isinstance(image, dict): + last_mask = get_valid_mask(image['mask']) + else: + last_mask = None + mask = global_state['mask'] + + # mask in global state is a placeholder with all 1. + if (mask == 1).all(): + mask = last_mask + + # last_mask = global_state['last_mask'] + editing_mode = global_state['editing_state'] + + if last_mask is None: + return global_state + + if editing_mode == 'remove_mask': + updated_mask = np.clip(mask - last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do remove.') + elif editing_mode == 'add_mask': + updated_mask = np.clip(mask + last_mask, 0, 1) + print(f'Last editing_state is {editing_mode}, do add.') + else: + updated_mask = mask + print(f'Last editing_state is {editing_mode}, ' + 'do nothing to mask.') + + global_state['mask'] = updated_mask + # global_state['last_mask'] = None # clear buffer + return global_state + + +valid_checkpoints_dict = { + f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) + for f in os.listdir(cache_dir) + if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) +} +print(f'File under cache_dir ({cache_dir}):') +print(os.listdir(cache_dir)) +print('Valid checkpoint file:') +print(valid_checkpoints_dict) + +init_pkl = 'stylegan_human_v2_512' + +with gr.Blocks() as app: + + # renderer = Renderer() + global_state = gr.State({ + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': init_pkl + }) + + # init image + global_state = init_images(global_state) + + with gr.Row(): + + with gr.Row(): + + # Left --> tools + with gr.Column(scale=3): + + # Pickle + with gr.Row(): + + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Pickle', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_pretrained_dropdown = gr.Dropdown( + choices=list(valid_checkpoints_dict.keys()), + label="Pretrained Model", + value=init_pkl, + ) + + # Latent + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Latent', show_label=False) + + with gr.Column(scale=4, min_width=10): + form_seed_number = gr.Number( + value=global_state.value['params']['seed'], + interactive=True, + label="Seed", + ) + form_lr_number = gr.Number( + value=global_state.value["params"]["lr"], + interactive=True, + label="Step Size") + + with gr.Row(): + with gr.Column(scale=2, min_width=10): + form_reset_image = gr.Button("Reset Image") + with gr.Column(scale=3, min_width=10): + form_latent_space = gr.Radio( + ['w', 'w+'], + value=global_state.value['params'] + ['latent_space'], + interactive=True, + label='Latent space to optimize', + show_label=False, + ) + + # Drag + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Drag', show_label=False) + with gr.Column(scale=4, min_width=10): + with gr.Row(): + with gr.Column(scale=1, min_width=10): + enable_add_points = gr.Button('Add Points') + with gr.Column(scale=1, min_width=10): + undo_points = gr.Button('Reset Points') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_start_btn = gr.Button("Start") + with gr.Column(scale=1, min_width=10): + form_stop_btn = gr.Button("Stop") + + form_steps_number = gr.Number(value=0, + label="Steps", + interactive=False) + + # Mask + with gr.Row(): + with gr.Column(scale=1, min_width=10): + gr.Markdown(value='Mask', show_label=False) + with gr.Column(scale=4, min_width=10): + enable_add_mask = gr.Button('Edit Flexible Area') + with gr.Row(): + with gr.Column(scale=1, min_width=10): + form_reset_mask_btn = gr.Button("Reset mask") + with gr.Column(scale=1, min_width=10): + show_mask = gr.Checkbox( + label='Show Mask', + value=global_state.value['show_mask'], + show_label=False) + + with gr.Row(): + form_lambda_number = gr.Number( + value=global_state.value["params"] + ["motion_lambda"], + interactive=True, + label="Lambda", + ) + + form_draw_interval_number = gr.Number( + value=global_state.value["draw_interval"], + label="Draw Interval (steps)", + interactive=True, + visible=False) + + # Right --> Image + with gr.Column(scale=8): + form_image = ImageMask( + value=global_state.value['images']['image_show'], + brush_radius=20).style( + width=768, + height=768) # NOTE: hard image size code here. + gr.Markdown(""" + ## Quick Start + + 1. Select desired `Pretrained Model` and adjust `Seed` to generate an + initial image. + 2. Click on image to add control points. + 3. Click `Start` and enjoy it! + + ## Advance Usage + + 1. Change `Step Size` to adjust learning rate in drag optimization. + 2. Select `w` or `w+` to change latent space to optimize: + * Optimize on `w` space may cause greater influence to the image. + * Optimize on `w+` space may work slower than `w`, but usually achieve + better results. + * Note that changing the latent space will reset the image, points and + mask (this has the same effect as `Reset Image` button). + 3. Click `Edit Flexible Area` to create a mask and constrain the + unmasked region to remain unchanged. + """) + gr.HTML(""" + +
+ Gradio demo supported by + + OpenMMLab MMagic +
+ """) + + # Network & latents tab listeners + def on_change_pretrained_dropdown(pretrained_value, global_state): + """Function to handle model change. + 1. Set pretrained value to global_state + 2. Re-init images and clear all states + """ + + global_state['pretrained_weight'] = pretrained_value + init_images(global_state) + clear_state(global_state) + + return global_state, global_state["images"]['image_show'] + + form_pretrained_dropdown.change( + on_change_pretrained_dropdown, + inputs=[form_pretrained_dropdown, global_state], + outputs=[global_state, form_image], + ) + + def on_click_reset_image(global_state): + """Reset image to the original one and clear all states + 1. Re-init images + 2. Clear all states + """ + + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + form_reset_image.click( + on_click_reset_image, + inputs=[global_state], + outputs=[global_state, form_image], + ) + + # Update parameters + def on_change_update_image_seed(seed, global_state): + """Function to handle generation seed change. + 1. Set seed to global_state + 2. Re-init images and clear all states + """ + + global_state["params"]["seed"] = int(seed) + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + form_seed_number.change( + on_change_update_image_seed, + inputs=[form_seed_number, global_state], + outputs=[global_state, form_image], + ) + + def on_click_latent_space(latent_space, global_state): + """Function to reset latent space to optimize. + NOTE: this function we reset the image and all controls + 1. Set latent-space to global_state + 2. Re-init images and clear all state + """ + + global_state['params']['latent_space'] = latent_space + init_images(global_state) + clear_state(global_state) + + return global_state, global_state['images']['image_show'] + + form_latent_space.change(on_click_latent_space, + inputs=[form_latent_space, global_state], + outputs=[global_state, form_image]) + + # ==== Params + form_lambda_number.change( + partial(on_change_single_global_state, ["params", "motion_lambda"]), + inputs=[form_lambda_number, global_state], + outputs=[global_state], + ) + + def on_change_lr(lr, global_state): + if lr == 0: + print('lr is 0, do nothing.') + return global_state + else: + global_state["params"]["lr"] = lr + renderer = global_state['renderer'] + renderer.update_lr(lr) + print('New optimizer: ') + print(renderer.w_optim) + return global_state + + form_lr_number.change( + on_change_lr, + inputs=[form_lr_number, global_state], + outputs=[global_state], + ) + + def on_click_start(global_state, image): + p_in_pixels = [] + t_in_pixels = [] + valid_points = [] + + # handle of start drag in mask editing mode + global_state = preprocess_mask_info(global_state, image) + + # Prepare the points for the inference + if len(global_state["points"]) == 0: + # yield on_click_start_wo_points(global_state, image) + image_raw = global_state['images']['image_raw'] + update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + yield ( + global_state, + 0, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Checkbox.update(interactive=True), + # gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + ) + else: + + # Transform the points into torch tensors + for key_point, point in global_state["points"].items(): + try: + p_start = point.get("start_temp", point["start"]) + p_end = point["target"] + + if p_start is None or p_end is None: + continue + + except KeyError: + continue + + p_in_pixels.append(p_start) + t_in_pixels.append(p_end) + valid_points.append(key_point) + + mask = torch.tensor(global_state['mask']).float() + drag_mask = 1 - mask + + renderer: Renderer = global_state["renderer"] + global_state['temporal_params']['stop'] = False + global_state['editing_state'] = 'running' + + # reverse points order + p_to_opt = reverse_point_pairs(p_in_pixels) + t_to_opt = reverse_point_pairs(t_in_pixels) + print('Running with:') + print(f' Source: {p_in_pixels}') + print(f' Target: {t_in_pixels}') + step_idx = 0 + while True: + if global_state["temporal_params"]["stop"]: + break + + # do drage here! + renderer._render_drag_impl( + global_state['generator_params'], + p_to_opt, # point + t_to_opt, # target + drag_mask, # mask, + global_state['params']['motion_lambda'], # lambda_mask + reg=0, + feature_idx=5, # NOTE: do not support change for now + r1=global_state['params']['r1_in_pixels'], # r1 + r2=global_state['params']['r2_in_pixels'], # r2 + # random_seed = 0, + # noise_mode = 'const', + trunc_psi=global_state['params']['trunc_psi'], + # force_fp32 = False, + # layer_name = None, + # sel_channels = 3, + # base_channel = 0, + # img_scale_db = 0, + # img_normalize = False, + # untransform = False, + is_drag=True, + to_pil=True) + + if step_idx % global_state['draw_interval'] == 0: + print('Current Source:') + for key_point, p_i, t_i in zip(valid_points, p_to_opt, + t_to_opt): + global_state["points"][key_point]["start_temp"] = [ + p_i[1], + p_i[0], + ] + global_state["points"][key_point]["target"] = [ + t_i[1], + t_i[0], + ] + start_temp = global_state["points"][key_point][ + "start_temp"] + print(f' {start_temp}') + + image_result = global_state['generator_params']['image'] + image_draw = update_image_draw( + image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + global_state['images']['image_raw'] = image_result + + yield ( + global_state, + step_idx, + global_state['images']['image_show'], + # gr.File.update(visible=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + # latent space + gr.Radio.update(interactive=False), + gr.Button.update(interactive=False), + # enable stop button in loop + gr.Button.update(interactive=True), + + # update other comps + gr.Dropdown.update(interactive=False), + gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + gr.Button.update(interactive=False), + gr.Button.update(interactive=False), + gr.Checkbox.update(interactive=False), + # gr.Number.update(interactive=False), + gr.Number.update(interactive=False), + ) + + # increate step + step_idx += 1 + + image_result = global_state['generator_params']['image'] + global_state['images']['image_raw'] = image_result + image_draw = update_image_draw(image_result, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state) + + # fp = NamedTemporaryFile(suffix=".png", delete=False) + # image_result.save(fp, "PNG") + + global_state['editing_state'] = 'add_points' + + yield ( + global_state, + 0, # reset step to 0 after stop. + global_state['images']['image_show'], + # gr.File.update(visible=True, value=fp.name), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + gr.Button.update(interactive=True), + # latent space + gr.Radio.update(interactive=True), + gr.Button.update(interactive=True), + # NOTE: disable stop button with loop finish + gr.Button.update(interactive=False), + + # update other comps + gr.Dropdown.update(interactive=True), + gr.Number.update(interactive=True), + gr.Number.update(interactive=True), + gr.Checkbox.update(interactive=True), + gr.Number.update(interactive=True), + ) + + form_start_btn.click( + on_click_start, + inputs=[global_state, form_image], + outputs=[ + global_state, + form_steps_number, + form_image, + # form_download_result_file, + # >>> buttons + form_reset_image, + enable_add_points, + enable_add_mask, + undo_points, + form_reset_mask_btn, + form_latent_space, + form_start_btn, + form_stop_btn, + # <<< buttonm + # >>> inputs comps + form_pretrained_dropdown, + form_seed_number, + form_lr_number, + show_mask, + form_lambda_number, + ], + ) + + def on_click_stop(global_state): + """Function to handle stop button is clicked. + 1. send a stop signal by set global_state["temporal_params"]["stop"] as True + 2. Disable Stop button + """ + global_state["temporal_params"]["stop"] = True + + return global_state, gr.Button.update(interactive=False) + + form_stop_btn.click(on_click_stop, + inputs=[global_state], + outputs=[global_state, form_stop_btn]) + + form_draw_interval_number.change( + partial( + on_change_single_global_state, + "draw_interval", + map_transform=lambda x: int(x), + ), + inputs=[form_draw_interval_number, global_state], + outputs=[global_state], + ) + + def on_click_remove_point(global_state): + choice = global_state["curr_point"] + del global_state["points"][choice] + + choices = list(global_state["points"].keys()) + + if len(choices) > 0: + global_state["curr_point"] = choices[0] + + return ( + gr.Dropdown.update(choices=choices, value=choices[0]), + global_state, + ) + + # Mask + def on_click_reset_mask(global_state): + global_state['mask'] = np.ones( + ( + global_state["images"]["image_raw"].size[1], + global_state["images"]["image_raw"].size[0], + ), + dtype=np.uint8, + ) + image_draw = update_image_draw(global_state['images']['image_raw'], + global_state['points'], + global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + form_reset_mask_btn.click( + on_click_reset_mask, + inputs=[global_state], + outputs=[global_state, form_image], + ) + + # Image + def on_click_enable_draw(global_state, image): + """Function to start add mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to add_mask + 3. Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + + def on_click_remove_draw(global_state, image): + """Function to start remove mask mode. + 1. Preprocess mask info from last state + 2. Change editing state to remove_mask + 3. Set curr image with points and mask + """ + global_state = preprocess_mask_info(global_state, image) + global_state['edinting_state'] = 'remove_mask' + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], + global_state['mask'], True, + global_state) + return (global_state, + gr.Image.update(value=image_draw, interactive=True)) + + enable_add_mask.click(on_click_enable_draw, + inputs=[global_state, form_image], + outputs=[ + global_state, + form_image, + ]) + + def on_click_add_point(global_state, image: dict): + """Function switch from add mask mode to add points mode. + 1. Updaste mask buffer if need + 2. Change global_state['editing_state'] to 'add_points' + 3. Set current image with mask + """ + + global_state = preprocess_mask_info(global_state, image) + global_state['editing_state'] = 'add_points' + mask = global_state['mask'] + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, global_state['points'], mask, + global_state['show_mask'], global_state) + + return (global_state, + gr.Image.update(value=image_draw, interactive=False)) + + enable_add_points.click(on_click_add_point, + inputs=[global_state, form_image], + outputs=[global_state, form_image]) + + def on_click_image(global_state, evt: gr.SelectData): + """This function only support click for point selection + """ + xy = evt.index + if global_state['editing_state'] != 'add_points': + print(f'In {global_state["editing_state"]} state. ' + 'Do not add points.') + + return global_state, global_state['images']['image_show'] + + points = global_state["points"] + + point_idx = get_latest_points_pair(points) + if point_idx is None: + points[0] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + elif points[point_idx].get('target', None) is None: + points[point_idx]['target'] = xy + print(f'Click Image - Target - {xy}') + else: + points[point_idx + 1] = {'start': xy, 'target': None} + print(f'Click Image - Start - {xy}') + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + + return global_state, image_draw + + form_image.select( + on_click_image, + inputs=[global_state], + outputs=[global_state, form_image], + ) + + def on_click_clear_points(global_state): + """Function to handle clear all control points + 1. clear global_state['points'] (clear_state) + 2. re-init network + 2. re-draw image + """ + clear_state(global_state, target='point') + + renderer: Renderer = global_state["renderer"] + renderer.feat_refs = None + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw(image_raw, {}, global_state['mask'], + global_state['show_mask'], global_state) + return global_state, image_draw + + undo_points.click(on_click_clear_points, + inputs=[global_state], + outputs=[global_state, form_image]) + + def on_click_show_mask(global_state, show_mask): + """Function to control whether show mask on image.""" + global_state['show_mask'] = show_mask + + image_raw = global_state['images']['image_raw'] + image_draw = update_image_draw( + image_raw, + global_state['points'], + global_state['mask'], + global_state['show_mask'], + global_state, + ) + return global_state, image_draw + + show_mask.change( + on_click_show_mask, + inputs=[global_state, show_mask], + outputs=[global_state, form_image], + ) + +gr.close_all() +app.queue(concurrency_count=5, max_size=20) +app.launch(share=args.share) diff --git a/viz/renderer.py b/viz/renderer.py index 488b699..3a7d228 100644 --- a/viz/renderer.py +++ b/viz/renderer.py @@ -47,7 +47,7 @@ class CaptureSuccess(Exception): def add_watermark_np(input_image_array, watermark_text="Watermark"): image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA") - + # Initialize text image txt = Image.new('RGBA', image.size, (255, 255, 255, 0)) font = ImageFont.truetype('arial.ttf', round(25/512*image.size[0])) @@ -68,20 +68,25 @@ def add_watermark_np(input_image_array, watermark_text="Watermark"): #---------------------------------------------------------------------------- class Renderer: - def __init__(self): + def __init__(self, disable_timing=False): self._device = torch.device('cuda') self._pkl_data = dict() # {pkl: dict | CapturedException, ...} self._networks = dict() # {cache_key: torch.nn.Module, ...} self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...} self._cmaps = dict() # {name: torch.Tensor, ...} self._is_timing = False - self._start_event = torch.cuda.Event(enable_timing=True) - self._end_event = torch.cuda.Event(enable_timing=True) + if not disable_timing: + self._start_event = torch.cuda.Event(enable_timing=True) + self._end_event = torch.cuda.Event(enable_timing=True) + self._disable_timing = disable_timing self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...} def render(self, **args): - self._is_timing = True - self._start_event.record(torch.cuda.current_stream(self._device)) + if self._disable_timing: + self._is_timing = False + else: + self._start_event.record(torch.cuda.current_stream(self._device)) + self._is_timing = True res = dnnlib.EasyDict() try: init_net = False @@ -90,9 +95,6 @@ class Renderer: if hasattr(self, 'pkl'): if self.pkl != args['pkl']: init_net = True - if hasattr(self, 'w_load'): - if self.w_load is not args['w_load']: - init_net = True if hasattr(self, 'w0_seed'): if self.w0_seed != args['w0_seed']: init_net = True @@ -107,7 +109,8 @@ class Renderer: self._render_drag_impl(res, **args) except: res.error = CapturedException() - self._end_event.record(torch.cuda.current_stream(self._device)) + if not self._disable_timing: + self._end_event.record(torch.cuda.current_stream(self._device)) if 'image' in res: res.image = self.to_cpu(res.image).detach().numpy() res.image = add_watermark_np(res.image, 'AI Generated') @@ -115,8 +118,9 @@ class Renderer: res.stats = self.to_cpu(res.stats).detach().numpy() if 'error' in res: res.error = str(res.error) + # if 'stop' in res and res.stop: - if self._is_timing: + if self._is_timing and not self._disable_timing: self._end_event.synchronize() res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3 self._is_timing = False @@ -228,6 +232,7 @@ class Renderer: res.error = CapturedException() G.synthesis.input.transform.copy_(torch.from_numpy(m)) + # Generate random latents. self.w0_seed = w0_seed self.w_load = w_load @@ -246,13 +251,20 @@ class Renderer: if w_plus: self.w = w.detach() else: - self.w = w[:,0,:].detach() + self.w = w[:, 0, :].detach() self.w.requires_grad = True self.w_optim = torch.optim.Adam([self.w], lr=lr) self.feat_refs = None self.points0_pt = None + def update_lr(self, lr): + + del self.w_optim + self.w_optim = torch.optim.Adam([self.w], lr=lr) + print(f'Rebuild optimizer with lr: {lr}') + print(' Remain feat_refs and points0_pt') + def _render_drag_impl(self, res, points = [], targets = [], @@ -274,6 +286,7 @@ class Renderer: untransform = False, is_drag = False, reset = False, + to_pil = False, **kwargs ): G = self.G @@ -361,6 +374,10 @@ class Renderer: img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8) img = img * (10 ** (img_scale_db / 20)) img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) + if to_pil: + from PIL import Image + img = img.cpu().numpy() + img = Image.fromarray(img) res.image = img #----------------------------------------------------------------------------