diff --git a/visualizer_drag_gradio.py b/visualizer_drag_gradio.py index 023ce16..699787d 100644 --- a/visualizer_drag_gradio.py +++ b/visualizer_drag_gradio.py @@ -9,24 +9,19 @@ import torch from PIL import Image import dnnlib -from gradio_utils import ( - ImageMask, - draw_mask_on_image, - draw_points_on_image, - get_latest_points_pair, - get_valid_mask, - on_change_single_global_state, -) +from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, + get_latest_points_pair, get_valid_mask, + on_change_single_global_state) from viz.renderer import Renderer, add_watermark_np parser = ArgumentParser() -parser.add_argument("--share", action="store_true") -parser.add_argument("--cache-dir", type=str, default="./checkpoints") +parser.add_argument('--share', action='store_true') +parser.add_argument('--cache-dir', type=str, default='./checkpoints') args = parser.parse_args() cache_dir = args.cache_dir -device = "cuda" +device = 'cuda' def reverse_point_pairs(points): @@ -43,18 +38,17 @@ def clear_state(global_state, target=None): 2. set global_state['mask'] as full-one mask. """ if target is None: - target = ["point", "mask"] + target = ['point', 'mask'] if not isinstance(target, list): target = [target] - if "point" in target: - global_state["points"] = dict() - print("Clear Points State!") - if "mask" in target: + if 'point' in target: + global_state['points'] = dict() + print('Clear Points State!') + if 'mask' in target: image_raw = global_state["images"]["image_raw"] - global_state["mask"] = np.ones( - (image_raw.size[1], image_raw.size[0]), dtype=np.uint8 - ) - print("Clear mask State!") + global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]), + dtype=np.uint8) + print('Clear mask State!') return global_state @@ -73,46 +67,43 @@ def init_images(global_state): else: state = global_state - state["renderer"].init_network( - state["generator_params"], # res - valid_checkpoints_dict[state["pretrained_weight"]], # pkl - state["params"]["seed"], # w0_seed, + state['renderer'].init_network( + state['generator_params'], # res + valid_checkpoints_dict[state['pretrained_weight']], # pkl + state['params']['seed'], # w0_seed, None, # w_load - state["params"]["latent_space"] == "w+", # w_plus - "const", - state["params"]["trunc_psi"], # trunc_psi, - state["params"]["trunc_cutoff"], # trunc_cutoff, + state['params']['latent_space'] == 'w+', # w_plus + 'const', + state['params']['trunc_psi'], # trunc_psi, + state['params']['trunc_cutoff'], # trunc_cutoff, None, # input_transform - state["params"]["lr"], # lr, - ) - - state["renderer"]._render_drag_impl( - state["generator_params"], is_drag=False, to_pil=True + state['params']['lr'] # lr, ) - init_image = state["generator_params"].image - state["images"]["image_orig"] = init_image - state["images"]["image_raw"] = init_image - state["images"]["image_show"] = Image.fromarray( - add_watermark_np(np.array(init_image)) - ) - state["mask"] = np.ones((init_image.size[1], init_image.size[0]), dtype=np.uint8) + state['renderer']._render_drag_impl(state['generator_params'], + is_drag=False, + to_pil=True) + + init_image = state['generator_params'].image + state['images']['image_orig'] = init_image + state['images']['image_raw'] = init_image + state['images']['image_show'] = Image.fromarray( + add_watermark_np(np.array(init_image))) + state['mask'] = np.ones((init_image.size[1], init_image.size[0]), + dtype=np.uint8) return global_state def update_image_draw(image, points, mask, show_mask, global_state=None): + image_draw = draw_points_on_image(image, points) - if ( - show_mask - and mask is not None - and not (mask == 0).all() - and not (mask == 1).all() - ): + if show_mask and mask is not None and not (mask == 0).all() and not ( + mask == 1).all(): image_draw = draw_mask_on_image(image_draw, mask) image_draw = Image.fromarray(add_watermark_np(np.array(image_draw))) if global_state is not None: - global_state["images"]["image_show"] = image_draw + global_state['images']['image_show'] = image_draw return image_draw @@ -124,97 +115,102 @@ def preprocess_mask_info(global_state, image): 2.2 global_state is add_mask: """ if isinstance(image, dict): - last_mask = get_valid_mask(image["mask"]) + last_mask = get_valid_mask(image['mask']) else: last_mask = None - mask = global_state["mask"] + mask = global_state['mask'] # mask in global state is a placeholder with all 1. if (mask == 1).all(): mask = last_mask # last_mask = global_state['last_mask'] - editing_mode = global_state["editing_state"] + editing_mode = global_state['editing_state'] if last_mask is None: return global_state - if editing_mode == "remove_mask": + if editing_mode == 'remove_mask': updated_mask = np.clip(mask - last_mask, 0, 1) - print(f"Last editing_state is {editing_mode}, do remove.") - elif editing_mode == "add_mask": + print(f'Last editing_state is {editing_mode}, do remove.') + elif editing_mode == 'add_mask': updated_mask = np.clip(mask + last_mask, 0, 1) - print(f"Last editing_state is {editing_mode}, do add.") + print(f'Last editing_state is {editing_mode}, do add.') else: updated_mask = mask - print(f"Last editing_state is {editing_mode}, " "do nothing to mask.") + print(f'Last editing_state is {editing_mode}, ' + 'do nothing to mask.') - global_state["mask"] = updated_mask + global_state['mask'] = updated_mask # global_state['last_mask'] = None # clear buffer return global_state valid_checkpoints_dict = { - f.split("/")[-1].split(".")[0]: osp.join(cache_dir, f) + f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f) for f in os.listdir(cache_dir) - if (f.endswith("pkl") and osp.exists(osp.join(cache_dir, f))) + if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f))) } -print(f"File under cache_dir ({cache_dir}):") +print(f'File under cache_dir ({cache_dir}):') print(os.listdir(cache_dir)) -print("Valid checkpoint file:") +print('Valid checkpoint file:') print(valid_checkpoints_dict) -init_pkl = list(valid_checkpoints_dict.keys())[4] +init_pkl = 'stylegan_human_v2_512' with gr.Blocks() as app: + # renderer = Renderer() - global_state = gr.State( - { - "images": { - # image_orig: the original image, change with seed/model is changed - # image_raw: image with mask and points, change durning optimization - # image_show: image showed on screen - }, - "temporal_params": { - # stop - }, - "mask": None, # mask for visualization, 1 for editing and 0 for unchange - "last_mask": None, # last edited mask - "show_mask": True, # add button - "generator_params": dnnlib.EasyDict(), - "params": { - "seed": 0, - "motion_lambda": 20, - "r1_in_pixels": 3, - "r2_in_pixels": 12, - "magnitude_direction_in_pixels": 1.0, - "latent_space": "w+", - "trunc_psi": 0.7, - "trunc_cutoff": None, - "lr": 0.001, - }, - "device": device, - "draw_interval": 1, - "renderer": Renderer(disable_timing=True), - "points": {}, - "curr_point": None, - "curr_type_point": "start", - "editing_state": "add_points", - "pretrained_weight": init_pkl, - } - ) + global_state = gr.State({ + "images": { + # image_orig: the original image, change with seed/model is changed + # image_raw: image with mask and points, change durning optimization + # image_show: image showed on screen + }, + "temporal_params": { + # stop + }, + 'mask': + None, # mask for visualization, 1 for editing and 0 for unchange + 'last_mask': None, # last edited mask + 'show_mask': True, # add button + "generator_params": dnnlib.EasyDict(), + "params": { + "seed": 0, + "motion_lambda": 20, + "r1_in_pixels": 3, + "r2_in_pixels": 12, + "magnitude_direction_in_pixels": 1.0, + "latent_space": "w+", + "trunc_psi": 0.7, + "trunc_cutoff": None, + "lr": 0.001, + }, + "device": device, + "draw_interval": 1, + "renderer": Renderer(disable_timing=True), + "points": {}, + "curr_point": None, + "curr_type_point": "start", + 'editing_state': 'add_points', + 'pretrained_weight': init_pkl + }) # init image global_state = init_images(global_state) with gr.Row(): + with gr.Row(): + # Left --> tools with gr.Column(scale=3): + # Pickle with gr.Row(): + with gr.Column(scale=1, min_width=10): - gr.Markdown(value="Pickle", show_label=False) + gr.Markdown(value='Pickle', show_label=False) with gr.Column(scale=4, min_width=10): form_pretrained_dropdown = gr.Dropdown( @@ -226,71 +222,71 @@ with gr.Blocks() as app: # Latent with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value="Latent", show_label=False) + gr.Markdown(value='Latent', show_label=False) with gr.Column(scale=4, min_width=10): form_seed_number = gr.Number( - value=global_state.value["params"]["seed"], + value=global_state.value['params']['seed'], interactive=True, label="Seed", ) form_lr_number = gr.Number( value=global_state.value["params"]["lr"], interactive=True, - label="Step Size", - ) + label="Step Size") with gr.Row(): with gr.Column(scale=2, min_width=10): form_reset_image = gr.Button("Reset Image") with gr.Column(scale=3, min_width=10): form_latent_space = gr.Radio( - ["w", "w+"], - value=global_state.value["params"]["latent_space"], + ['w', 'w+'], + value=global_state.value['params'] + ['latent_space'], interactive=True, - label="Latent space to optimize", + label='Latent space to optimize', show_label=False, ) # Drag with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value="Drag", show_label=False) + gr.Markdown(value='Drag', show_label=False) with gr.Column(scale=4, min_width=10): with gr.Row(): with gr.Column(scale=1, min_width=10): - enable_add_points = gr.Button("Add Points") + enable_add_points = gr.Button('Add Points') with gr.Column(scale=1, min_width=10): - undo_points = gr.Button("Reset Points") + undo_points = gr.Button('Reset Points') with gr.Row(): with gr.Column(scale=1, min_width=10): form_start_btn = gr.Button("Start") with gr.Column(scale=1, min_width=10): form_stop_btn = gr.Button("Stop") - form_steps_number = gr.Number( - value=0, label="Steps", interactive=False - ) + form_steps_number = gr.Number(value=0, + label="Steps", + interactive=False) # Mask with gr.Row(): with gr.Column(scale=1, min_width=10): - gr.Markdown(value="Mask", show_label=False) + gr.Markdown(value='Mask', show_label=False) with gr.Column(scale=4, min_width=10): - enable_add_mask = gr.Button("Edit Flexible Area") + enable_add_mask = gr.Button('Edit Flexible Area') with gr.Row(): with gr.Column(scale=1, min_width=10): form_reset_mask_btn = gr.Button("Reset mask") with gr.Column(scale=1, min_width=10): show_mask = gr.Checkbox( - label="Show Mask", - value=global_state.value["show_mask"], - show_label=False, - ) + label='Show Mask', + value=global_state.value['show_mask'], + show_label=False) with gr.Row(): form_lambda_number = gr.Number( - value=global_state.value["params"]["motion_lambda"], + value=global_state.value["params"] + ["motion_lambda"], interactive=True, label="Lambda", ) @@ -299,18 +295,16 @@ with gr.Blocks() as app: value=global_state.value["draw_interval"], label="Draw Interval (steps)", interactive=True, - visible=False, - ) + visible=False) # Right --> Image with gr.Column(scale=8): form_image = ImageMask( - value=global_state.value["images"]["image_show"], brush_radius=20 - ).style( - width=768, height=768 - ) # NOTE: hard image size code here. - gr.Markdown( - """ + value=global_state.value['images']['image_show'], + brush_radius=20).style( + width=768, + height=768) # NOTE: hard image size code here. + gr.Markdown(""" ## Quick Start 1. Select desired `Pretrained Model` and adjust `Seed` to generate an @@ -329,10 +323,8 @@ with gr.Blocks() as app: mask (this has the same effect as `Reset Image` button). 3. Click `Edit Flexible Area` to create a mask and constrain the unmasked region to remain unchanged. - """ - ) - gr.HTML( - """ + """) + gr.HTML("""