|
|
|
|
@ -14,36 +14,6 @@ import imageio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def noise_regularize(noises):
|
|
|
|
|
loss = 0
|
|
|
|
|
|
|
|
|
|
for noise in noises:
|
|
|
|
|
size = noise.shape[2]
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
loss = (
|
|
|
|
|
loss
|
|
|
|
|
+ (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
|
|
|
|
|
+ (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if size <= 8:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
|
|
|
|
|
noise = noise.mean([3, 5])
|
|
|
|
|
size //= 2
|
|
|
|
|
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def noise_normalize_(noises):
|
|
|
|
|
for noise in noises:
|
|
|
|
|
mean = noise.mean()
|
|
|
|
|
std = noise.std()
|
|
|
|
|
|
|
|
|
|
noise.data.add_(-mean).div_(std)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
|
|
|
|
|
lr_ramp = min(1, (1 - t) / rampdown)
|
|
|
|
|
@ -53,10 +23,7 @@ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
|
|
|
|
|
return initial_lr * lr_ramp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def latent_noise(latent, strength):
|
|
|
|
|
noise = torch.randn_like(latent) * strength
|
|
|
|
|
|
|
|
|
|
return latent + noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_image(tensor):
|
|
|
|
|
@ -259,6 +226,27 @@ class PTI:
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
def train(self,img,w_plus=False):
|
|
|
|
|
if torch.is_tensor(img) == False:
|
|
|
|
|
transform = transforms.Compose(
|
|
|
|
|
[
|
|
|
|
|
transforms.Resize(self.g_ema.img_resolution, ),
|
|
|
|
|
transforms.CenterCrop(self.g_ema.img_resolution),
|
|
|
|
|
transforms.ToTensor(),
|
|
|
|
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
real_img = transform(img).to('cuda').unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
img = transforms.functional.resize(img, self.g_ema.img_resolution)
|
|
|
|
|
transform = transforms.Compose(
|
|
|
|
|
[
|
|
|
|
|
transforms.CenterCrop(self.g_ema.img_resolution),
|
|
|
|
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
real_img = transform(img).to('cuda').unsqueeze(0)
|
|
|
|
|
inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus)
|
|
|
|
|
w_pivot = inversed_result['latent']
|
|
|
|
|
if w_plus:
|
|
|
|
|
@ -275,7 +263,7 @@ class PTI:
|
|
|
|
|
optimizer.param_groups[0]["lr"] = lr
|
|
|
|
|
|
|
|
|
|
generated_image = self.g_ema.synthesis(ws,noise_mode='const')
|
|
|
|
|
loss = self.cacl_loss(self.percept,generated_image,inversed_result['real'])
|
|
|
|
|
loss = self.cacl_loss(self.percept,generated_image,real_img)
|
|
|
|
|
pbar.set_description(
|
|
|
|
|
(
|
|
|
|
|
f"loss: {loss.item():.4f}"
|
|
|
|
|
|