# Source code for neural_de.transformations.diffusion.rev_vpsde

import logging

import torch

from neural_de.utils.twe_logger import get_logger


class RevVPSDE(torch.nn.Module):
    """
    Reverse-time Variance Preserving SDE (VP-SDE) driven by a diffusion model.

    The forward VP-SDE uses the linear schedule
    ``beta(t) = beta_min + t * (beta_max - beta_min)`` for ``t`` in ``[0, 1]``.
    :meth:`f` and :meth:`g` expose the drift and diffusion of the reverse-time SDE
    (time change ``t' = 1 - t``) on flattened 2D tensors of shape
    ``(batch_size, c*h*w)``, the layout sdeint supports.

    Args:
        model: diffusion model, called as ``model(images, discrete_timesteps)`` with
            channel-first images of shape ``(batch_size, *img_shape)``. Its output must
            have at least ``img_shape[0]`` channels; only the first ``img_shape[0]`` are
            used (presumably the noise prediction; models with a learned variance
            append further channels, which are ignored).
        beta_min: value of beta at t = 0, i.e. ``beta(0)``.
        beta_max: value of beta at t = 1, i.e. ``beta(1)``.
        N: number of discrete timesteps. Continuous times are mapped to
            ``(t * N)`` truncated to an integer before being passed to the model.
        img_shape: Image dimension, channel-first.
        logger: logger (logging.Logger). Defaults to ``get_logger()``.
    """

    def __init__(self, model: torch.nn.Module, beta_min: float = 0.1, beta_max: float = 20,
                 N: int = 1000, img_shape: tuple = (3, 256, 256),
                 logger: logging.Logger = None):
        super().__init__()
        self._logger = logger if logger is not None else get_logger()
        self._model = model
        self._img_shape = img_shape
        self._beta_0 = beta_min
        self._beta_1 = beta_max
        self._N = N
        # Coefficient of t**2 in log(alpha_bar(t)) = -integral_0^t beta(s) ds
        #                                         = -0.5 * (beta_1 - beta_0) * t**2 - beta_0 * t
        self._beta_range = -0.5 * (self._beta_1 - self._beta_0)
        # SDE description for the solver: diagonal noise, Ito interpretation.
        self.noise_type = "diagonal"
        self.sde_type = "ito"
        # Discretised beta schedule (N values); not read by the methods of this class.
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)

    def vpsde_fn(self, t: torch.Tensor, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the forward variance-preserving SDE to a batch of images.

        Args:
            t: current time, shape ``(batch_size,)``, expected in ``[0, 1]``.
            x: flattened image batch, shape ``(batch_size, c*h*w)``.

        Returns:
            ``(drift, diffusion)``: ``drift = -0.5 * beta(t) * x`` has the shape of ``x``;
            ``diffusion = sqrt(beta(t))`` has the shape of ``t``.
        """
        beta_t = self._beta_0 + t * (self._beta_1 - self._beta_0)
        drift = -0.5 * beta_t[:, None] * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def _extract_info_from_output_tensor(self, timesteps: torch.Tensor,
                                         broadcast_shape: tuple) -> torch.Tensor:
        """
        Compute ``-1 / sqrt(1 - alpha_bar(t))`` and broadcast it to ``broadcast_shape``.

        ``alpha_bar(t) = exp(-0.5 * (beta_1 - beta_0) * t**2 - beta_0 * t)``. The factor
        multiplies the model output to obtain the score used in the reverse drift.

        Args:
            timesteps: current time, shape ``(batch_size,)``.
            broadcast_shape: target shape; trailing singleton dims are appended to the
                factor before expanding.

        Returns:
            float32 tensor of shape ``broadcast_shape``. The factor diverges as ``t -> 0``
            (``alpha_bar(0) == 1``), so ``timesteps`` should be strictly positive.
        """
        log_alpha_bar = self._beta_range * timesteps ** 2 - self._beta_0 * timesteps
        # 1 - exp(z) == -expm1(z); expm1 avoids catastrophic cancellation for small |z|,
        # which is the regime reached as f/g evaluate the model at 1 - t with t -> 1.
        res = (-1. / torch.sqrt(-torch.expm1(log_alpha_bar))).float()
        while res.dim() < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)

    def rvpsde_fn(self, t: torch.Tensor, x: torch.Tensor, return_type: str = 'drift'):
        """
        Create the drift and diffusion functions for the reverse SDE.

        Args:
            t: current time, shape ``(batch_size,)``.
            x: flattened image batch, shape ``(batch_size, c*h*w)``.
            return_type: if "drift", return the reverse drift
                ``drift(x, t) - diffusion(t)**2 * score(x, t)`` (one model call).
                Any other value returns only the diffusion, shape ``(batch_size,)``.

        Raises:
            ValueError: if the model output has fewer than ``img_shape[0]`` channels.
        """
        drift, diffusion = self.vpsde_fn(t, x)
        if return_type != 'drift':
            return diffusion

        channels = self._img_shape[0]
        # reshape (rather than view) also accepts a non-contiguous x.
        x_img = x.reshape(-1, *self._img_shape)
        # Map continuous time to the model's discrete timestep index.
        disc_steps = (t.float() * self._N).long()
        model_output = self._model(x_img, disc_steps)
        if model_output.shape[1] < channels:
            raise ValueError(
                f"Model output has {model_output.shape[1]} channels, "
                f"expected at least {channels}.")
        # Keep the first `channels` channels; extra channels (e.g. a learned variance
        # head) are ignored.
        model_output = model_output[:, :channels].reshape(x.shape[0], -1)
        score = self._extract_info_from_output_tensor(t, x.shape) * model_output
        return drift - diffusion[:, None] ** 2 * score

    def f(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Creates the drift function -f(x, 1-t) (by t' = 1 - t).

        Sdeint only supports a 2D tensor (batch_size, c*h*w).

        Args:
            t: current solver time, a tensor expandable to ``(batch_size,)`` (e.g. a scalar).
            x: flattened image batch, shape ``(batch_size, c*h*w)``.

        Returns:
            drift tensor with the shape of ``x``.
        """
        t = t.expand(x.shape[0])
        return -self.rvpsde_fn(1 - t, x, return_type='drift')

    def g(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Create the diffusion function g(1-t) (by t' = 1 - t).

        Sdeint only supports a 2D tensor (batch_size, c*h*w).

        Args:
            t: current solver time, a tensor expandable to ``(batch_size,)`` (e.g. a scalar).
            x: flattened image batch, shape ``(batch_size, c*h*w)``.

        Returns:
            diagonal diffusion with the shape of ``x`` (the per-sample value repeated
            across features).
        """
        t = t.expand(x.shape[0])
        diffusion = self.rvpsde_fn(1 - t, x, return_type='diffusion')
        return diffusion[:, None].expand(x.shape)