mirror of https://github.com/XingangPan/DragGAN
support gradio demo for DragGAN
parent
91b9322e2f
commit
ab719c4a7a
@ -0,0 +1,9 @@
|
||||
from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
|
||||
get_latest_points_pair, get_valid_mask,
|
||||
on_change_single_global_state)
|
||||
|
||||
__all__ = [
|
||||
'draw_mask_on_image', 'draw_points_on_image',
|
||||
'on_change_single_global_state', 'get_latest_points_pair',
|
||||
'get_valid_mask', 'ImageMask'
|
||||
]
|
||||
@ -0,0 +1,154 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
class ImageMask(gr.components.Image):
|
||||
"""
|
||||
Sets: source="canvas", tool="sketch"
|
||||
"""
|
||||
|
||||
is_template = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(source="upload",
|
||||
tool="sketch",
|
||||
interactive=False,
|
||||
**kwargs)
|
||||
|
||||
def preprocess(self, x):
|
||||
if x is None:
|
||||
return x
|
||||
if self.tool == "sketch" and self.source in ["upload", "webcam"
|
||||
] and type(x) != dict:
|
||||
decode_image = gr.processing_utils.decode_base64_to_image(x)
|
||||
width, height = decode_image.size
|
||||
mask = np.ones((height, width, 4), dtype=np.uint8)
|
||||
mask[..., -1] = 255
|
||||
mask = self.postprocess(mask)
|
||||
x = {'image': x, 'mask': mask}
|
||||
return super().preprocess(x)
|
||||
|
||||
|
||||
def get_valid_mask(mask: np.ndarray):
|
||||
"""Convert mask from gr.Image(0 to 255, RGBA) to binary mask.
|
||||
"""
|
||||
if mask.ndim == 3:
|
||||
mask_pil = Image.fromarray(mask).convert('L')
|
||||
mask = np.array(mask_pil)
|
||||
if mask.max() == 255:
|
||||
mask = mask / 255
|
||||
return mask
|
||||
|
||||
|
||||
def draw_points_on_image(image,
|
||||
points,
|
||||
curr_point=None,
|
||||
highlight_all=True,
|
||||
radius_scale=0.01):
|
||||
overlay_rgba = Image.new("RGBA", image.size, 0)
|
||||
overlay_draw = ImageDraw.Draw(overlay_rgba)
|
||||
for point_key, point in points.items():
|
||||
if ((curr_point is not None and curr_point == point_key)
|
||||
or highlight_all):
|
||||
p_color = (255, 0, 0)
|
||||
t_color = (0, 0, 255)
|
||||
|
||||
else:
|
||||
p_color = (255, 0, 0, 35)
|
||||
t_color = (0, 0, 255, 35)
|
||||
|
||||
rad_draw = int(image.size[0] * radius_scale)
|
||||
|
||||
p_start = point.get("start_temp", point["start"])
|
||||
p_target = point["target"]
|
||||
|
||||
if p_start is not None and p_target is not None:
|
||||
p_draw = int(p_start[0]), int(p_start[1])
|
||||
t_draw = int(p_target[0]), int(p_target[1])
|
||||
|
||||
overlay_draw.line(
|
||||
(p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
|
||||
fill=(255, 255, 0),
|
||||
width=2,
|
||||
)
|
||||
|
||||
if p_start is not None:
|
||||
p_draw = int(p_start[0]), int(p_start[1])
|
||||
overlay_draw.ellipse(
|
||||
(
|
||||
p_draw[0] - rad_draw,
|
||||
p_draw[1] - rad_draw,
|
||||
p_draw[0] + rad_draw,
|
||||
p_draw[1] + rad_draw,
|
||||
),
|
||||
fill=p_color,
|
||||
)
|
||||
|
||||
if curr_point is not None and curr_point == point_key:
|
||||
# overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0))
|
||||
overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))
|
||||
|
||||
if p_target is not None:
|
||||
t_draw = int(p_target[0]), int(p_target[1])
|
||||
overlay_draw.ellipse(
|
||||
(
|
||||
t_draw[0] - rad_draw,
|
||||
t_draw[1] - rad_draw,
|
||||
t_draw[0] + rad_draw,
|
||||
t_draw[1] + rad_draw,
|
||||
),
|
||||
fill=t_color,
|
||||
)
|
||||
|
||||
if curr_point is not None and curr_point == point_key:
|
||||
# overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0))
|
||||
overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))
|
||||
|
||||
return Image.alpha_composite(image.convert("RGBA"),
|
||||
overlay_rgba).convert("RGB")
|
||||
|
||||
|
||||
def draw_mask_on_image(image, mask):
|
||||
im_mask = np.uint8(mask * 255)
|
||||
im_mask_rgba = np.concatenate(
|
||||
(
|
||||
np.tile(im_mask[..., None], [1, 1, 3]),
|
||||
45 * np.ones(
|
||||
(im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
|
||||
|
||||
return Image.alpha_composite(image.convert("RGBA"),
|
||||
im_mask_rgba).convert("RGB")
|
||||
|
||||
|
||||
def on_change_single_global_state(keys,
|
||||
value,
|
||||
global_state,
|
||||
map_transform=None):
|
||||
if map_transform is not None:
|
||||
value = map_transform(value)
|
||||
|
||||
curr_state = global_state
|
||||
if isinstance(keys, str):
|
||||
last_key = keys
|
||||
|
||||
else:
|
||||
for k in keys[:-1]:
|
||||
curr_state = curr_state[k]
|
||||
|
||||
last_key = keys[-1]
|
||||
|
||||
curr_state[last_key] = value
|
||||
return global_state
|
||||
|
||||
|
||||
def get_latest_points_pair(points_dict):
|
||||
if not points_dict:
|
||||
return None
|
||||
point_idx = list(points_dict.keys())
|
||||
latest_point_idx = max(point_idx)
|
||||
return latest_point_idx
|
||||
@ -0,0 +1,866 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import dnnlib
|
||||
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
|
||||
get_latest_points_pair, get_valid_mask,
|
||||
on_change_single_global_state)
|
||||
from viz.renderer import Renderer, add_watermark_np
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--share', action='store_true')
|
||||
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
|
||||
args = parser.parse_args()
|
||||
|
||||
cache_dir = args.cache_dir
|
||||
|
||||
device = 'cuda'
|
||||
|
||||
|
||||
def reverse_point_pairs(points):
|
||||
new_points = []
|
||||
for p in points:
|
||||
new_points.append([p[1], p[0]])
|
||||
return new_points
|
||||
|
||||
|
||||
def clear_state(global_state, target=None):
|
||||
"""Clear target history state from global_state
|
||||
If target is not defined, points and mask will be both removed.
|
||||
1. set global_state['points'] as empty dict
|
||||
2. set global_state['mask'] as full-one mask.
|
||||
"""
|
||||
if target is None:
|
||||
target = ['point', 'mask']
|
||||
if not isinstance(target, list):
|
||||
target = [target]
|
||||
if 'point' in target:
|
||||
global_state['points'] = dict()
|
||||
print('Clear Points State!')
|
||||
if 'mask' in target:
|
||||
image_raw = global_state["images"]["image_raw"]
|
||||
global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
|
||||
dtype=np.uint8)
|
||||
print('Clear mask State!')
|
||||
|
||||
return global_state
|
||||
|
||||
|
||||
def init_images(global_state):
|
||||
"""This function is called only ones with Gradio App is started.
|
||||
0. pre-process global_state, unpack value from global_state of need
|
||||
1. Re-init renderer
|
||||
2. run `renderer._render_drag_impl` with `is_drag=False` to generate
|
||||
new image
|
||||
3. Assign images to global state and re-generate mask
|
||||
"""
|
||||
|
||||
if isinstance(global_state, gr.State):
|
||||
state = global_state.value
|
||||
else:
|
||||
state = global_state
|
||||
|
||||
state['renderer'].init_network(
|
||||
state['generator_params'], # res
|
||||
valid_checkpoints_dict[state['pretrained_weight']], # pkl
|
||||
state['params']['seed'], # w0_seed,
|
||||
None, # w_load
|
||||
state['params']['latent_space'] == 'w+', # w_plus
|
||||
'const',
|
||||
state['params']['trunc_psi'], # trunc_psi,
|
||||
state['params']['trunc_cutoff'], # trunc_cutoff,
|
||||
None, # input_transform
|
||||
state['params']['lr'] # lr,
|
||||
)
|
||||
|
||||
state['renderer']._render_drag_impl(state['generator_params'],
|
||||
is_drag=False,
|
||||
to_pil=True)
|
||||
|
||||
init_image = state['generator_params'].image
|
||||
state['images']['image_orig'] = init_image
|
||||
state['images']['image_raw'] = init_image
|
||||
state['images']['image_show'] = Image.fromarray(
|
||||
add_watermark_np(np.array(init_image)))
|
||||
state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
|
||||
dtype=np.uint8)
|
||||
return global_state
|
||||
|
||||
|
||||
def update_image_draw(image, points, mask, show_mask, global_state=None):
|
||||
|
||||
image_draw = draw_points_on_image(image, points)
|
||||
if show_mask and mask is not None and not (mask == 0).all() and not (
|
||||
mask == 1).all():
|
||||
image_draw = draw_mask_on_image(image_draw, mask)
|
||||
|
||||
image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
|
||||
if global_state is not None:
|
||||
global_state['images']['image_show'] = image_draw
|
||||
return image_draw
|
||||
|
||||
|
||||
def preprocess_mask_info(global_state, image):
|
||||
"""Function to handle mask information.
|
||||
1. last_mask is None: Do not need to change mask, return mask
|
||||
2. last_mask is not None:
|
||||
2.1 global_state is remove_mask:
|
||||
2.2 global_state is add_mask:
|
||||
"""
|
||||
if isinstance(image, dict):
|
||||
last_mask = get_valid_mask(image['mask'])
|
||||
else:
|
||||
last_mask = None
|
||||
mask = global_state['mask']
|
||||
|
||||
# mask in global state is a placeholder with all 1.
|
||||
if (mask == 1).all():
|
||||
mask = last_mask
|
||||
|
||||
# last_mask = global_state['last_mask']
|
||||
editing_mode = global_state['editing_state']
|
||||
|
||||
if last_mask is None:
|
||||
return global_state
|
||||
|
||||
if editing_mode == 'remove_mask':
|
||||
updated_mask = np.clip(mask - last_mask, 0, 1)
|
||||
print(f'Last editing_state is {editing_mode}, do remove.')
|
||||
elif editing_mode == 'add_mask':
|
||||
updated_mask = np.clip(mask + last_mask, 0, 1)
|
||||
print(f'Last editing_state is {editing_mode}, do add.')
|
||||
else:
|
||||
updated_mask = mask
|
||||
print(f'Last editing_state is {editing_mode}, '
|
||||
'do nothing to mask.')
|
||||
|
||||
global_state['mask'] = updated_mask
|
||||
# global_state['last_mask'] = None # clear buffer
|
||||
return global_state
|
||||
|
||||
|
||||
valid_checkpoints_dict = {
|
||||
f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
|
||||
for f in os.listdir(cache_dir)
|
||||
if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
|
||||
}
|
||||
print(f'File under cache_dir ({cache_dir}):')
|
||||
print(os.listdir(cache_dir))
|
||||
print('Valid checkpoint file:')
|
||||
print(valid_checkpoints_dict)
|
||||
|
||||
init_pkl = 'stylegan_human_v2_512'
|
||||
|
||||
with gr.Blocks() as app:
|
||||
|
||||
# renderer = Renderer()
|
||||
global_state = gr.State({
|
||||
"images": {
|
||||
# image_orig: the original image, change with seed/model is changed
|
||||
# image_raw: image with mask and points, change durning optimization
|
||||
# image_show: image showed on screen
|
||||
},
|
||||
"temporal_params": {
|
||||
# stop
|
||||
},
|
||||
'mask':
|
||||
None, # mask for visualization, 1 for editing and 0 for unchange
|
||||
'last_mask': None, # last edited mask
|
||||
'show_mask': True, # add button
|
||||
"generator_params": dnnlib.EasyDict(),
|
||||
"params": {
|
||||
"seed": 0,
|
||||
"motion_lambda": 20,
|
||||
"r1_in_pixels": 3,
|
||||
"r2_in_pixels": 12,
|
||||
"magnitude_direction_in_pixels": 1.0,
|
||||
"latent_space": "w+",
|
||||
"trunc_psi": 0.7,
|
||||
"trunc_cutoff": None,
|
||||
"lr": 0.001,
|
||||
},
|
||||
"device": device,
|
||||
"draw_interval": 1,
|
||||
"renderer": Renderer(disable_timing=True),
|
||||
"points": {},
|
||||
"curr_point": None,
|
||||
"curr_type_point": "start",
|
||||
'editing_state': 'add_points',
|
||||
'pretrained_weight': init_pkl
|
||||
})
|
||||
|
||||
# init image
|
||||
global_state = init_images(global_state)
|
||||
|
||||
with gr.Row():
|
||||
|
||||
with gr.Row():
|
||||
|
||||
# Left --> tools
|
||||
with gr.Column(scale=3):
|
||||
|
||||
# Pickle
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
gr.Markdown(value='Pickle', show_label=False)
|
||||
|
||||
with gr.Column(scale=4, min_width=10):
|
||||
form_pretrained_dropdown = gr.Dropdown(
|
||||
choices=list(valid_checkpoints_dict.keys()),
|
||||
label="Pretrained Model",
|
||||
value=init_pkl,
|
||||
)
|
||||
|
||||
# Latent
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
gr.Markdown(value='Latent', show_label=False)
|
||||
|
||||
with gr.Column(scale=4, min_width=10):
|
||||
form_seed_number = gr.Number(
|
||||
value=global_state.value['params']['seed'],
|
||||
interactive=True,
|
||||
label="Seed",
|
||||
)
|
||||
form_lr_number = gr.Number(
|
||||
value=global_state.value["params"]["lr"],
|
||||
interactive=True,
|
||||
label="Step Size")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2, min_width=10):
|
||||
form_reset_image = gr.Button("Reset Image")
|
||||
with gr.Column(scale=3, min_width=10):
|
||||
form_latent_space = gr.Radio(
|
||||
['w', 'w+'],
|
||||
value=global_state.value['params']
|
||||
['latent_space'],
|
||||
interactive=True,
|
||||
label='Latent space to optimize',
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
# Drag
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
gr.Markdown(value='Drag', show_label=False)
|
||||
with gr.Column(scale=4, min_width=10):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
enable_add_points = gr.Button('Add Points')
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
undo_points = gr.Button('Reset Points')
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
form_start_btn = gr.Button("Start")
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
form_stop_btn = gr.Button("Stop")
|
||||
|
||||
form_steps_number = gr.Number(value=0,
|
||||
label="Steps",
|
||||
interactive=False)
|
||||
|
||||
# Mask
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
gr.Markdown(value='Mask', show_label=False)
|
||||
with gr.Column(scale=4, min_width=10):
|
||||
enable_add_mask = gr.Button('Edit Flexible Area')
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
form_reset_mask_btn = gr.Button("Reset mask")
|
||||
with gr.Column(scale=1, min_width=10):
|
||||
show_mask = gr.Checkbox(
|
||||
label='Show Mask',
|
||||
value=global_state.value['show_mask'],
|
||||
show_label=False)
|
||||
|
||||
with gr.Row():
|
||||
form_lambda_number = gr.Number(
|
||||
value=global_state.value["params"]
|
||||
["motion_lambda"],
|
||||
interactive=True,
|
||||
label="Lambda",
|
||||
)
|
||||
|
||||
form_draw_interval_number = gr.Number(
|
||||
value=global_state.value["draw_interval"],
|
||||
label="Draw Interval (steps)",
|
||||
interactive=True,
|
||||
visible=False)
|
||||
|
||||
# Right --> Image
|
||||
with gr.Column(scale=8):
|
||||
form_image = ImageMask(
|
||||
value=global_state.value['images']['image_show'],
|
||||
brush_radius=20).style(
|
||||
width=768,
|
||||
height=768) # NOTE: hard image size code here.
|
||||
gr.Markdown("""
|
||||
## Quick Start
|
||||
|
||||
1. Select desired `Pretrained Model` and adjust `Seed` to generate an
|
||||
initial image.
|
||||
2. Click on image to add control points.
|
||||
3. Click `Start` and enjoy it!
|
||||
|
||||
## Advance Usage
|
||||
|
||||
1. Change `Step Size` to adjust learning rate in drag optimization.
|
||||
2. Select `w` or `w+` to change latent space to optimize:
|
||||
* Optimize on `w` space may cause greater influence to the image.
|
||||
* Optimize on `w+` space may work slower than `w`, but usually achieve
|
||||
better results.
|
||||
* Note that changing the latent space will reset the image, points and
|
||||
mask (this has the same effect as `Reset Image` button).
|
||||
3. Click `Edit Flexible Area` to create a mask and constrain the
|
||||
unmasked region to remain unchanged.
|
||||
""")
|
||||
gr.HTML("""
|
||||
<style>
|
||||
.container {
|
||||
position: absolute;
|
||||
height: 50px;
|
||||
text-align: center;
|
||||
line-height: 50px;
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
<div class="container">
|
||||
Gradio demo supported by
|
||||
<img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
|
||||
<a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
|
||||
</div>
|
||||
""")
|
||||
|
||||
# Network & latents tab listeners
|
||||
def on_change_pretrained_dropdown(pretrained_value, global_state):
|
||||
"""Function to handle model change.
|
||||
1. Set pretrained value to global_state
|
||||
2. Re-init images and clear all states
|
||||
"""
|
||||
|
||||
global_state['pretrained_weight'] = pretrained_value
|
||||
init_images(global_state)
|
||||
clear_state(global_state)
|
||||
|
||||
return global_state, global_state["images"]['image_show']
|
||||
|
||||
form_pretrained_dropdown.change(
|
||||
on_change_pretrained_dropdown,
|
||||
inputs=[form_pretrained_dropdown, global_state],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
def on_click_reset_image(global_state):
|
||||
"""Reset image to the original one and clear all states
|
||||
1. Re-init images
|
||||
2. Clear all states
|
||||
"""
|
||||
|
||||
init_images(global_state)
|
||||
clear_state(global_state)
|
||||
|
||||
return global_state, global_state['images']['image_show']
|
||||
|
||||
form_reset_image.click(
|
||||
on_click_reset_image,
|
||||
inputs=[global_state],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
# Update parameters
|
||||
def on_change_update_image_seed(seed, global_state):
|
||||
"""Function to handle generation seed change.
|
||||
1. Set seed to global_state
|
||||
2. Re-init images and clear all states
|
||||
"""
|
||||
|
||||
global_state["params"]["seed"] = int(seed)
|
||||
init_images(global_state)
|
||||
clear_state(global_state)
|
||||
|
||||
return global_state, global_state['images']['image_show']
|
||||
|
||||
form_seed_number.change(
|
||||
on_change_update_image_seed,
|
||||
inputs=[form_seed_number, global_state],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
def on_click_latent_space(latent_space, global_state):
|
||||
"""Function to reset latent space to optimize.
|
||||
NOTE: this function we reset the image and all controls
|
||||
1. Set latent-space to global_state
|
||||
2. Re-init images and clear all state
|
||||
"""
|
||||
|
||||
global_state['params']['latent_space'] = latent_space
|
||||
init_images(global_state)
|
||||
clear_state(global_state)
|
||||
|
||||
return global_state, global_state['images']['image_show']
|
||||
|
||||
form_latent_space.change(on_click_latent_space,
|
||||
inputs=[form_latent_space, global_state],
|
||||
outputs=[global_state, form_image])
|
||||
|
||||
# ==== Params
|
||||
form_lambda_number.change(
|
||||
partial(on_change_single_global_state, ["params", "motion_lambda"]),
|
||||
inputs=[form_lambda_number, global_state],
|
||||
outputs=[global_state],
|
||||
)
|
||||
|
||||
def on_change_lr(lr, global_state):
|
||||
if lr == 0:
|
||||
print('lr is 0, do nothing.')
|
||||
return global_state
|
||||
else:
|
||||
global_state["params"]["lr"] = lr
|
||||
renderer = global_state['renderer']
|
||||
renderer.update_lr(lr)
|
||||
print('New optimizer: ')
|
||||
print(renderer.w_optim)
|
||||
return global_state
|
||||
|
||||
form_lr_number.change(
|
||||
on_change_lr,
|
||||
inputs=[form_lr_number, global_state],
|
||||
outputs=[global_state],
|
||||
)
|
||||
|
||||
def on_click_start(global_state, image):
|
||||
p_in_pixels = []
|
||||
t_in_pixels = []
|
||||
valid_points = []
|
||||
|
||||
# handle of start drag in mask editing mode
|
||||
global_state = preprocess_mask_info(global_state, image)
|
||||
|
||||
# Prepare the points for the inference
|
||||
if len(global_state["points"]) == 0:
|
||||
# yield on_click_start_wo_points(global_state, image)
|
||||
image_raw = global_state['images']['image_raw']
|
||||
update_image_draw(
|
||||
image_raw,
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'],
|
||||
global_state,
|
||||
)
|
||||
|
||||
yield (
|
||||
global_state,
|
||||
0,
|
||||
global_state['images']['image_show'],
|
||||
# gr.File.update(visible=False),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
# latent space
|
||||
gr.Radio.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
# NOTE: disable stop button
|
||||
gr.Button.update(interactive=False),
|
||||
|
||||
# update other comps
|
||||
gr.Dropdown.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Checkbox.update(interactive=True),
|
||||
# gr.Number.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
)
|
||||
else:
|
||||
|
||||
# Transform the points into torch tensors
|
||||
for key_point, point in global_state["points"].items():
|
||||
try:
|
||||
p_start = point.get("start_temp", point["start"])
|
||||
p_end = point["target"]
|
||||
|
||||
if p_start is None or p_end is None:
|
||||
continue
|
||||
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
p_in_pixels.append(p_start)
|
||||
t_in_pixels.append(p_end)
|
||||
valid_points.append(key_point)
|
||||
|
||||
mask = torch.tensor(global_state['mask']).float()
|
||||
drag_mask = 1 - mask
|
||||
|
||||
renderer: Renderer = global_state["renderer"]
|
||||
global_state['temporal_params']['stop'] = False
|
||||
global_state['editing_state'] = 'running'
|
||||
|
||||
# reverse points order
|
||||
p_to_opt = reverse_point_pairs(p_in_pixels)
|
||||
t_to_opt = reverse_point_pairs(t_in_pixels)
|
||||
print('Running with:')
|
||||
print(f' Source: {p_in_pixels}')
|
||||
print(f' Target: {t_in_pixels}')
|
||||
step_idx = 0
|
||||
while True:
|
||||
if global_state["temporal_params"]["stop"]:
|
||||
break
|
||||
|
||||
# do drage here!
|
||||
renderer._render_drag_impl(
|
||||
global_state['generator_params'],
|
||||
p_to_opt, # point
|
||||
t_to_opt, # target
|
||||
drag_mask, # mask,
|
||||
global_state['params']['motion_lambda'], # lambda_mask
|
||||
reg=0,
|
||||
feature_idx=5, # NOTE: do not support change for now
|
||||
r1=global_state['params']['r1_in_pixels'], # r1
|
||||
r2=global_state['params']['r2_in_pixels'], # r2
|
||||
# random_seed = 0,
|
||||
# noise_mode = 'const',
|
||||
trunc_psi=global_state['params']['trunc_psi'],
|
||||
# force_fp32 = False,
|
||||
# layer_name = None,
|
||||
# sel_channels = 3,
|
||||
# base_channel = 0,
|
||||
# img_scale_db = 0,
|
||||
# img_normalize = False,
|
||||
# untransform = False,
|
||||
is_drag=True,
|
||||
to_pil=True)
|
||||
|
||||
if step_idx % global_state['draw_interval'] == 0:
|
||||
print('Current Source:')
|
||||
for key_point, p_i, t_i in zip(valid_points, p_to_opt,
|
||||
t_to_opt):
|
||||
global_state["points"][key_point]["start_temp"] = [
|
||||
p_i[1],
|
||||
p_i[0],
|
||||
]
|
||||
global_state["points"][key_point]["target"] = [
|
||||
t_i[1],
|
||||
t_i[0],
|
||||
]
|
||||
start_temp = global_state["points"][key_point][
|
||||
"start_temp"]
|
||||
print(f' {start_temp}')
|
||||
|
||||
image_result = global_state['generator_params']['image']
|
||||
image_draw = update_image_draw(
|
||||
image_result,
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'],
|
||||
global_state,
|
||||
)
|
||||
global_state['images']['image_raw'] = image_result
|
||||
|
||||
yield (
|
||||
global_state,
|
||||
step_idx,
|
||||
global_state['images']['image_show'],
|
||||
# gr.File.update(visible=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
# latent space
|
||||
gr.Radio.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
# enable stop button in loop
|
||||
gr.Button.update(interactive=True),
|
||||
|
||||
# update other comps
|
||||
gr.Dropdown.update(interactive=False),
|
||||
gr.Number.update(interactive=False),
|
||||
gr.Number.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Button.update(interactive=False),
|
||||
gr.Checkbox.update(interactive=False),
|
||||
# gr.Number.update(interactive=False),
|
||||
gr.Number.update(interactive=False),
|
||||
)
|
||||
|
||||
# increate step
|
||||
step_idx += 1
|
||||
|
||||
image_result = global_state['generator_params']['image']
|
||||
global_state['images']['image_raw'] = image_result
|
||||
image_draw = update_image_draw(image_result,
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'],
|
||||
global_state)
|
||||
|
||||
# fp = NamedTemporaryFile(suffix=".png", delete=False)
|
||||
# image_result.save(fp, "PNG")
|
||||
|
||||
global_state['editing_state'] = 'add_points'
|
||||
|
||||
yield (
|
||||
global_state,
|
||||
0, # reset step to 0 after stop.
|
||||
global_state['images']['image_show'],
|
||||
# gr.File.update(visible=True, value=fp.name),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
# latent space
|
||||
gr.Radio.update(interactive=True),
|
||||
gr.Button.update(interactive=True),
|
||||
# NOTE: disable stop button with loop finish
|
||||
gr.Button.update(interactive=False),
|
||||
|
||||
# update other comps
|
||||
gr.Dropdown.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
gr.Checkbox.update(interactive=True),
|
||||
gr.Number.update(interactive=True),
|
||||
)
|
||||
|
||||
form_start_btn.click(
|
||||
on_click_start,
|
||||
inputs=[global_state, form_image],
|
||||
outputs=[
|
||||
global_state,
|
||||
form_steps_number,
|
||||
form_image,
|
||||
# form_download_result_file,
|
||||
# >>> buttons
|
||||
form_reset_image,
|
||||
enable_add_points,
|
||||
enable_add_mask,
|
||||
undo_points,
|
||||
form_reset_mask_btn,
|
||||
form_latent_space,
|
||||
form_start_btn,
|
||||
form_stop_btn,
|
||||
# <<< buttonm
|
||||
# >>> inputs comps
|
||||
form_pretrained_dropdown,
|
||||
form_seed_number,
|
||||
form_lr_number,
|
||||
show_mask,
|
||||
form_lambda_number,
|
||||
],
|
||||
)
|
||||
|
||||
def on_click_stop(global_state):
|
||||
"""Function to handle stop button is clicked.
|
||||
1. send a stop signal by set global_state["temporal_params"]["stop"] as True
|
||||
2. Disable Stop button
|
||||
"""
|
||||
global_state["temporal_params"]["stop"] = True
|
||||
|
||||
return global_state, gr.Button.update(interactive=False)
|
||||
|
||||
form_stop_btn.click(on_click_stop,
|
||||
inputs=[global_state],
|
||||
outputs=[global_state, form_stop_btn])
|
||||
|
||||
form_draw_interval_number.change(
|
||||
partial(
|
||||
on_change_single_global_state,
|
||||
"draw_interval",
|
||||
map_transform=lambda x: int(x),
|
||||
),
|
||||
inputs=[form_draw_interval_number, global_state],
|
||||
outputs=[global_state],
|
||||
)
|
||||
|
||||
def on_click_remove_point(global_state):
|
||||
choice = global_state["curr_point"]
|
||||
del global_state["points"][choice]
|
||||
|
||||
choices = list(global_state["points"].keys())
|
||||
|
||||
if len(choices) > 0:
|
||||
global_state["curr_point"] = choices[0]
|
||||
|
||||
return (
|
||||
gr.Dropdown.update(choices=choices, value=choices[0]),
|
||||
global_state,
|
||||
)
|
||||
|
||||
# Mask
|
||||
def on_click_reset_mask(global_state):
|
||||
global_state['mask'] = np.ones(
|
||||
(
|
||||
global_state["images"]["image_raw"].size[1],
|
||||
global_state["images"]["image_raw"].size[0],
|
||||
),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
image_draw = update_image_draw(global_state['images']['image_raw'],
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'], global_state)
|
||||
return global_state, image_draw
|
||||
|
||||
form_reset_mask_btn.click(
|
||||
on_click_reset_mask,
|
||||
inputs=[global_state],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
# Image
|
||||
def on_click_enable_draw(global_state, image):
|
||||
"""Function to start add mask mode.
|
||||
1. Preprocess mask info from last state
|
||||
2. Change editing state to add_mask
|
||||
3. Set curr image with points and mask
|
||||
"""
|
||||
global_state = preprocess_mask_info(global_state, image)
|
||||
global_state['editing_state'] = 'add_mask'
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(image_raw, global_state['points'],
|
||||
global_state['mask'], True,
|
||||
global_state)
|
||||
return (global_state,
|
||||
gr.Image.update(value=image_draw, interactive=True))
|
||||
|
||||
def on_click_remove_draw(global_state, image):
|
||||
"""Function to start remove mask mode.
|
||||
1. Preprocess mask info from last state
|
||||
2. Change editing state to remove_mask
|
||||
3. Set curr image with points and mask
|
||||
"""
|
||||
global_state = preprocess_mask_info(global_state, image)
|
||||
global_state['edinting_state'] = 'remove_mask'
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(image_raw, global_state['points'],
|
||||
global_state['mask'], True,
|
||||
global_state)
|
||||
return (global_state,
|
||||
gr.Image.update(value=image_draw, interactive=True))
|
||||
|
||||
enable_add_mask.click(on_click_enable_draw,
|
||||
inputs=[global_state, form_image],
|
||||
outputs=[
|
||||
global_state,
|
||||
form_image,
|
||||
])
|
||||
|
||||
def on_click_add_point(global_state, image: dict):
|
||||
"""Function switch from add mask mode to add points mode.
|
||||
1. Updaste mask buffer if need
|
||||
2. Change global_state['editing_state'] to 'add_points'
|
||||
3. Set current image with mask
|
||||
"""
|
||||
|
||||
global_state = preprocess_mask_info(global_state, image)
|
||||
global_state['editing_state'] = 'add_points'
|
||||
mask = global_state['mask']
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(image_raw, global_state['points'], mask,
|
||||
global_state['show_mask'], global_state)
|
||||
|
||||
return (global_state,
|
||||
gr.Image.update(value=image_draw, interactive=False))
|
||||
|
||||
enable_add_points.click(on_click_add_point,
|
||||
inputs=[global_state, form_image],
|
||||
outputs=[global_state, form_image])
|
||||
|
||||
def on_click_image(global_state, evt: gr.SelectData):
|
||||
"""This function only support click for point selection
|
||||
"""
|
||||
xy = evt.index
|
||||
if global_state['editing_state'] != 'add_points':
|
||||
print(f'In {global_state["editing_state"]} state. '
|
||||
'Do not add points.')
|
||||
|
||||
return global_state, global_state['images']['image_show']
|
||||
|
||||
points = global_state["points"]
|
||||
|
||||
point_idx = get_latest_points_pair(points)
|
||||
if point_idx is None:
|
||||
points[0] = {'start': xy, 'target': None}
|
||||
print(f'Click Image - Start - {xy}')
|
||||
elif points[point_idx].get('target', None) is None:
|
||||
points[point_idx]['target'] = xy
|
||||
print(f'Click Image - Target - {xy}')
|
||||
else:
|
||||
points[point_idx + 1] = {'start': xy, 'target': None}
|
||||
print(f'Click Image - Start - {xy}')
|
||||
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(
|
||||
image_raw,
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'],
|
||||
global_state,
|
||||
)
|
||||
|
||||
return global_state, image_draw
|
||||
|
||||
form_image.select(
|
||||
on_click_image,
|
||||
inputs=[global_state],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
def on_click_clear_points(global_state):
|
||||
"""Function to handle clear all control points
|
||||
1. clear global_state['points'] (clear_state)
|
||||
2. re-init network
|
||||
2. re-draw image
|
||||
"""
|
||||
clear_state(global_state, target='point')
|
||||
|
||||
renderer: Renderer = global_state["renderer"]
|
||||
renderer.feat_refs = None
|
||||
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(image_raw, {}, global_state['mask'],
|
||||
global_state['show_mask'], global_state)
|
||||
return global_state, image_draw
|
||||
|
||||
undo_points.click(on_click_clear_points,
|
||||
inputs=[global_state],
|
||||
outputs=[global_state, form_image])
|
||||
|
||||
def on_click_show_mask(global_state, show_mask):
|
||||
"""Function to control whether show mask on image."""
|
||||
global_state['show_mask'] = show_mask
|
||||
|
||||
image_raw = global_state['images']['image_raw']
|
||||
image_draw = update_image_draw(
|
||||
image_raw,
|
||||
global_state['points'],
|
||||
global_state['mask'],
|
||||
global_state['show_mask'],
|
||||
global_state,
|
||||
)
|
||||
return global_state, image_draw
|
||||
|
||||
show_mask.change(
|
||||
on_click_show_mask,
|
||||
inputs=[global_state, show_mask],
|
||||
outputs=[global_state, form_image],
|
||||
)
|
||||
|
||||
gr.close_all()
|
||||
app.queue(concurrency_count=5, max_size=20)
|
||||
app.launch(share=args.share)
|
||||
Loading…
Reference in New Issue