mirror of https://github.com/XingangPan/DragGAN
Merge caeb80f324 into 336f120ce1
commit 417058bd37
				@ -0,0 +1,53 @@
import torch
from inversion import inverse_image, get_lr

from tqdm import tqdm
from torch.nn import functional as F
from lpips import util


def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr

    def calc_loss(self, percept, generated_image, real_image):
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss

    def train(self, img):
        inversed_result = inverse_image(self.g_ema, img, self.g_ema.img_resolution)
        w_pivot = inversed_result['latent']
        ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
        toggle_grad(self.g_ema, True)
        percept = util.PerceptualLoss(
            model="net-lin", net="vgg", use_gpu=True
        )
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            # Normalize the step index to [0, 1] before computing the ramped lr.
            lr = get_lr(i / self.max_pti_step, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(percept, generated_image, inversed_result['real'])
            pbar.set_description(f"loss: {loss.item():.4f}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')

        return generated_image
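
# A minimal usage sketch of the PTI class above (not part of the diff), assuming
# `G` is a StyleGAN2 generator already loaded elsewhere (e.g. the way the
# __main__ block of the larger script below obtains it); the image path is a
# hypothetical placeholder.
from PIL import Image

pti = PTI(G, l2_lambda=1, max_pti_step=400, pti_lr=3e-4)
tuned_image = pti.train(Image.open('example.png'))  # inverts, then tunes the generator weights
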
@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,343 @@
import math
import os
from viz import renderer
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import dataclasses
import dnnlib
from .lpips import util
import imageio


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # t is the normalized progress in [0, 1]: linear warmup over the first
    # `rampup` fraction of steps, cosine rampdown over the last `rampdown` fraction.
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp

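
# A quick numeric sketch of the schedule above (not part of the diff): linear
# warmup over the first 5% of progress, flat plateau, cosine rampdown over the
# final 25%.
for t in (0.0, 0.025, 0.05, 0.5, 0.875, 1.0):
    print(f"t={t:.3f} lr={get_lr(t, 0.1):.4f}")
# -> 0.0000, 0.0500, 0.1000, 0.1000, 0.0500, 0.0000
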
def make_image(tensor):
    # Convert a [-1, 1] NCHW float tensor to a uint8 NHWC numpy array.
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


@dataclasses.dataclass
class InverseConfig:
    # Annotations are required for @dataclass to treat these as fields.
    lr_warmup: float = 0.05
    lr_decay: float = 0.25
    lr: float = 0.1
    noise: float = 0.05
    noise_decay: float = 0.75
    step: int = 1000
    noise_regularize: float = 1e5
    mse: float = 0.1


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=InverseConfig(),
    device='cuda:0'
):
    args = config

    n_mean_latent = 10000

    resize = min(image_size, 256)

    if not torch.is_tensor(image):
        transform = transforms.Compose(
            [
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(image)
    else:
        img = transforms.functional.resize(image, resize)
        transform = transforms.Compose(
            [
                transforms.CenterCrop(resize),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(img)

    imgs = []
    imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    with torch.no_grad():
        # Estimate the mean and std of W by sampling the mapping network.
        noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
        w_samples = g_ema.mapping(noise_sample, None)
        w_samples = w_samples[:, :1, :]
        w_avg = w_samples.mean(0)
        w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5

    # Randomize the noise buffers in place so the optimizer updates the same
    # tensors that the synthesis network reads from.
    noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
    for noise in noises.values():
        noise[:] = torch.randn_like(noise)
        noise.requires_grad = True

    w_opt = w_avg.detach().clone()
    if w_plus:
        w_opt = w_opt.repeat(1, g_ema.mapping.num_ws, 1)
    w_opt.requires_grad = True

    optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2

        w_noise = torch.randn_like(w_opt) * noise_strength
        if w_plus:
            ws = w_opt + w_noise
        else:
            ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])

        img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if img_gen.shape[2] > 256:
            img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')

        p_loss = percept(img_gen, imgs)

        # Noise regularization.
        reg_loss = 0.0
        for v in noises.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Normalize noise.
        with torch.no_grad():
            for buf in noises.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        if (i + 1) % 100 == 0:
            latent_path.append(w_opt.detach().clone())

        pbar.set_description(
            f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
        )

    if w_plus:
        ws = latent_path[-1]
    else:
        ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])

    img_gen = g_ema.synthesis(ws, noise_mode='const')

    result = {
        "latent": latent_path[-1],
        "sample": img_gen,
        "real": imgs,
    }

    return result

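
# A hypothetical call sketch for inverse_image (not part of the diff), with `G`
# and `percept` built as in the __main__ block at the bottom of this file;
# 'example.png' is a placeholder.
result = inverse_image(G, Image.open('example.png'), percept, G.img_resolution, w_plus=False)
w_pivot = result['latent']         # [1, 1, w_dim] pivot latent code
reconstruction = result['sample']  # image synthesized at the pivot
target = result['real']            # preprocessed target image batch
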
def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr
        self.percept = percept

    def calc_loss(self, percept, generated_image, real_image):
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss

    def train(self, img, w_plus=False):
        if not torch.is_tensor(img):
            transform = transforms.Compose(
                [
                    transforms.Resize(self.g_ema.img_resolution),
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to('cuda').unsqueeze(0)
        else:
            img = transforms.functional.resize(img, self.g_ema.img_resolution)
            transform = transforms.Compose(
                [
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to('cuda').unsqueeze(0)
        inversed_result = inverse_image(self.g_ema, img, self.percept, self.g_ema.img_resolution, w_plus)
        w_pivot = inversed_result['latent']
        if w_plus:
            ws = w_pivot
        else:
            ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
        toggle_grad(self.g_ema, True)
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(self.percept, generated_image, real_img)
            pbar.set_description(f"loss: {loss.item():.4f}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')

        return generated_image, ws


if __name__ == "__main__":
    state = {
        "images": {
            # image_orig: the original image, changes when the seed/model changes
            # image_raw: image with mask and points, changes during optimization
            # image_show: image shown on screen
        },
        "temporal_params": {
            # stop
        },
        'mask': None,  # mask for visualization, 1 for editing and 0 for unchanged
        'last_mask': None,  # last edited mask
        'show_mask': True,  # add button
        "generator_params": dnnlib.EasyDict(),
        "params": {
            "seed": 0,
            "motion_lambda": 20,
            "r1_in_pixels": 3,
            "r2_in_pixels": 12,
            "magnitude_direction_in_pixels": 1.0,
            "latent_space": "w+",
            "trunc_psi": 0.7,
            "trunc_cutoff": None,
            "lr": 0.001,
        },
        "device": 'cuda:0',
        "draw_interval": 1,
        "renderer": renderer.Renderer(disable_timing=True),
        "points": {},
        "curr_point": None,
        "curr_type_point": "start",
        'editing_state': 'add_points',
        'pretrained_weight': 'stylegan2_horses_256_pytorch'
    }
    cache_dir = '../checkpoints'
    valid_checkpoints_dict = {
        f.split('/')[-1].split('.')[0]: os.path.join(cache_dir, f)
        for f in os.listdir(cache_dir)
        if (f.endswith('pkl') and os.path.exists(os.path.join(cache_dir, f)))
    }
    state['renderer'].init_network(
        state['generator_params'],  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        state['params']['seed'],  # w0_seed
        None,  # w_load
        state['params']['latent_space'] == 'w+',  # w_plus
        'const',
        state['params']['trunc_psi'],  # trunc_psi
        state['params']['trunc_cutoff'],  # trunc_cutoff
        None,  # input_transform
        state['params']['lr']  # lr
    )
    image = Image.open('/home/tianhao/research/drag3d/horse/render/0.png')
    G = state['renderer'].G
    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=True
    )
    pti = PTI(G, percept)
    result = pti.train(image, True)
    imageio.imsave('../horse/test.png', make_image(result[0])[0])
@ -0,0 +1,5 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
    def __init__(self):
        pass

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=[0]):
        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids

    def forward(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s' % save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        pass

    def get_image_paths(self):
        return self.image_paths

    def save_done(self, flag=False):
        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
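
# A hypothetical sketch of the save helper above (not part of the diff): the
# filename is composed as '%s_net_%s.pth' % (epoch_label, network_label), so
# this writes ./ckpt/latest_net_G.pth.
import os
import torch

os.makedirs('./ckpt', exist_ok=True)
base = BaseModel()
base.initialize(use_gpu=False)
base.save_network(torch.nn.Linear(4, 4), './ckpt', 'G', 'latest')  # stand-in network
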
@ -0,0 +1,314 @@

from __future__ import absolute_import

import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm
import urllib.request

from IPython import embed

from . import networks_basic as networks
from . import util


class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def get_path(base_path):
    BASE_DIR = os.path.join('checkpoints')

    save_path = os.path.join(BASE_DIR, base_path)
    if not os.path.exists(save_path):
        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
        print(f'{base_path} not found')
        print('Try to download from huggingface: ', url)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        download_url(url, save_path)
        print('Downloaded to ', save_path)
    return save_path


def download_url(url, output_path):
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)


class DistModel(BaseModel):
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
                   use_gpu=True, printNet=False, spatial=False,
                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - [0] by default, gpus to use
        '''
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]' % (model, net)

        if(self.model == 'net-lin'):  # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                model_path = get_path('weights/v%s/%s.pth' % (version, net))

            if(not is_train):
                print('Loading model from: %s' % model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model == 'net'):  # pretrained network
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2', 'l2']):
            self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace)  # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:  # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else:  # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])  # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
        OUTPUT
            computed distances between in0 and in1
        '''
        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
                module.weight.data = torch.clamp(module.weight.data, min=0)

    def set_input(self, data):
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref, requires_grad=True)
        self.var_p0 = Variable(self.input_p0, requires_grad=True)
        self.var_p1 = Variable(self.input_p1, requires_grad=True)

    def forward_train(self):  # run forward pass
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)

        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                               ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        zoom_factor = 256 / self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self, nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        print('update lr decay: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


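
# A hypothetical evaluation-mode sketch for DistModel (not part of the diff);
# on first use the linear calibration weights are fetched through get_path above.
model = DistModel()
model.initialize(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())
img0 = torch.rand(1, 3, 64, 64) * 2 - 1  # inputs are expected in [-1, 1]
img1 = torch.rand(1, 3, 64, 64) * 2 - 1
d = model.forward(img0, img1)            # shape [1, 1, 1, 1] per-image distance
print(float(d))
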
def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
        d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
        gts += data['judge'].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5

    return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))

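
# A worked one-sample example of the agreement rule above (not part of the
# diff): the metric says d0 < d1 (prefers p0), while 60% of humans preferred
# p1, so the sample's score is 0.4.
import numpy as np
d0s, d1s, gts = np.array([0.2]), np.array([0.5]), np.array([0.6])
scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
print(scores)  # [0.4]
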
def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test triplets in data_loader
    '''

    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
        gts += data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs

    precs = TPs / (TPs + FPs)
    recs = TPs / (TPs + FNs)
    score = util.voc_ap(recs, precs)

    return (score, dict(ds=ds, sames=sames))
@ -0,0 +1,188 @@

from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

from . import util


def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2, 3], keepdim=keepdim)

def upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W
    in_H = in_tens.shape[2]
    scale_factor = 1. * out_H / in_H

    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)

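
# A small shape sketch for the two helpers above (not part of the diff):
# spatial_average collapses the spatial dimensions, upsample rescales a
# distance map back to image resolution.
x = torch.rand(2, 1, 32, 32)
print(spatial_average(x).shape)     # torch.Size([2, 1, 1, 1])
print(upsample(x, out_H=64).shape)  # torch.Size([2, 1, 64, 64])
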
# Learned perceptual metric
class PNetLin(nn.Module):
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg', 'vgg16']):
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif(self.pnet_type == 'alex'):
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif(self.pnet_type == 'squeeze'):
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if(self.pnet_type == 'squeeze'):  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1, self.L):
            val += res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(), ] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
        layers += [nn.LeakyReLU(0.2, True), ]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
        layers += [nn.LeakyReLU(0.2, True), ]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
        if(use_sigmoid):
            layers += [nn.Sigmoid(), ]
        self.model = nn.Sequential(*layers)

    def forward(self, d0, d1, eps=0.1):
        return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))

class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge + 1.) / 2.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, per)

# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace

class L2(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0] == 1)  # currently only supports batchSize 1

        if(self.colorspace == 'RGB'):
            (N, C, X, Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3).view(N)
            return value
        elif(self.colorspace == 'Lab'):
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
            ret_var = Variable(torch.Tensor((value,)))
            if(self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var

class DSSIM(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0] == 1)  # currently only supports batchSize 1

        if(self.colorspace == 'RGB'):
            value = util.dssim(1. * util.tensor2im(in0.data), 1. * util.tensor2im(in1.data), range=255.).astype('float')
        elif(self.colorspace == 'Lab'):
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
        ret_var = Variable(torch.Tensor((value,)))
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network', net)
    print('Total number of parameters: %d' % num_params)
@ -0,0 +1,181 @@
 | 
			
		||||
from collections import namedtuple
 | 
			
		||||
import torch
 | 
			
		||||
from torchvision import models as tv
 | 
			
		||||
from IPython import embed
 | 
			
		||||
 | 
			
		||||
class squeezenet(torch.nn.Module):
 | 
			
		||||
    def __init__(self, requires_grad=False, pretrained=True):
 | 
			
		||||
        super(squeezenet, self).__init__()
 | 
			
		||||
        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
 | 
			
		||||
        self.slice1 = torch.nn.Sequential()
 | 
			
		||||
        self.slice2 = torch.nn.Sequential()
 | 
			
		||||
        self.slice3 = torch.nn.Sequential()
 | 
			
		||||
        self.slice4 = torch.nn.Sequential()
 | 
			
		||||
        self.slice5 = torch.nn.Sequential()
 | 
			
		||||
        self.slice6 = torch.nn.Sequential()
 | 
			
		||||
        self.slice7 = torch.nn.Sequential()
 | 
			
		||||
        self.N_slices = 7
 | 
			
		||||
        for x in range(2):
 | 
			
		||||
            self.slice1.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(2,5):
 | 
			
		||||
            self.slice2.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(5, 8):
 | 
			
		||||
            self.slice3.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(8, 10):
 | 
			
		||||
            self.slice4.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(10, 11):
 | 
			
		||||
            self.slice5.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(11, 12):
 | 
			
		||||
            self.slice6.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        for x in range(12, 13):
 | 
			
		||||
            self.slice7.add_module(str(x), pretrained_features[x])
 | 
			
		||||
        if not requires_grad:
 | 
			
		||||
            for param in self.parameters():
 | 
			
		||||
                param.requires_grad = False
 | 
			
		||||
 | 
			
		||||
    def forward(self, X):
 | 
			
		||||
        h = self.slice1(X)
 | 
			
		||||
        h_relu1 = h
 | 
			
		||||
        h = self.slice2(h)
 | 
			
		||||
        h_relu2 = h
 | 
			
		||||
        h = self.slice3(h)
 | 
			
		||||
        h_relu3 = h
 | 
			
		||||
        h = self.slice4(h)
 | 
			
		||||
        h_relu4 = h
 | 
			
		||||
        h = self.slice5(h)
 | 
			
		||||
        h_relu5 = h
 | 
			
		||||
        h = self.slice6(h)
 | 
			
		||||
        h_relu6 = h
 | 
			
		||||
        h = self.slice7(h)
 | 
			
		||||
        h_relu7 = h
 | 
			
		||||
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
 | 
			
		||||
        out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
 | 
			
		||||
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class alexnet(torch.nn.Module):
 | 
			
		||||
    def __init__(self, requires_grad=False, pretrained=True):
 | 
			
		||||
        super(alexnet, self).__init__()
 | 
			
		||||
        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
 | 
			
		||||
        self.slice1 = torch.nn.Sequential()
 | 
			
		||||
        self.slice2 = torch.nn.Sequential()
 | 
			
		||||
        self.slice3 = torch.nn.Sequential()
 | 
			
		||||
        self.slice4 = torch.nn.Sequential()
 | 
			
		||||
        self.slice5 = torch.nn.Sequential()
 | 
			
		||||
        self.N_slices = 5
 | 
			
		||||
        for x in range(2):
 | 
			
		||||
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
 | 
			
		||||
        for x in range(2, 5):
 | 
			
		||||
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
 | 
			
		||||
        for x in range(5, 8):
 | 
			
		||||
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
 | 
			
		||||
        for x in range(8, 10):
 | 
			
		||||
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
 | 
			
		||||
        for x in range(10, 12):
 | 
			
		||||
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
 | 
			
		||||
        if not requires_grad:
 | 
			
		||||
            for param in self.parameters():
 | 
			
		||||
                param.requires_grad = False
 | 
			
		||||
 | 
			
		||||
    def forward(self, X):
 | 
			
		||||
        h = self.slice1(X)
 | 
			
		||||
        h_relu1 = h
 | 
			
		||||
        h = self.slice2(h)
 | 
			
		||||
        h_relu2 = h
 | 
			
		||||
        h = self.slice3(h)
 | 
			
		||||
        h_relu3 = h
 | 
			
		||||
        h = self.slice4(h)
 | 
			
		||||
        h_relu4 = h
 | 
			
		||||
        h = self.slice5(h)
 | 
			
		||||
        h_relu5 = h
 | 
			
		||||
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
 | 
			
		||||
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
 | 
			
		||||
 | 
			
		||||
        return out


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out
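

# A minimal usage sketch (illustrative, not part of the original file): these
# slice wrappers return a namedtuple of intermediate activations, which is how
# LPIPS-style losses read off multi-layer features. Assumes the torchvision
# VGG-16 weights can be downloaded.
def _example_vgg_features():
    net = vgg16(requires_grad=False, pretrained=True)
    x = torch.randn(1, 3, 224, 224)
    feats = net(x)
    # For a 224x224 input, relu3_3 is 1x256x56x56.
    return feats.relu3_3.shape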


class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if num == 18:
            self.net = tv.resnet18(pretrained=pretrained)
        elif num == 34:
            self.net = tv.resnet34(pretrained=pretrained)
        elif num == 50:
            self.net = tv.resnet50(pretrained=pretrained)
        elif num == 101:
            self.net = tv.resnet101(pretrained=pretrained)
        elif num == 152:
            self.net = tv.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
@ -0,0 +1,160 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from skimage.metrics import structural_similarity
import torch

from . import dist_model


class PerceptualLoss(torch.nn.Module):
    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]):  # VGG using our perceptually-learned weights (LPIPS metric)
    # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
        super(PerceptualLoss, self).__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized' % self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """
        Pred and target are Variables.
        If normalize is True, assumes the images are in [0, 1] and scales them to [-1, +1].
        If normalize is False, assumes the images are already in [-1, +1].

        Inputs pred and target are Nx3xHxW.
        Output is a pytorch Variable of length N.
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        return self.model.forward(target, pred)
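

# Hedged usage sketch (not part of the original file): an LPIPS distance
# between two batches already scaled to [-1, 1]. Assumes the pretrained
# linear weights can be fetched; use_gpu=False keeps everything on CPU.
def _example_perceptual_loss():
    loss_fn = PerceptualLoss(model='net-lin', net='vgg', use_gpu=False)
    img0 = torch.zeros(1, 3, 64, 64)
    img1 = 0.5 * torch.ones(1, 3, 64, 64)
    return loss_fn(img0, img1)  # one distance value per batch element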


def normalize_tensor(in_feat, eps=1e-10):
    # Unit-normalize features along the channel dimension.
    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


def l2(p0, p1, range=255.):
    return .5 * np.mean((p0 / range - p1 / range)**2)


def psnr(p0, p1, peak=255.):
    return 10 * np.log10(peak**2 / np.mean((1. * p0 - 1. * p1)**2))


def dssim(p0, p1, range=255.):
    return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.


def rgb2lab(in_img, mean_cent=False):
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if mean_cent:
        # Center the L channel around zero (L is in [0, 100]).
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    return img_lab
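

# Hedged usage sketch (illustrative only): the three pixel-space metrics above
# on a pair of uint8 HxWx3 images. dssim assumes an skimage version that still
# accepts multichannel=True.
def _example_pixel_metrics():
    a = np.zeros((32, 32, 3), dtype=np.uint8)
    b = np.full((32, 32, 3), 10, dtype=np.uint8)
    return l2(a, b), psnr(a, b), dssim(a, b)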


def tensor2np(tensor_obj):
    # change dimension of a tensor object into a numpy array
    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))


def np2tensor(np_obj):
    # change dimension of a numpy array into a tensor array
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    # image tensor to lab tensor
    from skimage import color

    img = tensor2im(image_tensor)
    img_lab = color.rgb2lab(img)
    if mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    if to_norm and not mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
        img_lab = img_lab / 100.

    return np2tensor(img_lab)


def tensorlab2tensor(lab_tensor, return_inbnd=False):
    from skimage import color
    import warnings
    warnings.filterwarnings("ignore")

    lab = tensor2np(lab_tensor) * 100.
    lab[:, :, 0] = lab[:, :, 0] + 50

    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
    if return_inbnd:
        # convert back to lab, see if we match
        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
        mask = 1. * np.isclose(lab_back, lab, atol=2.)
        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
        return (im2tensor(rgb_back), mask)
    else:
        return im2tensor(rgb_back)


def rgb2lab(input):
    # NOTE: overrides the rgb2lab(in_img, mean_cent) defined above; this
    # variant expects uint8-range input and rescales it to [0, 1] first.
    from skimage import color
    return color.rgb2lab(input / 255.)


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2vec(vector_tensor):
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
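

# Hedged round-trip sketch (illustrative only): im2tensor maps a uint8 HxWx3
# image to a 1x3xHxW tensor in roughly [-1, 1]; tensor2im maps it back, equal
# to the original up to uint8 truncation.
def _example_tensor_roundtrip():
    img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
    t = im2tensor(img)
    back = tensor2im(t)
    return np.abs(back.astype(int) - img.astype(int)).max()  # at most 1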


def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11-point method (default: False).
    """
    if use_07_metric:
        # 11-point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap
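

# Hedged usage sketch (illustrative values): average precision for a tiny
# monotone recall curve; the default branch integrates the interpolated
# precision envelope.
def _example_voc_ap():
    rec = np.array([0.1, 0.2, 0.4, 1.0])
    prec = np.array([1.0, 0.5, 0.4, 0.25])
    return voc_ap(rec, prec)  # 0.38 for these values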
@ -0,0 +1,964 @@

import os
import os.path as osp
from argparse import ArgumentParser
from functools import partial

import gradio as gr
import numpy as np
import torch
from PIL import Image
import imageio
import dnnlib
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
                          get_latest_points_pair, get_valid_mask,
                          on_change_single_global_state)
from viz.renderer import Renderer, add_watermark_np
from gan_inv.inversion import PTI
from gan_inv.lpips import util

parser = ArgumentParser()
parser.add_argument('--share', action='store_true',
                    help='request a public Gradio share link')
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
args = parser.parse_args()

cache_dir = args.cache_dir

device = 'cuda'


def reverse_point_pairs(points):
    # Swap (x, y) click coordinates to the (row, col) order the renderer expects.
    new_points = []
    for p in points:
        new_points.append([p[1], p[0]])
    return new_points


def clear_state(global_state, target=None):
    """Clear target history state from global_state.
    If target is not defined, points and mask will both be removed:
    1. set global_state['points'] to an empty dict
    2. set global_state['mask'] to an all-ones mask
    """
    if target is None:
        target = ['point', 'mask']
    if not isinstance(target, list):
        target = [target]
    if 'point' in target:
        global_state['points'] = dict()
        print('Clear Points State!')
    if 'mask' in target:
        image_raw = global_state["images"]["image_raw"]
        global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
                                       dtype=np.uint8)
        print('Clear mask State!')

    return global_state


def init_images(global_state):
    """This function is called only once, when the Gradio app starts.
    0. pre-process global_state, unpacking values from global_state as needed
    1. Re-init renderer
    2. run `renderer._render_drag_impl` with `is_drag=False` to generate a
       new image
    3. Assign images to global state and re-generate mask
    """

    if isinstance(global_state, gr.State):
        state = global_state.value
    else:
        state = global_state

    state['renderer'].init_network(
        state['generator_params'],  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        state['params']['seed'],  # w0_seed,
        None,  # w_load
        state['params']['latent_space'] == 'w+',  # w_plus
        'const',
        state['params']['trunc_psi'],  # trunc_psi,
        state['params']['trunc_cutoff'],  # trunc_cutoff,
        None,  # input_transform
        state['params']['lr']  # lr,
    )

    state['renderer']._render_drag_impl(state['generator_params'],
                                        is_drag=False,
                                        to_pil=True)

    init_image = state['generator_params'].image
    state['images']['image_orig'] = init_image
    state['images']['image_raw'] = init_image
    state['images']['image_show'] = Image.fromarray(
        add_watermark_np(np.array(init_image)))
    state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
                            dtype=np.uint8)
    return global_state


def update_image_draw(image, points, mask, show_mask, global_state=None):

    image_draw = draw_points_on_image(image, points)
    # Only overlay the mask when it is shown and actually mixes 0s and 1s.
    if show_mask and mask is not None and not (mask == 0).all() and not (
            mask == 1).all():
        image_draw = draw_mask_on_image(image_draw, mask)

    image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
    if global_state is not None:
        global_state['images']['image_show'] = image_draw
    return image_draw


def preprocess_mask_info(global_state, image):
    """Function to handle mask information.
    1. last_mask is None: no mask to merge, return global_state unchanged
    2. last_mask is not None:
        2.1 editing_state is remove_mask: subtract last_mask from the mask
        2.2 editing_state is add_mask: add last_mask to the mask
    """
    if isinstance(image, dict):
        last_mask = get_valid_mask(image['mask'])
    else:
        last_mask = None
    mask = global_state['mask']

    # mask in global state is a placeholder with all 1.
    if (mask == 1).all():
        mask = last_mask

    # last_mask = global_state['last_mask']
    editing_mode = global_state['editing_state']

    if last_mask is None:
        return global_state

    if editing_mode == 'remove_mask':
        updated_mask = np.clip(mask - last_mask, 0, 1)
        print(f'Last editing_state is {editing_mode}, do remove.')
    elif editing_mode == 'add_mask':
        updated_mask = np.clip(mask + last_mask, 0, 1)
        print(f'Last editing_state is {editing_mode}, do add.')
    else:
        updated_mask = mask
        print(f'Last editing_state is {editing_mode}, '
              'do nothing to mask.')

    global_state['mask'] = updated_mask
    # global_state['last_mask'] = None  # clear buffer
    return global_state
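

# Hedged sketch (illustrative only) of the binary-mask arithmetic above:
# np.clip keeps the mask in {0, 1} when a brush stroke is added or removed.
def _example_mask_update():
    mask = np.ones((4, 4), dtype=np.uint8)
    stroke = np.zeros((4, 4), dtype=np.uint8)
    stroke[1:3, 1:3] = 1
    removed = np.clip(mask - stroke, 0, 1)   # 'remove_mask' branch
    added = np.clip(removed + stroke, 0, 1)  # 'add_mask' branch
    return removed, added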


valid_checkpoints_dict = {
    f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
    for f in os.listdir(cache_dir)
    if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
}
print(f'Files under cache_dir ({cache_dir}):')
print(os.listdir(cache_dir))
print('Valid checkpoint files:')
print(valid_checkpoints_dict)

init_pkl = 'stylegan2_lions_512_pytorch'


# Network & latents tab listeners
def on_change_pretrained_dropdown(pretrained_value, global_state):
    """Function to handle model change.
    1. Set pretrained value to global_state
    2. Re-init images and clear all states
    """
    global_state['pretrained_weight'] = pretrained_value
    init_images(global_state)
    clear_state(global_state)

    return global_state, global_state["images"]['image_show']


def on_click_reset_image(global_state):
    """Reset image to the original one and clear all states.
    1. Re-init images
    2. Clear all states
    """

    init_images(global_state)
    clear_state(global_state)

    return global_state, global_state['images']['image_show']


# Update parameters
def on_change_update_image_seed(seed, global_state):
    """Function to handle generation seed change.
    1. Set seed to global_state
    2. Re-init images and clear all states
    """

    global_state["params"]["seed"] = int(seed)
    init_images(global_state)
    clear_state(global_state)

    return global_state, global_state['images']['image_show']


def on_click_latent_space(latent_space, global_state):
    """Function to reset the latent space to optimize.
    NOTE: this function resets the image and all controls
    1. Set latent-space to global_state
    2. Re-init images and clear all states
    """

    global_state['params']['latent_space'] = latent_space
    init_images(global_state)
    clear_state(global_state)

    return global_state, global_state['images']['image_show']


def on_click_inverse_custom_image(custom_image, global_state):
    print('start GAN inversion')

    if isinstance(global_state, gr.State):
        state = global_state.value
    else:
        state = global_state

    state['renderer'].init_network(
        state['generator_params'],  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        state['params']['seed'],  # w0_seed,
        None,  # w_load
        state['params']['latent_space'] == 'w+',  # w_plus
        'const',
        state['params']['trunc_psi'],  # trunc_psi,
        state['params']['trunc_cutoff'],  # trunc_cutoff,
        None,  # input_transform
        state['params']['lr']  # lr,
    )

    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=True
    )

    image = Image.open(custom_image.name)

    pti = PTI(state['renderer'].G, percept)
    inversed_img, w_pivot = pti.train(image, state['params']['latent_space'] == 'w+')
    # Map the generator output from [-1, 1] to a uint8 HxWxC image.
    inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
    inversed_img = inversed_img.cpu().numpy()
    inversed_img = Image.fromarray(inversed_img)
    state['images']['image_show'] = Image.fromarray(
        add_watermark_np(np.array(inversed_img)))

    state['images']['image_orig'] = inversed_img
    state['images']['image_raw'] = inversed_img

    state['mask'] = np.ones((inversed_img.size[1], inversed_img.size[0]),
                            dtype=np.uint8)
    state['generator_params'].image = inversed_img
    state['generator_params'].w = w_pivot.detach().cpu().numpy()
    state['renderer'].set_latent(w_pivot, state['params']['trunc_psi'], state['params']['trunc_cutoff'])

    del percept
    del pti
    print('inversion end')

    return global_state, state['images']['image_show'], gr.Button.update(interactive=True)


def on_save_image(global_state, form_save_image_path):
    # Convert the PIL image to an array for imageio.
    imageio.imsave(form_save_image_path, np.array(global_state['images']['image_raw']))


def on_reset_custom_image(global_state):
    if isinstance(global_state, gr.State):
        state = global_state.value
    else:
        state = global_state
    clear_state(state)
    # Restart the drag optimizer from the initial latent w0.
    state['renderer'].w = state['renderer'].w0.detach().clone()
    state['renderer'].w.requires_grad = True
    state['renderer'].w_optim = torch.optim.Adam([state['renderer'].w], lr=state['renderer'].lr)
    state['renderer']._render_drag_impl(state['generator_params'],
                                        is_drag=False,
                                        to_pil=True)

    init_image = state['generator_params'].image
    state['images']['image_orig'] = init_image
    state['images']['image_raw'] = init_image
    state['images']['image_show'] = Image.fromarray(
        add_watermark_np(np.array(init_image)))
    state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
                            dtype=np.uint8)
    return state, state['images']['image_show']


def on_change_lr(lr, global_state):
    if lr == 0:
        print('lr is 0, do nothing.')
        return global_state
    else:
        global_state["params"]["lr"] = lr
        renderer = global_state['renderer']
        renderer.update_lr(lr)
        print('New optimizer: ')
        print(renderer.w_optim)
    return global_state


def on_click_start(global_state, image):
    p_in_pixels = []
    t_in_pixels = []
    valid_points = []

    # handle start of drag in mask editing mode
    global_state = preprocess_mask_info(global_state, image)

    # Prepare the points for the inference
    if len(global_state["points"]) == 0:
        # yield on_click_start_wo_points(global_state, image)
        image_raw = global_state['images']['image_raw']
        update_image_draw(
            image_raw,
            global_state['points'],
            global_state['mask'],
            global_state['show_mask'],
            global_state,
        )

        # NOTE: one entry below per output component bound in form_start_btn.click.
        yield (
            global_state,
            0,
            global_state['images']['image_show'],
            # gr.File.update(visible=False),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            # latent space
            gr.Radio.update(interactive=True),
            gr.Button.update(interactive=True),
            # NOTE: disable stop button
            gr.Button.update(interactive=False),

            # update other comps
            gr.Dropdown.update(interactive=True),
            gr.Number.update(interactive=True),
            gr.Number.update(interactive=True),
            gr.Checkbox.update(interactive=True),
            gr.Number.update(interactive=True),
        )
    else:

        # Transform the points into torch tensors
        for key_point, point in global_state["points"].items():
            try:
                p_start = point.get("start_temp", point["start"])
                p_end = point["target"]

                if p_start is None or p_end is None:
                    continue

            except KeyError:
                continue

            p_in_pixels.append(p_start)
            t_in_pixels.append(p_end)
            valid_points.append(key_point)

        mask = torch.tensor(global_state['mask']).float()
        drag_mask = 1 - mask

        renderer: Renderer = global_state["renderer"]
        global_state['temporal_params']['stop'] = False
        global_state['editing_state'] = 'running'

        # reverse points order
        p_to_opt = reverse_point_pairs(p_in_pixels)
        t_to_opt = reverse_point_pairs(t_in_pixels)
        # print('Running with:')
        # print(f'    Source: {p_in_pixels}')
        # print(f'    Target: {t_in_pixels}')
        step_idx = 0
        while True:
            if global_state["temporal_params"]["stop"]:
                break

            # do drag here!
            renderer._render_drag_impl(
                global_state['generator_params'],
                p_to_opt,  # point
                t_to_opt,  # target
                drag_mask,  # mask,
                global_state['params']['motion_lambda'],  # lambda_mask
                reg=0,
                feature_idx=5,  # NOTE: do not support change for now
                r1=global_state['params']['r1_in_pixels'],  # r1
                r2=global_state['params']['r2_in_pixels'],  # r2
                # random_seed     = 0,
                # noise_mode      = 'const',
                trunc_psi=global_state['params']['trunc_psi'],
                # force_fp32      = False,
                # layer_name      = None,
                # sel_channels    = 3,
                # base_channel    = 0,
                # img_scale_db    = 0,
                # img_normalize   = False,
                # untransform     = False,
                is_drag=True,
                to_pil=True)

            if step_idx % global_state['draw_interval'] == 0:
                # print('Current Source:')
                for key_point, p_i, t_i in zip(valid_points, p_to_opt,
                                               t_to_opt):
                    global_state["points"][key_point]["start_temp"] = [
                        p_i[1],
                        p_i[0],
                    ]
                    global_state["points"][key_point]["target"] = [
                        t_i[1],
                        t_i[0],
                    ]
                    start_temp = global_state["points"][key_point][
                        "start_temp"]
                    # print(f'    {start_temp}')

                image_result = global_state['generator_params']['image']
                image_draw = update_image_draw(
                    image_result,
                    global_state['points'],
                    global_state['mask'],
                    global_state['show_mask'],
                    global_state,
                )
                global_state['images']['image_raw'] = image_result

            yield (
                global_state,
                step_idx,
                global_state['images']['image_show'],
                # gr.File.update(visible=False),
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                # latent space
                gr.Radio.update(interactive=False),
                gr.Button.update(interactive=False),
                # enable stop button in loop
                gr.Button.update(interactive=True),

                # update other comps
                gr.Dropdown.update(interactive=False),
                gr.Number.update(interactive=False),
                gr.Number.update(interactive=False),
                gr.Checkbox.update(interactive=False),
                gr.Number.update(interactive=False),
            )

            # increment step
            step_idx += 1

        image_result = global_state['generator_params']['image']
        global_state['images']['image_raw'] = image_result
        image_draw = update_image_draw(image_result,
                                       global_state['points'],
                                       global_state['mask'],
                                       global_state['show_mask'],
                                       global_state)

        # fp = NamedTemporaryFile(suffix=".png", delete=False)
        # image_result.save(fp, "PNG")

        global_state['editing_state'] = 'add_points'

        yield (
            global_state,
            0,  # reset step to 0 after stop.
            global_state['images']['image_show'],
            # gr.File.update(visible=True, value=fp.name),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=True),
            # latent space
            gr.Radio.update(interactive=True),
            gr.Button.update(interactive=True),
            # NOTE: disable stop button when the loop finishes
            gr.Button.update(interactive=False),

            # update other comps
            gr.Dropdown.update(interactive=True),
            gr.Number.update(interactive=True),
            gr.Number.update(interactive=True),
            gr.Checkbox.update(interactive=True),
            gr.Number.update(interactive=True),
        )
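

# Hedged sketch (illustrative only) of the generator-callback pattern used by
# on_click_start: a handler bound with .click() may `yield` tuples, and Gradio
# pushes each one to the bound output components as the loop runs.
def _example_streaming_callback(n):
    for step in range(int(n)):
        yield step + 1  # one UI update per yield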


def on_click_stop(global_state):
    """Function to handle a click on the stop button.
    1. send a stop signal by setting global_state["temporal_params"]["stop"] to True
    2. Disable the Stop button
    """
    global_state["temporal_params"]["stop"] = True

    return global_state, gr.Button.update(interactive=False)


def on_click_remove_point(global_state):
    choice = global_state["curr_point"]
    del global_state["points"][choice]

    choices = list(global_state["points"].keys())

    if len(choices) > 0:
        global_state["curr_point"] = choices[0]

    return (
        # Guard against an empty choices list after the last point is removed.
        gr.Dropdown.update(choices=choices,
                           value=choices[0] if choices else None),
        global_state,
    )


# Mask
def on_click_reset_mask(global_state):
    global_state['mask'] = np.ones(
        (
            global_state["images"]["image_raw"].size[1],
            global_state["images"]["image_raw"].size[0],
        ),
        dtype=np.uint8,
    )
    image_draw = update_image_draw(global_state['images']['image_raw'],
                                   global_state['points'],
                                   global_state['mask'],
                                   global_state['show_mask'], global_state)
    return global_state, image_draw


# Image
def on_click_enable_draw(global_state, image):
    """Function to start add mask mode.
    1. Preprocess mask info from last state
    2. Change editing state to add_mask
    3. Set curr image with points and mask
    """
    global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'add_mask'
    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(image_raw, global_state['points'],
                                   global_state['mask'], True,
                                   global_state)
    return (global_state,
            gr.Image.update(value=image_draw, interactive=True))


def on_click_remove_draw(global_state, image):
    """Function to start remove mask mode.
    1. Preprocess mask info from last state
    2. Change editing state to remove_mask
    3. Set curr image with points and mask
    """
    global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'remove_mask'
    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(image_raw, global_state['points'],
                                   global_state['mask'], True,
                                   global_state)
    return (global_state,
            gr.Image.update(value=image_draw, interactive=True))


def on_click_add_point(global_state, image: dict):
    """Function to switch from add mask mode to add points mode.
    1. Update the mask buffer if needed
    2. Change global_state['editing_state'] to 'add_points'
    3. Set current image with mask
    """

    global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'add_points'
    mask = global_state['mask']
    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(image_raw, global_state['points'], mask,
                                   global_state['show_mask'], global_state)

    return (global_state,
            gr.Image.update(value=image_draw, interactive=False))


def on_click_image(global_state, evt: gr.SelectData):
    """This function only supports clicks for point selection."""
    xy = evt.index
    if global_state['editing_state'] != 'add_points':
        print(f'In {global_state["editing_state"]} state. '
              'Do not add points.')

        return global_state, global_state['images']['image_show']

    points = global_state["points"]

    point_idx = get_latest_points_pair(points)
    if point_idx is None:
        points[0] = {'start': xy, 'target': None}
        print(f'Click Image - Start - {xy}')
    elif points[point_idx].get('target', None) is None:
        points[point_idx]['target'] = xy
        print(f'Click Image - Target - {xy}')
    else:
        points[point_idx + 1] = {'start': xy, 'target': None}
        print(f'Click Image - Start - {xy}')

    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(
        image_raw,
        global_state['points'],
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )

    return global_state, image_draw
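

# The points buffer populated above has this shape (illustrative values):
#   global_state['points'] = {
#       0: {'start': [x0, y0], 'target': [x1, y1]},
#       1: {'start': [x2, y2], 'target': None},  # waiting for a target click
#   }
# Each click either opens a new pair or fills in the latest pair's target.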


def on_click_clear_points(global_state):
    """Function to handle clearing all control points.
    1. clear global_state['points'] (clear_state)
    2. reset the renderer's feature references
    3. re-draw image
    """
    clear_state(global_state, target='point')

    renderer: Renderer = global_state["renderer"]
    renderer.feat_refs = None

    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(image_raw, {}, global_state['mask'],
                                   global_state['show_mask'], global_state)
    return global_state, image_draw


def on_click_show_mask(global_state, show_mask):
    """Function to control whether the mask is shown on the image."""
    global_state['show_mask'] = show_mask

    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(
        image_raw,
        global_state['points'],
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )
    return global_state, image_draw


if __name__ == "__main__":
    with gr.Blocks() as app:
        # renderer = Renderer()
        global_state = gr.State({
            "images": {
                # image_orig: the original image, changes when the seed/model changes
                # image_raw: image with mask and points, changes during optimization
                # image_show: image shown on screen
            },
            "temporal_params": {
                # stop
            },
            'mask':
                None,  # mask for visualization, 1 for editable and 0 for unchanged
            'last_mask': None,  # last edited mask
            'show_mask': True,  # add button
            "generator_params": dnnlib.EasyDict(),
            "params": {
                "seed": 0,
                "motion_lambda": 20,
                "r1_in_pixels": 3,
                "r2_in_pixels": 12,
                "magnitude_direction_in_pixels": 1.0,
                "latent_space": "w+",
                "trunc_psi": 0.7,
                "trunc_cutoff": None,
                "lr": 0.001,
            },
            "device": device,
            "draw_interval": 1,
            "renderer": Renderer(disable_timing=True),
            "points": {},
            "curr_point": None,
            "curr_type_point": "start",
            'editing_state': 'add_points',
            'pretrained_weight': init_pkl
        })

        # init image
        global_state = init_images(global_state)

        with gr.Row():
            with gr.Row():
                # Left --> tools
                with gr.Column(scale=3):
                    # Pickle
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            gr.Markdown(value='Pickle', show_label=False)

                        with gr.Column(scale=4, min_width=10):
                            form_pretrained_dropdown = gr.Dropdown(
                                choices=list(valid_checkpoints_dict.keys()),
                                label="Pretrained Model",
                                value=init_pkl,
                            )

                    # Latent
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            gr.Markdown(value='Latent', show_label=False)

                        with gr.Column(scale=4, min_width=10):
                            form_seed_number = gr.Number(
                                value=global_state.value['params']['seed'],
                                interactive=True,
                                label="Seed",
                            )
                            form_lr_number = gr.Number(
                                value=global_state.value["params"]["lr"],
                                interactive=True,
                                label="Step Size")

                            with gr.Row():
                                with gr.Column(scale=2, min_width=10):
                                    form_reset_image = gr.Button("Reset Image")
                                with gr.Column(scale=3, min_width=10):
                                    form_latent_space = gr.Radio(
                                        ['w', 'w+'],
                                        value=global_state.value['params']
                                        ['latent_space'],
                                        interactive=True,
                                        label='Latent space to optimize',
                                        show_label=False,
                                    )
                            with gr.Row():
                                with gr.Column(scale=3, min_width=10):
                                    form_custom_image = gr.UploadButton(label="Inverse Custom Image",
                                                                        file_types=['.png', '.jpg', '.jpeg'])
                                with gr.Column(scale=3, min_width=10):
                                    form_reset_custom_image = gr.Button('Reset Custom Image', interactive=False)
                            with gr.Row():
                                with gr.Column(scale=3, min_width=10):
                                    form_save_image_path = gr.Textbox(label="Save image to", value='./test.png')
                                    form_save_image = gr.Button('Save', interactive=True)
                    # Drag
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            gr.Markdown(value='Drag', show_label=False)
                        with gr.Column(scale=4, min_width=10):
                            with gr.Row():
                                with gr.Column(scale=1, min_width=10):
                                    enable_add_points = gr.Button('Add Points')
                                with gr.Column(scale=1, min_width=10):
                                    undo_points = gr.Button('Reset Points')
                            with gr.Row():
                                with gr.Column(scale=1, min_width=10):
                                    form_start_btn = gr.Button("Start")
                                with gr.Column(scale=1, min_width=10):
                                    form_stop_btn = gr.Button("Stop")

                            form_steps_number = gr.Number(value=0,
                                                          label="Steps",
                                                          interactive=False)

                    # Mask
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            gr.Markdown(value='Mask', show_label=False)
                        with gr.Column(scale=4, min_width=10):
                            enable_add_mask = gr.Button('Edit Flexible Area')
                            with gr.Row():
                                with gr.Column(scale=1, min_width=10):
                                    form_reset_mask_btn = gr.Button("Reset mask")
                                with gr.Column(scale=1, min_width=10):
                                    show_mask = gr.Checkbox(
                                        label='Show Mask',
                                        value=global_state.value['show_mask'],
                                        show_label=False)

                            with gr.Row():
                                form_lambda_number = gr.Number(
                                    value=global_state.value["params"]
                                    ["motion_lambda"],
                                    interactive=True,
                                    label="Lambda",
                                )

                    form_draw_interval_number = gr.Number(
                        value=global_state.value["draw_interval"],
                        label="Draw Interval (steps)",
                        interactive=True,
                        visible=False)
                # Right --> Image
                with gr.Column(scale=8):
                    form_image = ImageMask(
                        value=global_state.value['images']['image_show'],
                        brush_radius=20).style(
                            width=768,
                            height=768)  # NOTE: image size is hard-coded here.
        gr.Markdown("""
            ## Quick Start

            1. Select the desired `Pretrained Model` and adjust `Seed` to generate an
               initial image.
            2. Click on the image to add control points.
            3. Click `Start` and enjoy it!

            ## Advanced Usage

            1. Change `Step Size` to adjust the learning rate of the drag optimization.
            2. Select `w` or `w+` to change the latent space to optimize:
            * Optimizing in `w` space may influence the image more strongly.
            * Optimizing in `w+` space may be slower than `w`, but usually achieves
              better results.
            * Note that changing the latent space will reset the image, points and
              mask (this has the same effect as the `Reset Image` button).
            3. Click `Edit Flexible Area` to create a mask and constrain the
               unmasked region to remain unchanged.
            """)
        gr.HTML("""
            <style>
                .container {
                    position: absolute;
                    height: 50px;
                    text-align: center;
                    line-height: 50px;
                    width: 100%;
                }
            </style>
            <div class="container">
            Gradio demo supported by
            <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
            <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
            </div>
            """)
 | 
			
		||||
        show_mask.change(
 | 
			
		||||
            on_click_show_mask,
 | 
			
		||||
            inputs=[global_state, show_mask],
 | 
			
		||||
            outputs=[global_state, form_image],
 | 
			
		||||
        )
 | 
			
		||||
        undo_points.click(on_click_clear_points,
 | 
			
		||||
                          inputs=[global_state],
 | 
			
		||||
                          outputs=[global_state, form_image])
 | 
			
		||||
        form_image.select(
 | 
			
		||||
            on_click_image,
 | 
			
		||||
            inputs=[global_state],
 | 
			
		||||
            outputs=[global_state, form_image],
 | 
			
		||||
        )
 | 
			
		||||
        enable_add_mask.click(on_click_enable_draw,
                              inputs=[global_state, form_image],
                              outputs=[
                                  global_state,
                                  form_image,
                              ])
        enable_add_points.click(on_click_add_point,
                                inputs=[global_state, form_image],
                                outputs=[global_state, form_image])
        form_reset_mask_btn.click(
            on_click_reset_mask,
            inputs=[global_state],
            outputs=[global_state, form_image],
        )

        form_stop_btn.click(on_click_stop,
                            inputs=[global_state],
                            outputs=[global_state, form_stop_btn])

        form_draw_interval_number.change(
            partial(
                on_change_single_global_state,
                "draw_interval",
                map_transform=int,
            ),
            inputs=[form_draw_interval_number, global_state],
            outputs=[global_state],
        )
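        # `functools.partial` pre-binds the state key (and an optional
        # `map_transform`) so one generic handler can serve many widgets;
        # Gradio then supplies the remaining (value, global_state) arguments.
        # A minimal hypothetical implementation consistent with both call
        # sites in this file (the real one is defined earlier):
        #
        #     def on_change_single_global_state(key, value, global_state,
        #                                       map_transform=None):
        #         if map_transform is not None:
        #             value = map_transform(value)
        #         target = global_state
        #         if isinstance(key, list):      # e.g. ["params", "motion_lambda"]
        #             for k in key[:-1]:
        #                 target = target[k]
        #             key = key[-1]
        #         target[key] = value
        #         return global_state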
        form_start_btn.click(
            on_click_start,
            inputs=[global_state, form_image],
            outputs=[
                global_state,
                form_steps_number,
                form_image,
                # form_download_result_file,
                # >>> buttons
                form_reset_image,
                enable_add_points,
                enable_add_mask,
                undo_points,
                form_reset_mask_btn,
                form_latent_space,
                form_start_btn,
                form_stop_btn,
                # <<< buttons
                # >>> input comps
                form_pretrained_dropdown,
                form_seed_number,
                form_lr_number,
                show_mask,
                form_lambda_number,
            ],
        )
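        # `on_click_start` returns an update for every component listed above,
        # presumably so the UI can be locked while the drag optimization is
        # running. A common Gradio pattern for that (a sketch, not this file's
        # exact return value):
        #
        #     disabled = gr.update(interactive=False)
        #     return (global_state, step_count, image,
        #             *([disabled] * 13))   # 8 buttons + 5 input comps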
        form_lr_number.change(
            on_change_lr,
            inputs=[form_lr_number, global_state],
            outputs=[global_state],
        )
        form_custom_image.upload(on_click_inverse_custom_image,
                                 inputs=[form_custom_image, global_state],
                                 outputs=[global_state, form_image, form_reset_custom_image])
        form_save_image.click(on_save_image,
                              inputs=[global_state, form_save_image_path],
                              outputs=[])

        form_reset_custom_image.click(on_reset_custom_image,
                                      inputs=[global_state],
                                      outputs=[global_state, form_image])
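        # Uploading a custom image triggers GAN inversion: a latent code is
        # optimized so the generator reproduces the upload, and later drags
        # then edit that latent. `form_reset_custom_image` is also listed as
        # an output, presumably so the reset button can be updated once an
        # inversion result exists.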
        # ==== Params
        form_lambda_number.change(
            partial(on_change_single_global_state, ["params", "motion_lambda"]),
            inputs=[form_lambda_number, global_state],
            outputs=[global_state],
        )
        form_latent_space.change(on_click_latent_space,
                                 inputs=[form_latent_space, global_state],
                                 outputs=[global_state, form_image])
        form_seed_number.change(
            on_change_update_image_seed,
            inputs=[form_seed_number, global_state],
            outputs=[global_state, form_image],
        )
        form_reset_image.click(
            on_click_reset_image,
            inputs=[global_state],
            outputs=[global_state, form_image],
        )
        form_pretrained_dropdown.change(
            on_change_pretrained_dropdown,
            inputs=[form_pretrained_dropdown, global_state],
            outputs=[global_state, form_image],
        )
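    # Note on the launch settings below (Gradio 3.x semantics): the queue
    # processes up to `concurrency_count` events in parallel and turns away
    # new requests once `max_size` are already waiting; `share=True`
    # additionally exposes the app through a temporary public *.gradio.live
    # URL.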
    # gr.close_all()
    app.queue(concurrency_count=3, max_size=20)
    app.launch(share=args.share)