|
|
|
|
@ -65,9 +65,10 @@ def style_mixing_video(network_pkl: str,
|
|
|
|
|
print('col_seeds: ', dst_seeds)
|
|
|
|
|
num_frames = int(np.rint(duration_sec * mp4_fps))
|
|
|
|
|
print('Loading networks from "%s"...' % network_pkl)
|
|
|
|
|
device = torch.device('cuda')
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
|
|
|
|
dtype = torch.float32 if device.type == 'mps' else torch.float64
|
|
|
|
|
with dnnlib.util.open_url(network_pkl) as f:
|
|
|
|
|
Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
|
|
|
|
|
Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
|
|
|
|
|
max_style = int(2 * np.log2(Gs.img_resolution)) - 3
|
|
|
|
|
@ -80,14 +81,14 @@ def style_mixing_video(network_pkl: str,
|
|
|
|
|
src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2- 1), mode="wrap")
|
|
|
|
|
src_z /= np.sqrt(np.mean(np.square(src_z)))
|
|
|
|
|
# Map into the disentangled latent space W and apply the truncation trick
|
|
|
|
|
src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
|
|
|
|
|
src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
|
|
|
|
|
w_avg = Gs.mapping.w_avg
|
|
|
|
|
src_w = w_avg + (src_w - w_avg) * truncation_psi
|
|
|
|
|
|
|
|
|
|
# Top row latents (fixed reference)
|
|
|
|
|
print('Generating Destination W vectors...')
|
|
|
|
|
dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
|
|
|
|
|
dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
|
|
|
|
|
dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
|
|
|
|
|
dst_w = w_avg + (dst_w - w_avg) * truncation_psi
|
|
|
|
|
# Get the width and height of each image:
|
|
|
|
|
H = Gs.img_resolution # 1024
|
|
|
|
|
@ -120,7 +121,7 @@ def style_mixing_video(network_pkl: str,
|
|
|
|
|
for col, dst_image in enumerate(list(dst_images)):
|
|
|
|
|
# Select the pertinent latent w column:
|
|
|
|
|
w_col = np.stack([dst_w[col].cpu()]) # [18, 512] -> [1, 18, 512]
|
|
|
|
|
w_col = torch.from_numpy(w_col).to(device)
|
|
|
|
|
w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
|
|
|
|
|
# Replace the values defined by col_styles:
|
|
|
|
|
w_col[:, col_styles] = src_w[frame_idx, col_styles]#.cpu()
|
|
|
|
|
# Generate these synthesized images:
|
|
|
|
|
|