# Source code for neural_de.transformations.diffusion.unet.timestep_embed_sequential

from torch import nn

from neural_de.transformations.diffusion.unet.timestep_block import TimestepBlock


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that forwards a timestep embedding to the children
    able to consume it.

    Children that are instances of ``TimestepBlock`` are called as
    ``child(x, emb)``; every other child is called as ``child(x)``. Children
    run in registration order, each one receiving the previous child's output.
    """

    def forward(self, x, emb):
        """
        Run ``x`` through every child module in order.

        :param x: input handed to the first child.
        :param emb: timestep embedding, passed only to ``TimestepBlock`` children.
        :return: output of the last child (``x`` itself if the container is empty).
        """
        out = x
        for child in self:
            # Only timestep-aware blocks accept the extra embedding argument.
            out = child(out, emb) if isinstance(child, TimestepBlock) else child(out)
        return out