diff --git a/viz/renderer.py b/viz/renderer.py index 26e4f11..6955f3d 100644 --- a/viz/renderer.py +++ b/viz/renderer.py @@ -225,7 +225,7 @@ class Renderer: res.num_ws = G.num_ws res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers()) res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform')) - + self.lr = lr # Set input transform. if res.has_input_transform: m = np.eye(3) @@ -262,6 +262,17 @@ class Renderer: self.feat_refs = None self.points0_pt = None + def set_latent(self,w,trunc_psi,trunc_cutoff): + #label = torch.zeros([1, self.G.c_dim], device=self._device) + #w = self.G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) + self.w0 = w.detach().clone() + if self.w_plus: + self.w = w.detach() + else: + self.w = w[:, 0, :].detach() + self.w.requires_grad = True + self.w_optim = torch.optim.Adam([self.w], lr=self.lr) + def update_lr(self, lr): del self.w_optim