Add files via upload

pull/142/head
Tianhao Xie 2 years ago committed by GitHub
parent 2c428ff9bd
commit d6aa972708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,36 +14,6 @@ import imageio
def noise_regularize(noises):
loss = 0
for noise in noises:
size = noise.shape[2]
while True:
loss = (
loss
+ (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
+ (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
)
if size <= 8:
break
noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
noise = noise.mean([3, 5])
size //= 2
return loss
def noise_normalize_(noises):
for noise in noises:
mean = noise.mean()
std = noise.std()
noise.data.add_(-mean).div_(std)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
@ -53,10 +23,7 @@ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
return initial_lr * lr_ramp
def latent_noise(latent, strength):
noise = torch.randn_like(latent) * strength
return latent + noise
def make_image(tensor):
@ -259,6 +226,27 @@ class PTI:
return loss
def train(self,img,w_plus=False):
if torch.is_tensor(img) == False:
transform = transforms.Compose(
[
transforms.Resize(self.g_ema.img_resolution, ),
transforms.CenterCrop(self.g_ema.img_resolution),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
real_img = transform(img).to('cuda').unsqueeze(0)
else:
img = transforms.functional.resize(img, self.g_ema.img_resolution)
transform = transforms.Compose(
[
transforms.CenterCrop(self.g_ema.img_resolution),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
real_img = transform(img).to('cuda').unsqueeze(0)
inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus)
w_pivot = inversed_result['latent']
if w_plus:
@ -275,7 +263,7 @@ class PTI:
optimizer.param_groups[0]["lr"] = lr
generated_image = self.g_ema.synthesis(ws,noise_mode='const')
loss = self.cacl_loss(self.percept,generated_image,inversed_result['real'])
loss = self.cacl_loss(self.percept,generated_image,real_img)
pbar.set_description(
(
f"loss: {loss.item():.4f}"

Loading…
Cancel
Save