mirror of https://github.com/XingangPan/DragGAN
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
6.0 KiB
Python
169 lines
6.0 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
import imgui
|
|
import dnnlib
|
|
from gui_utils import imgui_utils
|
|
|
|
#----------------------------------------------------------------------------
|
|
|
|
class DragWidget:
    """ImGui panel that collects drag points, target points and an edit mask.

    The widget keeps the interaction state (handle points, target points,
    brush mask, drag flag, step counter) and, every time it is called,
    publishes that state on ``viz.args`` for the rest of the application.

    Positions are stored as ``[y, x]`` lists (see ``add_point``).
    """

    def __init__(self, viz):
        """Create the widget with an all-ones 256x256 mask and no points.

        Args:
            viz: Owning visualizer object. This class reads ``label_w``,
                ``button_w``, ``font_size``, ``frame_delta`` and ``result``
                from it and writes to ``viz.args``.
        """
        self.viz = viz
        self.point = [-1, -1]       # Last pressed position as [y, x]; [-1, -1] until the first click.
        self.points = []            # Handle points, each [y, x].
        self.targets = []           # Target points, each [y, x]; paired with `points` by index.
        self.is_point = True        # True: next released click is a handle point; False: a target.
        self.last_click = False     # `click` value seen on the previous add_point() call.
        self.is_drag = False
        self.iteration = 0          # Incremented once per __call__ while dragging.
        self.mode = 'point'         # One of 'point', 'flexible', 'fixed'.
        self.r_mask = 50            # Brush radius used by draw_mask().
        self.show_mask = False
        # Keep width/height defined from the start so draw_mask() and the
        # 'Reset mask' button work even if init_mask() has not been called yet.
        self.width, self.height = 256, 256
        self.mask = torch.ones(self.height, self.width)
        # The values below are only forwarded to viz.args by this class;
        # their meaning is defined by the consumer of viz.args.
        self.lambda_mask = 20
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Dispatch a mouse event according to the current mode.

        In 'point' mode the event is forwarded to ``add_point``; in any other
        mode the mask is painted at (x, y) while ``down`` is truthy.
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Record a handle or target point when a click is released.

        While ``click`` is truthy the position is remembered as ``[y, x]``.
        On the first call where ``click`` is falsy after a truthy one, any
        running drag is stopped and the remembered position is appended to
        ``points`` or ``targets``, alternating between the two.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Set the mask size to ``w`` x ``h`` and reset the mask to all ones."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a disc of radius ``r_mask`` centred at (x, y) into the mask.

        'flexible' mode writes 0 inside the disc, 'fixed' mode writes 1;
        any other mode leaves the mask unchanged.
        """
        # NOTE(review): linspace(0, W, W) spaces samples W / (W - 1) apart, so
        # grid coordinates are slightly stretched relative to integer pixel
        # indices. Kept as-is to preserve the painted footprint - confirm
        # whether arange(W) was intended.
        xs = torch.linspace(0, self.width, self.width)
        ys = torch.linspace(0, self.height, self.height)
        # Broadcasting a (1, W) row against an (H, 1) column produces the same
        # (H, W) grid as torch.meshgrid(ys, xs), without the warning newer
        # PyTorch versions emit when `indexing` is not passed.
        circle = (xs[None, :] - x) ** 2 + (ys[:, None] - y) ** 2 < self.r_mask ** 2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the handle points with ``points``."""
        self.points = points

    def reset_point(self):
        """Remove all handle and target points; the next click is a handle."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read ``[y, x]`` integer points from ``<self.path>_<suffix>.txt``.

        Each non-blank line must hold two whitespace-separated integers.
        This is best-effort: on a missing/unreadable file or a malformed
        line a message is printed and the points parsed so far are returned.

        Returns:
            list of ``[y, x]`` integer pairs (possibly empty).
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue  # Tolerate blank lines, e.g. a trailing newline.
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except OSError:
            print(f'Wrong point file path: {point_path}')
        except ValueError:
            # Wrong number of fields or a non-integer token on some line.
            print(f'Malformed point file: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (when ``show`` is truthy) and publish state to ``viz.args``.

        The state is written to ``viz.args`` on every call, whether or not
        the panel itself is drawn.
        """
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Drag')
                imgui.same_line(viz.label_w)

                if imgui_utils.button('Add point', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled='image' in viz.result):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled='image' in viz.result):
                    self.is_drag = True
                    if len(self.points) > len(self.targets):
                        # Drop the handle point that never received a target,
                        # and expect a handle point again on the next click so
                        # that points/targets stay paired by index.
                        self.points = self.points[:len(self.targets)]
                        self.is_point = True

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled='image' in viz.result):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled='image' in viz.result):
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Hand out shallow copies so later appends here do not change the
        # lists already published on viz.args.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
|
|
|
|
|
|
#----------------------------------------------------------------------------
|