from torch import nn
from neural_de.transformations.diffusion.unet.nn import conv_nd
from torch.nn import functional
[docs]
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
[docs]
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = functional.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = functional.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x