Source code for neural_de.transformations.diffusion.unet.upsample

from torch import nn
from neural_de.transformations.diffusion.unet.nn import conv_nd
from torch.nn import functional


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    The input is upsampled by a factor of 2 with nearest-neighbour
    interpolation and, if requested, passed through a convolution created by
    ``conv_nd`` (called with kernel size 3 and ``padding=1``).

    :param channels: number of channels expected in the inputs (checked in
        ``forward``).
    :param use_conv: a bool determining if a convolution is applied after
        upsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions and the first spatial
        dimension is left unchanged.
    :param out_channels: number of output channels of the convolution;
        defaults to ``channels`` when falsy. Note that it only affects the
        output when ``use_conv`` is True -- without the convolution the
        output keeps ``channels`` channels.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        """
        Upsample ``x`` (and optionally convolve it).

        :param x: input tensor whose dimension 1 must equal ``self.channels``;
            presumably laid out as (batch, channels, *spatial) -- confirm
            against callers.
        :return: the upsampled tensor, convolved if ``use_conv`` is True.
        :raises ValueError: if ``x.shape[1] != self.channels``.
        """
        # Explicit check rather than ``assert`` so validation is not
        # stripped when Python runs with ``-O``.
        if x.shape[1] != self.channels:
            raise ValueError(
                f"Upsample expected {self.channels} input channels, "
                f"got {x.shape[1]}"
            )
        if self.dims == 3:
            # Keep the first spatial axis; double only the inner two.
            x = functional.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = functional.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x