mirror of https://github.com/XingangPan/DragGAN
				
				
				
			support gradio demo for DragGAN
							parent
							
								
									91b9322e2f
								
							
						
					
					
						commit
						ab719c4a7a
					
				@ -0,0 +1,9 @@
 | 
			
		||||
from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
 | 
			
		||||
                    get_latest_points_pair, get_valid_mask,
 | 
			
		||||
                    on_change_single_global_state)
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'draw_mask_on_image', 'draw_points_on_image',
 | 
			
		||||
    'on_change_single_global_state', 'get_latest_points_pair',
 | 
			
		||||
    'get_valid_mask', 'ImageMask'
 | 
			
		||||
]
 | 
			
		||||
@ -0,0 +1,154 @@
 | 
			
		||||
import gradio as gr
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image, ImageDraw
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageMask(gr.components.Image):
 | 
			
		||||
    """
 | 
			
		||||
    Sets: source="canvas", tool="sketch"
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    is_template = True
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        super().__init__(source="upload",
 | 
			
		||||
                         tool="sketch",
 | 
			
		||||
                         interactive=False,
 | 
			
		||||
                         **kwargs)
 | 
			
		||||
 | 
			
		||||
    def preprocess(self, x):
 | 
			
		||||
        if x is None:
 | 
			
		||||
            return x
 | 
			
		||||
        if self.tool == "sketch" and self.source in ["upload", "webcam"
 | 
			
		||||
                                                     ] and type(x) != dict:
 | 
			
		||||
            decode_image = gr.processing_utils.decode_base64_to_image(x)
 | 
			
		||||
            width, height = decode_image.size
 | 
			
		||||
            mask = np.ones((height, width, 4), dtype=np.uint8)
 | 
			
		||||
            mask[..., -1] = 255
 | 
			
		||||
            mask = self.postprocess(mask)
 | 
			
		||||
            x = {'image': x, 'mask': mask}
 | 
			
		||||
        return super().preprocess(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_valid_mask(mask: np.ndarray):
 | 
			
		||||
    """Convert mask from gr.Image(0 to 255, RGBA) to binary mask.
 | 
			
		||||
    """
 | 
			
		||||
    if mask.ndim == 3:
 | 
			
		||||
        mask_pil = Image.fromarray(mask).convert('L')
 | 
			
		||||
        mask = np.array(mask_pil)
 | 
			
		||||
    if mask.max() == 255:
 | 
			
		||||
        mask = mask / 255
 | 
			
		||||
    return mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_points_on_image(image,
 | 
			
		||||
                         points,
 | 
			
		||||
                         curr_point=None,
 | 
			
		||||
                         highlight_all=True,
 | 
			
		||||
                         radius_scale=0.01):
 | 
			
		||||
    overlay_rgba = Image.new("RGBA", image.size, 0)
 | 
			
		||||
    overlay_draw = ImageDraw.Draw(overlay_rgba)
 | 
			
		||||
    for point_key, point in points.items():
 | 
			
		||||
        if ((curr_point is not None and curr_point == point_key)
 | 
			
		||||
                or highlight_all):
 | 
			
		||||
            p_color = (255, 0, 0)
 | 
			
		||||
            t_color = (0, 0, 255)
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            p_color = (255, 0, 0, 35)
 | 
			
		||||
            t_color = (0, 0, 255, 35)
 | 
			
		||||
 | 
			
		||||
        rad_draw = int(image.size[0] * radius_scale)
 | 
			
		||||
 | 
			
		||||
        p_start = point.get("start_temp", point["start"])
 | 
			
		||||
        p_target = point["target"]
 | 
			
		||||
 | 
			
		||||
        if p_start is not None and p_target is not None:
 | 
			
		||||
            p_draw = int(p_start[0]), int(p_start[1])
 | 
			
		||||
            t_draw = int(p_target[0]), int(p_target[1])
 | 
			
		||||
 | 
			
		||||
            overlay_draw.line(
 | 
			
		||||
                (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
 | 
			
		||||
                fill=(255, 255, 0),
 | 
			
		||||
                width=2,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if p_start is not None:
 | 
			
		||||
            p_draw = int(p_start[0]), int(p_start[1])
 | 
			
		||||
            overlay_draw.ellipse(
 | 
			
		||||
                (
 | 
			
		||||
                    p_draw[0] - rad_draw,
 | 
			
		||||
                    p_draw[1] - rad_draw,
 | 
			
		||||
                    p_draw[0] + rad_draw,
 | 
			
		||||
                    p_draw[1] + rad_draw,
 | 
			
		||||
                ),
 | 
			
		||||
                fill=p_color,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if curr_point is not None and curr_point == point_key:
 | 
			
		||||
                # overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0))
 | 
			
		||||
                overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))
 | 
			
		||||
 | 
			
		||||
        if p_target is not None:
 | 
			
		||||
            t_draw = int(p_target[0]), int(p_target[1])
 | 
			
		||||
            overlay_draw.ellipse(
 | 
			
		||||
                (
 | 
			
		||||
                    t_draw[0] - rad_draw,
 | 
			
		||||
                    t_draw[1] - rad_draw,
 | 
			
		||||
                    t_draw[0] + rad_draw,
 | 
			
		||||
                    t_draw[1] + rad_draw,
 | 
			
		||||
                ),
 | 
			
		||||
                fill=t_color,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if curr_point is not None and curr_point == point_key:
 | 
			
		||||
                # overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0))
 | 
			
		||||
                overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))
 | 
			
		||||
 | 
			
		||||
    return Image.alpha_composite(image.convert("RGBA"),
 | 
			
		||||
                                 overlay_rgba).convert("RGB")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_mask_on_image(image, mask):
 | 
			
		||||
    im_mask = np.uint8(mask * 255)
 | 
			
		||||
    im_mask_rgba = np.concatenate(
 | 
			
		||||
        (
 | 
			
		||||
            np.tile(im_mask[..., None], [1, 1, 3]),
 | 
			
		||||
            45 * np.ones(
 | 
			
		||||
                (im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8),
 | 
			
		||||
        ),
 | 
			
		||||
        axis=-1,
 | 
			
		||||
    )
 | 
			
		||||
    im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
 | 
			
		||||
 | 
			
		||||
    return Image.alpha_composite(image.convert("RGBA"),
 | 
			
		||||
                                 im_mask_rgba).convert("RGB")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_change_single_global_state(keys,
 | 
			
		||||
                                  value,
 | 
			
		||||
                                  global_state,
 | 
			
		||||
                                  map_transform=None):
 | 
			
		||||
    if map_transform is not None:
 | 
			
		||||
        value = map_transform(value)
 | 
			
		||||
 | 
			
		||||
    curr_state = global_state
 | 
			
		||||
    if isinstance(keys, str):
 | 
			
		||||
        last_key = keys
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        for k in keys[:-1]:
 | 
			
		||||
            curr_state = curr_state[k]
 | 
			
		||||
 | 
			
		||||
        last_key = keys[-1]
 | 
			
		||||
 | 
			
		||||
    curr_state[last_key] = value
 | 
			
		||||
    return global_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_latest_points_pair(points_dict):
 | 
			
		||||
    if not points_dict:
 | 
			
		||||
        return None
 | 
			
		||||
    point_idx = list(points_dict.keys())
 | 
			
		||||
    latest_point_idx = max(point_idx)
 | 
			
		||||
    return latest_point_idx
 | 
			
		||||
@ -0,0 +1,866 @@
 | 
			
		||||
import os
 | 
			
		||||
import os.path as osp
 | 
			
		||||
from argparse import ArgumentParser
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
import gradio as gr
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
import dnnlib
 | 
			
		||||
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
 | 
			
		||||
                          get_latest_points_pair, get_valid_mask,
 | 
			
		||||
                          on_change_single_global_state)
 | 
			
		||||
from viz.renderer import Renderer, add_watermark_np
 | 
			
		||||
 | 
			
		||||
parser = ArgumentParser()
 | 
			
		||||
parser.add_argument('--share', action='store_true')
 | 
			
		||||
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
 | 
			
		||||
args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
cache_dir = args.cache_dir
 | 
			
		||||
 | 
			
		||||
device = 'cuda'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def reverse_point_pairs(points):
 | 
			
		||||
    new_points = []
 | 
			
		||||
    for p in points:
 | 
			
		||||
        new_points.append([p[1], p[0]])
 | 
			
		||||
    return new_points
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clear_state(global_state, target=None):
 | 
			
		||||
    """Clear target history state from global_state
 | 
			
		||||
    If target is not defined, points and mask will be both removed.
 | 
			
		||||
    1. set global_state['points'] as empty dict
 | 
			
		||||
    2. set global_state['mask'] as full-one mask.
 | 
			
		||||
    """
 | 
			
		||||
    if target is None:
 | 
			
		||||
        target = ['point', 'mask']
 | 
			
		||||
    if not isinstance(target, list):
 | 
			
		||||
        target = [target]
 | 
			
		||||
    if 'point' in target:
 | 
			
		||||
        global_state['points'] = dict()
 | 
			
		||||
        print('Clear Points State!')
 | 
			
		||||
    if 'mask' in target:
 | 
			
		||||
        image_raw = global_state["images"]["image_raw"]
 | 
			
		||||
        global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
 | 
			
		||||
                                       dtype=np.uint8)
 | 
			
		||||
        print('Clear mask State!')
 | 
			
		||||
 | 
			
		||||
    return global_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_images(global_state):
 | 
			
		||||
    """This function is called only ones with Gradio App is started.
 | 
			
		||||
    0. pre-process global_state, unpack value from global_state of need
 | 
			
		||||
    1. Re-init renderer
 | 
			
		||||
    2. run `renderer._render_drag_impl` with `is_drag=False` to generate
 | 
			
		||||
       new image
 | 
			
		||||
    3. Assign images to global state and re-generate mask
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if isinstance(global_state, gr.State):
 | 
			
		||||
        state = global_state.value
 | 
			
		||||
    else:
 | 
			
		||||
        state = global_state
 | 
			
		||||
 | 
			
		||||
    state['renderer'].init_network(
 | 
			
		||||
        state['generator_params'],  # res
 | 
			
		||||
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
 | 
			
		||||
        state['params']['seed'],  # w0_seed,
 | 
			
		||||
        None,  # w_load
 | 
			
		||||
        state['params']['latent_space'] == 'w+',  # w_plus
 | 
			
		||||
        'const',
 | 
			
		||||
        state['params']['trunc_psi'],  # trunc_psi,
 | 
			
		||||
        state['params']['trunc_cutoff'],  # trunc_cutoff,
 | 
			
		||||
        None,  # input_transform
 | 
			
		||||
        state['params']['lr']  # lr,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    state['renderer']._render_drag_impl(state['generator_params'],
 | 
			
		||||
                                        is_drag=False,
 | 
			
		||||
                                        to_pil=True)
 | 
			
		||||
 | 
			
		||||
    init_image = state['generator_params'].image
 | 
			
		||||
    state['images']['image_orig'] = init_image
 | 
			
		||||
    state['images']['image_raw'] = init_image
 | 
			
		||||
    state['images']['image_show'] = Image.fromarray(
 | 
			
		||||
        add_watermark_np(np.array(init_image)))
 | 
			
		||||
    state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
 | 
			
		||||
                            dtype=np.uint8)
 | 
			
		||||
    return global_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_image_draw(image, points, mask, show_mask, global_state=None):
 | 
			
		||||
 | 
			
		||||
    image_draw = draw_points_on_image(image, points)
 | 
			
		||||
    if show_mask and mask is not None and not (mask == 0).all() and not (
 | 
			
		||||
            mask == 1).all():
 | 
			
		||||
        image_draw = draw_mask_on_image(image_draw, mask)
 | 
			
		||||
 | 
			
		||||
    image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
 | 
			
		||||
    if global_state is not None:
 | 
			
		||||
        global_state['images']['image_show'] = image_draw
 | 
			
		||||
    return image_draw
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_mask_info(global_state, image):
 | 
			
		||||
    """Function to handle mask information.
 | 
			
		||||
    1. last_mask is None: Do not need to change mask, return mask
 | 
			
		||||
    2. last_mask is not None:
 | 
			
		||||
        2.1 global_state is remove_mask:
 | 
			
		||||
        2.2 global_state is add_mask:
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(image, dict):
 | 
			
		||||
        last_mask = get_valid_mask(image['mask'])
 | 
			
		||||
    else:
 | 
			
		||||
        last_mask = None
 | 
			
		||||
    mask = global_state['mask']
 | 
			
		||||
 | 
			
		||||
    # mask in global state is a placeholder with all 1.
 | 
			
		||||
    if (mask == 1).all():
 | 
			
		||||
        mask = last_mask
 | 
			
		||||
 | 
			
		||||
    # last_mask = global_state['last_mask']
 | 
			
		||||
    editing_mode = global_state['editing_state']
 | 
			
		||||
 | 
			
		||||
    if last_mask is None:
 | 
			
		||||
        return global_state
 | 
			
		||||
 | 
			
		||||
    if editing_mode == 'remove_mask':
 | 
			
		||||
        updated_mask = np.clip(mask - last_mask, 0, 1)
 | 
			
		||||
        print(f'Last editing_state is {editing_mode}, do remove.')
 | 
			
		||||
    elif editing_mode == 'add_mask':
 | 
			
		||||
        updated_mask = np.clip(mask + last_mask, 0, 1)
 | 
			
		||||
        print(f'Last editing_state is {editing_mode}, do add.')
 | 
			
		||||
    else:
 | 
			
		||||
        updated_mask = mask
 | 
			
		||||
        print(f'Last editing_state is {editing_mode}, '
 | 
			
		||||
              'do nothing to mask.')
 | 
			
		||||
 | 
			
		||||
    global_state['mask'] = updated_mask
 | 
			
		||||
    # global_state['last_mask'] = None  # clear buffer
 | 
			
		||||
    return global_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
valid_checkpoints_dict = {
 | 
			
		||||
    f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
 | 
			
		||||
    for f in os.listdir(cache_dir)
 | 
			
		||||
    if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
 | 
			
		||||
}
 | 
			
		||||
print(f'File under cache_dir ({cache_dir}):')
 | 
			
		||||
print(os.listdir(cache_dir))
 | 
			
		||||
print('Valid checkpoint file:')
 | 
			
		||||
print(valid_checkpoints_dict)
 | 
			
		||||
 | 
			
		||||
init_pkl = 'stylegan_human_v2_512'
 | 
			
		||||
 | 
			
		||||
with gr.Blocks() as app:
 | 
			
		||||
 | 
			
		||||
    # renderer = Renderer()
 | 
			
		||||
    global_state = gr.State({
 | 
			
		||||
        "images": {
 | 
			
		||||
            # image_orig: the original image, change with seed/model is changed
 | 
			
		||||
            # image_raw: image with mask and points, change durning optimization
 | 
			
		||||
            # image_show: image showed on screen
 | 
			
		||||
        },
 | 
			
		||||
        "temporal_params": {
 | 
			
		||||
            # stop
 | 
			
		||||
        },
 | 
			
		||||
        'mask':
 | 
			
		||||
        None,  # mask for visualization, 1 for editing and 0 for unchange
 | 
			
		||||
        'last_mask': None,  # last edited mask
 | 
			
		||||
        'show_mask': True,  # add button
 | 
			
		||||
        "generator_params": dnnlib.EasyDict(),
 | 
			
		||||
        "params": {
 | 
			
		||||
            "seed": 0,
 | 
			
		||||
            "motion_lambda": 20,
 | 
			
		||||
            "r1_in_pixels": 3,
 | 
			
		||||
            "r2_in_pixels": 12,
 | 
			
		||||
            "magnitude_direction_in_pixels": 1.0,
 | 
			
		||||
            "latent_space": "w+",
 | 
			
		||||
            "trunc_psi": 0.7,
 | 
			
		||||
            "trunc_cutoff": None,
 | 
			
		||||
            "lr": 0.001,
 | 
			
		||||
        },
 | 
			
		||||
        "device": device,
 | 
			
		||||
        "draw_interval": 1,
 | 
			
		||||
        "renderer": Renderer(disable_timing=True),
 | 
			
		||||
        "points": {},
 | 
			
		||||
        "curr_point": None,
 | 
			
		||||
        "curr_type_point": "start",
 | 
			
		||||
        'editing_state': 'add_points',
 | 
			
		||||
        'pretrained_weight': init_pkl
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
    # init image
 | 
			
		||||
    global_state = init_images(global_state)
 | 
			
		||||
 | 
			
		||||
    with gr.Row():
 | 
			
		||||
 | 
			
		||||
        with gr.Row():
 | 
			
		||||
 | 
			
		||||
            # Left --> tools
 | 
			
		||||
            with gr.Column(scale=3):
 | 
			
		||||
 | 
			
		||||
                # Pickle
 | 
			
		||||
                with gr.Row():
 | 
			
		||||
 | 
			
		||||
                    with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                        gr.Markdown(value='Pickle', show_label=False)
 | 
			
		||||
 | 
			
		||||
                    with gr.Column(scale=4, min_width=10):
 | 
			
		||||
                        form_pretrained_dropdown = gr.Dropdown(
 | 
			
		||||
                            choices=list(valid_checkpoints_dict.keys()),
 | 
			
		||||
                            label="Pretrained Model",
 | 
			
		||||
                            value=init_pkl,
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                # Latent
 | 
			
		||||
                with gr.Row():
 | 
			
		||||
                    with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                        gr.Markdown(value='Latent', show_label=False)
 | 
			
		||||
 | 
			
		||||
                    with gr.Column(scale=4, min_width=10):
 | 
			
		||||
                        form_seed_number = gr.Number(
 | 
			
		||||
                            value=global_state.value['params']['seed'],
 | 
			
		||||
                            interactive=True,
 | 
			
		||||
                            label="Seed",
 | 
			
		||||
                        )
 | 
			
		||||
                        form_lr_number = gr.Number(
 | 
			
		||||
                            value=global_state.value["params"]["lr"],
 | 
			
		||||
                            interactive=True,
 | 
			
		||||
                            label="Step Size")
 | 
			
		||||
 | 
			
		||||
                        with gr.Row():
 | 
			
		||||
                            with gr.Column(scale=2, min_width=10):
 | 
			
		||||
                                form_reset_image = gr.Button("Reset Image")
 | 
			
		||||
                            with gr.Column(scale=3, min_width=10):
 | 
			
		||||
                                form_latent_space = gr.Radio(
 | 
			
		||||
                                    ['w', 'w+'],
 | 
			
		||||
                                    value=global_state.value['params']
 | 
			
		||||
                                    ['latent_space'],
 | 
			
		||||
                                    interactive=True,
 | 
			
		||||
                                    label='Latent space to optimize',
 | 
			
		||||
                                    show_label=False,
 | 
			
		||||
                                )
 | 
			
		||||
 | 
			
		||||
                # Drag
 | 
			
		||||
                with gr.Row():
 | 
			
		||||
                    with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                        gr.Markdown(value='Drag', show_label=False)
 | 
			
		||||
                    with gr.Column(scale=4, min_width=10):
 | 
			
		||||
                        with gr.Row():
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                enable_add_points = gr.Button('Add Points')
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                undo_points = gr.Button('Reset Points')
 | 
			
		||||
                        with gr.Row():
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                form_start_btn = gr.Button("Start")
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                form_stop_btn = gr.Button("Stop")
 | 
			
		||||
 | 
			
		||||
                        form_steps_number = gr.Number(value=0,
 | 
			
		||||
                                                      label="Steps",
 | 
			
		||||
                                                      interactive=False)
 | 
			
		||||
 | 
			
		||||
                # Mask
 | 
			
		||||
                with gr.Row():
 | 
			
		||||
                    with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                        gr.Markdown(value='Mask', show_label=False)
 | 
			
		||||
                    with gr.Column(scale=4, min_width=10):
 | 
			
		||||
                        enable_add_mask = gr.Button('Edit Flexible Area')
 | 
			
		||||
                        with gr.Row():
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                form_reset_mask_btn = gr.Button("Reset mask")
 | 
			
		||||
                            with gr.Column(scale=1, min_width=10):
 | 
			
		||||
                                show_mask = gr.Checkbox(
 | 
			
		||||
                                    label='Show Mask',
 | 
			
		||||
                                    value=global_state.value['show_mask'],
 | 
			
		||||
                                    show_label=False)
 | 
			
		||||
 | 
			
		||||
                        with gr.Row():
 | 
			
		||||
                            form_lambda_number = gr.Number(
 | 
			
		||||
                                value=global_state.value["params"]
 | 
			
		||||
                                ["motion_lambda"],
 | 
			
		||||
                                interactive=True,
 | 
			
		||||
                                label="Lambda",
 | 
			
		||||
                            )
 | 
			
		||||
 | 
			
		||||
                form_draw_interval_number = gr.Number(
 | 
			
		||||
                    value=global_state.value["draw_interval"],
 | 
			
		||||
                    label="Draw Interval (steps)",
 | 
			
		||||
                    interactive=True,
 | 
			
		||||
                    visible=False)
 | 
			
		||||
 | 
			
		||||
            # Right --> Image
 | 
			
		||||
            with gr.Column(scale=8):
 | 
			
		||||
                form_image = ImageMask(
 | 
			
		||||
                    value=global_state.value['images']['image_show'],
 | 
			
		||||
                    brush_radius=20).style(
 | 
			
		||||
                        width=768,
 | 
			
		||||
                        height=768)  # NOTE: hard image size code here.
 | 
			
		||||
    gr.Markdown("""
 | 
			
		||||
        ## Quick Start
 | 
			
		||||
 | 
			
		||||
        1. Select desired `Pretrained Model` and adjust `Seed` to generate an
 | 
			
		||||
           initial image.
 | 
			
		||||
        2. Click on image to add control points.
 | 
			
		||||
        3. Click `Start` and enjoy it!
 | 
			
		||||
 | 
			
		||||
        ## Advance Usage
 | 
			
		||||
 | 
			
		||||
        1. Change `Step Size` to adjust learning rate in drag optimization.
 | 
			
		||||
        2. Select `w` or `w+` to change latent space to optimize:
 | 
			
		||||
        * Optimize on `w` space may cause greater influence to the image.
 | 
			
		||||
        * Optimize on `w+` space may work slower than `w`, but usually achieve
 | 
			
		||||
          better results.
 | 
			
		||||
        * Note that changing the latent space will reset the image, points and
 | 
			
		||||
          mask (this has the same effect as `Reset Image` button).
 | 
			
		||||
        3. Click `Edit Flexible Area` to create a mask and constrain the
 | 
			
		||||
           unmasked region to remain unchanged.
 | 
			
		||||
        """)
 | 
			
		||||
    gr.HTML("""
 | 
			
		||||
        <style>
 | 
			
		||||
            .container {
 | 
			
		||||
                position: absolute;
 | 
			
		||||
                height: 50px;
 | 
			
		||||
                text-align: center;
 | 
			
		||||
                line-height: 50px;
 | 
			
		||||
                width: 100%;
 | 
			
		||||
            }
 | 
			
		||||
        </style>
 | 
			
		||||
        <div class="container">
 | 
			
		||||
        Gradio demo supported by
 | 
			
		||||
        <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
 | 
			
		||||
        <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
 | 
			
		||||
        </div>
 | 
			
		||||
        """)
 | 
			
		||||
 | 
			
		||||
    # Network & latents tab listeners
 | 
			
		||||
    def on_change_pretrained_dropdown(pretrained_value, global_state):
 | 
			
		||||
        """Function to handle model change.
 | 
			
		||||
        1. Set pretrained value to global_state
 | 
			
		||||
        2. Re-init images and clear all states
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        global_state['pretrained_weight'] = pretrained_value
 | 
			
		||||
        init_images(global_state)
 | 
			
		||||
        clear_state(global_state)
 | 
			
		||||
 | 
			
		||||
        return global_state, global_state["images"]['image_show']
 | 
			
		||||
 | 
			
		||||
    form_pretrained_dropdown.change(
 | 
			
		||||
        on_change_pretrained_dropdown,
 | 
			
		||||
        inputs=[form_pretrained_dropdown, global_state],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_reset_image(global_state):
 | 
			
		||||
        """Reset image to the original one and clear all states
 | 
			
		||||
        1. Re-init images
 | 
			
		||||
        2. Clear all states
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        init_images(global_state)
 | 
			
		||||
        clear_state(global_state)
 | 
			
		||||
 | 
			
		||||
        return global_state, global_state['images']['image_show']
 | 
			
		||||
 | 
			
		||||
    form_reset_image.click(
 | 
			
		||||
        on_click_reset_image,
 | 
			
		||||
        inputs=[global_state],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Update parameters
 | 
			
		||||
    def on_change_update_image_seed(seed, global_state):
 | 
			
		||||
        """Function to handle generation seed change.
 | 
			
		||||
        1. Set seed to global_state
 | 
			
		||||
        2. Re-init images and clear all states
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        global_state["params"]["seed"] = int(seed)
 | 
			
		||||
        init_images(global_state)
 | 
			
		||||
        clear_state(global_state)
 | 
			
		||||
 | 
			
		||||
        return global_state, global_state['images']['image_show']
 | 
			
		||||
 | 
			
		||||
    form_seed_number.change(
 | 
			
		||||
        on_change_update_image_seed,
 | 
			
		||||
        inputs=[form_seed_number, global_state],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_latent_space(latent_space, global_state):
 | 
			
		||||
        """Function to reset latent space to optimize.
 | 
			
		||||
        NOTE: this function we reset the image and all controls
 | 
			
		||||
        1. Set latent-space to global_state
 | 
			
		||||
        2. Re-init images and clear all state
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        global_state['params']['latent_space'] = latent_space
 | 
			
		||||
        init_images(global_state)
 | 
			
		||||
        clear_state(global_state)
 | 
			
		||||
 | 
			
		||||
        return global_state, global_state['images']['image_show']
 | 
			
		||||
 | 
			
		||||
    form_latent_space.change(on_click_latent_space,
 | 
			
		||||
                             inputs=[form_latent_space, global_state],
 | 
			
		||||
                             outputs=[global_state, form_image])
 | 
			
		||||
 | 
			
		||||
    # ==== Params
 | 
			
		||||
    form_lambda_number.change(
 | 
			
		||||
        partial(on_change_single_global_state, ["params", "motion_lambda"]),
 | 
			
		||||
        inputs=[form_lambda_number, global_state],
 | 
			
		||||
        outputs=[global_state],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_change_lr(lr, global_state):
 | 
			
		||||
        if lr == 0:
 | 
			
		||||
            print('lr is 0, do nothing.')
 | 
			
		||||
            return global_state
 | 
			
		||||
        else:
 | 
			
		||||
            global_state["params"]["lr"] = lr
 | 
			
		||||
            renderer = global_state['renderer']
 | 
			
		||||
            renderer.update_lr(lr)
 | 
			
		||||
            print('New optimizer: ')
 | 
			
		||||
            print(renderer.w_optim)
 | 
			
		||||
        return global_state
 | 
			
		||||
 | 
			
		||||
    form_lr_number.change(
 | 
			
		||||
        on_change_lr,
 | 
			
		||||
        inputs=[form_lr_number, global_state],
 | 
			
		||||
        outputs=[global_state],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_start(global_state, image):
 | 
			
		||||
        p_in_pixels = []
 | 
			
		||||
        t_in_pixels = []
 | 
			
		||||
        valid_points = []
 | 
			
		||||
 | 
			
		||||
        # handle of start drag in mask editing mode
 | 
			
		||||
        global_state = preprocess_mask_info(global_state, image)
 | 
			
		||||
 | 
			
		||||
        # Prepare the points for the inference
 | 
			
		||||
        if len(global_state["points"]) == 0:
 | 
			
		||||
            # yield on_click_start_wo_points(global_state, image)
 | 
			
		||||
            image_raw = global_state['images']['image_raw']
 | 
			
		||||
            update_image_draw(
 | 
			
		||||
                image_raw,
 | 
			
		||||
                global_state['points'],
 | 
			
		||||
                global_state['mask'],
 | 
			
		||||
                global_state['show_mask'],
 | 
			
		||||
                global_state,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            yield (
 | 
			
		||||
                global_state,
 | 
			
		||||
                0,
 | 
			
		||||
                global_state['images']['image_show'],
 | 
			
		||||
                # gr.File.update(visible=False),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                # latent space
 | 
			
		||||
                gr.Radio.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                # NOTE: disable stop button
 | 
			
		||||
                gr.Button.update(interactive=False),
 | 
			
		||||
 | 
			
		||||
                # update other comps
 | 
			
		||||
                gr.Dropdown.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Checkbox.update(interactive=True),
 | 
			
		||||
                # gr.Number.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
            # Transform the points into torch tensors
 | 
			
		||||
            for key_point, point in global_state["points"].items():
 | 
			
		||||
                try:
 | 
			
		||||
                    p_start = point.get("start_temp", point["start"])
 | 
			
		||||
                    p_end = point["target"]
 | 
			
		||||
 | 
			
		||||
                    if p_start is None or p_end is None:
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                except KeyError:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                p_in_pixels.append(p_start)
 | 
			
		||||
                t_in_pixels.append(p_end)
 | 
			
		||||
                valid_points.append(key_point)
 | 
			
		||||
 | 
			
		||||
            mask = torch.tensor(global_state['mask']).float()
 | 
			
		||||
            drag_mask = 1 - mask
 | 
			
		||||
 | 
			
		||||
            renderer: Renderer = global_state["renderer"]
 | 
			
		||||
            global_state['temporal_params']['stop'] = False
 | 
			
		||||
            global_state['editing_state'] = 'running'
 | 
			
		||||
 | 
			
		||||
            # reverse points order
 | 
			
		||||
            p_to_opt = reverse_point_pairs(p_in_pixels)
 | 
			
		||||
            t_to_opt = reverse_point_pairs(t_in_pixels)
 | 
			
		||||
            print('Running with:')
 | 
			
		||||
            print(f'    Source: {p_in_pixels}')
 | 
			
		||||
            print(f'    Target: {t_in_pixels}')
 | 
			
		||||
            step_idx = 0
 | 
			
		||||
            while True:
 | 
			
		||||
                if global_state["temporal_params"]["stop"]:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                # do drage here!
 | 
			
		||||
                renderer._render_drag_impl(
 | 
			
		||||
                    global_state['generator_params'],
 | 
			
		||||
                    p_to_opt,  # point
 | 
			
		||||
                    t_to_opt,  # target
 | 
			
		||||
                    drag_mask,  # mask,
 | 
			
		||||
                    global_state['params']['motion_lambda'],  # lambda_mask
 | 
			
		||||
                    reg=0,
 | 
			
		||||
                    feature_idx=5,  # NOTE: do not support change for now
 | 
			
		||||
                    r1=global_state['params']['r1_in_pixels'],  # r1
 | 
			
		||||
                    r2=global_state['params']['r2_in_pixels'],  # r2
 | 
			
		||||
                    # random_seed     = 0,
 | 
			
		||||
                    # noise_mode      = 'const',
 | 
			
		||||
                    trunc_psi=global_state['params']['trunc_psi'],
 | 
			
		||||
                    # force_fp32      = False,
 | 
			
		||||
                    # layer_name      = None,
 | 
			
		||||
                    # sel_channels    = 3,
 | 
			
		||||
                    # base_channel    = 0,
 | 
			
		||||
                    # img_scale_db    = 0,
 | 
			
		||||
                    # img_normalize   = False,
 | 
			
		||||
                    # untransform     = False,
 | 
			
		||||
                    is_drag=True,
 | 
			
		||||
                    to_pil=True)
 | 
			
		||||
 | 
			
		||||
                if step_idx % global_state['draw_interval'] == 0:
 | 
			
		||||
                    print('Current Source:')
 | 
			
		||||
                    for key_point, p_i, t_i in zip(valid_points, p_to_opt,
 | 
			
		||||
                                                   t_to_opt):
 | 
			
		||||
                        global_state["points"][key_point]["start_temp"] = [
 | 
			
		||||
                            p_i[1],
 | 
			
		||||
                            p_i[0],
 | 
			
		||||
                        ]
 | 
			
		||||
                        global_state["points"][key_point]["target"] = [
 | 
			
		||||
                            t_i[1],
 | 
			
		||||
                            t_i[0],
 | 
			
		||||
                        ]
 | 
			
		||||
                        start_temp = global_state["points"][key_point][
 | 
			
		||||
                            "start_temp"]
 | 
			
		||||
                        print(f'    {start_temp}')
 | 
			
		||||
 | 
			
		||||
                    image_result = global_state['generator_params']['image']
 | 
			
		||||
                    image_draw = update_image_draw(
 | 
			
		||||
                        image_result,
 | 
			
		||||
                        global_state['points'],
 | 
			
		||||
                        global_state['mask'],
 | 
			
		||||
                        global_state['show_mask'],
 | 
			
		||||
                        global_state,
 | 
			
		||||
                    )
 | 
			
		||||
                    global_state['images']['image_raw'] = image_result
 | 
			
		||||
 | 
			
		||||
                yield (
 | 
			
		||||
                    global_state,
 | 
			
		||||
                    step_idx,
 | 
			
		||||
                    global_state['images']['image_show'],
 | 
			
		||||
                    # gr.File.update(visible=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    # latent space
 | 
			
		||||
                    gr.Radio.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    # enable stop button in loop
 | 
			
		||||
                    gr.Button.update(interactive=True),
 | 
			
		||||
 | 
			
		||||
                    # update other comps
 | 
			
		||||
                    gr.Dropdown.update(interactive=False),
 | 
			
		||||
                    gr.Number.update(interactive=False),
 | 
			
		||||
                    gr.Number.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Button.update(interactive=False),
 | 
			
		||||
                    gr.Checkbox.update(interactive=False),
 | 
			
		||||
                    # gr.Number.update(interactive=False),
 | 
			
		||||
                    gr.Number.update(interactive=False),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # increate step
 | 
			
		||||
                step_idx += 1
 | 
			
		||||
 | 
			
		||||
            image_result = global_state['generator_params']['image']
 | 
			
		||||
            global_state['images']['image_raw'] = image_result
 | 
			
		||||
            image_draw = update_image_draw(image_result,
 | 
			
		||||
                                           global_state['points'],
 | 
			
		||||
                                           global_state['mask'],
 | 
			
		||||
                                           global_state['show_mask'],
 | 
			
		||||
                                           global_state)
 | 
			
		||||
 | 
			
		||||
            # fp = NamedTemporaryFile(suffix=".png", delete=False)
 | 
			
		||||
            # image_result.save(fp, "PNG")
 | 
			
		||||
 | 
			
		||||
            global_state['editing_state'] = 'add_points'
 | 
			
		||||
 | 
			
		||||
            yield (
 | 
			
		||||
                global_state,
 | 
			
		||||
                0,  # reset step to 0 after stop.
 | 
			
		||||
                global_state['images']['image_show'],
 | 
			
		||||
                # gr.File.update(visible=True, value=fp.name),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                # latent space
 | 
			
		||||
                gr.Radio.update(interactive=True),
 | 
			
		||||
                gr.Button.update(interactive=True),
 | 
			
		||||
                # NOTE: disable stop button with loop finish
 | 
			
		||||
                gr.Button.update(interactive=False),
 | 
			
		||||
 | 
			
		||||
                # update other comps
 | 
			
		||||
                gr.Dropdown.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
                gr.Checkbox.update(interactive=True),
 | 
			
		||||
                gr.Number.update(interactive=True),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    form_start_btn.click(
 | 
			
		||||
        on_click_start,
 | 
			
		||||
        inputs=[global_state, form_image],
 | 
			
		||||
        outputs=[
 | 
			
		||||
            global_state,
 | 
			
		||||
            form_steps_number,
 | 
			
		||||
            form_image,
 | 
			
		||||
            # form_download_result_file,
 | 
			
		||||
            # >>> buttons
 | 
			
		||||
            form_reset_image,
 | 
			
		||||
            enable_add_points,
 | 
			
		||||
            enable_add_mask,
 | 
			
		||||
            undo_points,
 | 
			
		||||
            form_reset_mask_btn,
 | 
			
		||||
            form_latent_space,
 | 
			
		||||
            form_start_btn,
 | 
			
		||||
            form_stop_btn,
 | 
			
		||||
            # <<< buttonm
 | 
			
		||||
            # >>> inputs comps
 | 
			
		||||
            form_pretrained_dropdown,
 | 
			
		||||
            form_seed_number,
 | 
			
		||||
            form_lr_number,
 | 
			
		||||
            show_mask,
 | 
			
		||||
            form_lambda_number,
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_stop(global_state):
 | 
			
		||||
        """Function to handle stop button is clicked.
 | 
			
		||||
        1. send a stop signal by set global_state["temporal_params"]["stop"] as True
 | 
			
		||||
        2. Disable Stop button
 | 
			
		||||
        """
 | 
			
		||||
        global_state["temporal_params"]["stop"] = True
 | 
			
		||||
 | 
			
		||||
        return global_state, gr.Button.update(interactive=False)
 | 
			
		||||
 | 
			
		||||
    form_stop_btn.click(on_click_stop,
 | 
			
		||||
                        inputs=[global_state],
 | 
			
		||||
                        outputs=[global_state, form_stop_btn])
 | 
			
		||||
 | 
			
		||||
    form_draw_interval_number.change(
 | 
			
		||||
        partial(
 | 
			
		||||
            on_change_single_global_state,
 | 
			
		||||
            "draw_interval",
 | 
			
		||||
            map_transform=lambda x: int(x),
 | 
			
		||||
        ),
 | 
			
		||||
        inputs=[form_draw_interval_number, global_state],
 | 
			
		||||
        outputs=[global_state],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_remove_point(global_state):
 | 
			
		||||
        choice = global_state["curr_point"]
 | 
			
		||||
        del global_state["points"][choice]
 | 
			
		||||
 | 
			
		||||
        choices = list(global_state["points"].keys())
 | 
			
		||||
 | 
			
		||||
        if len(choices) > 0:
 | 
			
		||||
            global_state["curr_point"] = choices[0]
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            gr.Dropdown.update(choices=choices, value=choices[0]),
 | 
			
		||||
            global_state,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Mask
 | 
			
		||||
    def on_click_reset_mask(global_state):
 | 
			
		||||
        global_state['mask'] = np.ones(
 | 
			
		||||
            (
 | 
			
		||||
                global_state["images"]["image_raw"].size[1],
 | 
			
		||||
                global_state["images"]["image_raw"].size[0],
 | 
			
		||||
            ),
 | 
			
		||||
            dtype=np.uint8,
 | 
			
		||||
        )
 | 
			
		||||
        image_draw = update_image_draw(global_state['images']['image_raw'],
 | 
			
		||||
                                       global_state['points'],
 | 
			
		||||
                                       global_state['mask'],
 | 
			
		||||
                                       global_state['show_mask'], global_state)
 | 
			
		||||
        return global_state, image_draw
 | 
			
		||||
 | 
			
		||||
    form_reset_mask_btn.click(
 | 
			
		||||
        on_click_reset_mask,
 | 
			
		||||
        inputs=[global_state],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Image
 | 
			
		||||
    def on_click_enable_draw(global_state, image):
 | 
			
		||||
        """Function to start add mask mode.
 | 
			
		||||
        1. Preprocess mask info from last state
 | 
			
		||||
        2. Change editing state to add_mask
 | 
			
		||||
        3. Set curr image with points and mask
 | 
			
		||||
        """
 | 
			
		||||
        global_state = preprocess_mask_info(global_state, image)
 | 
			
		||||
        global_state['editing_state'] = 'add_mask'
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(image_raw, global_state['points'],
 | 
			
		||||
                                       global_state['mask'], True,
 | 
			
		||||
                                       global_state)
 | 
			
		||||
        return (global_state,
 | 
			
		||||
                gr.Image.update(value=image_draw, interactive=True))
 | 
			
		||||
 | 
			
		||||
    def on_click_remove_draw(global_state, image):
 | 
			
		||||
        """Function to start remove mask mode.
 | 
			
		||||
        1. Preprocess mask info from last state
 | 
			
		||||
        2. Change editing state to remove_mask
 | 
			
		||||
        3. Set curr image with points and mask
 | 
			
		||||
        """
 | 
			
		||||
        global_state = preprocess_mask_info(global_state, image)
 | 
			
		||||
        global_state['edinting_state'] = 'remove_mask'
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(image_raw, global_state['points'],
 | 
			
		||||
                                       global_state['mask'], True,
 | 
			
		||||
                                       global_state)
 | 
			
		||||
        return (global_state,
 | 
			
		||||
                gr.Image.update(value=image_draw, interactive=True))
 | 
			
		||||
 | 
			
		||||
    enable_add_mask.click(on_click_enable_draw,
 | 
			
		||||
                          inputs=[global_state, form_image],
 | 
			
		||||
                          outputs=[
 | 
			
		||||
                              global_state,
 | 
			
		||||
                              form_image,
 | 
			
		||||
                          ])
 | 
			
		||||
 | 
			
		||||
    def on_click_add_point(global_state, image: dict):
 | 
			
		||||
        """Function switch from add mask mode to add points mode.
 | 
			
		||||
        1. Updaste mask buffer if need
 | 
			
		||||
        2. Change global_state['editing_state'] to 'add_points'
 | 
			
		||||
        3. Set current image with mask
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        global_state = preprocess_mask_info(global_state, image)
 | 
			
		||||
        global_state['editing_state'] = 'add_points'
 | 
			
		||||
        mask = global_state['mask']
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(image_raw, global_state['points'], mask,
 | 
			
		||||
                                       global_state['show_mask'], global_state)
 | 
			
		||||
 | 
			
		||||
        return (global_state,
 | 
			
		||||
                gr.Image.update(value=image_draw, interactive=False))
 | 
			
		||||
 | 
			
		||||
    enable_add_points.click(on_click_add_point,
 | 
			
		||||
                            inputs=[global_state, form_image],
 | 
			
		||||
                            outputs=[global_state, form_image])
 | 
			
		||||
 | 
			
		||||
    def on_click_image(global_state, evt: gr.SelectData):
 | 
			
		||||
        """This function only support click for point selection
 | 
			
		||||
        """
 | 
			
		||||
        xy = evt.index
 | 
			
		||||
        if global_state['editing_state'] != 'add_points':
 | 
			
		||||
            print(f'In {global_state["editing_state"]} state. '
 | 
			
		||||
                  'Do not add points.')
 | 
			
		||||
 | 
			
		||||
            return global_state, global_state['images']['image_show']
 | 
			
		||||
 | 
			
		||||
        points = global_state["points"]
 | 
			
		||||
 | 
			
		||||
        point_idx = get_latest_points_pair(points)
 | 
			
		||||
        if point_idx is None:
 | 
			
		||||
            points[0] = {'start': xy, 'target': None}
 | 
			
		||||
            print(f'Click Image - Start - {xy}')
 | 
			
		||||
        elif points[point_idx].get('target', None) is None:
 | 
			
		||||
            points[point_idx]['target'] = xy
 | 
			
		||||
            print(f'Click Image - Target - {xy}')
 | 
			
		||||
        else:
 | 
			
		||||
            points[point_idx + 1] = {'start': xy, 'target': None}
 | 
			
		||||
            print(f'Click Image - Start - {xy}')
 | 
			
		||||
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(
 | 
			
		||||
            image_raw,
 | 
			
		||||
            global_state['points'],
 | 
			
		||||
            global_state['mask'],
 | 
			
		||||
            global_state['show_mask'],
 | 
			
		||||
            global_state,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return global_state, image_draw
 | 
			
		||||
 | 
			
		||||
    form_image.select(
 | 
			
		||||
        on_click_image,
 | 
			
		||||
        inputs=[global_state],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def on_click_clear_points(global_state):
 | 
			
		||||
        """Function to handle clear all control points
 | 
			
		||||
        1. clear global_state['points'] (clear_state)
 | 
			
		||||
        2. re-init network
 | 
			
		||||
        2. re-draw image
 | 
			
		||||
        """
 | 
			
		||||
        clear_state(global_state, target='point')
 | 
			
		||||
 | 
			
		||||
        renderer: Renderer = global_state["renderer"]
 | 
			
		||||
        renderer.feat_refs = None
 | 
			
		||||
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(image_raw, {}, global_state['mask'],
 | 
			
		||||
                                       global_state['show_mask'], global_state)
 | 
			
		||||
        return global_state, image_draw
 | 
			
		||||
 | 
			
		||||
    undo_points.click(on_click_clear_points,
 | 
			
		||||
                      inputs=[global_state],
 | 
			
		||||
                      outputs=[global_state, form_image])
 | 
			
		||||
 | 
			
		||||
    def on_click_show_mask(global_state, show_mask):
 | 
			
		||||
        """Function to control whether show mask on image."""
 | 
			
		||||
        global_state['show_mask'] = show_mask
 | 
			
		||||
 | 
			
		||||
        image_raw = global_state['images']['image_raw']
 | 
			
		||||
        image_draw = update_image_draw(
 | 
			
		||||
            image_raw,
 | 
			
		||||
            global_state['points'],
 | 
			
		||||
            global_state['mask'],
 | 
			
		||||
            global_state['show_mask'],
 | 
			
		||||
            global_state,
 | 
			
		||||
        )
 | 
			
		||||
        return global_state, image_draw
 | 
			
		||||
 | 
			
		||||
    show_mask.change(
 | 
			
		||||
        on_click_show_mask,
 | 
			
		||||
        inputs=[global_state, show_mask],
 | 
			
		||||
        outputs=[global_state, form_image],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
gr.close_all()
 | 
			
		||||
app.queue(concurrency_count=5, max_size=20)
 | 
			
		||||
app.launch(share=args.share)
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue