# mirror of https://github.com/XingangPan/DragGAN
import dataclasses
import math
import os

import imageio
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import dnnlib
from viz import renderer
from .lpips import util


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # Cosine learning-rate schedule: ramp up over the first `rampup` fraction
    # of training and ramp down over the last `rampdown` fraction, where t is
    # the normalized progress in [0, 1].
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp
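
# Quick sanity check of the schedule shape (these values follow directly from
# the formula above):
#   get_lr(0.0, 0.1) -> 0.0   (start of warm-up)
#   get_lr(0.5, 0.1) -> 0.1   (plateau at the full learning rate)
#   get_lr(1.0, 0.1) -> 0.0   (fully decayed)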


def make_image(tensor):
    # Convert a batch of [-1, 1] float images (NCHW) into uint8 numpy arrays (NHWC).
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
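
# For example, imageio.imsave('out.png', make_image(img_gen)[0]) writes the
# first image of a batch as an (H, W, 3) uint8 PNG; this is how the __main__
# demo below saves its result.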


@dataclasses.dataclass
class InverseConfig:
    # Type annotations are required for @dataclass to register these as fields;
    # the original bare assignments were silently treated as plain class
    # attributes and could not be overridden via the constructor.
    lr_warmup: float = 0.05
    lr_decay: float = 0.25
    lr: float = 0.1
    noise: float = 0.05
    noise_decay: float = 0.75
    step: int = 1000
    noise_regularize: float = 1e5
    mse: float = 0.1
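
# With the annotations above, settings can be overridden per call, e.g.
#   inverse_image(g_ema, image, percept, config=InverseConfig(step=500, lr=0.05))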


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=InverseConfig(),
    device='cuda:0',
):
    args = config

    n_mean_latent = 10000

    resize = min(image_size, 256)

    if not torch.is_tensor(image):
        # PIL input: resize, center-crop, and normalize to [-1, 1].
        transform = transforms.Compose(
            [
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(image)
    else:
        # Tensor input is assumed to be CHW in [0, 1] already.
        img = transforms.functional.resize(image, resize)
        transform = transforms.Compose(
            [
                transforms.CenterCrop(resize),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(img)

    imgs = img.unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        # Estimate the mean and standard deviation of W from a large batch of
        # mapped z samples; w_avg seeds the optimization and w_std scales the
        # injected latent noise.
        noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
        w_samples = g_ema.mapping(noise_sample, None)
        w_samples = w_samples[:, :1, :]
        w_avg = w_samples.mean(0)
        w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5

    # Re-randomize the generator's noise buffers in place and make them
    # trainable. (Assigning `noise = torch.randn_like(noise)`, as the original
    # did, only rebinds the loop variable and leaves the actual buffers
    # untouched and untrainable.)
    noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
    for noise in noises.values():
        noise[:] = torch.randn_like(noise)
        noise.requires_grad = True

    # Start from the average latent; in w+ mode every synthesis layer gets its
    # own copy of w to optimize.
    w_opt = w_avg.detach().clone()
    if w_plus:
        w_opt = w_opt.repeat(1, g_ema.mapping.num_ws, 1)
    w_opt.requires_grad = True

    optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        # Inject annealed noise into the latent: strong early on to encourage
        # exploration, decayed to zero after args.noise_decay of the run.
        noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        w_noise = torch.randn_like(w_opt) * noise_strength
        if w_plus:
            ws = w_opt + w_noise
        else:
            ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])

        img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample to 256x256 if the image is larger than that; the VGG
        # backbone of the perceptual loss was built for small inputs.
        if img_gen.shape[2] > 256:
            img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')

        p_loss = percept(img_gen, imgs)

        # Noise regularization: penalize spatial autocorrelation of the noise
        # maps at multiple scales so that image content cannot leak into them.
        reg_loss = 0.0
        for v in noises.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-normalize each noise buffer to zero mean and unit variance.
        with torch.no_grad():
            for buf in noises.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        # Snapshot the latent every 100 steps (and on the final step, so that
        # latent_path is never empty even for runs shorter than 100 steps).
        if (i + 1) % 100 == 0 or (i + 1) == args.step:
            latent_path.append(w_opt.detach().clone())

        pbar.set_description(
            f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
        )

    # Re-synthesize the final image from the last latent snapshot.
    if w_plus:
        ws = latent_path[-1]
    else:
        ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])

    img_gen = g_ema.synthesis(ws, noise_mode='const')

    result = {
        "latent": latent_path[-1],
        "sample": img_gen,
        "real": imgs,
    }

    return result
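
# Minimal usage sketch (hypothetical file names; assumes a loaded StyleGAN2-ADA
# generator `G` and an LPIPS perceptual loss like the one built in __main__):
#
#   result = inverse_image(G, Image.open('input.png'), percept,
#                          image_size=G.img_resolution, w_plus=True)
#   imageio.imsave('projected.png', make_image(result['sample'])[0])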


def toggle_grad(model, flag=True):
    # Enable or disable gradients for every parameter of a module.
    # (Renamed from the original misspelling `toogle_grad`.)
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr
        self.percept = percept

    def calc_loss(self, percept, generated_image, real_image):
        # (Renamed from the original misspelling `cacl_loss`.)
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss
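
    # i.e. loss = LPIPS(G(w), x) + l2_lambda * MSE(G(w), x), the reconstruction
    # objective used for pivotal tuning (Roich et al., "Pivotal Tuning for
    # Latent-based Editing of Real Images").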

    def train(self, img, w_plus=False):
        if not torch.is_tensor(img):
            # PIL input: resize, center-crop, and normalize to [-1, 1].
            transform = transforms.Compose(
                [
                    transforms.Resize(self.g_ema.img_resolution),
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to('cuda').unsqueeze(0)
        else:
            # Tensor input is assumed to be CHW in [0, 1] already.
            img = transforms.functional.resize(img, self.g_ema.img_resolution)
            transform = transforms.Compose(
                [
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to('cuda').unsqueeze(0)

        # Step 1: invert the image to a pivot latent.
        inversed_result = inverse_image(self.g_ema, img, self.percept,
                                        self.g_ema.img_resolution, w_plus)
        w_pivot = inversed_result['latent']
        if w_plus:
            ws = w_pivot
        else:
            ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])

        # Step 2: fine-tune the generator weights around the fixed pivot.
        toggle_grad(self.g_ema, True)
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(self.percept, generated_image, real_img)
            pbar.set_description(f"loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')

        return generated_image, ws
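
# Typical PTI workflow, as exercised by the demo below: invert the image to a
# pivot latent with inverse_image(), then fine-tune the generator around that
# pivot so the reconstruction becomes near-exact while the surrounding latent
# space stays editable.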


if __name__ == "__main__":
    state = {
        "images": {
            # image_orig: the original image; changes when the seed/model changes
            # image_raw: image with mask and points; changes during optimization
            # image_show: the image shown on screen
        },
        "temporal_params": {
            # stop
        },
        'mask': None,  # mask for visualization; 1 for editable regions, 0 for unchanged
        'last_mask': None,  # last edited mask
        'show_mask': True,  # add button
        "generator_params": dnnlib.EasyDict(),
        "params": {
            "seed": 0,
            "motion_lambda": 20,
            "r1_in_pixels": 3,
            "r2_in_pixels": 12,
            "magnitude_direction_in_pixels": 1.0,
            "latent_space": "w+",
            "trunc_psi": 0.7,
            "trunc_cutoff": None,
            "lr": 0.001,
        },
        "device": 'cuda:0',
        "draw_interval": 1,
        "renderer": renderer.Renderer(disable_timing=True),
        "points": {},
        "curr_point": None,
        "curr_type_point": "start",
        'editing_state': 'add_points',
        'pretrained_weight': 'stylegan2_horses_256_pytorch',
    }
    cache_dir = '../checkpoints'
    valid_checkpoints_dict = {
        f.split('/')[-1].split('.')[0]: os.path.join(cache_dir, f)
        for f in os.listdir(cache_dir)
        if (f.endswith('pkl') and os.path.exists(os.path.join(cache_dir, f)))
    }
    state['renderer'].init_network(state['generator_params'],  # res
                                   valid_checkpoints_dict[state['pretrained_weight']],  # pkl
                                   state['params']['seed'],  # w0_seed,
                                   None,  # w_load
                                   state['params']['latent_space'] == 'w+',  # w_plus
                                   'const',
                                   state['params']['trunc_psi'],  # trunc_psi,
                                   state['params']['trunc_cutoff'],  # trunc_cutoff,
                                   None,  # input_transform
                                   state['params']['lr']  # lr
                                   )
    image = Image.open('/home/tianhao/research/drag3d/horse/render/0.png')
    G = state['renderer'].G
    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=True
    )
    pti = PTI(G, percept)
    generated_image, ws = pti.train(image, w_plus=True)
    imageio.imsave('../horse/test.png', make_image(generated_image)[0])