# Copyright (c) SenseTime Research. All rights reserved.

import math

import cv2
import numpy as np
import torch
from torchvision import transforms


def visual(output, out_path):
    """Save a generator output tensor (NCHW, values in [-1, 1]) as an image file."""
    # Map from [-1, 1] to [0, 1] and clamp outliers.
    output = (output + 1) / 2
    output = torch.clamp(output, 0, 1)
    # Replicate single-channel (grayscale) outputs to three channels.
    if output.shape[1] == 1:
        output = torch.cat([output, output, output], 1)
    # Take the first image in the batch and move to HWC layout on the CPU.
    output = output[0].detach().cpu().permute(1, 2, 0).numpy()
    output = (output * 255).astype(np.uint8)
    # RGB -> BGR, since cv2.imwrite expects BGR channel order.
    output = output[:, :, ::-1]
    cv2.imwrite(out_path, output)
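
# Example usage (illustrative sketch only): `generator` and the output path are
# hypothetical stand-ins, not names from this repository; `visual` expects an
# NCHW tensor with values in [-1, 1].
#
#   img = generator(latent)            # e.g. a 1 x 3 x H x W tensor in [-1, 1]
#   visual(img, 'out/sample.png')      # written via OpenCV, hence the RGB->BGR flip above
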
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning-rate schedule: linear warm-up over the first `rampup` fraction of
    steps, constant in the middle, cosine decay over the final `rampdown`
    fraction. `t` is the fraction of optimization completed, in [0, 1]."""
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp
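
# Example usage (illustrative sketch): a typical way to apply this schedule is to
# reset the optimizer's learning rate every step, with `t` the fraction of steps
# completed. `optimizer`, `num_steps`, and `initial_lr` are assumed to be defined
# by the surrounding optimization loop, not by this file.
#
#   for step in range(num_steps):
#       lr = get_lr(step / num_steps, initial_lr)
#       for param_group in optimizer.param_groups:
#           param_group['lr'] = lr
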
def latent_noise(latent, strength):
    # Perturb a latent code with Gaussian noise of the given strength.
    noise = torch.randn_like(latent) * strength
    return latent + noise
def noise_regularize_(noises):
    """Penalize spatial correlation in each noise map across dyadic scales,
    encouraging the maps to remain white noise."""
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            # Correlation between each pixel and its horizontal / vertical neighbour.
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            # Downsample by 2x (average over 2x2 blocks) and repeat at the coarser scale.
            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss
def noise_normalize_(noises):
    # Re-standardize each noise map in place to zero mean and unit variance.
    for noise in noises:
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)
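
# Example usage (illustrative sketch): in StyleGAN-style latent projection the
# per-layer noise maps are optimized alongside the latent code. The weight 1e5
# and the names `noises`, `rec_loss`, and `optimizer` below are assumptions for
# illustration, not values taken from this repository.
#
#   reg_loss = noise_regularize_(noises)      # noises: list of 1 x 1 x H x W tensors
#   total_loss = rec_loss + 1e5 * reg_loss
#   optimizer.zero_grad()
#   total_loss.backward()
#   optimizer.step()
#   noise_normalize_(noises)                  # keep each map zero-mean, unit-variance
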
def tensor_to_numpy(x):
    # Convert the first image of an NCHW [-1, 1] tensor to an HWC uint8 array in [0, 255].
    x = x[0].permute(1, 2, 0)
    x = torch.clamp(x, -1, 1)
    x = (x + 1) * 127.5
    x = x.cpu().detach().numpy().astype(np.uint8)
    return x
def numpy_to_tensor(x):
    # Convert an HWC array in [0, 255] to a 1 x 3 x H x W CUDA float tensor in [-1, 1].
    x = (x / 255 - 0.5) * 2
    x = torch.from_numpy(x).unsqueeze(0).permute(0, 3, 1, 2)
    x = x.cuda().float()
    return x
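
# Example round trip (illustrative sketch; 'input.png' is a hypothetical path):
# `numpy_to_tensor` maps an H x W x 3 array in [0, 255] to a 1 x 3 x H x W CUDA
# float tensor in [-1, 1], and `tensor_to_numpy` inverts that mapping for the
# first image in the batch.
#
#   img = cv2.cvtColor(cv2.imread('input.png'), cv2.COLOR_BGR2RGB)
#   x = numpy_to_tensor(img.astype(np.float32))
#   back = tensor_to_numpy(x)                  # H x W x 3 uint8, approximately equal to img
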
def tensor_to_pil(x):
    # Map a [-1, 1] tensor to [0, 255] and convert to a PIL image.
    x = torch.clamp(x, -1, 1)
    x = (x + 1) * 127.5
    return transforms.ToPILImage()(x.squeeze_(0))