Source code for neural_de.transformations.diffusion.rev_guided_diffusion

import logging

import numpy as np
import torch
import torchsde

from neural_de.transformations.diffusion.diffpure_config import DiffPureConfig, CHANNEL_MULT
from neural_de.transformations.diffusion.rev_vpsde import RevVPSDE
from neural_de.utils.twe_logger import get_logger
from neural_de.transformations.diffusion.unet.unet_model import UNetModel


class RevGuidedDiffusion(torch.nn.Module):
    """
    Implements the rev-guided diffusion.

    The module wraps a ``UNetModel`` inside a ``RevVPSDE`` and exposes
    :meth:`image_editing_sample`, which noises a batch of images and then
    integrates the reverse SDE with ``torchsde`` to produce the output batch.

    Args:
        config: An instance of DiffPureConfig, it has been created in the input of the
            DiffusionEnhancer class. NOTE: the instance is modified in place (``use_fp16``,
            ``channel_mult`` and ``attention_resolutions`` may be overwritten).
        device: "cuda" or "cpu". Gpu is highly recommended but some steps are available
            with cpu.
        logger: logger. When ``None``, the logger returned by ``get_logger()`` is used.

    Raises:
        NotImplementedError: if ``config.channel_mult`` is ``None`` and ``config.image_size``
            has no entry in ``CHANNEL_MULT``.
    """

    def __init__(self, config: DiffPureConfig, device: torch.DeviceObjType = None,
                 logger: logging.Logger = None):
        super().__init__()
        self._config = config
        self._logger = logger if logger is not None else get_logger()
        self._device = device

        self._logger.info('Building DiffPure model')
        self._logger.debug(f'Model Diffpure loaded with config : {self._config}')

        if not torch.cuda.is_available():
            # fp16 is switched off whenever CUDA is not available.
            self._config.use_fp16 = False
            self._logger.info("No cuda detected, the diffusion model will use cpu, "
                              "which provokes very slow inferences")

        if config.channel_mult is None:
            try:
                config.channel_mult = CHANNEL_MULT[config.image_size]
            except KeyError as err:
                # Keep the exception type callers expect, but preserve the original cause.
                raise NotImplementedError(
                    f"unsupported image size: {config.image_size}") from err

        # Each configured value r is replaced by image_size // r (as an int32 numpy array).
        # NOTE(review): this overwrites the value on the shared config object, so building
        # a second model from the same config instance would apply the division twice.
        config.attention_resolutions = (
            config.image_size // np.array(config.attention_resolutions).astype(np.int32)
        )

        out_channels = 3 if not self._config.learn_sigma else 6
        self._model = UNetModel(in_channels=3, out_channels=out_channels,
                                config=self._config, logger=self._logger)

        self._logger.info(f'Loading DiffPure weights to device : {self._device}')
        # SECURITY NOTE: torch.load unpickles the file; weights_path must point to a
        # trusted checkpoint.
        self._model.load_state_dict(torch.load(self._config.weights_path,
                                               map_location=self._device))
        if self._config.use_fp16:
            self._model.convert_to_fp16()
        self._model.eval().to(self._device)

        self._rev_vpsde = RevVPSDE(model=self._model, img_shape=self._config.img_shape,
                                   logger=self._logger).to(self._device)
        self._betas = self._rev_vpsde.discrete_betas.float().to(self._device)

    def image_editing_sample(self, img: torch.Tensor):
        """
        This method applies the rev-guided diffusion to a batch of images.

        For each of the ``config.sample_step`` iterations, the current batch is mixed with
        Gaussian noise, then the reverse SDE is integrated with the Euler method between
        ``t0`` and ``t1``. The output of one iteration is the input of the next one.

        Args:
            img: Tensor (batch of images). The first dimension is the batch dimension; the
                remaining dimensions are flattened for the SDE solver.

        Returns:
            Tensor (batch of images): the outputs of all iterations concatenated along
            dimension 0.
        """
        batch_size = img.shape[0]
        state_size = int(np.prod(img.shape[1:]))  # c*h*w

        img = img.to(self._device)
        x0 = img

        # Cumulative product of (1 - beta); it does not depend on the iteration, so it is
        # computed once instead of once per sampling step.
        a = (1 - self._betas).cumprod(dim=0).to(self._device)

        # Integration bounds are identical for every iteration.
        # NOTE(review): t0 is derived from config.t even when rand_t randomises the noise
        # level below, and 1000 is presumably the number of discrete steps - confirm.
        epsilon_dt0, epsilon_dt1 = 0, 1e-5
        t0, t1 = 1 - self._config.t * 1. / 1000 + epsilon_dt0, 1 - epsilon_dt1
        t_size = 2
        ts = torch.linspace(t0, t1, t_size).to(self._device)

        xs = []
        for _ in range(self._config.sample_step):
            # randn_like already allocates on x0's device, no extra transfer is needed.
            e = torch.randn_like(x0)

            total_noise_levels = self._config.t
            if self._config.rand_t:
                # NOTE(review): np.random.randint requires low < high, so t_delta must be
                # strictly positive when rand_t is enabled.
                total_noise_levels = self._config.t + np.random.randint(
                    -self._config.t_delta, self._config.t_delta)

            # Mix the current images with noise at the selected noise level.
            alpha = a[total_noise_levels - 1]
            x = x0 * alpha.sqrt() + e * (1.0 - alpha).sqrt()

            x_ = x.reshape(batch_size, -1)  # (batch_size, state_size)
            if self._config.use_bm:
                bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(batch_size, state_size),
                                               device=self._device)
                xs_ = torchsde.sdeint_adjoint(self._rev_vpsde, x_, ts, method='euler', bm=bm)
            else:
                xs_ = torchsde.sdeint_adjoint(self._rev_vpsde, x_, ts, method='euler')

            # Keep the state at the last time point and restore the image layout.
            x0 = xs_[-1].view(x.shape)  # (batch_size, c, h, w)
            xs.append(x0)

        return torch.cat(xs, dim=0)