Source code for neural_de.transformations.diffusion.unet.timestep_embed_sequential
from torch import nn
from neural_de.transformations.diffusion.unet.timestep_block import TimestepBlock
[docs]
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that forwards a timestep embedding to the children
    able to consume it.

    Children that are instances of ``TimestepBlock`` are called as
    ``child(h, emb)``; every other child is called as ``child(h)``. The output
    of each child becomes the input of the next one.

    Because this container is itself a ``TimestepBlock`` and its ``forward``
    takes ``(x, emb)``, it can be nested inside another
    ``TimestepEmbedSequential`` and will receive the embedding from it.
    """

    def forward(self, x, emb):
        """
        Run ``x`` through every child module, in iteration order.

        :param x: input fed to the first child.
        :param emb: timestep embedding; handed only to ``TimestepBlock``
            children.
        :return: output of the last child, or ``x`` unchanged when the
            container has no children.
        """
        hidden = x
        for module in self:
            # Only timestep-aware blocks accept the extra embedding argument.
            takes_embedding = isinstance(module, TimestepBlock)
            hidden = module(hidden, emb) if takes_embedding else module(hidden)
        return hidden