# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Custom TensorFlow ops for efficient resampling of 2D images."""

import os
import numpy as np
import tensorflow as tf
from .. import custom_ops

def _get_plugin():
    return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')

#----------------------------------------------------------------------------

def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
    r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.

    Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
    and performs the following operations for each image, batched across
    `majorDim` and `minorDim`:

    1. Upsample the image by inserting zeros after each pixel (`upx`, `upy`).

    2. Pad the image with zeros by the specified number of pixels on each side
       (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
       corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
       image so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by throwing away pixels (`downx`, `downy`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
        k:      2D FIR filter of the shape `[firH, firW]`.
        upx:    Integer upsampling factor along the X-axis (default: 1).
        upy:    Integer upsampling factor along the Y-axis (default: 1).
        downx:  Integer downsampling factor along the X-axis (default: 1).
        downy:  Integer downsampling factor along the Y-axis (default: 1).
        padx0:  Number of pixels to pad on the left side (default: 0).
        padx1:  Number of pixels to pad on the right side (default: 0).
        pady0:  Number of pixels to pad on the top side (default: 0).
        pady1:  Number of pixels to pad on the bottom side (default: 0).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
    """

    impl_dict = {
        'ref':  _upfirdn_2d_ref,
        'cuda': _upfirdn_2d_cuda,
    }
    return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)

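# Usage sketch (illustrative, not part of the original file; assumes a TF 1.x
# graph/session and, for impl='cuda', a compiled plugin). A 3x3 blur kernel
# with 2x upsampling turns a [8, 64, 64, 3] batch into [8, 128, 128, 3]:
#
#   x = tf.zeros([8, 64, 64, 3])                     # [majorDim, inH, inW, minorDim]
#   k = np.outer([1., 2., 1.], [1., 2., 1.]) / 16.0  # normalized 3x3 FIR filter
#   y = upfirdn_2d(x, k, upx=2, upy=2, padx0=2, padx1=0, pady0=2, pady1=0, impl='ref')
#   # outW = (64*2 + 2 + 0 - 3) // 1 + 1 = 128, and likewise outH = 128.
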
#----------------------------------------------------------------------------
def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    inH = x.shape[1].value
    inW = x.shape[2].value
    minorDim = _shape(x, 3)
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Upsample (insert zeros).
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    # Pad (crop if negative).
    x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
    x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]

    # Convolve with filter.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
    w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
    x = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='VALID', data_format='NCHW')
    x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample (throw away pixels).
    return x[:, ::downy, ::downx, :]

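# Worked example of the zero-insertion upsampling step above, in plain NumPy
# (an illustrative sketch, not part of the original file):
#
#   a = np.array([[1, 2],
#                 [3, 4]])
#   up = np.zeros((4, 4)); up[::2, ::2] = a   # upx = upy = 2
#   # up == [[1, 0, 2, 0],
#   #        [0, 0, 0, 0],
#   #        [3, 0, 4, 0],
#   #        [0, 0, 0, 0]]
#
# Note that the reference op places the zeros *after* each pixel, matching the
# `tf.pad` calls above.
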
#----------------------------------------------------------------------------
def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    majorDim, inH, inW, minorDim = x.shape.as_list()
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
    outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
    assert outW >= 1 and outH >= 1

    # The gradient of upfirdn is itself an upfirdn: swap the up/down factors,
    # flip the kernel, and adjust the padding so the gradient maps the output
    # shape back to the input shape.
    kc = tf.constant(k, dtype=x.dtype)
    gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
    gpadx0 = kernelW - padx0 - 1
    gpady0 = kernelH - pady0 - 1
    gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
    gpady1 = inH * upy - outH * downy + pady0 - upy + 1

    @tf.custom_gradient
    def func(x):
        y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
        y.set_shape([majorDim, outH, outW, minorDim])
        @tf.custom_gradient
        def grad(dy):
            dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
            dx.set_shape([majorDim, inH, inW, minorDim])
            # Returning `func` as the gradient of `grad` makes gradients of
            # arbitrary order reuse the same fused op.
            return dx, func
        return y, grad
    return func(x)

#----------------------------------------------------------------------------
def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
    r"""Filter a batch of 2D images with the given FIR filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and filters each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """

    k = _setup_kernel(k) * gain
    p = k.shape[0] - 1
    return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)

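# Usage sketch (illustrative, not part of the original file): smoothing with a
# separable binomial filter leaves a constant image unchanged when gain=1.
#
#   x = tf.ones([1, 3, 16, 16])                      # NCHW
#   y = filter_2d(x, k=[1., 3., 3., 1.], impl='ref')
#   # y has the same shape as x, and pixels away from the zero-padded borders
#   # stay ~1.0 because _setup_kernel() normalizes the 4x4 outer-product
#   # kernel to sum to 1.
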
#----------------------------------------------------------------------------
def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and upsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to nearest-neighbor
                      upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)

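# Usage sketch (illustrative, not part of the original file): with the default
# k=[1]*factor, upsample_2d() reduces to nearest-neighbor upsampling.
#
#   x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 1, 2, 2])
#   y = upsample_2d(x, factor=2, impl='ref')   # shape [1, 1, 4, 4]
#   # Each input pixel is repeated in a 2x2 block: the kernel [1, 1] becomes a
#   # 2x2 box after _setup_kernel(), and the gain * factor**2 = 4 scaling
#   # undoes both the 1/4 normalization and the zero-insertion energy loss.
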
#----------------------------------------------------------------------------
def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and downsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to average pooling.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)

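# Usage sketch (illustrative, not part of the original file): with the default
# k=[1]*factor, downsample_2d() reduces to 2x2 average pooling.
#
#   x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 1, 4, 4])
#   y = downsample_2d(x, factor=2, impl='ref')   # shape [1, 1, 2, 2]
#   # Each output pixel is the mean of a 2x2 input block, since the normalized
#   # box kernel sums to 1 and downsampling keeps every second filtered pixel.
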
#----------------------------------------------------------------------------
def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Grouped convolution can be performed by setting `inChannels` to the
                      number of input channels divided by `numGroups`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to nearest-neighbor
                      upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    w = tf.convert_to_tensor(w)
    assert w.shape.rank == 4
    convH = w.shape[0].value
    convW = w.shape[1].value
    inC = _shape(w, 2)
    outC = _shape(w, 3)
    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)

    # Determine data dimensions.
    if data_format == 'NCHW':
        stride = [1, 1, factor, factor]
        output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
        num_groups = _shape(x, 1) // inC
    else:
        stride = [1, factor, factor, 1]
        output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
        num_groups = _shape(x, 3) // inC

    # Transpose weights.
    w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
    w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
    w = tf.reshape(w, [convH, convW, -1, num_groups * inC])

    # Execute.
    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
    return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)

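# Usage sketch (illustrative, not part of the original file): a fused 2x
# upsample + 3x3 convolution, as used in StyleGAN2 synthesis layers.
#
#   x = tf.zeros([1, 8, 16, 16])        # NCHW
#   w = tf.zeros([3, 3, 8, 4])          # [filterH, filterW, inChannels, outChannels]
#   y = upsample_conv_2d(x, w, k=[1., 3., 3., 1.], factor=2, impl='ref')
#   # y has shape [1, 4, 32, 32]: the transpose conv upsamples and convolves
#   # in one pass, then the FIR filter smooths the zero-inserted samples.
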
#----------------------------------------------------------------------------
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Grouped convolution can be performed by setting `inChannels` to the
                      number of input channels divided by `numGroups`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to average pooling.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    w = tf.convert_to_tensor(w)
    convH, convW, _inC, _outC = w.shape.as_list()
    assert convW == convH
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    if data_format == 'NCHW':
        s = [1, 1, factor, factor]
    else:
        s = [1, factor, factor, 1]
    x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
    return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)

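# Usage sketch (illustrative, not part of the original file): a fused 3x3
# convolution + 2x downsample, the mirror image of upsample_conv_2d().
#
#   x = tf.zeros([1, 4, 32, 32])        # NCHW
#   w = tf.zeros([3, 3, 4, 8])          # [filterH, filterW, inChannels, outChannels]
#   y = conv_downsample_2d(x, w, k=[1., 3., 3., 1.], factor=2, impl='ref')
#   # y has shape [1, 8, 16, 16]: the FIR filter low-pass filters the input
#   # first, then the strided convolution does conv + downsample in one pass.
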
#----------------------------------------------------------------------------
# Internal helper funcs.

def _shape(tf_expr, dim_idx):
    """Return the size of dimension `dim_idx`, as a Python int if statically known."""
    if tf_expr.shape.rank is not None:
        dim = tf_expr.shape[dim_idx].value
        if dim is not None:
            return dim
    return tf.shape(tf_expr)[dim_idx]

def _setup_kernel(k):
    """Convert a separable 1D or full 2D FIR filter to a normalized float32 2D kernel."""
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k

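# Example (illustrative, not part of the original file): a separable 1D filter
# is expanded to 2D via an outer product and normalized to sum to 1.
#
#   _setup_kernel([1., 3., 3., 1.])
#   # array([[0.015625, 0.046875, 0.046875, 0.015625],
#   #        [0.046875, 0.140625, 0.140625, 0.046875],
#   #        [0.046875, 0.140625, 0.140625, 0.046875],
#   #        [0.015625, 0.046875, 0.046875, 0.015625]], dtype=float32)
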
def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
    """Apply `upfirdn_2d()` with identical X/Y factors, folding NCHW channels into the batch dimension."""
    assert data_format in ['NCHW', 'NHWC']
    assert x.shape.rank == 4
    y = x
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
    y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
    return y

#----------------------------------------------------------------------------