Revise to be consistent with paper

pull/374/head
Xingang Pan 2 years ago
parent c5e88b3eaf
commit d0422b1b38

@ -353,11 +353,11 @@ class Renderer:
distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
relis, reljs = torch.where(distance < round(r1 / 512 * h))
direction = direction / (torch.linalg.norm(direction) + 1e-7)
gridh = (relis-direction[1]) / (h-1) * 2 - 1
gridw = (reljs-direction[0]) / (w-1) * 2 - 1
gridh = (relis+direction[1]) / (h-1) * 2 - 1
gridw = (reljs+direction[0]) / (w-1) * 2 - 1
grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs].detach(), target)
loss = loss_motion
if mask is not None:

Loading…
Cancel
Save