| 
						
						
							
								
							
						
						
					 | 
					 | 
					@ -9,24 +9,19 @@ import torch
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					from PIL import Image
 | 
					 | 
					 | 
					 | 
					from PIL import Image
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					import dnnlib
 | 
					 | 
					 | 
					 | 
					import dnnlib
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					from gradio_utils import (
 | 
					 | 
					 | 
					 | 
					from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    ImageMask,
 | 
					 | 
					 | 
					 | 
					                          get_latest_points_pair, get_valid_mask,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    draw_mask_on_image,
 | 
					 | 
					 | 
					 | 
					                          on_change_single_global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    draw_points_on_image,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    get_latest_points_pair,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    get_valid_mask,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    on_change_single_global_state,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					)
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					from viz.renderer import Renderer, add_watermark_np
 | 
					 | 
					 | 
					 | 
					from viz.renderer import Renderer, add_watermark_np
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					parser = ArgumentParser()
 | 
					 | 
					 | 
					 | 
					parser = ArgumentParser()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					parser.add_argument("--share", action="store_true")
 | 
					 | 
					 | 
					 | 
					parser.add_argument('--share', action='store_true')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					parser.add_argument("--cache-dir", type=str, default="./checkpoints")
 | 
					 | 
					 | 
					 | 
					parser.add_argument('--cache-dir', type=str, default='./checkpoints')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					args = parser.parse_args()
 | 
					 | 
					 | 
					 | 
					args = parser.parse_args()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					cache_dir = args.cache_dir
 | 
					 | 
					 | 
					 | 
					cache_dir = args.cache_dir
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					device = "cuda"
 | 
					 | 
					 | 
					 | 
					device = 'cuda'
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					def reverse_point_pairs(points):
 | 
					 | 
					 | 
					 | 
					def reverse_point_pairs(points):
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -43,18 +38,17 @@ def clear_state(global_state, target=None):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    2. set global_state['mask'] as full-one mask.
 | 
					 | 
					 | 
					 | 
					    2. set global_state['mask'] as full-one mask.
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    """
 | 
					 | 
					 | 
					 | 
					    """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if target is None:
 | 
					 | 
					 | 
					 | 
					    if target is None:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        target = ["point", "mask"]
 | 
					 | 
					 | 
					 | 
					        target = ['point', 'mask']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if not isinstance(target, list):
 | 
					 | 
					 | 
					 | 
					    if not isinstance(target, list):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        target = [target]
 | 
					 | 
					 | 
					 | 
					        target = [target]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if "point" in target:
 | 
					 | 
					 | 
					 | 
					    if 'point' in target:
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["points"] = dict()
 | 
					 | 
					 | 
					 | 
					        global_state['points'] = dict()
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        print("Clear Points State!")
 | 
					 | 
					 | 
					 | 
					        print('Clear Points State!')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if "mask" in target:
 | 
					 | 
					 | 
					 | 
					    if 'mask' in target:
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["mask"] = np.ones(
 | 
					 | 
					 | 
					 | 
					        global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            (image_raw.size[1], image_raw.size[0]), dtype=np.uint8
 | 
					 | 
					 | 
					 | 
					                                       dtype=np.uint8)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					        print('Clear mask State!')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        print("Clear mask State!")
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    return global_state
 | 
					 | 
					 | 
					 | 
					    return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -73,46 +67,43 @@ def init_images(global_state):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    else:
 | 
					 | 
					 | 
					 | 
					    else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state = global_state
 | 
					 | 
					 | 
					 | 
					        state = global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["renderer"].init_network(
 | 
					 | 
					 | 
					 | 
					    state['renderer'].init_network(
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["generator_params"],  # res
 | 
					 | 
					 | 
					 | 
					        state['generator_params'],  # res
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        valid_checkpoints_dict[state["pretrained_weight"]],  # pkl
 | 
					 | 
					 | 
					 | 
					        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["params"]["seed"],  # w0_seed,
 | 
					 | 
					 | 
					 | 
					        state['params']['seed'],  # w0_seed,
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        None,  # w_load
 | 
					 | 
					 | 
					 | 
					        None,  # w_load
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["params"]["latent_space"] == "w+",  # w_plus
 | 
					 | 
					 | 
					 | 
					        state['params']['latent_space'] == 'w+',  # w_plus
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        "const",
 | 
					 | 
					 | 
					 | 
					        'const',
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["params"]["trunc_psi"],  # trunc_psi,
 | 
					 | 
					 | 
					 | 
					        state['params']['trunc_psi'],  # trunc_psi,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["params"]["trunc_cutoff"],  # trunc_cutoff,
 | 
					 | 
					 | 
					 | 
					        state['params']['trunc_cutoff'],  # trunc_cutoff,
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        None,  # input_transform
 | 
					 | 
					 | 
					 | 
					        None,  # input_transform
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["params"]["lr"],  # lr,
 | 
					 | 
					 | 
					 | 
					        state['params']['lr']  # lr,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["renderer"]._render_drag_impl(
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        state["generator_params"], is_drag=False, to_pil=True
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					    )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    init_image = state["generator_params"].image
 | 
					 | 
					 | 
					 | 
					    state['renderer']._render_drag_impl(state['generator_params'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["images"]["image_orig"] = init_image
 | 
					 | 
					 | 
					 | 
					                                        is_drag=False,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["images"]["image_raw"] = init_image
 | 
					 | 
					 | 
					 | 
					                                        to_pil=True)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["images"]["image_show"] = Image.fromarray(
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        add_watermark_np(np.array(init_image))
 | 
					 | 
					 | 
					 | 
					    init_image = state['generator_params'].image
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					    state['images']['image_orig'] = init_image
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    state["mask"] = np.ones((init_image.size[1], init_image.size[0]), dtype=np.uint8)
 | 
					 | 
					 | 
					 | 
					    state['images']['image_raw'] = init_image
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					    state['images']['image_show'] = Image.fromarray(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        add_watermark_np(np.array(init_image)))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					    state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                            dtype=np.uint8)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    return global_state
 | 
					 | 
					 | 
					 | 
					    return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					def update_image_draw(image, points, mask, show_mask, global_state=None):
 | 
					 | 
					 | 
					 | 
					def update_image_draw(image, points, mask, show_mask, global_state=None):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    image_draw = draw_points_on_image(image, points)
 | 
					 | 
					 | 
					 | 
					    image_draw = draw_points_on_image(image, points)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if (
 | 
					 | 
					 | 
					 | 
					    if show_mask and mask is not None and not (mask == 0).all() and not (
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        show_mask
 | 
					 | 
					 | 
					 | 
					            mask == 1).all():
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        and mask is not None
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        and not (mask == 0).all()
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        and not (mask == 1).all()
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    ):
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = draw_mask_on_image(image_draw, mask)
 | 
					 | 
					 | 
					 | 
					        image_draw = draw_mask_on_image(image_draw, mask)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
 | 
					 | 
					 | 
					 | 
					    image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if global_state is not None:
 | 
					 | 
					 | 
					 | 
					    if global_state is not None:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["images"]["image_show"] = image_draw
 | 
					 | 
					 | 
					 | 
					        global_state['images']['image_show'] = image_draw
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    return image_draw
 | 
					 | 
					 | 
					 | 
					    return image_draw
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -124,97 +115,102 @@ def preprocess_mask_info(global_state, image):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        2.2 global_state is add_mask:
 | 
					 | 
					 | 
					 | 
					        2.2 global_state is add_mask:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    """
 | 
					 | 
					 | 
					 | 
					    """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if isinstance(image, dict):
 | 
					 | 
					 | 
					 | 
					    if isinstance(image, dict):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        last_mask = get_valid_mask(image["mask"])
 | 
					 | 
					 | 
					 | 
					        last_mask = get_valid_mask(image['mask'])
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    else:
 | 
					 | 
					 | 
					 | 
					    else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        last_mask = None
 | 
					 | 
					 | 
					 | 
					        last_mask = None
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    mask = global_state["mask"]
 | 
					 | 
					 | 
					 | 
					    mask = global_state['mask']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # mask in global state is a placeholder with all 1.
 | 
					 | 
					 | 
					 | 
					    # mask in global state is a placeholder with all 1.
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if (mask == 1).all():
 | 
					 | 
					 | 
					 | 
					    if (mask == 1).all():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        mask = last_mask
 | 
					 | 
					 | 
					 | 
					        mask = last_mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # last_mask = global_state['last_mask']
 | 
					 | 
					 | 
					 | 
					    # last_mask = global_state['last_mask']
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    editing_mode = global_state["editing_state"]
 | 
					 | 
					 | 
					 | 
					    editing_mode = global_state['editing_state']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if last_mask is None:
 | 
					 | 
					 | 
					 | 
					    if last_mask is None:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state
 | 
					 | 
					 | 
					 | 
					        return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if editing_mode == "remove_mask":
 | 
					 | 
					 | 
					 | 
					    if editing_mode == 'remove_mask':
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        updated_mask = np.clip(mask - last_mask, 0, 1)
 | 
					 | 
					 | 
					 | 
					        updated_mask = np.clip(mask - last_mask, 0, 1)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        print(f"Last editing_state is {editing_mode}, do remove.")
 | 
					 | 
					 | 
					 | 
					        print(f'Last editing_state is {editing_mode}, do remove.')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    elif editing_mode == "add_mask":
 | 
					 | 
					 | 
					 | 
					    elif editing_mode == 'add_mask':
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        updated_mask = np.clip(mask + last_mask, 0, 1)
 | 
					 | 
					 | 
					 | 
					        updated_mask = np.clip(mask + last_mask, 0, 1)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        print(f"Last editing_state is {editing_mode}, do add.")
 | 
					 | 
					 | 
					 | 
					        print(f'Last editing_state is {editing_mode}, do add.')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    else:
 | 
					 | 
					 | 
					 | 
					    else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        updated_mask = mask
 | 
					 | 
					 | 
					 | 
					        updated_mask = mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        print(f"Last editing_state is {editing_mode}, " "do nothing to mask.")
 | 
					 | 
					 | 
					 | 
					        print(f'Last editing_state is {editing_mode}, '
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					              'do nothing to mask.')
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    global_state["mask"] = updated_mask
 | 
					 | 
					 | 
					 | 
					    global_state['mask'] = updated_mask
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # global_state['last_mask'] = None  # clear buffer
 | 
					 | 
					 | 
					 | 
					    # global_state['last_mask'] = None  # clear buffer
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    return global_state
 | 
					 | 
					 | 
					 | 
					    return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					valid_checkpoints_dict = {
 | 
					 | 
					 | 
					 | 
					valid_checkpoints_dict = {
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    f.split("/")[-1].split(".")[0]: osp.join(cache_dir, f)
 | 
					 | 
					 | 
					 | 
					    f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    for f in os.listdir(cache_dir)
 | 
					 | 
					 | 
					 | 
					    for f in os.listdir(cache_dir)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    if (f.endswith("pkl") and osp.exists(osp.join(cache_dir, f)))
 | 
					 | 
					 | 
					 | 
					    if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					}
 | 
					 | 
					 | 
					 | 
					}
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					print(f"File under cache_dir ({cache_dir}):")
 | 
					 | 
					 | 
					 | 
					print(f'File under cache_dir ({cache_dir}):')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					print(os.listdir(cache_dir))
 | 
					 | 
					 | 
					 | 
					print(os.listdir(cache_dir))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					print("Valid checkpoint file:")
 | 
					 | 
					 | 
					 | 
					print('Valid checkpoint file:')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					print(valid_checkpoints_dict)
 | 
					 | 
					 | 
					 | 
					print(valid_checkpoints_dict)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					init_pkl = list(valid_checkpoints_dict.keys())[4]
 | 
					 | 
					 | 
					 | 
					init_pkl = 'stylegan_human_v2_512'
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					with gr.Blocks() as app:
 | 
					 | 
					 | 
					 | 
					with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # renderer = Renderer()
 | 
					 | 
					 | 
					 | 
					    # renderer = Renderer()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    global_state = gr.State(
 | 
					 | 
					 | 
					 | 
					    global_state = gr.State({
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        {
 | 
					 | 
					 | 
					 | 
					        "images": {
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "images": {
 | 
					 | 
					 | 
					 | 
					            # image_orig: the original image, change with seed/model is changed
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # image_orig: the original image, change with seed/model is changed
 | 
					 | 
					 | 
					 | 
					            # image_raw: image with mask and points, change durning optimization
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # image_raw: image with mask and points, change durning optimization
 | 
					 | 
					 | 
					 | 
					            # image_show: image showed on screen
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # image_show: image showed on screen
 | 
					 | 
					 | 
					 | 
					        },
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            },
 | 
					 | 
					 | 
					 | 
					        "temporal_params": {
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "temporal_params": {
 | 
					 | 
					 | 
					 | 
					            # stop
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # stop
 | 
					 | 
					 | 
					 | 
					        },
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            },
 | 
					 | 
					 | 
					 | 
					        'mask':
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "mask": None,  # mask for visualization, 1 for editing and 0 for unchange
 | 
					 | 
					 | 
					 | 
					        None,  # mask for visualization, 1 for editing and 0 for unchange
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "last_mask": None,  # last edited mask
 | 
					 | 
					 | 
					 | 
					        'last_mask': None,  # last edited mask
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "show_mask": True,  # add button
 | 
					 | 
					 | 
					 | 
					        'show_mask': True,  # add button
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "generator_params": dnnlib.EasyDict(),
 | 
					 | 
					 | 
					 | 
					        "generator_params": dnnlib.EasyDict(),
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "params": {
 | 
					 | 
					 | 
					 | 
					        "params": {
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "seed": 0,
 | 
					 | 
					 | 
					 | 
					            "seed": 0,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "motion_lambda": 20,
 | 
					 | 
					 | 
					 | 
					            "motion_lambda": 20,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "r1_in_pixels": 3,
 | 
					 | 
					 | 
					 | 
					            "r1_in_pixels": 3,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "r2_in_pixels": 12,
 | 
					 | 
					 | 
					 | 
					            "r2_in_pixels": 12,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "magnitude_direction_in_pixels": 1.0,
 | 
					 | 
					 | 
					 | 
					            "magnitude_direction_in_pixels": 1.0,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "latent_space": "w+",
 | 
					 | 
					 | 
					 | 
					            "latent_space": "w+",
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "trunc_psi": 0.7,
 | 
					 | 
					 | 
					 | 
					            "trunc_psi": 0.7,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "trunc_cutoff": None,
 | 
					 | 
					 | 
					 | 
					            "trunc_cutoff": None,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                "lr": 0.001,
 | 
					 | 
					 | 
					 | 
					            "lr": 0.001,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            },
 | 
					 | 
					 | 
					 | 
					        },
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "device": device,
 | 
					 | 
					 | 
					 | 
					        "device": device,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "draw_interval": 1,
 | 
					 | 
					 | 
					 | 
					        "draw_interval": 1,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "renderer": Renderer(disable_timing=True),
 | 
					 | 
					 | 
					 | 
					        "renderer": Renderer(disable_timing=True),
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "points": {},
 | 
					 | 
					 | 
					 | 
					        "points": {},
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "curr_point": None,
 | 
					 | 
					 | 
					 | 
					        "curr_point": None,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "curr_type_point": "start",
 | 
					 | 
					 | 
					 | 
					        "curr_type_point": "start",
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "editing_state": "add_points",
 | 
					 | 
					 | 
					 | 
					        'editing_state': 'add_points',
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            "pretrained_weight": init_pkl,
 | 
					 | 
					 | 
					 | 
					        'pretrained_weight': init_pkl
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        }
 | 
					 | 
					 | 
					 | 
					    })
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # init image
 | 
					 | 
					 | 
					 | 
					    # init image
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    global_state = init_images(global_state)
 | 
					 | 
					 | 
					 | 
					    global_state = init_images(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    with gr.Row():
 | 
					 | 
					 | 
					 | 
					    with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        with gr.Row():
 | 
					 | 
					 | 
					 | 
					        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # Left --> tools
 | 
					 | 
					 | 
					 | 
					            # Left --> tools
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            with gr.Column(scale=3):
 | 
					 | 
					 | 
					 | 
					            with gr.Column(scale=3):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # Pickle
 | 
					 | 
					 | 
					 | 
					                # Pickle
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        gr.Markdown(value="Pickle", show_label=False)
 | 
					 | 
					 | 
					 | 
					                        gr.Markdown(value='Pickle', show_label=False)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        form_pretrained_dropdown = gr.Dropdown(
 | 
					 | 
					 | 
					 | 
					                        form_pretrained_dropdown = gr.Dropdown(
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -226,71 +222,71 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # Latent
 | 
					 | 
					 | 
					 | 
					                # Latent
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        gr.Markdown(value="Latent", show_label=False)
 | 
					 | 
					 | 
					 | 
					                        gr.Markdown(value='Latent', show_label=False)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        form_seed_number = gr.Number(
 | 
					 | 
					 | 
					 | 
					                        form_seed_number = gr.Number(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            value=global_state.value["params"]["seed"],
 | 
					 | 
					 | 
					 | 
					                            value=global_state.value['params']['seed'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            interactive=True,
 | 
					 | 
					 | 
					 | 
					                            interactive=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            label="Seed",
 | 
					 | 
					 | 
					 | 
					                            label="Seed",
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        )
 | 
					 | 
					 | 
					 | 
					                        )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        form_lr_number = gr.Number(
 | 
					 | 
					 | 
					 | 
					                        form_lr_number = gr.Number(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            value=global_state.value["params"]["lr"],
 | 
					 | 
					 | 
					 | 
					                            value=global_state.value["params"]["lr"],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            interactive=True,
 | 
					 | 
					 | 
					 | 
					                            interactive=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            label="Step Size",
 | 
					 | 
					 | 
					 | 
					                            label="Step Size")
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=2, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=2, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                form_reset_image = gr.Button("Reset Image")
 | 
					 | 
					 | 
					 | 
					                                form_reset_image = gr.Button("Reset Image")
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=3, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=3, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                form_latent_space = gr.Radio(
 | 
					 | 
					 | 
					 | 
					                                form_latent_space = gr.Radio(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    ["w", "w+"],
 | 
					 | 
					 | 
					 | 
					                                    ['w', 'w+'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    value=global_state.value["params"]["latent_space"],
 | 
					 | 
					 | 
					 | 
					                                    value=global_state.value['params']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                                    ['latent_space'],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    interactive=True,
 | 
					 | 
					 | 
					 | 
					                                    interactive=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    label="Latent space to optimize",
 | 
					 | 
					 | 
					 | 
					                                    label='Latent space to optimize',
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    show_label=False,
 | 
					 | 
					 | 
					 | 
					                                    show_label=False,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                )
 | 
					 | 
					 | 
					 | 
					                                )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # Drag
 | 
					 | 
					 | 
					 | 
					                # Drag
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        gr.Markdown(value="Drag", show_label=False)
 | 
					 | 
					 | 
					 | 
					                        gr.Markdown(value='Drag', show_label=False)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                enable_add_points = gr.Button("Add Points")
 | 
					 | 
					 | 
					 | 
					                                enable_add_points = gr.Button('Add Points')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                undo_points = gr.Button("Reset Points")
 | 
					 | 
					 | 
					 | 
					                                undo_points = gr.Button('Reset Points')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                form_start_btn = gr.Button("Start")
 | 
					 | 
					 | 
					 | 
					                                form_start_btn = gr.Button("Start")
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                form_stop_btn = gr.Button("Stop")
 | 
					 | 
					 | 
					 | 
					                                form_stop_btn = gr.Button("Stop")
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        form_steps_number = gr.Number(
 | 
					 | 
					 | 
					 | 
					                        form_steps_number = gr.Number(value=0,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            value=0, label="Steps", interactive=False
 | 
					 | 
					 | 
					 | 
					                                                      label="Steps",
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        )
 | 
					 | 
					 | 
					 | 
					                                                      interactive=False)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # Mask
 | 
					 | 
					 | 
					 | 
					                # Mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
					 | 
					 | 
					 | 
					                with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        gr.Markdown(value="Mask", show_label=False)
 | 
					 | 
					 | 
					 | 
					                        gr.Markdown(value='Mask', show_label=False)
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
					 | 
					 | 
					 | 
					                    with gr.Column(scale=4, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        enable_add_mask = gr.Button("Edit Flexible Area")
 | 
					 | 
					 | 
					 | 
					                        enable_add_mask = gr.Button('Edit Flexible Area')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                form_reset_mask_btn = gr.Button("Reset mask")
 | 
					 | 
					 | 
					 | 
					                                form_reset_mask_btn = gr.Button("Reset mask")
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
					 | 
					 | 
					 | 
					                            with gr.Column(scale=1, min_width=10):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                show_mask = gr.Checkbox(
 | 
					 | 
					 | 
					 | 
					                                show_mask = gr.Checkbox(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    label="Show Mask",
 | 
					 | 
					 | 
					 | 
					                                    label='Show Mask',
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    value=global_state.value["show_mask"],
 | 
					 | 
					 | 
					 | 
					                                    value=global_state.value['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                    show_label=False,
 | 
					 | 
					 | 
					 | 
					                                    show_label=False)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
					 | 
					 | 
					 | 
					                        with gr.Row():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            form_lambda_number = gr.Number(
 | 
					 | 
					 | 
					 | 
					                            form_lambda_number = gr.Number(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                value=global_state.value["params"]["motion_lambda"],
 | 
					 | 
					 | 
					 | 
					                                value=global_state.value["params"]
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                                ["motion_lambda"],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                interactive=True,
 | 
					 | 
					 | 
					 | 
					                                interactive=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                                label="Lambda",
 | 
					 | 
					 | 
					 | 
					                                label="Lambda",
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            )
 | 
					 | 
					 | 
					 | 
					                            )
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -299,18 +295,16 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    value=global_state.value["draw_interval"],
 | 
					 | 
					 | 
					 | 
					                    value=global_state.value["draw_interval"],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    label="Draw Interval (steps)",
 | 
					 | 
					 | 
					 | 
					                    label="Draw Interval (steps)",
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    interactive=True,
 | 
					 | 
					 | 
					 | 
					                    interactive=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    visible=False,
 | 
					 | 
					 | 
					 | 
					                    visible=False)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # Right --> Image
 | 
					 | 
					 | 
					 | 
					            # Right --> Image
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            with gr.Column(scale=8):
 | 
					 | 
					 | 
					 | 
					            with gr.Column(scale=8):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                form_image = ImageMask(
 | 
					 | 
					 | 
					 | 
					                form_image = ImageMask(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    value=global_state.value["images"]["image_show"], brush_radius=20
 | 
					 | 
					 | 
					 | 
					                    value=global_state.value['images']['image_show'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                ).style(
 | 
					 | 
					 | 
					 | 
					                    brush_radius=20).style(
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    width=768, height=768
 | 
					 | 
					 | 
					 | 
					                        width=768,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                )  # NOTE: hard image size code here.
 | 
					 | 
					 | 
					 | 
					                        height=768)  # NOTE: hard image size code here.
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    gr.Markdown(
 | 
					 | 
					 | 
					 | 
					    gr.Markdown("""
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        ## Quick Start
 | 
					 | 
					 | 
					 | 
					        ## Quick Start
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        1. Select desired `Pretrained Model` and adjust `Seed` to generate an
 | 
					 | 
					 | 
					 | 
					        1. Select desired `Pretrained Model` and adjust `Seed` to generate an
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -329,10 +323,8 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					          mask (this has the same effect as `Reset Image` button).
 | 
					 | 
					 | 
					 | 
					          mask (this has the same effect as `Reset Image` button).
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        3. Click `Edit Flexible Area` to create a mask and constrain the
 | 
					 | 
					 | 
					 | 
					        3. Click `Edit Flexible Area` to create a mask and constrain the
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					           unmasked region to remain unchanged.
 | 
					 | 
					 | 
					 | 
					           unmasked region to remain unchanged.
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					    gr.HTML("""
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    gr.HTML(
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        <style>
 | 
					 | 
					 | 
					 | 
					        <style>
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            .container {
 | 
					 | 
					 | 
					 | 
					            .container {
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                position: absolute;
 | 
					 | 
					 | 
					 | 
					                position: absolute;
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -347,8 +339,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
 | 
					 | 
					 | 
					 | 
					        <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
 | 
					 | 
					 | 
					 | 
					        <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        </div>
 | 
					 | 
					 | 
					 | 
					        </div>
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # Network & latents tab listeners
 | 
					 | 
					 | 
					 | 
					    # Network & latents tab listeners
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_change_pretrained_dropdown(pretrained_value, global_state):
 | 
					 | 
					 | 
					 | 
					    def on_change_pretrained_dropdown(pretrained_value, global_state):
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -357,11 +348,11 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        2. Re-init images and clear all states
 | 
					 | 
					 | 
					 | 
					        2. Re-init images and clear all states
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["pretrained_weight"] = pretrained_value
 | 
					 | 
					 | 
					 | 
					        global_state['pretrained_weight'] = pretrained_value
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, global_state["images"]["image_show"]
 | 
					 | 
					 | 
					 | 
					        return global_state, global_state["images"]['image_show']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_pretrained_dropdown.change(
 | 
					 | 
					 | 
					 | 
					    form_pretrained_dropdown.change(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_change_pretrained_dropdown,
 | 
					 | 
					 | 
					 | 
					        on_change_pretrained_dropdown,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -378,7 +369,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, global_state["images"]["image_show"]
 | 
					 | 
					 | 
					 | 
					        return global_state, global_state['images']['image_show']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_reset_image.click(
 | 
					 | 
					 | 
					 | 
					    form_reset_image.click(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_reset_image,
 | 
					 | 
					 | 
					 | 
					        on_click_reset_image,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -397,7 +388,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, global_state["images"]["image_show"]
 | 
					 | 
					 | 
					 | 
					        return global_state, global_state['images']['image_show']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_seed_number.change(
 | 
					 | 
					 | 
					 | 
					    form_seed_number.change(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_change_update_image_seed,
 | 
					 | 
					 | 
					 | 
					        on_change_update_image_seed,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -412,17 +403,15 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        2. Re-init images and clear all state
 | 
					 | 
					 | 
					 | 
					        2. Re-init images and clear all state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["params"]["latent_space"] = latent_space
 | 
					 | 
					 | 
					 | 
					        global_state['params']['latent_space'] = latent_space
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
					 | 
					 | 
					 | 
					        init_images(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
					 | 
					 | 
					 | 
					        clear_state(global_state)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, global_state["images"]["image_show"]
 | 
					 | 
					 | 
					 | 
					        return global_state, global_state['images']['image_show']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_latent_space.change(
 | 
					 | 
					 | 
					 | 
					    form_latent_space.change(on_click_latent_space,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_latent_space,
 | 
					 | 
					 | 
					 | 
					                             inputs=[form_latent_space, global_state],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        inputs=[form_latent_space, global_state],
 | 
					 | 
					 | 
					 | 
					                             outputs=[global_state, form_image])
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        outputs=[global_state, form_image],
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # ==== Params
 | 
					 | 
					 | 
					 | 
					    # ==== Params
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_lambda_number.change(
 | 
					 | 
					 | 
					 | 
					    form_lambda_number.change(
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -433,13 +422,13 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_change_lr(lr, global_state):
 | 
					 | 
					 | 
					 | 
					    def on_change_lr(lr, global_state):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if lr == 0:
 | 
					 | 
					 | 
					 | 
					        if lr == 0:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print("lr is 0, do nothing.")
 | 
					 | 
					 | 
					 | 
					            print('lr is 0, do nothing.')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            return global_state
 | 
					 | 
					 | 
					 | 
					            return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        else:
 | 
					 | 
					 | 
					 | 
					        else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["params"]["lr"] = lr
 | 
					 | 
					 | 
					 | 
					            global_state["params"]["lr"] = lr
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            renderer = global_state["renderer"]
 | 
					 | 
					 | 
					 | 
					            renderer = global_state['renderer']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            renderer.update_lr(lr)
 | 
					 | 
					 | 
					 | 
					            renderer.update_lr(lr)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print("New optimizer: ")
 | 
					 | 
					 | 
					 | 
					            print('New optimizer: ')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(renderer.w_optim)
 | 
					 | 
					 | 
					 | 
					            print(renderer.w_optim)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state
 | 
					 | 
					 | 
					 | 
					        return global_state
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -460,19 +449,19 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        # Prepare the points for the inference
 | 
					 | 
					 | 
					 | 
					        # Prepare the points for the inference
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if len(global_state["points"]) == 0:
 | 
					 | 
					 | 
					 | 
					        if len(global_state["points"]) == 0:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # yield on_click_start_wo_points(global_state, image)
 | 
					 | 
					 | 
					 | 
					            # yield on_click_start_wo_points(global_state, image)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					            image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            update_image_draw(
 | 
					 | 
					 | 
					 | 
					            update_image_draw(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                image_raw,
 | 
					 | 
					 | 
					 | 
					                image_raw,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["points"],
 | 
					 | 
					 | 
					 | 
					                global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["mask"],
 | 
					 | 
					 | 
					 | 
					                global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					                global_state['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state,
 | 
					 | 
					 | 
					 | 
					                global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            )
 | 
					 | 
					 | 
					 | 
					            )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            yield (
 | 
					 | 
					 | 
					 | 
					            yield (
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state,
 | 
					 | 
					 | 
					 | 
					                global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                0,
 | 
					 | 
					 | 
					 | 
					                0,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["images"]["image_show"],
 | 
					 | 
					 | 
					 | 
					                global_state['images']['image_show'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # gr.File.update(visible=False),
 | 
					 | 
					 | 
					 | 
					                # gr.File.update(visible=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -484,6 +473,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # NOTE: disable stop button
 | 
					 | 
					 | 
					 | 
					                # NOTE: disable stop button
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # update other comps
 | 
					 | 
					 | 
					 | 
					                # update other comps
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Dropdown.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Dropdown.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -495,6 +485,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            )
 | 
					 | 
					 | 
					 | 
					            )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        else:
 | 
					 | 
					 | 
					 | 
					        else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # Transform the points into torch tensors
 | 
					 | 
					 | 
					 | 
					            # Transform the points into torch tensors
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            for key_point, point in global_state["points"].items():
 | 
					 | 
					 | 
					 | 
					            for key_point, point in global_state["points"].items():
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                try:
 | 
					 | 
					 | 
					 | 
					                try:
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -511,19 +502,19 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                t_in_pixels.append(p_end)
 | 
					 | 
					 | 
					 | 
					                t_in_pixels.append(p_end)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                valid_points.append(key_point)
 | 
					 | 
					 | 
					 | 
					                valid_points.append(key_point)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            mask = torch.tensor(global_state["mask"]).float()
 | 
					 | 
					 | 
					 | 
					            mask = torch.tensor(global_state['mask']).float()
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            drag_mask = 1 - mask
 | 
					 | 
					 | 
					 | 
					            drag_mask = 1 - mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            renderer: Renderer = global_state["renderer"]
 | 
					 | 
					 | 
					 | 
					            renderer: Renderer = global_state["renderer"]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["temporal_params"]["stop"] = False
 | 
					 | 
					 | 
					 | 
					            global_state['temporal_params']['stop'] = False
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["editing_state"] = "running"
 | 
					 | 
					 | 
					 | 
					            global_state['editing_state'] = 'running'
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # reverse points order
 | 
					 | 
					 | 
					 | 
					            # reverse points order
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            p_to_opt = reverse_point_pairs(p_in_pixels)
 | 
					 | 
					 | 
					 | 
					            p_to_opt = reverse_point_pairs(p_in_pixels)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            t_to_opt = reverse_point_pairs(t_in_pixels)
 | 
					 | 
					 | 
					 | 
					            t_to_opt = reverse_point_pairs(t_in_pixels)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print("Running with:")
 | 
					 | 
					 | 
					 | 
					            print('Running with:')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f"    Source: {p_in_pixels}")
 | 
					 | 
					 | 
					 | 
					            print(f'    Source: {p_in_pixels}')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f"    Target: {t_in_pixels}")
 | 
					 | 
					 | 
					 | 
					            print(f'    Target: {t_in_pixels}')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            step_idx = 0
 | 
					 | 
					 | 
					 | 
					            step_idx = 0
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            while True:
 | 
					 | 
					 | 
					 | 
					            while True:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                if global_state["temporal_params"]["stop"]:
 | 
					 | 
					 | 
					 | 
					                if global_state["temporal_params"]["stop"]:
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -531,18 +522,18 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # do drage here!
 | 
					 | 
					 | 
					 | 
					                # do drage here!
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                renderer._render_drag_impl(
 | 
					 | 
					 | 
					 | 
					                renderer._render_drag_impl(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    global_state["generator_params"],
 | 
					 | 
					 | 
					 | 
					                    global_state['generator_params'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    p_to_opt,  # point
 | 
					 | 
					 | 
					 | 
					                    p_to_opt,  # point
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    t_to_opt,  # target
 | 
					 | 
					 | 
					 | 
					                    t_to_opt,  # target
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    drag_mask,  # mask,
 | 
					 | 
					 | 
					 | 
					                    drag_mask,  # mask,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    global_state["params"]["motion_lambda"],  # lambda_mask
 | 
					 | 
					 | 
					 | 
					                    global_state['params']['motion_lambda'],  # lambda_mask
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    reg=0,
 | 
					 | 
					 | 
					 | 
					                    reg=0,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    feature_idx=5,  # NOTE: do not support change for now
 | 
					 | 
					 | 
					 | 
					                    feature_idx=5,  # NOTE: do not support change for now
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    r1=global_state["params"]["r1_in_pixels"],  # r1
 | 
					 | 
					 | 
					 | 
					                    r1=global_state['params']['r1_in_pixels'],  # r1
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    r2=global_state["params"]["r2_in_pixels"],  # r2
 | 
					 | 
					 | 
					 | 
					                    r2=global_state['params']['r2_in_pixels'],  # r2
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # random_seed     = 0,
 | 
					 | 
					 | 
					 | 
					                    # random_seed     = 0,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # noise_mode      = 'const',
 | 
					 | 
					 | 
					 | 
					                    # noise_mode      = 'const',
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    trunc_psi=global_state["params"]["trunc_psi"],
 | 
					 | 
					 | 
					 | 
					                    trunc_psi=global_state['params']['trunc_psi'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # force_fp32      = False,
 | 
					 | 
					 | 
					 | 
					                    # force_fp32      = False,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # layer_name      = None,
 | 
					 | 
					 | 
					 | 
					                    # layer_name      = None,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # sel_channels    = 3,
 | 
					 | 
					 | 
					 | 
					                    # sel_channels    = 3,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -551,12 +542,12 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # img_normalize   = False,
 | 
					 | 
					 | 
					 | 
					                    # img_normalize   = False,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # untransform     = False,
 | 
					 | 
					 | 
					 | 
					                    # untransform     = False,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    is_drag=True,
 | 
					 | 
					 | 
					 | 
					                    is_drag=True,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    to_pil=True,
 | 
					 | 
					 | 
					 | 
					                    to_pil=True)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                if step_idx % global_state["draw_interval"] == 0:
 | 
					 | 
					 | 
					 | 
					                if step_idx % global_state['draw_interval'] == 0:
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    print("Current Source:")
 | 
					 | 
					 | 
					 | 
					                    print('Current Source:')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt):
 | 
					 | 
					 | 
					 | 
					                    for key_point, p_i, t_i in zip(valid_points, p_to_opt,
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                                                   t_to_opt):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        global_state["points"][key_point]["start_temp"] = [
 | 
					 | 
					 | 
					 | 
					                        global_state["points"][key_point]["start_temp"] = [
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            p_i[1],
 | 
					 | 
					 | 
					 | 
					                            p_i[1],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            p_i[0],
 | 
					 | 
					 | 
					 | 
					                            p_i[0],
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -565,23 +556,24 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            t_i[1],
 | 
					 | 
					 | 
					 | 
					                            t_i[1],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                            t_i[0],
 | 
					 | 
					 | 
					 | 
					                            t_i[0],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        ]
 | 
					 | 
					 | 
					 | 
					                        ]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        start_temp = global_state["points"][key_point]["start_temp"]
 | 
					 | 
					 | 
					 | 
					                        start_temp = global_state["points"][key_point][
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        print(f"    {start_temp}")
 | 
					 | 
					 | 
					 | 
					                            "start_temp"]
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                        print(f'    {start_temp}')
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    image_result = global_state["generator_params"]["image"]
 | 
					 | 
					 | 
					 | 
					                    image_result = global_state['generator_params']['image']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					                    image_draw = update_image_draw(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        image_result,
 | 
					 | 
					 | 
					 | 
					                        image_result,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        global_state["points"],
 | 
					 | 
					 | 
					 | 
					                        global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        global_state["mask"],
 | 
					 | 
					 | 
					 | 
					                        global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					                        global_state['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        global_state,
 | 
					 | 
					 | 
					 | 
					                        global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    )
 | 
					 | 
					 | 
					 | 
					                    )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    global_state["images"]["image_raw"] = image_result
 | 
					 | 
					 | 
					 | 
					                    global_state['images']['image_raw'] = image_result
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                yield (
 | 
					 | 
					 | 
					 | 
					                yield (
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    global_state,
 | 
					 | 
					 | 
					 | 
					                    global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    step_idx,
 | 
					 | 
					 | 
					 | 
					                    step_idx,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    global_state["images"]["image_show"],
 | 
					 | 
					 | 
					 | 
					                    global_state['images']['image_show'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # gr.File.update(visible=False),
 | 
					 | 
					 | 
					 | 
					                    # gr.File.update(visible=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -593,6 +585,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # enable stop button in loop
 | 
					 | 
					 | 
					 | 
					                    # enable stop button in loop
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                    gr.Button.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    # update other comps
 | 
					 | 
					 | 
					 | 
					                    # update other comps
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Dropdown.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                    gr.Dropdown.update(interactive=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    gr.Number.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                    gr.Number.update(interactive=False),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -607,25 +600,23 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # increate step
 | 
					 | 
					 | 
					 | 
					                # increate step
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                step_idx += 1
 | 
					 | 
					 | 
					 | 
					                step_idx += 1
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_result = global_state["generator_params"]["image"]
 | 
					 | 
					 | 
					 | 
					            image_result = global_state['generator_params']['image']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["images"]["image_raw"] = image_result
 | 
					 | 
					 | 
					 | 
					            global_state['images']['image_raw'] = image_result
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					            image_draw = update_image_draw(image_result,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                image_result,
 | 
					 | 
					 | 
					 | 
					                                           global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["points"],
 | 
					 | 
					 | 
					 | 
					                                           global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["mask"],
 | 
					 | 
					 | 
					 | 
					                                           global_state['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					                                           global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # fp = NamedTemporaryFile(suffix=".png", delete=False)
 | 
					 | 
					 | 
					 | 
					            # fp = NamedTemporaryFile(suffix=".png", delete=False)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # image_result.save(fp, "PNG")
 | 
					 | 
					 | 
					 | 
					            # image_result.save(fp, "PNG")
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["editing_state"] = "add_points"
 | 
					 | 
					 | 
					 | 
					            global_state['editing_state'] = 'add_points'
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            yield (
 | 
					 | 
					 | 
					 | 
					            yield (
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state,
 | 
					 | 
					 | 
					 | 
					                global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                0,  # reset step to 0 after stop.
 | 
					 | 
					 | 
					 | 
					                0,  # reset step to 0 after stop.
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["images"]["image_show"],
 | 
					 | 
					 | 
					 | 
					                global_state['images']['image_show'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # gr.File.update(visible=True, value=fp.name),
 | 
					 | 
					 | 
					 | 
					                # gr.File.update(visible=True, value=fp.name),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -637,6 +628,7 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # NOTE: disable stop button with loop finish
 | 
					 | 
					 | 
					 | 
					                # NOTE: disable stop button with loop finish
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=False),
 | 
					 | 
					 | 
					 | 
					                gr.Button.update(interactive=False),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                # update other comps
 | 
					 | 
					 | 
					 | 
					                # update other comps
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Dropdown.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Dropdown.update(interactive=True),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
					 | 
					 | 
					 | 
					                gr.Number.update(interactive=True),
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -681,9 +673,9 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, gr.Button.update(interactive=False)
 | 
					 | 
					 | 
					 | 
					        return global_state, gr.Button.update(interactive=False)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_stop_btn.click(
 | 
					 | 
					 | 
					 | 
					    form_stop_btn.click(on_click_stop,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_stop, inputs=[global_state], outputs=[global_state, form_stop_btn]
 | 
					 | 
					 | 
					 | 
					                        inputs=[global_state],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					                        outputs=[global_state, form_stop_btn])
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_draw_interval_number.change(
 | 
					 | 
					 | 
					 | 
					    form_draw_interval_number.change(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        partial(
 | 
					 | 
					 | 
					 | 
					        partial(
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -711,20 +703,17 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    # Mask
 | 
					 | 
					 | 
					 | 
					    # Mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_click_reset_mask(global_state):
 | 
					 | 
					 | 
					 | 
					    def on_click_reset_mask(global_state):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["mask"] = np.ones(
 | 
					 | 
					 | 
					 | 
					        global_state['mask'] = np.ones(
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            (
 | 
					 | 
					 | 
					 | 
					            (
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["images"]["image_raw"].size[1],
 | 
					 | 
					 | 
					 | 
					                global_state["images"]["image_raw"].size[1],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                global_state["images"]["image_raw"].size[0],
 | 
					 | 
					 | 
					 | 
					                global_state["images"]["image_raw"].size[0],
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            ),
 | 
					 | 
					 | 
					 | 
					            ),
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            dtype=np.uint8,
 | 
					 | 
					 | 
					 | 
					            dtype=np.uint8,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					        )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(global_state['images']['image_raw'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["images"]["image_raw"],
 | 
					 | 
					 | 
					 | 
					                                       global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["points"],
 | 
					 | 
					 | 
					 | 
					                                       global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["mask"],
 | 
					 | 
					 | 
					 | 
					                                       global_state['show_mask'], global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    form_reset_mask_btn.click(
 | 
					 | 
					 | 
					 | 
					    form_reset_mask_btn.click(
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -741,12 +730,13 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        3. Set curr image with points and mask
 | 
					 | 
					 | 
					 | 
					        3. Set curr image with points and mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["editing_state"] = "add_mask"
 | 
					 | 
					 | 
					 | 
					        global_state['editing_state'] = 'add_mask'
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(image_raw, global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw, global_state["points"], global_state["mask"], True, global_state
 | 
					 | 
					 | 
					 | 
					                                       global_state['mask'], True,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					                                       global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return (global_state, gr.Image.update(value=image_draw, interactive=True))
 | 
					 | 
					 | 
					 | 
					        return (global_state,
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                gr.Image.update(value=image_draw, interactive=True))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_click_remove_draw(global_state, image):
 | 
					 | 
					 | 
					 | 
					    def on_click_remove_draw(global_state, image):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """Function to start remove mask mode.
 | 
					 | 
					 | 
					 | 
					        """Function to start remove mask mode.
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -755,21 +745,20 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        3. Set curr image with points and mask
 | 
					 | 
					 | 
					 | 
					        3. Set curr image with points and mask
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["edinting_state"] = "remove_mask"
 | 
					 | 
					 | 
					 | 
					        global_state['edinting_state'] = 'remove_mask'
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(image_raw, global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw, global_state["points"], global_state["mask"], True, global_state
 | 
					 | 
					 | 
					 | 
					                                       global_state['mask'], True,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					                                       global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return (global_state, gr.Image.update(value=image_draw, interactive=True))
 | 
					 | 
					 | 
					 | 
					        return (global_state,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					                gr.Image.update(value=image_draw, interactive=True))
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    enable_add_mask.click(
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_enable_draw,
 | 
					 | 
					 | 
					 | 
					    enable_add_mask.click(on_click_enable_draw,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        inputs=[global_state, form_image],
 | 
					 | 
					 | 
					 | 
					                          inputs=[global_state, form_image],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        outputs=[
 | 
					 | 
					 | 
					 | 
					                          outputs=[
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state,
 | 
					 | 
					 | 
					 | 
					                              global_state,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            form_image,
 | 
					 | 
					 | 
					 | 
					                              form_image,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        ],
 | 
					 | 
					 | 
					 | 
					                          ])
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_click_add_point(global_state, image: dict):
 | 
					 | 
					 | 
					 | 
					    def on_click_add_point(global_state, image: dict):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """Function switch from add mask mode to add points mode.
 | 
					 | 
					 | 
					 | 
					        """Function switch from add mask mode to add points mode.
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -779,52 +768,48 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
					 | 
					 | 
					 | 
					        global_state = preprocess_mask_info(global_state, image)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["editing_state"] = "add_points"
 | 
					 | 
					 | 
					 | 
					        global_state['editing_state'] = 'add_points'
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        mask = global_state["mask"]
 | 
					 | 
					 | 
					 | 
					        mask = global_state['mask']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(image_raw, global_state['points'], mask,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw,
 | 
					 | 
					 | 
					 | 
					                                       global_state['show_mask'], global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["points"],
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            mask,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state,
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return (global_state, gr.Image.update(value=image_draw, interactive=False))
 | 
					 | 
					 | 
					 | 
					        return (global_state,
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                gr.Image.update(value=image_draw, interactive=False))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    enable_add_points.click(
 | 
					 | 
					 | 
					 | 
					    enable_add_points.click(on_click_add_point,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_add_point,
 | 
					 | 
					 | 
					 | 
					                            inputs=[global_state, form_image],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        inputs=[global_state, form_image],
 | 
					 | 
					 | 
					 | 
					                            outputs=[global_state, form_image])
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        outputs=[global_state, form_image],
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_click_image(global_state, evt: gr.SelectData):
 | 
					 | 
					 | 
					 | 
					    def on_click_image(global_state, evt: gr.SelectData):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """This function only support click for point selection"""
 | 
					 | 
					 | 
					 | 
					        """This function only support click for point selection
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        xy = evt.index
 | 
					 | 
					 | 
					 | 
					        xy = evt.index
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if global_state["editing_state"] != "add_points":
 | 
					 | 
					 | 
					 | 
					        if global_state['editing_state'] != 'add_points':
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f'In {global_state["editing_state"]} state. ' "Do not add points.")
 | 
					 | 
					 | 
					 | 
					            print(f'In {global_state["editing_state"]} state. '
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                  'Do not add points.')
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            return global_state, global_state["images"]["image_show"]
 | 
					 | 
					 | 
					 | 
					            return global_state, global_state['images']['image_show']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        points = global_state["points"]
 | 
					 | 
					 | 
					 | 
					        points = global_state["points"]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        point_idx = get_latest_points_pair(points)
 | 
					 | 
					 | 
					 | 
					        point_idx = get_latest_points_pair(points)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if point_idx is None:
 | 
					 | 
					 | 
					 | 
					        if point_idx is None:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            points[0] = {"start": xy, "target": None}
 | 
					 | 
					 | 
					 | 
					            points[0] = {'start': xy, 'target': None}
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f"Click Image - Start - {xy}")
 | 
					 | 
					 | 
					 | 
					            print(f'Click Image - Start - {xy}')
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        elif points[point_idx].get("target", None) is None:
 | 
					 | 
					 | 
					 | 
					        elif points[point_idx].get('target', None) is None:
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            points[point_idx]["target"] = xy
 | 
					 | 
					 | 
					 | 
					            points[point_idx]['target'] = xy
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f"Click Image - Target - {xy}")
 | 
					 | 
					 | 
					 | 
					            print(f'Click Image - Target - {xy}')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        else:
 | 
					 | 
					 | 
					 | 
					        else:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            points[point_idx + 1] = {"start": xy, "target": None}
 | 
					 | 
					 | 
					 | 
					            points[point_idx + 1] = {'start': xy, 'target': None}
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            print(f"Click Image - Start - {xy}")
 | 
					 | 
					 | 
					 | 
					            print(f'Click Image - Start - {xy}')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw,
 | 
					 | 
					 | 
					 | 
					            image_raw,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["points"],
 | 
					 | 
					 | 
					 | 
					            global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["mask"],
 | 
					 | 
					 | 
					 | 
					            global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					            global_state['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state,
 | 
					 | 
					 | 
					 | 
					            global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					        )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -842,31 +827,30 @@ with gr.Blocks() as app:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        2. re-init network
 | 
					 | 
					 | 
					 | 
					        2. re-init network
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        2. re-draw image
 | 
					 | 
					 | 
					 | 
					        2. re-draw image
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        clear_state(global_state, target="point")
 | 
					 | 
					 | 
					 | 
					        clear_state(global_state, target='point')
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        renderer: Renderer = global_state["renderer"]
 | 
					 | 
					 | 
					 | 
					        renderer: Renderer = global_state["renderer"]
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        renderer.feat_refs = None
 | 
					 | 
					 | 
					 | 
					        renderer.feat_refs = None
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(image_raw, {}, global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw, {}, global_state["mask"], global_state["show_mask"], global_state
 | 
					 | 
					 | 
					 | 
					                                       global_state['show_mask'], global_state)
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    undo_points.click(
 | 
					 | 
					 | 
					 | 
					    undo_points.click(on_click_clear_points,
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        on_click_clear_points, inputs=[global_state], outputs=[global_state, form_image]
 | 
					 | 
					 | 
					 | 
					                      inputs=[global_state],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    )
 | 
					 | 
					 | 
					 | 
					                      outputs=[global_state, form_image])
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def on_click_show_mask(global_state, show_mask):
 | 
					 | 
					 | 
					 | 
					    def on_click_show_mask(global_state, show_mask):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """Function to control whether show mask on image."""
 | 
					 | 
					 | 
					 | 
					        """Function to control whether show mask on image."""
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        global_state["show_mask"] = show_mask
 | 
					 | 
					 | 
					 | 
					        global_state['show_mask'] = show_mask
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_raw = global_state["images"]["image_raw"]
 | 
					 | 
					 | 
					 | 
					        image_raw = global_state['images']['image_raw']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
					 | 
					 | 
					 | 
					        image_draw = update_image_draw(
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            image_raw,
 | 
					 | 
					 | 
					 | 
					            image_raw,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["points"],
 | 
					 | 
					 | 
					 | 
					            global_state['points'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["mask"],
 | 
					 | 
					 | 
					 | 
					            global_state['mask'],
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state["show_mask"],
 | 
					 | 
					 | 
					 | 
					            global_state['show_mask'],
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            global_state,
 | 
					 | 
					 | 
					 | 
					            global_state,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        )
 | 
					 | 
					 | 
					 | 
					        )
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
					 | 
					 | 
					 | 
					        return global_state, image_draw
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					 | 
					
 
 |