"""
2D convolution class
"""
import functools
import torch
import torchvision
from torch import nn
class Conv2d(torch.nn.Module):
"""
2D convolution class
Args:
in_channels : int - Number of input channels
out_channels : int - Number of output channels
kernel_size : int - Size of kernel
stride : int - Stride of convolution
activation_func : func - Activation function after convolution
norm_layer : functools.partial - Normalization layer
use_bias : bool - If set, then use bias
padding_type : str - The name of padding layer: reflect | replicate | zero
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
norm_layer=nn.BatchNorm2d,
use_bias=False,
padding_type="reflect",
):
super(Conv2d, self).__init__()
self.activation_func = activation_func
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(kernel_size // 2)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(kernel_size // 2)]
elif padding_type == "zero":
p = kernel_size // 2
else:
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
conv_block += [
nn.Conv2d(
in_channels,
out_channels,
stride=stride,
kernel_size=kernel_size,
padding=p,
bias=use_bias,
),
norm_layer(out_channels),
]
self.conv = nn.Sequential(*conv_block)
    def forward(self, x):
        """Apply padding, convolution, and normalization, then the optional activation."""
        conv = self.conv(x)
        if self.activation_func is not None:
            return self.activation_func(conv)
        else:
            return conv
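
# Illustrative usage sketch (not part of the original module): with the default
# "reflect" padding of kernel_size // 2, Conv2d keeps the spatial size when
# stride=1 and halves it when stride=2. The function name and shapes below are
# hypothetical.
def _example_conv2d_shapes():
    block = Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1)
    down = Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2)
    x = torch.randn(1, 3, 64, 64)
    assert block(x).shape == (1, 16, 64, 64)  # same spatial resolution
    assert down(x).shape == (1, 16, 32, 32)   # halved by stride=2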
class UpConv2d(torch.nn.Module):
"""
Up-convolution (upsample + convolution) block class
Args:
in_channels : int - number of input channels
out_channels : int - number of output channels
kernel_size : int - size of kernel (k x k)
activation_func : func - activation function after convolution
norm_layer : functools.partial - normalization layer
use_bias : bool - if set, then use bias
padding_type : str - the name of padding layer: reflect | replicate | zero
interpolate_mode : str - the mode for interpolation: bilinear | nearest
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
norm_layer=nn.BatchNorm2d,
use_bias=False,
padding_type="reflect",
interpolate_mode="bilinear",
):
super(UpConv2d, self).__init__()
self.interpolate_mode = interpolate_mode
self.conv = Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
activation_func=activation_func,
norm_layer=norm_layer,
use_bias=use_bias,
padding_type=padding_type,
)
    def forward(self, x):
        """Upsample the input by a factor of 2, then apply the convolution block."""
        n_height, n_width = x.shape[2:4]
        shape = (int(2 * n_height), int(2 * n_width))
        # align_corners is only valid for interpolating modes (e.g. bilinear),
        # so leave it unset for nearest-neighbour upsampling
        align_corners = True if self.interpolate_mode != "nearest" else None
        upsample = torch.nn.functional.interpolate(
            x, size=shape, mode=self.interpolate_mode, align_corners=align_corners
        )
        conv = self.conv(upsample)
        return conv
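
# Illustrative usage sketch (not part of the original module): UpConv2d doubles
# the spatial resolution via interpolation and then applies a stride-1 Conv2d,
# so the output is exactly twice the input height and width. The function name
# below is hypothetical.
def _example_upconv2d_shapes():
    up = UpConv2d(in_channels=32, out_channels=16, interpolate_mode="bilinear")
    x = torch.randn(1, 32, 24, 24)
    assert up(x).shape == (1, 16, 48, 48)  # 2x upsampled, channels reduced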
class DecoderBlock(torch.nn.Module):
"""
Decoder block with skip connections
Args:
in_channels : int - number of input channels
skip_channels : int - number of skip connection channels
out_channels : int - number of output channels
activation_func : func - activation function after convolution
norm_layer : functools.partial - normalization layer
use_bias : bool - if set, then use bias
padding_type : str - the name of padding layer: reflect | replicate | zero
upsample_mode : str - the mode for interpolation: transpose | bilinear | nearest
"""
def __init__(
self,
in_channels,
skip_channels,
out_channels,
activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
norm_layer=nn.BatchNorm2d,
use_bias=False,
padding_type="reflect",
upsample_mode="transpose",
):
super(DecoderBlock, self).__init__()
self.skip_channels = skip_channels
self.upsample_mode = upsample_mode
# Upsampling
if upsample_mode == "transpose":
self.deconv = nn.Sequential(
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=use_bias,
),
norm_layer(out_channels),
activation_func,
)
else:
self.deconv = UpConv2d(
in_channels,
out_channels,
use_bias=use_bias,
activation_func=activation_func,
norm_layer=norm_layer,
padding_type=padding_type,
interpolate_mode=upsample_mode,
)
concat_channels = skip_channels + out_channels
self.conv = Conv2d(
concat_channels,
out_channels,
kernel_size=3,
stride=1,
activation_func=activation_func,
padding_type=padding_type,
norm_layer=norm_layer,
use_bias=use_bias,
)
    def forward(self, x, skip=None):
        """Upsample x and fuse it with the skip connection (if any) via concatenation."""
        deconv = self.deconv(x)
        if self.skip_channels > 0:
            concat = torch.cat([deconv, skip], dim=1)
        else:
            concat = deconv
        return self.conv(concat)
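
# Illustrative usage sketch (not part of the original module): DecoderBlock
# upsamples x by 2, concatenates the skip tensor along the channel dimension,
# and fuses the result with a 3x3 Conv2d. The skip tensor must therefore match
# the upsampled spatial size and provide skip_channels channels. Names and
# shapes below are hypothetical.
def _example_decoder_block():
    decoder = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64,
                           upsample_mode="bilinear")
    x = torch.randn(1, 128, 16, 16)    # low-resolution decoder feature
    skip = torch.randn(1, 64, 32, 32)  # encoder feature at 2x resolution
    assert decoder(x, skip).shape == (1, 64, 32, 32)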
class ResNetModified(nn.Module):
"""
    ResNet-based generator that consists of deformable ResNet blocks.
"""
def __init__(
self,
input_nc,
output_nc,
ngf=64,
norm_layer=nn.BatchNorm2d,
activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
use_dropout=False,
n_blocks=6,
padding_type="reflect",
upsample_mode="bilinear",
):
"""
Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
upsample_mode (str) -- mode for upsampling: transpose | bilinear
"""
assert n_blocks >= 0
super(ResNetModified, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
# Initial Convolution
self.initial_conv = nn.Sequential(
Conv2d(
in_channels=input_nc,
out_channels=ngf,
kernel_size=7,
padding_type=padding_type,
norm_layer=norm_layer,
activation_func=activation_func,
use_bias=use_bias,
),
Conv2d(
in_channels=ngf,
out_channels=ngf,
kernel_size=3,
padding_type=padding_type,
norm_layer=norm_layer,
activation_func=activation_func,
use_bias=use_bias,
),
)
# Downsample Blocks
n_downsampling = 2
mult = 2 ** 0
self.downsample_1 = Conv2d(
in_channels=ngf * mult,
out_channels=ngf * mult * 2,
kernel_size=3,
stride=2,
padding_type=padding_type,
norm_layer=norm_layer,
activation_func=activation_func,
use_bias=use_bias,
)
mult = 2 ** 1
self.downsample_2 = Conv2d(
in_channels=ngf * mult,
out_channels=ngf * mult * 2,
kernel_size=3,
stride=2,
padding_type=padding_type,
norm_layer=norm_layer,
activation_func=activation_func,
use_bias=use_bias,
)
# Residual Blocks
residual_blocks = []
mult = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks
residual_blocks += [
DeformableResnetBlock(
ngf * mult,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
use_bias=use_bias,
activation_func=activation_func,
)
]
self.residual_blocks = nn.Sequential(*residual_blocks)
# Upsampling
mult = 2 ** (n_downsampling - 0)
self.upsample_2 = DecoderBlock(
ngf * mult,
int(ngf * mult / 2),
int(ngf * mult / 2),
use_bias=use_bias,
activation_func=activation_func,
norm_layer=norm_layer,
padding_type=padding_type,
upsample_mode=upsample_mode,
)
mult = 2 ** (n_downsampling - 1)
self.upsample_1 = DecoderBlock(
ngf * mult,
int(ngf * mult / 2),
int(ngf * mult / 2),
use_bias=use_bias,
activation_func=activation_func,
norm_layer=norm_layer,
padding_type=padding_type,
upsample_mode=upsample_mode,
)
# Output Convolution
self.output_conv_naive = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(ngf, output_nc, kernel_size=3, padding=0),
nn.Tanh(),
)
def forward(self, input):
"""Standard forward"""
# Downsample
initial_conv_out = self.initial_conv(input)
downsample_1_out = self.downsample_1(initial_conv_out)
downsample_2_out = self.downsample_2(downsample_1_out)
# Residual
residual_blocks_out = self.residual_blocks(downsample_2_out)
# Upsample
upsample_2_out = self.upsample_2(residual_blocks_out, downsample_1_out)
upsample_1_out = self.upsample_1(upsample_2_out, initial_conv_out)
final_out = self.output_conv_naive(upsample_1_out)
return (final_out,)
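
# Illustrative usage sketch (not part of the original module): ResNetModified is
# an encoder-decoder with two stride-2 downsampling stages, n_blocks residual
# blocks at 1/4 resolution, and two skip-connected decoder stages, so the output
# matches the input resolution and is squashed into (-1, 1) by the final Tanh.
# This assumes DeformableResnetBlock (referenced in the constructor) is defined
# elsewhere in the module; the function name and input shape are hypothetical,
# with height and width divisible by 4.
def _example_resnet_modified():
    net = ResNetModified(input_nc=3, output_nc=3, ngf=64, n_blocks=6,
                         upsample_mode="bilinear")
    x = torch.randn(1, 3, 64, 64)
    (out,) = net(x)  # forward returns a one-element tuple
    assert out.shape == (1, 3, 64, 64)
    assert out.min() >= -1.0 and out.max() <= 1.0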