mirror of https://github.com/XingangPan/DragGAN
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.7 KiB
Python
155 lines
4.7 KiB
Python
import gradio as gr
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw
|
|
|
|
|
|
class ImageMask(gr.components.Image):
    """Image component preset for sketching a mask on an uploaded image.

    Sets: source="upload", tool="sketch", interactive=False
    """

    is_template = True

    def __init__(self, **kwargs):
        """Build the component, forwarding ``kwargs`` to the parent class.

        ``source``, ``tool`` and ``interactive`` are fixed here, so passing
        any of them in ``kwargs`` raises ``TypeError`` (duplicate keyword).
        """
        super().__init__(source="upload",
                         tool="sketch",
                         interactive=False,
                         **kwargs)

    def preprocess(self, x):
        """Ensure the payload is an image/mask dict, then defer to the parent.

        Args:
            x: ``None``, a dict payload, or a bare image payload that
                ``gr.processing_utils.decode_base64_to_image`` can decode.

        Returns:
            ``None`` when ``x`` is ``None``; otherwise whatever the parent
            ``preprocess`` returns for the (possibly wrapped) payload.
        """
        if x is None:
            return x

        # In sketch mode the payload is wrapped as {'image', 'mask'}; when
        # only a bare image arrives, synthesize a mask of the same size.
        needs_default_mask = (self.tool == "sketch"
                              and self.source in ("upload", "webcam")
                              and not isinstance(x, dict))
        if needs_default_mask:
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            # RGBA array: colour channels are 1, alpha is fully opaque.
            mask = np.ones((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            # NOTE(review): postprocess presumably encodes the array into
            # the same wire format as the image payload -- confirm.
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}

        return super().preprocess(x)
|
|
|
|
|
|
def get_valid_mask(mask: np.ndarray):
    """Reduce a mask array to a single channel scaled to the 0..1 range.

    A 3-D input is converted to 8-bit grayscale through PIL (mode 'L').
    The result is divided by 255 only when its maximum is exactly 255;
    otherwise it is returned without rescaling. No thresholding is done.
    """
    if mask.ndim == 3:
        # Collapse the channel axis via PIL's luminance conversion.
        mask = np.array(Image.fromarray(mask).convert('L'))

    return mask / 255 if mask.max() == 255 else mask
|
|
|
|
|
|
def _draw_point_marker(overlay_draw, center, radius, color, label=None):
    """Draw a filled circle at ``center`` and, if given, a text ``label``."""
    overlay_draw.ellipse(
        (
            center[0] - radius,
            center[1] - radius,
            center[0] + radius,
            center[1] + radius,
        ),
        fill=color,
    )
    if label is not None:
        overlay_draw.text(center, label, align="center", fill=(0, 0, 0))


def draw_points_on_image(image,
                         points,
                         curr_point=None,
                         highlight_all=True,
                         radius_scale=0.01):
    """Render start/target point markers on top of ``image``.

    For every entry in ``points`` a red circle is drawn at the start
    position, a blue circle at the target position and, when both exist, a
    yellow line joining them. The point whose key equals ``curr_point`` is
    additionally labelled "p" (start) and "t" (target).

    Args:
        image: PIL image used as background; it is not modified.
        points: Mapping of point key to a dict with "start" and "target"
            entries and an optional "start_temp" entry, which takes
            precedence over "start". "start" must always be present because
            the fallback is looked up eagerly. A position is either ``None``
            or an indexable whose items 0 and 1 are the x and y coordinates.
        curr_point: Key of the currently selected point, or ``None``.
        highlight_all: When true every marker is opaque; otherwise only the
            current point is opaque and the others use alpha 35.
        radius_scale: Marker radius as a fraction of ``image.size[0]``.

    Returns:
        A new RGB image with the overlay composited on ``image``.
    """
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    # The radius depends only on the image, so compute it once.
    rad_draw = int(image.size[0] * radius_scale)

    for point_key, point in points.items():
        is_current = curr_point is not None and curr_point == point_key
        if is_current or highlight_all:
            p_color = (255, 0, 0)
            t_color = (0, 0, 255)
        else:
            p_color = (255, 0, 0, 35)
            t_color = (0, 0, 255, 35)

        p_start = point.get("start_temp", point["start"])
        p_target = point["target"]

        # Convert each coordinate pair to ints a single time.
        p_draw = None
        if p_start is not None:
            p_draw = int(p_start[0]), int(p_start[1])
        t_draw = None
        if p_target is not None:
            t_draw = int(p_target[0]), int(p_target[1])

        if p_draw is not None and t_draw is not None:
            overlay_draw.line(
                (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
                fill=(255, 255, 0),
                width=2,
            )

        if p_draw is not None:
            _draw_point_marker(overlay_draw, p_draw, rad_draw, p_color,
                               "p" if is_current else None)

        if t_draw is not None:
            _draw_point_marker(overlay_draw, t_draw, rad_draw, t_color,
                               "t" if is_current else None)

    return Image.alpha_composite(image.convert("RGBA"),
                                 overlay_rgba).convert("RGB")
|
|
|
|
|
|
def draw_mask_on_image(image, mask):
    """Composite ``mask`` over ``image`` as a faint grayscale layer.

    Args:
        image: PIL image used as the background.
        mask: Array that is multiplied by 255 and cast to uint8. It is
            indexed as (rows, cols), so a 2-D array is expected; presumably
            its values lie in 0..1 -- confirm against the caller.

    Returns:
        A new RGB image with the overlay (constant alpha 45) blended in.
    """
    gray = np.uint8(mask * 255)
    rows, cols = gray.shape[0], gray.shape[1]

    # Replicate the single channel into R, G, B and append a constant alpha.
    rgb_planes = np.tile(gray[..., None], [1, 1, 3])
    alpha_plane = 45 * np.ones((rows, cols, 1), dtype=np.uint8)
    overlay_array = np.concatenate((rgb_planes, alpha_plane), axis=-1)

    overlay = Image.fromarray(overlay_array).convert("RGBA")
    base = image.convert("RGBA")
    return Image.alpha_composite(base, overlay).convert("RGB")
|
|
|
|
|
|
def on_change_single_global_state(keys,
                                  value,
                                  global_state,
                                  map_transform=None):
    """Store ``value`` in ``global_state`` at the location named by ``keys``.

    Args:
        keys: Either a single string key, or a sequence of keys describing
            a path through nested containers; the last item is the key
            that gets assigned.
        value: New value to store.
        global_state: Container that is mutated in place.
        map_transform: Optional callable applied to ``value`` before it is
            stored.

    Returns:
        The same ``global_state`` object that was passed in.
    """
    new_value = value if map_transform is None else map_transform(value)

    # A plain string addresses a top-level entry directly.
    if isinstance(keys, str):
        global_state[keys] = new_value
        return global_state

    # Otherwise walk down to the parent container of the final key.
    container = global_state
    for key in keys[:-1]:
        container = container[key]
    container[keys[-1]] = new_value
    return global_state
|
|
|
|
|
|
def get_latest_points_pair(points_dict):
    """Return the largest key of ``points_dict``, or ``None`` if it is empty.

    Despite the name, only the key is returned, not the point pair stored
    under it. A falsy argument (empty dict or ``None``) yields ``None``.
    """
    return max(points_dict.keys()) if points_dict else None
|