pull/142/merge
Tianhao Xie 2 years ago committed by GitHub
commit 417058bd37

.gitignore

@@ -4,6 +4,7 @@
__pycache__/
*.py[cod]
*$py.class
*.pyc
# C extensions
*.so

@@ -0,0 +1,53 @@
import torch
from inversion import inverse_image, get_lr
from tqdm import tqdm
from torch.nn import functional as F
from lpips import util
def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr

    def calc_loss(self, percept, generated_image, real_image):
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss

    def train(self, img):
        inversed_result = inverse_image(self.g_ema, img, self.g_ema.img_resolution)
        w_pivot = inversed_result['latent']
        ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
        toggle_grad(self.g_ema, True)
        percept = util.PerceptualLoss(
            model="net-lin", net="vgg", use_gpu=True
        )
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            # get_lr expects a progress fraction in [0, 1], not the raw step index.
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr
            generated_image, feature = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(percept, generated_image, inversed_result['real'])
            pbar.set_description(f"loss: {loss.item():.4f}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
        return generated_image
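
# Editor's note: a minimal usage sketch for the class above (illustrative
# only; `G` is assumed to be a loaded StyleGAN2 generator and `img` a PIL
# image at the generator's resolution):
#
#   pti = PTI(G)
#   generated = pti.train(img)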

@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty

@@ -0,0 +1,343 @@
import math
import os
from viz import renderer
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import dataclasses
import dnnlib
from .lpips import util
import imageio
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
lr_ramp = lr_ramp * min(1, t / rampup)
return initial_lr * lr_ramp
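
# Editor's note: the schedule above does a linear warmup over the first
# `rampup` fraction of steps and a cosine ramp-down over the last `rampdown`
# fraction. Approximate values with the defaults:
#
#   get_lr(0.00, 0.1)  -> 0.0      # start of warmup
#   get_lr(0.05, 0.1)  -> 0.1      # warmup complete
#   get_lr(0.50, 0.1)  -> 0.1      # flat plateau
#   get_lr(0.90, 0.1)  -> ~0.035   # inside the cosine ramp-down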
def make_image(tensor):
return (
tensor.detach()
.clamp_(min=-1, max=1)
.add(1)
.div_(2)
.mul(255)
.type(torch.uint8)
.permute(0, 2, 3, 1)
.to("cpu")
.numpy()
)
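
# Editor's note: make_image converts a batch of [-1, 1] NCHW tensors into
# uint8 NHWC numpy images, e.g. (sketch; `img_gen` is a generator output):
#
#   arr = make_image(img_gen)  # arr[0] is an HxWx3 uint8 image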
@dataclasses.dataclass
class InverseConfig:
    # Type annotations are required for @dataclass to treat these as fields.
    lr_warmup: float = 0.05
    lr_decay: float = 0.25
    lr: float = 0.1
    noise: float = 0.05
    noise_decay: float = 0.75
    step: int = 1000
    noise_regularize: float = 1e5
    mse: float = 0.1


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=InverseConfig(),
    device='cuda:0'
):
args = config
n_mean_latent = 10000
resize = min(image_size, 256)
    if not torch.is_tensor(image):
        transform = transforms.Compose(
            [
                transforms.Resize(resize),
transforms.CenterCrop(resize),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
img = transform(image)
else:
img = transforms.functional.resize(image,resize)
transform = transforms.Compose(
[
transforms.CenterCrop(resize),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
img = transform(img)
imgs = []
imgs.append(img)
imgs = torch.stack(imgs, 0).to(device)
with torch.no_grad():
#noise_sample = torch.randn(n_mean_latent, 512, device=device)
noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
#label = torch.zeros([n_mean_latent,g_ema.c_dim],device = device)
w_samples = g_ema.mapping(noise_sample,None)
w_samples = w_samples[:, :1, :]
w_avg = w_samples.mean(0)
w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5
noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
    # Initialize the noise buffers in place so the optimizer below receives
    # the same tensor objects that the synthesis network reads from.
    for buf in noises.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True
w_opt = w_avg.detach().clone()
if w_plus:
w_opt = w_opt.repeat(1,g_ema.mapping.num_ws, 1)
w_opt.requires_grad = True
#if args.w_plus:
#latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)
pbar = tqdm(range(args.step))
latent_path = []
for i in pbar:
t = i / args.step
lr = get_lr(t, args.lr)
optimizer.param_groups[0]["lr"] = lr
noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
w_noise = torch.randn_like(w_opt) * noise_strength
if w_plus:
ws = w_opt + w_noise
else:
ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])
img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)
#latent_n = latent_noise(latent_in, noise_strength.item())
#latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
#img_gen, F = g_ema.generate(latent, noise)
# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
if img_gen.shape[2] > 256:
img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')
        p_loss = percept(img_gen, imgs)
        # Noise regularization: penalize spatial self-correlation of each
        # noise map at multiple scales (average-pooling down to 8x8), pushing
        # the optimized noise toward i.i.d. Gaussian statistics.
        reg_loss = 0.0
        for v in noises.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
mse_loss = F.mse_loss(img_gen, imgs)
loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Normalize noise.
with torch.no_grad():
for buf in noises.values():
buf -= buf.mean()
buf *= buf.square().mean().rsqrt()
if (i + 1) % 100 == 0:
latent_path.append(w_opt.detach().clone())
pbar.set_description(
(
f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
)
)
#latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
#img_gen, F = g_ema.generate(latent, noise)
if w_plus:
ws = latent_path[-1]
else:
ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])
img_gen = g_ema.synthesis(ws, noise_mode='const')
result = {
"latent": latent_path[-1],
"sample": img_gen,
"real": imgs,
}
return result
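
# Editor's note: a minimal sketch of calling inverse_image directly, assuming
# `G` is a loaded StyleGAN2 generator and 'input.png' a hypothetical path:
#
#   from PIL import Image
#   percept = util.PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)
#   result = inverse_image(G, Image.open('input.png'), percept, G.img_resolution)
#   imageio.imsave('projected.png', make_image(result['sample'])[0])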
def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag
class PTI:
    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
self.g_ema = G
self.l2_lambda = l2_lambda
self.max_pti_step = max_pti_step
self.pti_lr = pti_lr
self.percept = percept
    def calc_loss(self, percept, generated_image, real_image):
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss
    def train(self, img, w_plus=False):
        if not torch.is_tensor(img):
            transform = transforms.Compose(
                [
                    transforms.Resize(self.g_ema.img_resolution),
transforms.CenterCrop(self.g_ema.img_resolution),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
real_img = transform(img).to('cuda').unsqueeze(0)
else:
img = transforms.functional.resize(img, self.g_ema.img_resolution)
transform = transforms.Compose(
[
transforms.CenterCrop(self.g_ema.img_resolution),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
real_img = transform(img).to('cuda').unsqueeze(0)
        inversed_result = inverse_image(self.g_ema, img, self.percept, self.g_ema.img_resolution, w_plus)
w_pivot = inversed_result['latent']
if w_plus:
ws = w_pivot
else:
ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
        toggle_grad(self.g_ema, True)
optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
print('start PTI')
pbar = tqdm(range(self.max_pti_step))
for i in pbar:
t = i / self.max_pti_step
lr = get_lr(t, self.pti_lr)
optimizer.param_groups[0]["lr"] = lr
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(self.percept, generated_image, real_img)
            pbar.set_description(f"loss: {loss.item():.4f}")
optimizer.zero_grad()
loss.backward()
optimizer.step()
with torch.no_grad():
generated_image = self.g_ema.synthesis(ws, noise_mode='const')
        return generated_image, ws
if __name__ == "__main__":
state = {
"images": {
            # image_orig: the original image; changes when the seed/model changes
            # image_raw: image with mask and points; changes during optimization
            # image_show: the image shown on screen
},
"temporal_params": {
# stop
},
        'mask': None,  # mask for visualization: 1 for editable, 0 for unchanged
        'last_mask': None,  # last edited mask
        'show_mask': True,  # add button
"generator_params": dnnlib.EasyDict(),
"params": {
"seed": 0,
"motion_lambda": 20,
"r1_in_pixels": 3,
"r2_in_pixels": 12,
"magnitude_direction_in_pixels": 1.0,
"latent_space": "w+",
"trunc_psi": 0.7,
"trunc_cutoff": None,
"lr": 0.001,
},
"device": 'cuda:0',
"draw_interval": 1,
"renderer": renderer.Renderer(disable_timing=True),
"points": {},
"curr_point": None,
"curr_type_point": "start",
'editing_state': 'add_points',
'pretrained_weight': 'stylegan2_horses_256_pytorch'
}
cache_dir = '../checkpoints'
valid_checkpoints_dict = {
f.split('/')[-1].split('.')[0]: os.path.join(cache_dir, f)
for f in os.listdir(cache_dir)
if (f.endswith('pkl') and os.path.exists(os.path.join(cache_dir, f)))
}
state['renderer'].init_network(state['generator_params'], # res
valid_checkpoints_dict[state['pretrained_weight']], # pkl
state['params']['seed'], # w0_seed,
None, # w_load
state['params']['latent_space'] == 'w+', # w_plus
'const',
state['params']['trunc_psi'], # trunc_psi,
state['params']['trunc_cutoff'], # trunc_cutoff,
None, # input_transform
state['params']['lr'] # lr
)
image = Image.open('/home/tianhao/research/drag3d/horse/render/0.png')
G = state['renderer'].G
#result = inverse_image(G,image,G.img_resolution)
percept = util.PerceptualLoss(
model="net-lin", net="vgg", use_gpu=True
)
pti = PTI(G,percept)
result = pti.train(image,True)
imageio.imsave('../horse/test.png', make_image(result[0])[0])

@@ -0,0 +1,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.autograd import Variable
class BaseModel():
def __init__(self):
pass
def name(self):
return 'BaseModel'
def initialize(self, use_gpu=True, gpu_ids=[0]):
self.use_gpu = use_gpu
self.gpu_ids = gpu_ids
def forward(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, path, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(path, save_filename)
torch.save(network.state_dict(), save_path)
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
print('Loading network from %s'%save_path)
network.load_state_dict(torch.load(save_path))
    def update_learning_rate(self):
        pass
def get_image_paths(self):
return self.image_paths
def save_done(self, flag=False):
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')

@@ -0,0 +1,314 @@
from __future__ import absolute_import
import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm
import urllib.request
from . import networks_basic as networks
from . import util
class DownloadProgressBar(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
def get_path(base_path):
BASE_DIR = os.path.join('checkpoints')
save_path = os.path.join(BASE_DIR, base_path)
if not os.path.exists(save_path):
url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
print(f'{base_path} not found')
print('Try to download from huggingface: ', url)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
download_url(url, save_path)
print('Downloaded to ', save_path)
return save_path
def download_url(url, output_path):
with DownloadProgressBar(unit='B', unit_scale=True,
miniters=1, desc=url.split('/')[-1]) as t:
urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
class DistModel(BaseModel):
def name(self):
return self.model_name
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
use_gpu=True, printNet=False, spatial=False,
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
'''
INPUTS
model - ['net-lin'] for linearly calibrated network
['net'] for off-the-shelf network
['L2'] for L2 distance in Lab colorspace
['SSIM'] for ssim in RGB colorspace
net - ['squeeze','alex','vgg']
model_path - if None, will look in weights/[NET_NAME].pth
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
use_gpu - bool - whether or not to use a GPU
printNet - bool - whether or not to print network architecture out
spatial - bool - whether to output an array containing varying distances across spatial dimensions
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
is_train - bool - [True] for training mode
lr - float - initial learning rate
beta1 - float - initial momentum term for adam
version - 0.1 for latest, 0.0 was original (with a bug)
gpu_ids - int array - [0] by default, gpus to use
'''
BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
self.model = model
self.net = net
self.is_train = is_train
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model_name = '%s [%s]' % (model, net)
if(self.model == 'net-lin'): # pretrained net + linear layer
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
use_dropout=True, spatial=spatial, version=version, lpips=True)
kw = {}
if not use_gpu:
kw['map_location'] = 'cpu'
if(model_path is None):
model_path = get_path('weights/v%s/%s.pth' % (version, net))
if(not is_train):
print('Loading model from: %s' % model_path)
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
elif(self.model == 'net'): # pretrained network
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
elif(self.model in ['L2', 'l2']):
self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
self.model_name = 'L2'
elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
self.model_name = 'SSIM'
else:
raise ValueError("Model [%s] not recognized." % self.model)
self.parameters = list(self.net.parameters())
if self.is_train: # training mode
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
self.rankLoss = networks.BCERankingLoss()
self.parameters += list(self.rankLoss.net.parameters())
self.lr = lr
self.old_lr = lr
self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
else: # test mode
self.net.eval()
if(use_gpu):
self.net.to(gpu_ids[0])
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
if(self.is_train):
self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
if(printNet):
print('---------- Networks initialized -------------')
networks.print_network(self.net)
print('-----------------------------------------------')
def forward(self, in0, in1, retPerLayer=False):
''' Function computes the distance between image patches in0 and in1
INPUTS
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
OUTPUT
computed distances between in0 and in1
'''
return self.net.forward(in0, in1, retPerLayer=retPerLayer)
# ***** TRAINING FUNCTIONS *****
def optimize_parameters(self):
self.forward_train()
self.optimizer_net.zero_grad()
self.backward_train()
self.optimizer_net.step()
self.clamp_weights()
def clamp_weights(self):
for module in self.net.modules():
if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
module.weight.data = torch.clamp(module.weight.data, min=0)
def set_input(self, data):
self.input_ref = data['ref']
self.input_p0 = data['p0']
self.input_p1 = data['p1']
self.input_judge = data['judge']
if(self.use_gpu):
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
self.var_ref = Variable(self.input_ref, requires_grad=True)
self.var_p0 = Variable(self.input_p0, requires_grad=True)
self.var_p1 = Variable(self.input_p1, requires_grad=True)
def forward_train(self): # run forward pass
# print(self.net.module.scaling_layer.shift)
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
self.d0 = self.forward(self.var_ref, self.var_p0)
self.d1 = self.forward(self.var_ref, self.var_p1)
self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
return self.loss_total
def backward_train(self):
torch.mean(self.loss_total).backward()
def compute_accuracy(self, d0, d1, judge):
''' d0, d1 are Variables, judge is a Tensor '''
d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
judge_per = judge.cpu().numpy().flatten()
return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
def get_current_errors(self):
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
('acc_r', self.acc_r)])
for key in retDict.keys():
retDict[key] = np.mean(retDict[key])
return retDict
def get_current_visuals(self):
zoom_factor = 256 / self.var_ref.data.size()[2]
ref_img = util.tensor2im(self.var_ref.data)
p0_img = util.tensor2im(self.var_p0.data)
p1_img = util.tensor2im(self.var_p1.data)
ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
return OrderedDict([('ref', ref_img_vis),
('p0', p0_img_vis),
('p1', p1_img_vis)])
def save(self, path, label):
if(self.use_gpu):
self.save_network(self.net.module, path, '', label)
else:
self.save_network(self.net, path, '', label)
self.save_network(self.rankLoss.net, path, 'rank', label)
def update_learning_rate(self, nepoch_decay):
lrd = self.lr / nepoch_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_net.param_groups:
param_group['lr'] = lr
        print('update lr decay: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
def score_2afc_dataset(data_loader, func, name=''):
''' Function computes Two Alternative Forced Choice (2AFC) score using
distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
OUTPUTS
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
[1] - dictionary with following elements
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
gts - N array in [0,1], preferred patch selected by human evaluators
(closer to "0" for left patch p0, "1" for right patch p1,
"0.6" means 60pct people preferred right patch, 40pct preferred left)
scores - N array in [0,1], corresponding to what percentage function agreed with humans
CONSTS
N - number of test triplets in data_loader
'''
d0s = []
d1s = []
gts = []
for data in tqdm(data_loader.load_data(), desc=name):
d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
gts += data['judge'].cpu().numpy().flatten().tolist()
d0s = np.array(d0s)
d1s = np.array(d1s)
gts = np.array(gts)
scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
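
# Editor's note: a typical call, per the docstring above (sketch; `loader`
# wraps a TwoAFCDataset and `model` is an initialized DistModel):
#
#   score, results = score_2afc_dataset(loader, model.forward, name='val')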
def score_jnd_dataset(data_loader, func, name=''):
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
OUTPUTS
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
[1] - dictionary with following elements
ds - N array containing distances between two patches shown to human evaluator
sames - N array containing fraction of people who thought the two patches were identical
CONSTS
N - number of test triplets in data_loader
'''
ds = []
gts = []
for data in tqdm(data_loader.load_data(), desc=name):
ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
gts += data['same'].cpu().numpy().flatten().tolist()
sames = np.array(gts)
ds = np.array(ds)
sorted_inds = np.argsort(ds)
ds_sorted = ds[sorted_inds]
sames_sorted = sames[sorted_inds]
TPs = np.cumsum(sames_sorted)
FPs = np.cumsum(1 - sames_sorted)
FNs = np.sum(sames_sorted) - TPs
precs = TPs / (TPs + FPs)
recs = TPs / (TPs + FNs)
score = util.voc_ap(recs, precs)
return(score, dict(ds=ds, sames=sames))
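
# Editor's note: the JND variant is called the same way (sketch):
#
#   score, results = score_jnd_dataset(loader, model.forward, name='val')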

@@ -0,0 +1,188 @@
from __future__ import absolute_import
import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from skimage import color
from . import pretrained_networks as pn
from . import util
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2,3],keepdim=keepdim)
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
in_H = in_tens.shape[2]
scale_factor = 1.*out_H/in_H
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
# Learned perceptual metric
class PNetLin(nn.Module):
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
super(PNetLin, self).__init__()
self.pnet_type = pnet_type
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips
self.version = version
self.scaling_layer = ScalingLayer()
if(self.pnet_type in ['vgg','vgg16']):
net_type = pn.vgg16
self.chns = [64,128,256,512,512]
elif(self.pnet_type=='alex'):
net_type = pn.alexnet
self.chns = [64,192,384,256,256]
elif(self.pnet_type=='squeeze'):
net_type = pn.squeezenet
self.chns = [64,128,256,384,384,512,512]
self.L = len(self.chns)
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
if(lpips):
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins+=[self.lin5,self.lin6]
def forward(self, in0, in1, retPerLayer=False):
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
if(self.lpips):
if(self.spatial):
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
else:
if(self.spatial):
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
val = res[0]
for l in range(1,self.L):
val += res[l]
if(retPerLayer):
return (val, res)
else:
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(),] if(use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
self.model = nn.Sequential(*layers)
class Dist2LogitLayer(nn.Module):
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
def __init__(self, chn_mid=32, use_sigmoid=True):
super(Dist2LogitLayer, self).__init__()
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
if(use_sigmoid):
layers += [nn.Sigmoid(),]
self.model = nn.Sequential(*layers)
def forward(self,d0,d1,eps=0.1):
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
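
# Editor's note: a minimal sketch of the layer above; d0 and d1 are
# hypothetical Nx1xHxW distance maps, and the output is a per-location
# value in [0, 1] when use_sigmoid is True:
#
#   layer = Dist2LogitLayer()
#   prob = layer(d0, d1)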
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super(BCERankingLoss, self).__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid)
# self.parameters = list(self.net.parameters())
self.loss = torch.nn.BCELoss()
def forward(self, d0, d1, judge):
per = (judge+1.)/2.
self.logit = self.net.forward(d0,d1)
return self.loss(self.logit, per)
# L2, DSSIM metrics
class FakeNet(nn.Module):
def __init__(self, use_gpu=True, colorspace='Lab'):
super(FakeNet, self).__init__()
self.use_gpu = use_gpu
self.colorspace=colorspace
class L2(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
(N,C,X,Y) = in0.size()
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
return value
elif(self.colorspace=='Lab'):
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
class DSSIM(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
elif(self.colorspace=='Lab'):
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print('Network',net)
print('Total number of parameters: %d' % num_params)

@@ -0,0 +1,181 @@
from collections import namedtuple
import torch
from torchvision import models as tv
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2,5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if(num==18):
self.net = tv.resnet18(pretrained=pretrained)
elif(num==34):
self.net = tv.resnet34(pretrained=pretrained)
elif(num==50):
self.net = tv.resnet50(pretrained=pretrained)
elif(num==101):
self.net = tv.resnet101(pretrained=pretrained)
elif(num==152):
self.net = tv.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out

@@ -0,0 +1,160 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from skimage.metrics import structural_similarity
import torch
from . import dist_model
class PerceptualLoss(torch.nn.Module):
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized'%self.model.name())
print('...Done')
def forward(self, pred, target, normalize=False):
"""
Pred and target are Variables.
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
If normalize is False, assumes the images are already between [-1,+1]
Inputs pred and target are Nx3xHxW
Output pytorch Variable N long
"""
if normalize:
target = 2 * target - 1
pred = 2 * pred - 1
return self.model.forward(target, pred)
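
# Editor's note: a minimal sketch of computing an LPIPS distance with the
# class above; `pred` and `target` are hypothetical Nx3xHxW tensors in [-1, 1]:
#
#   loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
#   d = loss_fn(pred, target)  # N distances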
def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)
def psnr(p0, p1, peak=255.):
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
def dssim(p0, p1, range=255.):
return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
def np2tensor(np_obj):
# change dimenion of np array into tensor array
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
# image tensor to lab tensor
from skimage import color
img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if(mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
if(to_norm and not mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
img_lab = img_lab/100.
return np2tensor(img_lab)
def tensorlab2tensor(lab_tensor,return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")
lab = tensor2np(lab_tensor)*100.
lab[:,:,0] = lab[:,:,0]+50
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
if(return_inbnd):
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1.*np.isclose(lab_back,lab,atol=2.)
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
return (im2tensor(rgb_back),mask)
else:
return im2tensor(rgb_back)
def rgb2lab(input):
from skimage import color
return color.rgb2lab(input / 255.)
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
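
# Editor's note: a worked example of the interpolated AP above. For
# rec = [0.5, 1.0], prec = [1.0, 0.5], the precision envelope is
# [1.0, 1.0, 0.5, 0.0] over mrec = [0.0, 0.5, 1.0, 1.0], so
# ap = 0.5 * 1.0 + 0.5 * 0.5 = 0.75:
#
#   voc_ap(np.array([0.5, 1.0]), np.array([1.0, 0.5]))  # -> 0.75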

@@ -0,0 +1,964 @@
import os
import os.path as osp
from argparse import ArgumentParser
from functools import partial
import gradio as gr
import numpy as np
import torch
from PIL import Image
import imageio
import dnnlib
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
get_latest_points_pair, get_valid_mask,
on_change_single_global_state)
from viz.renderer import Renderer, add_watermark_np
from gan_inv.inversion import PTI
from gan_inv.lpips import util
parser = ArgumentParser()
parser.add_argument('--share',default='False')
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
args = parser.parse_args()
cache_dir = args.cache_dir
device = 'cuda'
def reverse_point_pairs(points):
new_points = []
for p in points:
new_points.append([p[1], p[0]])
return new_points
def clear_state(global_state, target=None):
"""Clear target history state from global_state
If target is not defined, points and mask will be both removed.
1. set global_state['points'] as empty dict
2. set global_state['mask'] as full-one mask.
"""
if target is None:
target = ['point', 'mask']
if not isinstance(target, list):
target = [target]
if 'point' in target:
global_state['points'] = dict()
print('Clear Points State!')
if 'mask' in target:
image_raw = global_state["images"]["image_raw"]
global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
dtype=np.uint8)
print('Clear mask State!')
return global_state
def init_images(global_state):
"""This function is called only ones with Gradio App is started.
0. pre-process global_state, unpack value from global_state of need
1. Re-init renderer
2. run `renderer._render_drag_impl` with `is_drag=False` to generate
new image
3. Assign images to global state and re-generate mask
"""
if isinstance(global_state, gr.State):
state = global_state.value
else:
state = global_state
state['renderer'].init_network(
state['generator_params'], # res
valid_checkpoints_dict[state['pretrained_weight']], # pkl
state['params']['seed'], # w0_seed,
None, # w_load
state['params']['latent_space'] == 'w+', # w_plus
'const',
state['params']['trunc_psi'], # trunc_psi,
state['params']['trunc_cutoff'], # trunc_cutoff,
None, # input_transform
state['params']['lr'] # lr,
)
state['renderer']._render_drag_impl(state['generator_params'],
is_drag=False,
to_pil=True)
init_image = state['generator_params'].image
state['images']['image_orig'] = init_image
state['images']['image_raw'] = init_image
state['images']['image_show'] = Image.fromarray(
add_watermark_np(np.array(init_image)))
state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
dtype=np.uint8)
return global_state
def update_image_draw(image, points, mask, show_mask, global_state=None):
image_draw = draw_points_on_image(image, points)
if show_mask and mask is not None and not (mask == 0).all() and not (
mask == 1).all():
image_draw = draw_mask_on_image(image_draw, mask)
image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
if global_state is not None:
global_state['images']['image_show'] = image_draw
return image_draw
def preprocess_mask_info(global_state, image):
"""Function to handle mask information.
1. last_mask is None: Do not need to change mask, return mask
2. last_mask is not None:
2.1 global_state is remove_mask:
2.2 global_state is add_mask:
"""
if isinstance(image, dict):
last_mask = get_valid_mask(image['mask'])
else:
last_mask = None
mask = global_state['mask']
    # mask in global state is a placeholder filled with all 1s.
if (mask == 1).all():
mask = last_mask
# last_mask = global_state['last_mask']
editing_mode = global_state['editing_state']
if last_mask is None:
return global_state
if editing_mode == 'remove_mask':
updated_mask = np.clip(mask - last_mask, 0, 1)
print(f'Last editing_state is {editing_mode}, do remove.')
elif editing_mode == 'add_mask':
updated_mask = np.clip(mask + last_mask, 0, 1)
print(f'Last editing_state is {editing_mode}, do add.')
else:
updated_mask = mask
print(f'Last editing_state is {editing_mode}, '
'do nothing to mask.')
global_state['mask'] = updated_mask
# global_state['last_mask'] = None # clear buffer
return global_state
valid_checkpoints_dict = {
f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
for f in os.listdir(cache_dir)
if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
}
print(f'File under cache_dir ({cache_dir}):')
print(os.listdir(cache_dir))
print('Valid checkpoint file:')
print(valid_checkpoints_dict)
init_pkl = 'stylegan2_lions_512_pytorch'
# Network & latents tab listeners
def on_change_pretrained_dropdown(pretrained_value, global_state):
"""Function to handle model change.
1. Set pretrained value to global_state
2. Re-init images and clear all states
"""
global_state['pretrained_weight'] = pretrained_value
init_images(global_state)
clear_state(global_state)
return global_state, global_state["images"]['image_show']
def on_click_reset_image(global_state):
"""Reset image to the original one and clear all states
1. Re-init images
2. Clear all states
"""
init_images(global_state)
clear_state(global_state)
return global_state, global_state['images']['image_show']
# Update parameters
def on_change_update_image_seed(seed, global_state):
"""Function to handle generation seed change.
1. Set seed to global_state
2. Re-init images and clear all states
"""
global_state["params"]["seed"] = int(seed)
init_images(global_state)
clear_state(global_state)
return global_state, global_state['images']['image_show']
def on_click_latent_space(latent_space, global_state):
"""Function to reset latent space to optimize.
    NOTE: this function resets the image and all controls
1. Set latent-space to global_state
2. Re-init images and clear all state
"""
global_state['params']['latent_space'] = latent_space
init_images(global_state)
clear_state(global_state)
return global_state, global_state['images']['image_show']
def on_click_inverse_custom_image(custom_image,global_state):
    print('start GAN inversion')
if isinstance(global_state, gr.State):
state = global_state.value
else:
state = global_state
state['renderer'].init_network(
state['generator_params'], # res
valid_checkpoints_dict[state['pretrained_weight']], # pkl
state['params']['seed'], # w0_seed,
None, # w_load
state['params']['latent_space'] == 'w+', # w_plus
'const',
state['params']['trunc_psi'], # trunc_psi,
state['params']['trunc_cutoff'], # trunc_cutoff,
None, # input_transform
state['params']['lr'] # lr,
)
percept = util.PerceptualLoss(
model="net-lin", net="vgg", use_gpu=True
)
image = Image.open(custom_image.name)
    pti = PTI(global_state['renderer'].G, percept)
    inversed_img, w_pivot = pti.train(image, state['params']['latent_space'] == 'w+')
inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
inversed_img = inversed_img.cpu().numpy()
inversed_img = Image.fromarray(inversed_img)
global_state['images']['image_show'] = Image.fromarray(
add_watermark_np(np.array(inversed_img)))
global_state['images']['image_orig'] = inversed_img
global_state['images']['image_raw'] = inversed_img
global_state['mask'] = np.ones((inversed_img.size[1], inversed_img.size[0]),
dtype=np.uint8)
global_state['generator_params'].image = inversed_img
global_state['generator_params'].w = w_pivot.detach().cpu().numpy()
global_state['renderer'].set_latent(w_pivot,global_state['params']['trunc_psi'],global_state['params']['trunc_cutoff'])
del percept
del pti
    print('inversion finished')
return global_state, global_state['images']['image_show'], gr.Button.update(interactive=True)
def on_save_image(global_state,form_save_image_path):
imageio.imsave(form_save_image_path,global_state['images']['image_raw'])
def on_reset_custom_image(global_state):
if isinstance(global_state, gr.State):
state = global_state.value
else:
state = global_state
clear_state(state)
state['renderer'].w = state['renderer'].w0.detach().clone()
state['renderer'].w.requires_grad = True
state['renderer'].w_optim = torch.optim.Adam([state['renderer'].w], lr=state['renderer'].lr)
state['renderer']._render_drag_impl(state['generator_params'],
is_drag=False,
to_pil=True)
init_image = state['generator_params'].image
state['images']['image_orig'] = init_image
state['images']['image_raw'] = init_image
state['images']['image_show'] = Image.fromarray(
add_watermark_np(np.array(init_image)))
state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
dtype=np.uint8)
return state, state['images']['image_show']
def on_change_lr(lr, global_state):
if lr == 0:
print('lr is 0, do nothing.')
return global_state
else:
global_state["params"]["lr"] = lr
renderer = global_state['renderer']
renderer.update_lr(lr)
print('New optimizer: ')
print(renderer.w_optim)
return global_state
def on_click_start(global_state, image):
p_in_pixels = []
t_in_pixels = []
valid_points = []
    # handle the start of a drag while in mask-editing mode
global_state = preprocess_mask_info(global_state, image)
# Prepare the points for the inference
if len(global_state["points"]) == 0:
# yield on_click_start_wo_points(global_state, image)
image_raw = global_state['images']['image_raw']
update_image_draw(
image_raw,
global_state['points'],
global_state['mask'],
global_state['show_mask'],
global_state,
)
yield (
global_state,
0,
global_state['images']['image_show'],
# gr.File.update(visible=False),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
# latent space
gr.Radio.update(interactive=True),
gr.Button.update(interactive=True),
# NOTE: disable stop button
gr.Button.update(interactive=False),
# update other comps
gr.Dropdown.update(interactive=True),
gr.Number.update(interactive=True),
gr.Number.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Checkbox.update(interactive=True),
# gr.Number.update(interactive=True),
gr.Number.update(interactive=True),
)
else:
# Transform the points into torch tensors
for key_point, point in global_state["points"].items():
try:
p_start = point.get("start_temp", point["start"])
p_end = point["target"]
if p_start is None or p_end is None:
continue
except KeyError:
continue
p_in_pixels.append(p_start)
t_in_pixels.append(p_end)
valid_points.append(key_point)
mask = torch.tensor(global_state['mask']).float()
drag_mask = 1 - mask
renderer: Renderer = global_state["renderer"]
global_state['temporal_params']['stop'] = False
global_state['editing_state'] = 'running'
# reverse points order
p_to_opt = reverse_point_pairs(p_in_pixels)
t_to_opt = reverse_point_pairs(t_in_pixels)
#print('Running with:')
#print(f' Source: {p_in_pixels}')
#print(f' Target: {t_in_pixels}')
step_idx = 0
while True:
if global_state["temporal_params"]["stop"]:
break
            # do the drag update here
renderer._render_drag_impl(
global_state['generator_params'],
p_to_opt, # point
t_to_opt, # target
drag_mask, # mask,
global_state['params']['motion_lambda'], # lambda_mask
reg=0,
feature_idx=5, # NOTE: do not support change for now
r1=global_state['params']['r1_in_pixels'], # r1
r2=global_state['params']['r2_in_pixels'], # r2
# random_seed = 0,
# noise_mode = 'const',
trunc_psi=global_state['params']['trunc_psi'],
# force_fp32 = False,
# layer_name = None,
# sel_channels = 3,
# base_channel = 0,
# img_scale_db = 0,
# img_normalize = False,
# untransform = False,
is_drag=True,
to_pil=True)
if step_idx % global_state['draw_interval'] == 0:
#print('Current Source:')
for key_point, p_i, t_i in zip(valid_points, p_to_opt,
t_to_opt):
global_state["points"][key_point]["start_temp"] = [
p_i[1],
p_i[0],
]
global_state["points"][key_point]["target"] = [
t_i[1],
t_i[0],
]
start_temp = global_state["points"][key_point][
"start_temp"]
#print(f' {start_temp}')
image_result = global_state['generator_params']['image']
image_draw = update_image_draw(
image_result,
global_state['points'],
global_state['mask'],
global_state['show_mask'],
global_state,
)
global_state['images']['image_raw'] = image_result
yield (
global_state,
step_idx,
global_state['images']['image_show'],
# gr.File.update(visible=False),
gr.Button.update(interactive=False),
gr.Button.update(interactive=False),
gr.Button.update(interactive=False),
gr.Button.update(interactive=False),
gr.Button.update(interactive=False),
# latent space
gr.Radio.update(interactive=False),
gr.Button.update(interactive=False),
# enable stop button in loop
gr.Button.update(interactive=True),
# update other comps
gr.Dropdown.update(interactive=False),
gr.Number.update(interactive=False),
gr.Number.update(interactive=False),
gr.Button.update(interactive=False),
gr.Button.update(interactive=False),
gr.Checkbox.update(interactive=False),
# gr.Number.update(interactive=False),
gr.Number.update(interactive=False),
)
            # increment step
step_idx += 1
image_result = global_state['generator_params']['image']
global_state['images']['image_raw'] = image_result
image_draw = update_image_draw(image_result,
global_state['points'],
global_state['mask'],
global_state['show_mask'],
global_state)
# fp = NamedTemporaryFile(suffix=".png", delete=False)
# image_result.save(fp, "PNG")
global_state['editing_state'] = 'add_points'
yield (
global_state,
0, # reset step to 0 after stop.
global_state['images']['image_show'],
# gr.File.update(visible=True, value=fp.name),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
gr.Button.update(interactive=True),
# latent space
gr.Radio.update(interactive=True),
gr.Button.update(interactive=True),
# NOTE: disable stop button with loop finish
gr.Button.update(interactive=False),
# update other comps
gr.Dropdown.update(interactive=True),
gr.Number.update(interactive=True),
gr.Number.update(interactive=True),
gr.Checkbox.update(interactive=True),
gr.Number.update(interactive=True),
)
def on_click_stop(global_state):
"""Function to handle stop button is clicked.
1. send a stop signal by set global_state["temporal_params"]["stop"] as True
2. Disable Stop button
"""
global_state["temporal_params"]["stop"] = True
return global_state, gr.Button.update(interactive=False)
def on_click_remove_point(global_state):
choice = global_state["curr_point"]
del global_state["points"][choice]
choices = list(global_state["points"].keys())
if len(choices) > 0:
global_state["curr_point"] = choices[0]
return (
gr.Dropdown.update(choices=choices, value=choices[0]),
global_state,
)
# Mask
def on_click_reset_mask(global_state):
global_state['mask'] = np.ones(
(
global_state["images"]["image_raw"].size[1],
global_state["images"]["image_raw"].size[0],
),
dtype=np.uint8,
)
image_draw = update_image_draw(global_state['images']['image_raw'],
global_state['points'],
global_state['mask'],
global_state['show_mask'], global_state)
return global_state, image_draw
# Image
def on_click_enable_draw(global_state, image):
"""Function to start add mask mode.
1. Preprocess mask info from last state
2. Change editing state to add_mask
3. Set curr image with points and mask
"""
global_state = preprocess_mask_info(global_state, image)
global_state['editing_state'] = 'add_mask'
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(image_raw, global_state['points'],
global_state['mask'], True,
global_state)
return (global_state,
gr.Image.update(value=image_draw, interactive=True))
def on_click_remove_draw(global_state, image):
"""Function to start remove mask mode.
1. Preprocess mask info from last state
2. Change editing state to remove_mask
3. Set curr image with points and mask
"""
global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'remove_mask'
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(image_raw, global_state['points'],
global_state['mask'], True,
global_state)
return (global_state,
gr.Image.update(value=image_draw, interactive=True))
def on_click_add_point(global_state, image: dict):
"""Function switch from add mask mode to add points mode.
1. Updaste mask buffer if need
2. Change global_state['editing_state'] to 'add_points'
3. Set current image with mask
"""
global_state = preprocess_mask_info(global_state, image)
global_state['editing_state'] = 'add_points'
mask = global_state['mask']
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(image_raw, global_state['points'], mask,
global_state['show_mask'], global_state)
return (global_state,
gr.Image.update(value=image_draw, interactive=False))
def on_click_image(global_state, evt: gr.SelectData):
"""This function only support click for point selection
"""
xy = evt.index
if global_state['editing_state'] != 'add_points':
print(f'In {global_state["editing_state"]} state. '
'Do not add points.')
return global_state, global_state['images']['image_show']
points = global_state["points"]
point_idx = get_latest_points_pair(points)
if point_idx is None:
points[0] = {'start': xy, 'target': None}
print(f'Click Image - Start - {xy}')
elif points[point_idx].get('target', None) is None:
points[point_idx]['target'] = xy
print(f'Click Image - Target - {xy}')
else:
points[point_idx + 1] = {'start': xy, 'target': None}
print(f'Click Image - Start - {xy}')
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(
image_raw,
global_state['points'],
global_state['mask'],
global_state['show_mask'],
global_state,
)
return global_state, image_draw
def on_click_clear_points(global_state):
"""Function to handle clear all control points
1. clear global_state['points'] (clear_state)
2. re-init network
2. re-draw image
"""
clear_state(global_state, target='point')
renderer: Renderer = global_state["renderer"]
renderer.feat_refs = None
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(image_raw, {}, global_state['mask'],
global_state['show_mask'], global_state)
return global_state, image_draw
def on_click_show_mask(global_state, show_mask):
"""Function to control whether show mask on image."""
global_state['show_mask'] = show_mask
image_raw = global_state['images']['image_raw']
image_draw = update_image_draw(
image_raw,
global_state['points'],
global_state['mask'],
global_state['show_mask'],
global_state,
)
return global_state, image_draw
if __name__ == "__main__":
with gr.Blocks() as app:
# renderer = Renderer()
global_state = gr.State({
"images": {
# image_orig: the original image; changes when the seed/model changes
# image_raw: image with mask and points; changes during optimization
# image_show: the image shown on screen
},
"temporal_params": {
# stop
},
'mask': None,  # mask for visualization: 1 for editable, 0 for unchanged
'last_mask': None,  # last edited mask
'show_mask': True,  # whether to draw the mask on the shown image
"generator_params": dnnlib.EasyDict(),
"params": {
"seed": 0,
"motion_lambda": 20,
"r1_in_pixels": 3,
"r2_in_pixels": 12,
"magnitude_direction_in_pixels": 1.0,
"latent_space": "w+",
"trunc_psi": 0.7,
"trunc_cutoff": None,
"lr": 0.001,
},
"device": device,
"draw_interval": 1,
"renderer": Renderer(disable_timing=True),
"points": {},
"curr_point": None,
"curr_type_point": "start",
'editing_state': 'add_points',
'pretrained_weight': init_pkl
})
# init image
global_state = init_images(global_state)
with gr.Row():
with gr.Row():
# Left --> tools
with gr.Column(scale=3):
# Pickle
with gr.Row():
with gr.Column(scale=1, min_width=10):
gr.Markdown(value='Pickle', show_label=False)
with gr.Column(scale=4, min_width=10):
form_pretrained_dropdown = gr.Dropdown(
choices=list(valid_checkpoints_dict.keys()),
label="Pretrained Model",
value=init_pkl,
)
# Latent
with gr.Row():
with gr.Column(scale=1, min_width=10):
gr.Markdown(value='Latent', show_label=False)
with gr.Column(scale=4, min_width=10):
form_seed_number = gr.Number(
value=global_state.value['params']['seed'],
interactive=True,
label="Seed",
)
form_lr_number = gr.Number(
value=global_state.value["params"]["lr"],
interactive=True,
label="Step Size")
with gr.Row():
with gr.Column(scale=2, min_width=10):
form_reset_image = gr.Button("Reset Image")
with gr.Column(scale=3, min_width=10):
form_latent_space = gr.Radio(
['w', 'w+'],
value=global_state.value['params']
['latent_space'],
interactive=True,
label='Latent space to optimize',
show_label=False,
)
with gr.Row():
with gr.Column(scale=3, min_width=10):
form_custom_image = gr.UploadButton(label="Invert Custom Image",
file_types=['.png', '.jpg', '.jpeg'])
with gr.Column(scale=3, min_width=10):
form_reset_custom_image = gr.Button('Reset Custom Image', interactive=False)
with gr.Row():
with gr.Column(scale=3, min_width=10):
form_save_image_path = gr.Textbox(label="Save Image To", value='./test.png')
form_save_image = gr.Button('Save', interactive=True)
# Drag
with gr.Row():
with gr.Column(scale=1, min_width=10):
gr.Markdown(value='Drag', show_label=False)
with gr.Column(scale=4, min_width=10):
with gr.Row():
with gr.Column(scale=1, min_width=10):
enable_add_points = gr.Button('Add Points')
with gr.Column(scale=1, min_width=10):
undo_points = gr.Button('Reset Points')
with gr.Row():
with gr.Column(scale=1, min_width=10):
form_start_btn = gr.Button("Start")
with gr.Column(scale=1, min_width=10):
form_stop_btn = gr.Button("Stop")
form_steps_number = gr.Number(value=0,
label="Steps",
interactive=False)
# Mask
with gr.Row():
with gr.Column(scale=1, min_width=10):
gr.Markdown(value='Mask', show_label=False)
with gr.Column(scale=4, min_width=10):
enable_add_mask = gr.Button('Edit Flexible Area')
with gr.Row():
with gr.Column(scale=1, min_width=10):
form_reset_mask_btn = gr.Button("Reset Mask")
with gr.Column(scale=1, min_width=10):
show_mask = gr.Checkbox(
label='Show Mask',
value=global_state.value['show_mask'],
show_label=False)
with gr.Row():
form_lambda_number = gr.Number(
value=global_state.value["params"]
["motion_lambda"],
interactive=True,
label="Lambda",
)
form_draw_interval_number = gr.Number(
value=global_state.value["draw_interval"],
label="Draw Interval (steps)",
interactive=True,
visible=False)
# Right --> Image
with gr.Column(scale=8):
form_image = ImageMask(
value=global_state.value['images']['image_show'],
brush_radius=20).style(
width=768,
height=768)  # NOTE: the image size is hard-coded here.
gr.Markdown("""
## Quick Start
1. Select the desired `Pretrained Model` and adjust `Seed` to generate an
initial image.
2. Click on the image to add control points.
3. Click `Start` and enjoy it!
## Advanced Usage
1. Change `Step Size` to adjust the learning rate of the drag optimization.
2. Select `w` or `w+` to choose the latent space to optimize:
* Optimizing in `w` space may change the image more strongly.
* Optimizing in `w+` space may be slower than `w`, but usually achieves
better results.
* Note that changing the latent space will reset the image, points, and
mask (this has the same effect as the `Reset Image` button).
3. Click `Edit Flexible Area` to draw a mask and constrain the
unmasked region to stay unchanged.
""")
gr.HTML("""
<style>
.container {
position: absolute;
height: 50px;
text-align: center;
line-height: 50px;
width: 100%;
}
</style>
<div class="container">
Gradio demo supported by
<img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
<a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
</div>
""")
show_mask.change(
on_click_show_mask,
inputs=[global_state, show_mask],
outputs=[global_state, form_image],
)
undo_points.click(on_click_clear_points,
inputs=[global_state],
outputs=[global_state, form_image])
form_image.select(
on_click_image,
inputs=[global_state],
outputs=[global_state, form_image],
)
enable_add_mask.click(on_click_enable_draw,
inputs=[global_state, form_image],
outputs=[
global_state,
form_image,
])
enable_add_points.click(on_click_add_point,
inputs=[global_state, form_image],
outputs=[global_state, form_image])
form_reset_mask_btn.click(
on_click_reset_mask,
inputs=[global_state],
outputs=[global_state, form_image],
)
form_stop_btn.click(on_click_stop,
inputs=[global_state],
outputs=[global_state, form_stop_btn])
form_draw_interval_number.change(
partial(
on_change_single_global_state,
"draw_interval",
map_transform=lambda x: int(x),
),
inputs=[form_draw_interval_number, global_state],
outputs=[global_state],
)
form_start_btn.click(
on_click_start,
inputs=[global_state, form_image],
outputs=[
global_state,
form_steps_number,
form_image,
# form_download_result_file,
# >>> buttons
form_reset_image,
enable_add_points,
enable_add_mask,
undo_points,
form_reset_mask_btn,
form_latent_space,
form_start_btn,
form_stop_btn,
# <<< buttons
# >>> inputs comps
form_pretrained_dropdown,
form_seed_number,
form_lr_number,
show_mask,
form_lambda_number,
],
)
form_lr_number.change(
on_change_lr,
inputs=[form_lr_number, global_state],
outputs=[global_state],
)
form_custom_image.upload(on_click_inverse_custom_image,
inputs=[form_custom_image, global_state],
outputs=[global_state, form_image, form_reset_custom_image])
form_save_image.click(on_save_image,
inputs=[global_state, form_save_image_path],
outputs=[])
form_reset_custom_image.click(on_reset_custom_image,
inputs=[global_state],
outputs=[global_state, form_image])
# ==== Params
form_lambda_number.change(
partial(on_change_single_global_state, ["params", "motion_lambda"]),
inputs=[form_lambda_number, global_state],
outputs=[global_state],
)
form_latent_space.change(on_click_latent_space,
inputs=[form_latent_space, global_state],
outputs=[global_state, form_image])
form_seed_number.change(
on_change_update_image_seed,
inputs=[form_seed_number, global_state],
outputs=[global_state, form_image],
)
form_reset_image.click(
on_click_reset_image,
inputs=[global_state],
outputs=[global_state, form_image],
)
form_pretrained_dropdown.change(
on_change_pretrained_dropdown,
inputs=[form_pretrained_dropdown, global_state],
outputs=[global_state, form_image],
)
#gr.close_all()
app.queue(concurrency_count=3, max_size=20)
app.launch(share=args.share)

@ -225,7 +225,7 @@ class Renderer:
res.num_ws = G.num_ws
res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
self.lr = lr
# Set input transform.
if res.has_input_transform:
m = np.eye(3)
@ -262,6 +262,17 @@ class Renderer:
self.feat_refs = None
self.points0_pt = None
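# set_latent injects an externally obtained latent (e.g. from inverting a
# custom image) as the new pivot: w0 keeps an untouched copy, and only the
# trainable self.w gets an Adam optimizer.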
def set_latent(self,w,trunc_psi,trunc_cutoff):
#label = torch.zeros([1, self.G.c_dim], device=self._device)
#w = self.G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
self.w0 = w.detach().clone()
if self.w_plus:
self.w = w.detach()
else:
self.w = w[:, 0, :].detach()
self.w.requires_grad = True
self.w_optim = torch.optim.Adam([self.w], lr=self.lr)
def update_lr(self, lr):
del self.w_optim
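# Assumed continuation (the visible diff ends here): rebuild the optimizer
# with the new learning rate, mirroring set_latent above.
self.w_optim = torch.optim.Adam([self.w], lr=lr)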
