mirror of https://github.com/XingangPan/DragGAN
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
import torch
|
|
from inversion import inverse_image,get_lr
|
|
|
|
from tqdm import tqdm
|
|
from torch.nn import functional as F
|
|
from lpips import util
|
|
def toogle_grad(model, flag=True):
    """Switch gradient tracking on or off for every parameter of ``model``.

    The misspelled name ("toogle") is kept for compatibility with existing
    callers such as ``PTI.train``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``);
            each yielded parameter is modified in place.
        flag: Value assigned to every parameter's ``requires_grad``.
    """
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = flag
|
|
|
|
|
|
class PTI:
    """Fine-tune a generator around a fixed inverted latent to reproduce an image.

    ``train`` first inverts the target image to a latent code (the "pivot")
    with ``inverse_image``, then optimizes the generator's own weights with
    Adam so that synthesizing from that fixed pivot matches the image.
    (Presumably Pivotal Tuning Inversion, as the class name and the
    ``w_pivot`` variable suggest.)

    The generator passed in is stored by reference and trained *in place*;
    no copy is made.
    """

    def __init__(self, G, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        """Store the generator and the tuning hyper-parameters.

        Args:
            G: Generator to fine-tune. ``train`` reads ``G.img_resolution``
                and ``G.mapping.num_ws``, calls ``G.synthesis``, and updates
                ``G.parameters()`` in place.
            l2_lambda: Weight of the pixel-wise MSE term in the loss.
            max_pti_step: Number of optimization steps run by ``train``.
            pti_lr: Base learning rate fed to the ``get_lr`` schedule.
        """
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr

    def cacl_loss(self, percept, generated_image, real_image):
        """Return ``percept(gen, real).sum() + l2_lambda * MSE(gen, real)``.

        Args:
            percept: Callable returning a perceptual-distance tensor; its
                output is summed to a scalar.
            generated_image: Image tensor produced by the generator.
            real_image: Target image tensor, same shape as ``generated_image``.

        Returns:
            Scalar loss tensor.
        """
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        return p_loss + self.l2_lambda * mse_loss

    def train(self, img):
        """Invert ``img`` and fine-tune ``self.g_ema`` to reproduce it.

        Args:
            img: Target image, passed straight to ``inverse_image`` (its
                expected format is defined there -- not visible here).

        Returns:
            The raw output of ``self.g_ema.synthesis`` for the pivot latent,
            computed under ``torch.no_grad()`` after tuning. The training
            loop unpacks this output into two values, so it is a 2-tuple
            (presumably ``(image, features)`` -- TODO confirm). It is
            returned unchanged for compatibility with existing callers.
        """
        inversed_result = inverse_image(self.g_ema, img, self.g_ema.img_resolution)
        w_pivot = inversed_result['latent']
        # Repeat the single pivot latent once per synthesis-layer input.
        # NOTE(review): assumes 'latent' is shaped (1, 1, w_dim) -- confirm.
        ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])

        toogle_grad(self.g_ema, True)
        # NOTE(review): the perceptual-loss device is hard-coded to cuda:0.
        percept = util.PerceptualLoss(
            model="net-lin", net="vgg", use_gpu='cuda:0'
        )
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)

        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            # inversion.get_lr appears to follow the StyleGAN2-projector
            # schedule, which takes normalized progress t = step / total in
            # [0, 1] (inverse_image uses it that way). Passing the raw step
            # index drives that schedule's cosine ramp to ~0 on every step,
            # so the generator was never actually updated -- confirm against
            # inversion.get_lr.
            t = i / self.max_pti_step
            optimizer.param_groups[0]["lr"] = get_lr(t, self.pti_lr)

            generated_image, _ = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.cacl_loss(percept, generated_image, inversed_result['real'])
            pbar.set_description(f"loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')

        return generated_image
|