|
|
|
|
@ -353,11 +353,11 @@ class Renderer:
|
|
|
|
|
distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
|
|
|
|
|
relis, reljs = torch.where(distance < round(r1 / 512 * h))
|
|
|
|
|
direction = direction / (torch.linalg.norm(direction) + 1e-7)
|
|
|
|
|
gridh = (relis-direction[1]) / (h-1) * 2 - 1
|
|
|
|
|
gridw = (reljs-direction[0]) / (w-1) * 2 - 1
|
|
|
|
|
gridh = (relis+direction[1]) / (h-1) * 2 - 1
|
|
|
|
|
gridw = (reljs+direction[0]) / (w-1) * 2 - 1
|
|
|
|
|
grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
|
|
|
|
|
target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
|
|
|
|
|
loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
|
|
|
|
|
loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs].detach(), target)
|
|
|
|
|
|
|
|
|
|
loss = loss_motion
|
|
|
|
|
if mask is not None:
|
|
|
|
|
|