|
|
|
|
@ -9,24 +9,19 @@ import torch
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
import dnnlib
|
|
|
|
|
from gradio_utils import (
|
|
|
|
|
ImageMask,
|
|
|
|
|
draw_mask_on_image,
|
|
|
|
|
draw_points_on_image,
|
|
|
|
|
get_latest_points_pair,
|
|
|
|
|
get_valid_mask,
|
|
|
|
|
on_change_single_global_state,
|
|
|
|
|
)
|
|
|
|
|
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
|
|
|
|
|
get_latest_points_pair, get_valid_mask,
|
|
|
|
|
on_change_single_global_state)
|
|
|
|
|
from viz.renderer import Renderer, add_watermark_np
|
|
|
|
|
|
|
|
|
|
parser = ArgumentParser()
|
|
|
|
|
parser.add_argument("--share", action="store_true")
|
|
|
|
|
parser.add_argument("--cache-dir", type=str, default="./checkpoints")
|
|
|
|
|
parser.add_argument('--share', action='store_true')
|
|
|
|
|
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
cache_dir = args.cache_dir
|
|
|
|
|
|
|
|
|
|
device = "cuda"
|
|
|
|
|
device = 'cuda'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reverse_point_pairs(points):
|
|
|
|
|
@ -43,18 +38,17 @@ def clear_state(global_state, target=None):
|
|
|
|
|
2. set global_state['mask'] as full-one mask.
|
|
|
|
|
"""
|
|
|
|
|
if target is None:
|
|
|
|
|
target = ["point", "mask"]
|
|
|
|
|
target = ['point', 'mask']
|
|
|
|
|
if not isinstance(target, list):
|
|
|
|
|
target = [target]
|
|
|
|
|
if "point" in target:
|
|
|
|
|
global_state["points"] = dict()
|
|
|
|
|
print("Clear Points State!")
|
|
|
|
|
if "mask" in target:
|
|
|
|
|
if 'point' in target:
|
|
|
|
|
global_state['points'] = dict()
|
|
|
|
|
print('Clear Points State!')
|
|
|
|
|
if 'mask' in target:
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
global_state["mask"] = np.ones(
|
|
|
|
|
(image_raw.size[1], image_raw.size[0]), dtype=np.uint8
|
|
|
|
|
)
|
|
|
|
|
print("Clear mask State!")
|
|
|
|
|
global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
|
|
|
|
|
dtype=np.uint8)
|
|
|
|
|
print('Clear mask State!')
|
|
|
|
|
|
|
|
|
|
return global_state
|
|
|
|
|
|
|
|
|
|
@ -73,46 +67,43 @@ def init_images(global_state):
|
|
|
|
|
else:
|
|
|
|
|
state = global_state
|
|
|
|
|
|
|
|
|
|
state["renderer"].init_network(
|
|
|
|
|
state["generator_params"], # res
|
|
|
|
|
valid_checkpoints_dict[state["pretrained_weight"]], # pkl
|
|
|
|
|
state["params"]["seed"], # w0_seed,
|
|
|
|
|
state['renderer'].init_network(
|
|
|
|
|
state['generator_params'], # res
|
|
|
|
|
valid_checkpoints_dict[state['pretrained_weight']], # pkl
|
|
|
|
|
state['params']['seed'], # w0_seed,
|
|
|
|
|
None, # w_load
|
|
|
|
|
state["params"]["latent_space"] == "w+", # w_plus
|
|
|
|
|
"const",
|
|
|
|
|
state["params"]["trunc_psi"], # trunc_psi,
|
|
|
|
|
state["params"]["trunc_cutoff"], # trunc_cutoff,
|
|
|
|
|
state['params']['latent_space'] == 'w+', # w_plus
|
|
|
|
|
'const',
|
|
|
|
|
state['params']['trunc_psi'], # trunc_psi,
|
|
|
|
|
state['params']['trunc_cutoff'], # trunc_cutoff,
|
|
|
|
|
None, # input_transform
|
|
|
|
|
state["params"]["lr"], # lr,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
state["renderer"]._render_drag_impl(
|
|
|
|
|
state["generator_params"], is_drag=False, to_pil=True
|
|
|
|
|
state['params']['lr'] # lr,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
init_image = state["generator_params"].image
|
|
|
|
|
state["images"]["image_orig"] = init_image
|
|
|
|
|
state["images"]["image_raw"] = init_image
|
|
|
|
|
state["images"]["image_show"] = Image.fromarray(
|
|
|
|
|
add_watermark_np(np.array(init_image))
|
|
|
|
|
)
|
|
|
|
|
state["mask"] = np.ones((init_image.size[1], init_image.size[0]), dtype=np.uint8)
|
|
|
|
|
state['renderer']._render_drag_impl(state['generator_params'],
|
|
|
|
|
is_drag=False,
|
|
|
|
|
to_pil=True)
|
|
|
|
|
|
|
|
|
|
init_image = state['generator_params'].image
|
|
|
|
|
state['images']['image_orig'] = init_image
|
|
|
|
|
state['images']['image_raw'] = init_image
|
|
|
|
|
state['images']['image_show'] = Image.fromarray(
|
|
|
|
|
add_watermark_np(np.array(init_image)))
|
|
|
|
|
state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
|
|
|
|
|
dtype=np.uint8)
|
|
|
|
|
return global_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_image_draw(image, points, mask, show_mask, global_state=None):
|
|
|
|
|
|
|
|
|
|
image_draw = draw_points_on_image(image, points)
|
|
|
|
|
if (
|
|
|
|
|
show_mask
|
|
|
|
|
and mask is not None
|
|
|
|
|
and not (mask == 0).all()
|
|
|
|
|
and not (mask == 1).all()
|
|
|
|
|
):
|
|
|
|
|
if show_mask and mask is not None and not (mask == 0).all() and not (
|
|
|
|
|
mask == 1).all():
|
|
|
|
|
image_draw = draw_mask_on_image(image_draw, mask)
|
|
|
|
|
|
|
|
|
|
image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
|
|
|
|
|
if global_state is not None:
|
|
|
|
|
global_state["images"]["image_show"] = image_draw
|
|
|
|
|
global_state['images']['image_show'] = image_draw
|
|
|
|
|
return image_draw
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -124,97 +115,102 @@ def preprocess_mask_info(global_state, image):
|
|
|
|
|
2.2 global_state is add_mask:
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(image, dict):
|
|
|
|
|
last_mask = get_valid_mask(image["mask"])
|
|
|
|
|
last_mask = get_valid_mask(image['mask'])
|
|
|
|
|
else:
|
|
|
|
|
last_mask = None
|
|
|
|
|
mask = global_state["mask"]
|
|
|
|
|
mask = global_state['mask']
|
|
|
|
|
|
|
|
|
|
# mask in global state is a placeholder with all 1.
|
|
|
|
|
if (mask == 1).all():
|
|
|
|
|
mask = last_mask
|
|
|
|
|
|
|
|
|
|
# last_mask = global_state['last_mask']
|
|
|
|
|
editing_mode = global_state["editing_state"]
|
|
|
|
|
editing_mode = global_state['editing_state']
|
|
|
|
|
|
|
|
|
|
if last_mask is None:
|
|
|
|
|
return global_state
|
|
|
|
|
|
|
|
|
|
if editing_mode == "remove_mask":
|
|
|
|
|
if editing_mode == 'remove_mask':
|
|
|
|
|
updated_mask = np.clip(mask - last_mask, 0, 1)
|
|
|
|
|
print(f"Last editing_state is {editing_mode}, do remove.")
|
|
|
|
|
elif editing_mode == "add_mask":
|
|
|
|
|
print(f'Last editing_state is {editing_mode}, do remove.')
|
|
|
|
|
elif editing_mode == 'add_mask':
|
|
|
|
|
updated_mask = np.clip(mask + last_mask, 0, 1)
|
|
|
|
|
print(f"Last editing_state is {editing_mode}, do add.")
|
|
|
|
|
print(f'Last editing_state is {editing_mode}, do add.')
|
|
|
|
|
else:
|
|
|
|
|
updated_mask = mask
|
|
|
|
|
print(f"Last editing_state is {editing_mode}, " "do nothing to mask.")
|
|
|
|
|
print(f'Last editing_state is {editing_mode}, '
|
|
|
|
|
'do nothing to mask.')
|
|
|
|
|
|
|
|
|
|
global_state["mask"] = updated_mask
|
|
|
|
|
global_state['mask'] = updated_mask
|
|
|
|
|
# global_state['last_mask'] = None # clear buffer
|
|
|
|
|
return global_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
valid_checkpoints_dict = {
|
|
|
|
|
f.split("/")[-1].split(".")[0]: osp.join(cache_dir, f)
|
|
|
|
|
f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
|
|
|
|
|
for f in os.listdir(cache_dir)
|
|
|
|
|
if (f.endswith("pkl") and osp.exists(osp.join(cache_dir, f)))
|
|
|
|
|
if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
|
|
|
|
|
}
|
|
|
|
|
print(f"File under cache_dir ({cache_dir}):")
|
|
|
|
|
print(f'File under cache_dir ({cache_dir}):')
|
|
|
|
|
print(os.listdir(cache_dir))
|
|
|
|
|
print("Valid checkpoint file:")
|
|
|
|
|
print('Valid checkpoint file:')
|
|
|
|
|
print(valid_checkpoints_dict)
|
|
|
|
|
|
|
|
|
|
init_pkl = list(valid_checkpoints_dict.keys())[4]
|
|
|
|
|
init_pkl = 'stylegan_human_v2_512'
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as app:
|
|
|
|
|
|
|
|
|
|
# renderer = Renderer()
|
|
|
|
|
global_state = gr.State(
|
|
|
|
|
{
|
|
|
|
|
"images": {
|
|
|
|
|
# image_orig: the original image, change with seed/model is changed
|
|
|
|
|
# image_raw: image with mask and points, change durning optimization
|
|
|
|
|
# image_show: image showed on screen
|
|
|
|
|
},
|
|
|
|
|
"temporal_params": {
|
|
|
|
|
# stop
|
|
|
|
|
},
|
|
|
|
|
"mask": None, # mask for visualization, 1 for editing and 0 for unchange
|
|
|
|
|
"last_mask": None, # last edited mask
|
|
|
|
|
"show_mask": True, # add button
|
|
|
|
|
"generator_params": dnnlib.EasyDict(),
|
|
|
|
|
"params": {
|
|
|
|
|
"seed": 0,
|
|
|
|
|
"motion_lambda": 20,
|
|
|
|
|
"r1_in_pixels": 3,
|
|
|
|
|
"r2_in_pixels": 12,
|
|
|
|
|
"magnitude_direction_in_pixels": 1.0,
|
|
|
|
|
"latent_space": "w+",
|
|
|
|
|
"trunc_psi": 0.7,
|
|
|
|
|
"trunc_cutoff": None,
|
|
|
|
|
"lr": 0.001,
|
|
|
|
|
},
|
|
|
|
|
"device": device,
|
|
|
|
|
"draw_interval": 1,
|
|
|
|
|
"renderer": Renderer(disable_timing=True),
|
|
|
|
|
"points": {},
|
|
|
|
|
"curr_point": None,
|
|
|
|
|
"curr_type_point": "start",
|
|
|
|
|
"editing_state": "add_points",
|
|
|
|
|
"pretrained_weight": init_pkl,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
global_state = gr.State({
|
|
|
|
|
"images": {
|
|
|
|
|
# image_orig: the original image, change with seed/model is changed
|
|
|
|
|
# image_raw: image with mask and points, change durning optimization
|
|
|
|
|
# image_show: image showed on screen
|
|
|
|
|
},
|
|
|
|
|
"temporal_params": {
|
|
|
|
|
# stop
|
|
|
|
|
},
|
|
|
|
|
'mask':
|
|
|
|
|
None, # mask for visualization, 1 for editing and 0 for unchange
|
|
|
|
|
'last_mask': None, # last edited mask
|
|
|
|
|
'show_mask': True, # add button
|
|
|
|
|
"generator_params": dnnlib.EasyDict(),
|
|
|
|
|
"params": {
|
|
|
|
|
"seed": 0,
|
|
|
|
|
"motion_lambda": 20,
|
|
|
|
|
"r1_in_pixels": 3,
|
|
|
|
|
"r2_in_pixels": 12,
|
|
|
|
|
"magnitude_direction_in_pixels": 1.0,
|
|
|
|
|
"latent_space": "w+",
|
|
|
|
|
"trunc_psi": 0.7,
|
|
|
|
|
"trunc_cutoff": None,
|
|
|
|
|
"lr": 0.001,
|
|
|
|
|
},
|
|
|
|
|
"device": device,
|
|
|
|
|
"draw_interval": 1,
|
|
|
|
|
"renderer": Renderer(disable_timing=True),
|
|
|
|
|
"points": {},
|
|
|
|
|
"curr_point": None,
|
|
|
|
|
"curr_type_point": "start",
|
|
|
|
|
'editing_state': 'add_points',
|
|
|
|
|
'pretrained_weight': init_pkl
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# init image
|
|
|
|
|
global_state = init_images(global_state)
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
|
|
|
|
|
|
# Left --> tools
|
|
|
|
|
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
# Pickle
|
|
|
|
|
with gr.Row():
|
|
|
|
|
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
gr.Markdown(value="Pickle", show_label=False)
|
|
|
|
|
gr.Markdown(value='Pickle', show_label=False)
|
|
|
|
|
|
|
|
|
|
with gr.Column(scale=4, min_width=10):
|
|
|
|
|
form_pretrained_dropdown = gr.Dropdown(
|
|
|
|
|
@ -226,71 +222,71 @@ with gr.Blocks() as app:
|
|
|
|
|
# Latent
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
gr.Markdown(value="Latent", show_label=False)
|
|
|
|
|
gr.Markdown(value='Latent', show_label=False)
|
|
|
|
|
|
|
|
|
|
with gr.Column(scale=4, min_width=10):
|
|
|
|
|
form_seed_number = gr.Number(
|
|
|
|
|
value=global_state.value["params"]["seed"],
|
|
|
|
|
value=global_state.value['params']['seed'],
|
|
|
|
|
interactive=True,
|
|
|
|
|
label="Seed",
|
|
|
|
|
)
|
|
|
|
|
form_lr_number = gr.Number(
|
|
|
|
|
value=global_state.value["params"]["lr"],
|
|
|
|
|
interactive=True,
|
|
|
|
|
label="Step Size",
|
|
|
|
|
)
|
|
|
|
|
label="Step Size")
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=2, min_width=10):
|
|
|
|
|
form_reset_image = gr.Button("Reset Image")
|
|
|
|
|
with gr.Column(scale=3, min_width=10):
|
|
|
|
|
form_latent_space = gr.Radio(
|
|
|
|
|
["w", "w+"],
|
|
|
|
|
value=global_state.value["params"]["latent_space"],
|
|
|
|
|
['w', 'w+'],
|
|
|
|
|
value=global_state.value['params']
|
|
|
|
|
['latent_space'],
|
|
|
|
|
interactive=True,
|
|
|
|
|
label="Latent space to optimize",
|
|
|
|
|
label='Latent space to optimize',
|
|
|
|
|
show_label=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Drag
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
gr.Markdown(value="Drag", show_label=False)
|
|
|
|
|
gr.Markdown(value='Drag', show_label=False)
|
|
|
|
|
with gr.Column(scale=4, min_width=10):
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
enable_add_points = gr.Button("Add Points")
|
|
|
|
|
enable_add_points = gr.Button('Add Points')
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
undo_points = gr.Button("Reset Points")
|
|
|
|
|
undo_points = gr.Button('Reset Points')
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
form_start_btn = gr.Button("Start")
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
form_stop_btn = gr.Button("Stop")
|
|
|
|
|
|
|
|
|
|
form_steps_number = gr.Number(
|
|
|
|
|
value=0, label="Steps", interactive=False
|
|
|
|
|
)
|
|
|
|
|
form_steps_number = gr.Number(value=0,
|
|
|
|
|
label="Steps",
|
|
|
|
|
interactive=False)
|
|
|
|
|
|
|
|
|
|
# Mask
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
gr.Markdown(value="Mask", show_label=False)
|
|
|
|
|
gr.Markdown(value='Mask', show_label=False)
|
|
|
|
|
with gr.Column(scale=4, min_width=10):
|
|
|
|
|
enable_add_mask = gr.Button("Edit Flexible Area")
|
|
|
|
|
enable_add_mask = gr.Button('Edit Flexible Area')
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
form_reset_mask_btn = gr.Button("Reset mask")
|
|
|
|
|
with gr.Column(scale=1, min_width=10):
|
|
|
|
|
show_mask = gr.Checkbox(
|
|
|
|
|
label="Show Mask",
|
|
|
|
|
value=global_state.value["show_mask"],
|
|
|
|
|
show_label=False,
|
|
|
|
|
)
|
|
|
|
|
label='Show Mask',
|
|
|
|
|
value=global_state.value['show_mask'],
|
|
|
|
|
show_label=False)
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
|
form_lambda_number = gr.Number(
|
|
|
|
|
value=global_state.value["params"]["motion_lambda"],
|
|
|
|
|
value=global_state.value["params"]
|
|
|
|
|
["motion_lambda"],
|
|
|
|
|
interactive=True,
|
|
|
|
|
label="Lambda",
|
|
|
|
|
)
|
|
|
|
|
@ -299,18 +295,16 @@ with gr.Blocks() as app:
|
|
|
|
|
value=global_state.value["draw_interval"],
|
|
|
|
|
label="Draw Interval (steps)",
|
|
|
|
|
interactive=True,
|
|
|
|
|
visible=False,
|
|
|
|
|
)
|
|
|
|
|
visible=False)
|
|
|
|
|
|
|
|
|
|
# Right --> Image
|
|
|
|
|
with gr.Column(scale=8):
|
|
|
|
|
form_image = ImageMask(
|
|
|
|
|
value=global_state.value["images"]["image_show"], brush_radius=20
|
|
|
|
|
).style(
|
|
|
|
|
width=768, height=768
|
|
|
|
|
) # NOTE: hard image size code here.
|
|
|
|
|
gr.Markdown(
|
|
|
|
|
"""
|
|
|
|
|
value=global_state.value['images']['image_show'],
|
|
|
|
|
brush_radius=20).style(
|
|
|
|
|
width=768,
|
|
|
|
|
height=768) # NOTE: hard image size code here.
|
|
|
|
|
gr.Markdown("""
|
|
|
|
|
## Quick Start
|
|
|
|
|
|
|
|
|
|
1. Select desired `Pretrained Model` and adjust `Seed` to generate an
|
|
|
|
|
@ -329,10 +323,8 @@ with gr.Blocks() as app:
|
|
|
|
|
mask (this has the same effect as `Reset Image` button).
|
|
|
|
|
3. Click `Edit Flexible Area` to create a mask and constrain the
|
|
|
|
|
unmasked region to remain unchanged.
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
gr.HTML(
|
|
|
|
|
"""
|
|
|
|
|
""")
|
|
|
|
|
gr.HTML("""
|
|
|
|
|
<style>
|
|
|
|
|
.container {
|
|
|
|
|
position: absolute;
|
|
|
|
|
@ -347,8 +339,7 @@ with gr.Blocks() as app:
|
|
|
|
|
<img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
|
|
|
|
|
<a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
|
|
|
|
|
</div>
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
# Network & latents tab listeners
|
|
|
|
|
def on_change_pretrained_dropdown(pretrained_value, global_state):
|
|
|
|
|
@ -357,11 +348,11 @@ with gr.Blocks() as app:
|
|
|
|
|
2. Re-init images and clear all states
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
global_state["pretrained_weight"] = pretrained_value
|
|
|
|
|
global_state['pretrained_weight'] = pretrained_value
|
|
|
|
|
init_images(global_state)
|
|
|
|
|
clear_state(global_state)
|
|
|
|
|
|
|
|
|
|
return global_state, global_state["images"]["image_show"]
|
|
|
|
|
return global_state, global_state["images"]['image_show']
|
|
|
|
|
|
|
|
|
|
form_pretrained_dropdown.change(
|
|
|
|
|
on_change_pretrained_dropdown,
|
|
|
|
|
@ -378,7 +369,7 @@ with gr.Blocks() as app:
|
|
|
|
|
init_images(global_state)
|
|
|
|
|
clear_state(global_state)
|
|
|
|
|
|
|
|
|
|
return global_state, global_state["images"]["image_show"]
|
|
|
|
|
return global_state, global_state['images']['image_show']
|
|
|
|
|
|
|
|
|
|
form_reset_image.click(
|
|
|
|
|
on_click_reset_image,
|
|
|
|
|
@ -397,7 +388,7 @@ with gr.Blocks() as app:
|
|
|
|
|
init_images(global_state)
|
|
|
|
|
clear_state(global_state)
|
|
|
|
|
|
|
|
|
|
return global_state, global_state["images"]["image_show"]
|
|
|
|
|
return global_state, global_state['images']['image_show']
|
|
|
|
|
|
|
|
|
|
form_seed_number.change(
|
|
|
|
|
on_change_update_image_seed,
|
|
|
|
|
@ -412,17 +403,15 @@ with gr.Blocks() as app:
|
|
|
|
|
2. Re-init images and clear all state
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
global_state["params"]["latent_space"] = latent_space
|
|
|
|
|
global_state['params']['latent_space'] = latent_space
|
|
|
|
|
init_images(global_state)
|
|
|
|
|
clear_state(global_state)
|
|
|
|
|
|
|
|
|
|
return global_state, global_state["images"]["image_show"]
|
|
|
|
|
return global_state, global_state['images']['image_show']
|
|
|
|
|
|
|
|
|
|
form_latent_space.change(
|
|
|
|
|
on_click_latent_space,
|
|
|
|
|
inputs=[form_latent_space, global_state],
|
|
|
|
|
outputs=[global_state, form_image],
|
|
|
|
|
)
|
|
|
|
|
form_latent_space.change(on_click_latent_space,
|
|
|
|
|
inputs=[form_latent_space, global_state],
|
|
|
|
|
outputs=[global_state, form_image])
|
|
|
|
|
|
|
|
|
|
# ==== Params
|
|
|
|
|
form_lambda_number.change(
|
|
|
|
|
@ -433,13 +422,13 @@ with gr.Blocks() as app:
|
|
|
|
|
|
|
|
|
|
def on_change_lr(lr, global_state):
|
|
|
|
|
if lr == 0:
|
|
|
|
|
print("lr is 0, do nothing.")
|
|
|
|
|
print('lr is 0, do nothing.')
|
|
|
|
|
return global_state
|
|
|
|
|
else:
|
|
|
|
|
global_state["params"]["lr"] = lr
|
|
|
|
|
renderer = global_state["renderer"]
|
|
|
|
|
renderer = global_state['renderer']
|
|
|
|
|
renderer.update_lr(lr)
|
|
|
|
|
print("New optimizer: ")
|
|
|
|
|
print('New optimizer: ')
|
|
|
|
|
print(renderer.w_optim)
|
|
|
|
|
return global_state
|
|
|
|
|
|
|
|
|
|
@ -460,19 +449,19 @@ with gr.Blocks() as app:
|
|
|
|
|
# Prepare the points for the inference
|
|
|
|
|
if len(global_state["points"]) == 0:
|
|
|
|
|
# yield on_click_start_wo_points(global_state, image)
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
update_image_draw(
|
|
|
|
|
image_raw,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
yield (
|
|
|
|
|
global_state,
|
|
|
|
|
0,
|
|
|
|
|
global_state["images"]["image_show"],
|
|
|
|
|
global_state['images']['image_show'],
|
|
|
|
|
# gr.File.update(visible=False),
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
@ -484,6 +473,7 @@ with gr.Blocks() as app:
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
# NOTE: disable stop button
|
|
|
|
|
gr.Button.update(interactive=False),
|
|
|
|
|
|
|
|
|
|
# update other comps
|
|
|
|
|
gr.Dropdown.update(interactive=True),
|
|
|
|
|
gr.Number.update(interactive=True),
|
|
|
|
|
@ -495,6 +485,7 @@ with gr.Blocks() as app:
|
|
|
|
|
gr.Number.update(interactive=True),
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
# Transform the points into torch tensors
|
|
|
|
|
for key_point, point in global_state["points"].items():
|
|
|
|
|
try:
|
|
|
|
|
@ -511,19 +502,19 @@ with gr.Blocks() as app:
|
|
|
|
|
t_in_pixels.append(p_end)
|
|
|
|
|
valid_points.append(key_point)
|
|
|
|
|
|
|
|
|
|
mask = torch.tensor(global_state["mask"]).float()
|
|
|
|
|
mask = torch.tensor(global_state['mask']).float()
|
|
|
|
|
drag_mask = 1 - mask
|
|
|
|
|
|
|
|
|
|
renderer: Renderer = global_state["renderer"]
|
|
|
|
|
global_state["temporal_params"]["stop"] = False
|
|
|
|
|
global_state["editing_state"] = "running"
|
|
|
|
|
global_state['temporal_params']['stop'] = False
|
|
|
|
|
global_state['editing_state'] = 'running'
|
|
|
|
|
|
|
|
|
|
# reverse points order
|
|
|
|
|
p_to_opt = reverse_point_pairs(p_in_pixels)
|
|
|
|
|
t_to_opt = reverse_point_pairs(t_in_pixels)
|
|
|
|
|
print("Running with:")
|
|
|
|
|
print(f" Source: {p_in_pixels}")
|
|
|
|
|
print(f" Target: {t_in_pixels}")
|
|
|
|
|
print('Running with:')
|
|
|
|
|
print(f' Source: {p_in_pixels}')
|
|
|
|
|
print(f' Target: {t_in_pixels}')
|
|
|
|
|
step_idx = 0
|
|
|
|
|
while True:
|
|
|
|
|
if global_state["temporal_params"]["stop"]:
|
|
|
|
|
@ -531,18 +522,18 @@ with gr.Blocks() as app:
|
|
|
|
|
|
|
|
|
|
# do drage here!
|
|
|
|
|
renderer._render_drag_impl(
|
|
|
|
|
global_state["generator_params"],
|
|
|
|
|
global_state['generator_params'],
|
|
|
|
|
p_to_opt, # point
|
|
|
|
|
t_to_opt, # target
|
|
|
|
|
drag_mask, # mask,
|
|
|
|
|
global_state["params"]["motion_lambda"], # lambda_mask
|
|
|
|
|
global_state['params']['motion_lambda'], # lambda_mask
|
|
|
|
|
reg=0,
|
|
|
|
|
feature_idx=5, # NOTE: do not support change for now
|
|
|
|
|
r1=global_state["params"]["r1_in_pixels"], # r1
|
|
|
|
|
r2=global_state["params"]["r2_in_pixels"], # r2
|
|
|
|
|
r1=global_state['params']['r1_in_pixels'], # r1
|
|
|
|
|
r2=global_state['params']['r2_in_pixels'], # r2
|
|
|
|
|
# random_seed = 0,
|
|
|
|
|
# noise_mode = 'const',
|
|
|
|
|
trunc_psi=global_state["params"]["trunc_psi"],
|
|
|
|
|
trunc_psi=global_state['params']['trunc_psi'],
|
|
|
|
|
# force_fp32 = False,
|
|
|
|
|
# layer_name = None,
|
|
|
|
|
# sel_channels = 3,
|
|
|
|
|
@ -551,12 +542,12 @@ with gr.Blocks() as app:
|
|
|
|
|
# img_normalize = False,
|
|
|
|
|
# untransform = False,
|
|
|
|
|
is_drag=True,
|
|
|
|
|
to_pil=True,
|
|
|
|
|
)
|
|
|
|
|
to_pil=True)
|
|
|
|
|
|
|
|
|
|
if step_idx % global_state["draw_interval"] == 0:
|
|
|
|
|
print("Current Source:")
|
|
|
|
|
for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt):
|
|
|
|
|
if step_idx % global_state['draw_interval'] == 0:
|
|
|
|
|
print('Current Source:')
|
|
|
|
|
for key_point, p_i, t_i in zip(valid_points, p_to_opt,
|
|
|
|
|
t_to_opt):
|
|
|
|
|
global_state["points"][key_point]["start_temp"] = [
|
|
|
|
|
p_i[1],
|
|
|
|
|
p_i[0],
|
|
|
|
|
@ -565,23 +556,24 @@ with gr.Blocks() as app:
|
|
|
|
|
t_i[1],
|
|
|
|
|
t_i[0],
|
|
|
|
|
]
|
|
|
|
|
start_temp = global_state["points"][key_point]["start_temp"]
|
|
|
|
|
print(f" {start_temp}")
|
|
|
|
|
start_temp = global_state["points"][key_point][
|
|
|
|
|
"start_temp"]
|
|
|
|
|
print(f' {start_temp}')
|
|
|
|
|
|
|
|
|
|
image_result = global_state["generator_params"]["image"]
|
|
|
|
|
image_result = global_state['generator_params']['image']
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_result,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
global_state["images"]["image_raw"] = image_result
|
|
|
|
|
global_state['images']['image_raw'] = image_result
|
|
|
|
|
|
|
|
|
|
yield (
|
|
|
|
|
global_state,
|
|
|
|
|
step_idx,
|
|
|
|
|
global_state["images"]["image_show"],
|
|
|
|
|
global_state['images']['image_show'],
|
|
|
|
|
# gr.File.update(visible=False),
|
|
|
|
|
gr.Button.update(interactive=False),
|
|
|
|
|
gr.Button.update(interactive=False),
|
|
|
|
|
@ -593,6 +585,7 @@ with gr.Blocks() as app:
|
|
|
|
|
gr.Button.update(interactive=False),
|
|
|
|
|
# enable stop button in loop
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
|
|
|
|
|
# update other comps
|
|
|
|
|
gr.Dropdown.update(interactive=False),
|
|
|
|
|
gr.Number.update(interactive=False),
|
|
|
|
|
@ -607,25 +600,23 @@ with gr.Blocks() as app:
|
|
|
|
|
# increate step
|
|
|
|
|
step_idx += 1
|
|
|
|
|
|
|
|
|
|
image_result = global_state["generator_params"]["image"]
|
|
|
|
|
global_state["images"]["image_raw"] = image_result
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_result,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
image_result = global_state['generator_params']['image']
|
|
|
|
|
global_state['images']['image_raw'] = image_result
|
|
|
|
|
image_draw = update_image_draw(image_result,
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'],
|
|
|
|
|
global_state)
|
|
|
|
|
|
|
|
|
|
# fp = NamedTemporaryFile(suffix=".png", delete=False)
|
|
|
|
|
# image_result.save(fp, "PNG")
|
|
|
|
|
|
|
|
|
|
global_state["editing_state"] = "add_points"
|
|
|
|
|
global_state['editing_state'] = 'add_points'
|
|
|
|
|
|
|
|
|
|
yield (
|
|
|
|
|
global_state,
|
|
|
|
|
0, # reset step to 0 after stop.
|
|
|
|
|
global_state["images"]["image_show"],
|
|
|
|
|
global_state['images']['image_show'],
|
|
|
|
|
# gr.File.update(visible=True, value=fp.name),
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
@ -637,6 +628,7 @@ with gr.Blocks() as app:
|
|
|
|
|
gr.Button.update(interactive=True),
|
|
|
|
|
# NOTE: disable stop button with loop finish
|
|
|
|
|
gr.Button.update(interactive=False),
|
|
|
|
|
|
|
|
|
|
# update other comps
|
|
|
|
|
gr.Dropdown.update(interactive=True),
|
|
|
|
|
gr.Number.update(interactive=True),
|
|
|
|
|
@ -681,9 +673,9 @@ with gr.Blocks() as app:
|
|
|
|
|
|
|
|
|
|
return global_state, gr.Button.update(interactive=False)
|
|
|
|
|
|
|
|
|
|
form_stop_btn.click(
|
|
|
|
|
on_click_stop, inputs=[global_state], outputs=[global_state, form_stop_btn]
|
|
|
|
|
)
|
|
|
|
|
form_stop_btn.click(on_click_stop,
|
|
|
|
|
inputs=[global_state],
|
|
|
|
|
outputs=[global_state, form_stop_btn])
|
|
|
|
|
|
|
|
|
|
form_draw_interval_number.change(
|
|
|
|
|
partial(
|
|
|
|
|
@ -711,20 +703,17 @@ with gr.Blocks() as app:
|
|
|
|
|
|
|
|
|
|
# Mask
|
|
|
|
|
def on_click_reset_mask(global_state):
|
|
|
|
|
global_state["mask"] = np.ones(
|
|
|
|
|
global_state['mask'] = np.ones(
|
|
|
|
|
(
|
|
|
|
|
global_state["images"]["image_raw"].size[1],
|
|
|
|
|
global_state["images"]["image_raw"].size[0],
|
|
|
|
|
),
|
|
|
|
|
dtype=np.uint8,
|
|
|
|
|
)
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
global_state["images"]["image_raw"],
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
image_draw = update_image_draw(global_state['images']['image_raw'],
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'], global_state)
|
|
|
|
|
return global_state, image_draw
|
|
|
|
|
|
|
|
|
|
form_reset_mask_btn.click(
|
|
|
|
|
@ -741,12 +730,13 @@ with gr.Blocks() as app:
|
|
|
|
|
3. Set curr image with points and mask
|
|
|
|
|
"""
|
|
|
|
|
global_state = preprocess_mask_info(global_state, image)
|
|
|
|
|
global_state["editing_state"] = "add_mask"
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw, global_state["points"], global_state["mask"], True, global_state
|
|
|
|
|
)
|
|
|
|
|
return (global_state, gr.Image.update(value=image_draw, interactive=True))
|
|
|
|
|
global_state['editing_state'] = 'add_mask'
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(image_raw, global_state['points'],
|
|
|
|
|
global_state['mask'], True,
|
|
|
|
|
global_state)
|
|
|
|
|
return (global_state,
|
|
|
|
|
gr.Image.update(value=image_draw, interactive=True))
|
|
|
|
|
|
|
|
|
|
def on_click_remove_draw(global_state, image):
|
|
|
|
|
"""Function to start remove mask mode.
|
|
|
|
|
@ -755,21 +745,20 @@ with gr.Blocks() as app:
|
|
|
|
|
3. Set curr image with points and mask
|
|
|
|
|
"""
|
|
|
|
|
global_state = preprocess_mask_info(global_state, image)
|
|
|
|
|
global_state["edinting_state"] = "remove_mask"
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw, global_state["points"], global_state["mask"], True, global_state
|
|
|
|
|
)
|
|
|
|
|
return (global_state, gr.Image.update(value=image_draw, interactive=True))
|
|
|
|
|
|
|
|
|
|
enable_add_mask.click(
|
|
|
|
|
on_click_enable_draw,
|
|
|
|
|
inputs=[global_state, form_image],
|
|
|
|
|
outputs=[
|
|
|
|
|
global_state,
|
|
|
|
|
form_image,
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
global_state['edinting_state'] = 'remove_mask'
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(image_raw, global_state['points'],
|
|
|
|
|
global_state['mask'], True,
|
|
|
|
|
global_state)
|
|
|
|
|
return (global_state,
|
|
|
|
|
gr.Image.update(value=image_draw, interactive=True))
|
|
|
|
|
|
|
|
|
|
enable_add_mask.click(on_click_enable_draw,
|
|
|
|
|
inputs=[global_state, form_image],
|
|
|
|
|
outputs=[
|
|
|
|
|
global_state,
|
|
|
|
|
form_image,
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def on_click_add_point(global_state, image: dict):
|
|
|
|
|
"""Function switch from add mask mode to add points mode.
|
|
|
|
|
@ -779,52 +768,48 @@ with gr.Blocks() as app:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
global_state = preprocess_mask_info(global_state, image)
|
|
|
|
|
global_state["editing_state"] = "add_points"
|
|
|
|
|
mask = global_state["mask"]
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
mask,
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
global_state['editing_state'] = 'add_points'
|
|
|
|
|
mask = global_state['mask']
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(image_raw, global_state['points'], mask,
|
|
|
|
|
global_state['show_mask'], global_state)
|
|
|
|
|
|
|
|
|
|
return (global_state, gr.Image.update(value=image_draw, interactive=False))
|
|
|
|
|
return (global_state,
|
|
|
|
|
gr.Image.update(value=image_draw, interactive=False))
|
|
|
|
|
|
|
|
|
|
enable_add_points.click(
|
|
|
|
|
on_click_add_point,
|
|
|
|
|
inputs=[global_state, form_image],
|
|
|
|
|
outputs=[global_state, form_image],
|
|
|
|
|
)
|
|
|
|
|
enable_add_points.click(on_click_add_point,
|
|
|
|
|
inputs=[global_state, form_image],
|
|
|
|
|
outputs=[global_state, form_image])
|
|
|
|
|
|
|
|
|
|
def on_click_image(global_state, evt: gr.SelectData):
|
|
|
|
|
"""This function only support click for point selection"""
|
|
|
|
|
"""This function only support click for point selection
|
|
|
|
|
"""
|
|
|
|
|
xy = evt.index
|
|
|
|
|
if global_state["editing_state"] != "add_points":
|
|
|
|
|
print(f'In {global_state["editing_state"]} state. ' "Do not add points.")
|
|
|
|
|
if global_state['editing_state'] != 'add_points':
|
|
|
|
|
print(f'In {global_state["editing_state"]} state. '
|
|
|
|
|
'Do not add points.')
|
|
|
|
|
|
|
|
|
|
return global_state, global_state["images"]["image_show"]
|
|
|
|
|
return global_state, global_state['images']['image_show']
|
|
|
|
|
|
|
|
|
|
points = global_state["points"]
|
|
|
|
|
|
|
|
|
|
point_idx = get_latest_points_pair(points)
|
|
|
|
|
if point_idx is None:
|
|
|
|
|
points[0] = {"start": xy, "target": None}
|
|
|
|
|
print(f"Click Image - Start - {xy}")
|
|
|
|
|
elif points[point_idx].get("target", None) is None:
|
|
|
|
|
points[point_idx]["target"] = xy
|
|
|
|
|
print(f"Click Image - Target - {xy}")
|
|
|
|
|
points[0] = {'start': xy, 'target': None}
|
|
|
|
|
print(f'Click Image - Start - {xy}')
|
|
|
|
|
elif points[point_idx].get('target', None) is None:
|
|
|
|
|
points[point_idx]['target'] = xy
|
|
|
|
|
print(f'Click Image - Target - {xy}')
|
|
|
|
|
else:
|
|
|
|
|
points[point_idx + 1] = {"start": xy, "target": None}
|
|
|
|
|
print(f"Click Image - Start - {xy}")
|
|
|
|
|
points[point_idx + 1] = {'start': xy, 'target': None}
|
|
|
|
|
print(f'Click Image - Start - {xy}')
|
|
|
|
|
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -842,31 +827,30 @@ with gr.Blocks() as app:
|
|
|
|
|
2. re-init network
|
|
|
|
|
2. re-draw image
|
|
|
|
|
"""
|
|
|
|
|
clear_state(global_state, target="point")
|
|
|
|
|
clear_state(global_state, target='point')
|
|
|
|
|
|
|
|
|
|
renderer: Renderer = global_state["renderer"]
|
|
|
|
|
renderer.feat_refs = None
|
|
|
|
|
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw, {}, global_state["mask"], global_state["show_mask"], global_state
|
|
|
|
|
)
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(image_raw, {}, global_state['mask'],
|
|
|
|
|
global_state['show_mask'], global_state)
|
|
|
|
|
return global_state, image_draw
|
|
|
|
|
|
|
|
|
|
undo_points.click(
|
|
|
|
|
on_click_clear_points, inputs=[global_state], outputs=[global_state, form_image]
|
|
|
|
|
)
|
|
|
|
|
undo_points.click(on_click_clear_points,
|
|
|
|
|
inputs=[global_state],
|
|
|
|
|
outputs=[global_state, form_image])
|
|
|
|
|
|
|
|
|
|
def on_click_show_mask(global_state, show_mask):
|
|
|
|
|
"""Function to control whether show mask on image."""
|
|
|
|
|
global_state["show_mask"] = show_mask
|
|
|
|
|
global_state['show_mask'] = show_mask
|
|
|
|
|
|
|
|
|
|
image_raw = global_state["images"]["image_raw"]
|
|
|
|
|
image_raw = global_state['images']['image_raw']
|
|
|
|
|
image_draw = update_image_draw(
|
|
|
|
|
image_raw,
|
|
|
|
|
global_state["points"],
|
|
|
|
|
global_state["mask"],
|
|
|
|
|
global_state["show_mask"],
|
|
|
|
|
global_state['points'],
|
|
|
|
|
global_state['mask'],
|
|
|
|
|
global_state['show_mask'],
|
|
|
|
|
global_state,
|
|
|
|
|
)
|
|
|
|
|
return global_state, image_draw
|
|
|
|
|
|