import logging
from typing import Optional

import torch as th
import torch.nn as nn
from neural_de.transformations.diffusion.unet.attention_block import AttentionBlock
from neural_de.transformations.diffusion.unet.downsample import Downsample
from neural_de.transformations.diffusion.unet.nn import (
conv_nd,
linear,
zero_module,
normalization,
timestep_embedding,
)
from neural_de.transformations.diffusion.unet.res_block import ResBlock
from neural_de.transformations.diffusion.unet.timestep_embed_sequential import TimestepEmbedSequential
from neural_de.transformations.diffusion.unet.upsample import Upsample
from neural_de.transformations.diffusion.unet.utils import convert_module_to_f16, convert_module_to_f32
from neural_de.transformations.diffusion.diffpure_config import DiffPureConfig
from neural_de.utils.twe_logger import get_logger
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
config: DiffPureConfig,
        logger: Optional[logging.Logger] = None
):
super().__init__()
self._logger = logger if logger is not None else get_logger()
if config.num_heads_upsample == -1:
config.num_heads_upsample = config.num_heads
self._config = config
self._dtype = th.float16 if self._config.use_fp16 else th.float32
self._num_heads_upsample = config.num_heads_upsample
        # config.num_channels is the base channel width (the "model_channels" of the
        # reference UNet implementation).
time_embed_dim = self._config.num_channels * 4
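        # Two-layer MLP that projects sinusoidal timestep embeddings to time_embed_dim.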
self.time_embed = nn.Sequential(
linear(self._config.num_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
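        # Optional class-conditional embedding, added to the timestep embedding in forward().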
if self._config.num_classes is not None:
self._label_emb = nn.Embedding(self._config.num_classes, time_embed_dim)
ch = input_ch = int(self._config.channel_mult[0] * self._config.num_channels)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(self._config.dims, in_channels, ch, 3, padding=1))]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
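        # Encoder: at each resolution level, num_res_blocks residual blocks (with
        # attention at the configured downsampling factors), then a downsampling stage.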
for level, mult in enumerate(self._config.channel_mult):
for _ in range(self._config.num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
self._config.dropout,
out_channels=int(mult * self._config.num_channels),
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
)
]
ch = int(mult * self._config.num_channels)
if ds in self._config.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=self._config.use_checkpoint,
num_heads=self._config.num_heads,
num_head_channels=self._config.num_head_channels,
use_new_attention_order=self._config.use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(self._config.channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
self._config.dropout,
out_channels=out_ch,
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
down=True,
)
if self._config.resblock_updown
else Downsample(
ch, self._config.conv_resample, dims=self._config.dims,
out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
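        # Bottleneck at the lowest resolution: ResBlock -> AttentionBlock -> ResBlock.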
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
self._config.dropout,
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=self._config.use_checkpoint,
num_heads=self._config.num_heads,
num_head_channels=self._config.num_head_channels,
use_new_attention_order=self._config.use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
self._config.dropout,
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
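        # Decoder: mirrors the encoder; each block consumes the matching encoder
        # activation (skip connection) concatenated along the channel dimension.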
for level, mult in list(enumerate(self._config.channel_mult))[::-1]:
for i in range(self._config.num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
self._config.dropout,
out_channels=int(self._config.num_channels * mult),
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
)
]
ch = int(self._config.num_channels * mult)
if ds in self._config.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=self._config.use_checkpoint,
num_heads=self._config.num_heads_upsample,
num_head_channels=self._config.num_head_channels,
use_new_attention_order=self._config.use_new_attention_order,
)
)
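                # Upsample after the last ResBlock of each level except the outermost (level 0).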
if level and i == self._config.num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
self._config.dropout,
out_channels=out_ch,
dims=self._config.dims,
use_checkpoint=self._config.use_checkpoint,
use_scale_shift_norm=self._config.use_scale_shift_norm,
up=True,
)
if self._config.resblock_updown
else Upsample(ch, self._config.conv_resample, dims=self._config.dims,
out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
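        # Output head: the final conv is zero-initialized, so the model outputs zeros
        # at the start of training.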
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(self._config.dims, input_ch, out_channels, 3, padding=1)),
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
    def forward(self, x: th.Tensor, timesteps: th.Tensor, y: Optional[th.Tensor] = None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D Tensor of N timesteps.
        :param y: an [N] Tensor of labels, if the model is class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self._config.num_channels))
        if self._config.num_classes is not None:
            # Class-conditional models require a label tensor of shape [N].
            assert y is not None, "y must be provided when the model is class-conditional"
            emb = emb + self._label_emb(y)
h = x.type(self._dtype)
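        # Encoder pass: record every intermediate activation for the skip connections.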
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
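        # Decoder pass: pop the matching skip activation and concatenate on the channel axis.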
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h)
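

# Illustrative usage sketch (not part of the documented API; it assumes DiffPureConfig
# can be default-constructed with settings compatible with the chosen input size):
#
#     config = DiffPureConfig()
#     model = UNetModel(in_channels=3, out_channels=3, config=config)
#     x = th.randn(2, 3, 256, 256)
#     t = th.randint(0, 1000, (2,))
#     out = model(x, t)  # [2, 3, 256, 256]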