mirror of https://github.com/XingangPan/DragGAN
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
# Copyright (c) SenseTime Research. All rights reserved.
|
|
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
import wandb
|
|
from pti.pti_configs import global_config
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def log_image_from_w(w, G, name):
    """Render latent ``w`` with generator ``G`` and log the image to wandb.

    The image is logged under the key ``name`` with the caption
    ``"current inversion <name>"``, at the step held in
    ``global_config.training_step``.
    """
    rendered = Image.fromarray(get_image_from_w(w, G))
    entry = wandb.Image(rendered, caption=f"current inversion {name}")
    wandb.log({f"{name}": [entry]}, step=global_config.training_step)
|
|
|
|
|
|
def log_images_from_w(ws, G, names):
    """Log one rendered image to wandb per (name, latent) pair.

    Each latent is moved to ``global_config.device`` before rendering.
    Pairs are formed with ``zip``, so iteration stops at the shorter of
    ``ws`` and ``names``.
    """
    for label, latent in zip(names, ws):
        log_image_from_w(latent.to(global_config.device), G, label)
|
|
|
|
|
|
def plot_image_from_w(w, G):
    """Render latent ``w`` with generator ``G`` and show it via matplotlib."""
    as_pil = Image.fromarray(get_image_from_w(w, G))
    plt.imshow(as_pil)
    plt.show()
|
|
|
|
|
|
def plot_image(img):
    """Display the first image of a batch of generator outputs.

    ``img`` must be a 4-D torch tensor. It is permuted with (0, 2, 3, 1),
    which implies channel-first (N, C, H, W) input. Values are mapped via
    ``x * 127.5 + 128``, clamped to [0, 255], and cast to uint8 before
    display. NOTE(review): this mapping presumes inputs roughly in [-1, 1];
    confirm with callers.
    """
    channels_last = img.permute(0, 2, 3, 1)
    as_bytes = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    batch = as_bytes.detach().cpu().numpy()
    first_frame = Image.fromarray(batch[0])
    plt.imshow(first_frame)
    plt.show()
|
|
|
|
|
|
def save_image(name, method_type, results_dir, image, run_id):
    """Write ``image`` to ``<results_dir>/<method_type>_<name>_<run_id>.jpg``.

    ``image`` only needs a ``save(path)`` method, for example a PIL image.
    """
    destination = f'{results_dir}/{method_type}_{name}_{run_id}.jpg'
    image.save(destination)
|
|
|
|
|
|
def save_w(w, G, name, method_type, results_dir, run_id=None):
    """Render latent ``w`` with generator ``G`` and save it as a JPEG.

    Args:
        w: Latent passed to ``get_image_from_w``.
        G: Generator passed to ``get_image_from_w``.
        name: Name component of the output file.
        method_type: Prefix component of the output file.
        results_dir: Directory the file is written into.
        run_id: Optional run identifier. When given, the file is written
            through ``save_image`` as
            ``<results_dir>/<method_type>_<name>_<run_id>.jpg``. When
            omitted, it is written as
            ``<results_dir>/<method_type>_<name>.jpg``.
    """
    im = Image.fromarray(get_image_from_w(w, G), mode='RGB')
    if run_id is None:
        # save_image requires run_id positionally, so calling it without one
        # would raise TypeError. Write a suffix-free file name instead.
        im.save(f'{results_dir}/{method_type}_{name}.jpg')
    else:
        save_image(name, method_type, results_dir, im, run_id)
|
|
|
|
|
|
def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Save a horizontal comparison strip to ``<base_dir>/<file_name>.jpg``.

    Panels, from left to right:
      1. ``extra_image``, if one is given;
      2. each latent in ``image_latents`` rendered with ``old_G``;
      3. ``new_inv_image_latent`` rendered with ``new_G``.
    """
    panels = [] if extra_image is None else [extra_image]
    panels.extend(get_image_from_w(latent, old_G) for latent in image_latents)
    panels.append(get_image_from_w(new_inv_image_latent, new_G))
    strip = create_alongside_images(panels)
    strip.save(f'{base_dir}/{file_name}.jpg')
|
|
|
|
|
|
def save_single_image(base_dir, image_latent, G, file_name):
    """Render ``image_latent`` with ``G`` and save it as ``<base_dir>/<file_name>.jpg``."""
    pixels = get_image_from_w(image_latent, G)
    Image.fromarray(pixels, mode='RGB').save(f'{base_dir}/{file_name}.jpg')
|
|
|
|
|
|
def create_alongside_images(images):
    """Join ``images`` side by side (along axis 1) into one RGB PIL image.

    Each element is first converted with ``np.array``, so PIL images and
    numpy arrays are both accepted. ``np.concatenate`` requires every
    dimension except axis 1 to match. NOTE(review): for (H, W, C) inputs this
    means all panels must share height and channel count; confirm the layout
    with callers.
    """
    arrays = list(map(np.array, images))
    joined = np.concatenate(arrays, axis=1)
    return Image.fromarray(joined, mode='RGB')
|
|
|
|
|
|
def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 array.

    A ``w`` with at most two dimensions gets a leading batch dimension before
    being passed to ``G.synthesis(..., noise_mode='const')``. The output is
    permuted with (0, 2, 3, 1), i.e. channel-first to channel-last. It is then
    mapped via ``x * 127.5 + 128``, clamped to [0, 255], and cast to uint8.
    The first item of the batch is returned as a numpy array. Everything runs
    under ``torch.no_grad()``.
    """
    if w.dim() <= 2:
        w = w.unsqueeze(0)  # add the batch dimension synthesis expects
    with torch.no_grad():
        synthesized = G.synthesis(w, noise_mode='const')
        channels_last = synthesized.permute(0, 2, 3, 1)
        # +128 (rather than +127.5) adds 0.5 so the truncating uint8 cast
        # effectively rounds to the nearest integer.
        scaled = (channels_last * 127.5 + 128).clamp(0, 255)
        batch = scaled.to(torch.uint8).detach().cpu().numpy()
    return batch[0]
|