| 
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -14,36 +14,6 @@ import imageio
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def noise_regularize(noises):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    loss = 0
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    for noise in noises:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        size = noise.shape[2]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        while True:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            loss = (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                loss
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            if size <= 8:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                break
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            noise = noise.mean([3, 5])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            size //= 2
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    return loss
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def noise_normalize_(noises):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    for noise in noises:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        mean = noise.mean()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        std = noise.std()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        noise.data.add_(-mean).div_(std)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    lr_ramp = min(1, (1 - t) / rampdown)
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -53,10 +23,7 @@ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    return initial_lr * lr_ramp
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def latent_noise(latent, strength):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    noise = torch.randn_like(latent) * strength
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    return latent + noise
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				def make_image(tensor):
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -259,6 +226,27 @@ class PTI:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return loss
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def train(self,img,w_plus=False):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if torch.is_tensor(img) == False:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            transform = transforms.Compose(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.Resize(self.g_ema.img_resolution, ),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.CenterCrop(self.g_ema.img_resolution),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.ToTensor(),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            real_img = transform(img).to('cuda').unsqueeze(0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        else:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            img = transforms.functional.resize(img, self.g_ema.img_resolution)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            transform = transforms.Compose(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                [
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.CenterCrop(self.g_ema.img_resolution),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                ]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            real_img = transform(img).to('cuda').unsqueeze(0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        w_pivot = inversed_result['latent']
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if w_plus:
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -275,7 +263,7 @@ class PTI:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            optimizer.param_groups[0]["lr"] = lr
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            generated_image = self.g_ema.synthesis(ws,noise_mode='const')
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            loss = self.cacl_loss(self.percept,generated_image,inversed_result['real'])
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            loss = self.cacl_loss(self.percept,generated_image,real_img)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            pbar.set_description(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    f"loss: {loss.item():.4f}"
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
			
			 | 
			 | 
			
				
 
 |