"""
Night to day enhancer - Maxim based implementation
"""
import logging
from dataclasses import dataclass, asdict
from typing import Union
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from neural_de.transformations.transformation import BaseTransformation
from neural_de.external.maxim_tf.create_maxim_model import Model
from neural_de.external.maxim_tf.maxim.configs import MAXIM_CONFIGS
from neural_de.utils.math import get_pad_value, is_scaled
from neural_de.utils.twe_logger import log_and_raise
# Ask TensorFlow to grow GPU memory on demand instead of locking all of it at start-up.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu_instance in physical_devices:
        tf.config.experimental.set_memory_growth(gpu_instance, True)
except RuntimeError as err:
    # Memory growth can only be configured before the GPUs have been initialized. If another
    # module already initialized them, keep the current settings rather than failing the import.
    logging.getLogger(__name__).warning("Could not enable GPU memory growth: %s", err)
# Static configuration : the method is only valid for these parameters.
# TF-Hub handle of the saved model resolved and loaded in NightImageEnhancer.__init__.
_NIGHT_ENHANCER_MODEL = "https://tfhub.dev/sayakpaul/maxim_s-2_enhancement_fivek/1"
# Value passed as ``ratio`` to get_pad_value when padding image height and width.
_S2_PADDING = 64
[docs]
@dataclass
class NightConfig:
    """
    Static Enhancer configuration.

    Every field is merged, through ``dataclasses.asdict``, into the keyword arguments used to
    build the MAXIM model.
    """
    # Name of the MAXIM variant.
    variant: str = "S-2"
    # Dropout rate given to the model.
    dropout_rate: float = 0.0
    # Value of the model ``num_outputs`` parameter.
    num_outputs: int = 3
    # Value of the model ``use_bias`` parameter.
    use_bias: bool = True
    # Value of the model ``num_supervision_scales`` parameter.
    num_supervision_scales: int = 3
[docs]
class NightImageEnhancer(BaseTransformation):
"""
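Provides Night to Day image transformation using the MAXIM model.
Args:
device: Device on which the MAXIM model is built. 'cuda' selects the first logical GPU
listed by TensorFlow; any other value is passed as is to tf.device. Defaults to 'cpu'.
logger: It is recommended to use the Confiance logger, obtainable with
neural_de.utils.get_logger(...). If None, one logging with stdout will be provided.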
"""
def __init__(self, device: str = 'cpu', logger: logging.Logger = None):
    """
    Load the pre-trained enhancement model and prepare the (still empty) pipeline state.

    Args:
        device: 'cuda' selects the first logical GPU listed by TensorFlow; any other value is
            stored unchanged and later given to ``tf.device``.
        logger: Logger forwarded to the parent class constructor.
    """
    super().__init__(logger)
    # Resolve the symbolic 'cuda' name to the actual TensorFlow device name.
    self._device = (
        tf.config.list_logical_devices('GPU')[0].name if device == 'cuda' else device
    )
    saved_model_path = hub.resolve(_NIGHT_ENHANCER_MODEL)
    self._s2_model = tf.keras.models.load_model(saved_model_path)
    self._logger.info('Model %s loaded', _NIGHT_ENHANCER_MODEL)
    # The pipeline depends on the input resolution, so it is only built later on.
    self._preprocessing_size = None
    self._pipeline = None
    self._config = NightConfig()
def _init_pipeline(self):
    """
    Initialize the MAXIM model and pipeline.

    The model is built on ``self._device`` from the "S-2" entry of ``MAXIM_CONFIGS``,
    overridden by the fields of ``self._config`` and with ``self._preprocessing_size`` as
    ``input_resolution``. The weights of the model loaded in ``__init__`` are then copied
    into it.
    """
    with tf.device(self._device):
        # Build a fresh dict: updating the entry stored in MAXIM_CONFIGS in place would leak
        # this instance's settings (e.g. its input resolution) to every other user of the
        # shared configuration.
        configs = {
            **MAXIM_CONFIGS["S-2"],
            **asdict(self._config),
            "input_resolution": self._preprocessing_size,
        }
        self._pipeline = Model(**configs)
        self._pipeline.set_weights(self._s2_model.get_weights())
@staticmethod
def _preprocessing(images: np.ndarray):
    """
    Prepare an image batch for the MAXIM model :
        - divide the values by 255 when ``is_scaled`` reports the first image as not scaled
        - pad height and width (with reflection) by the amounts returned by
          ``get_pad_value`` for a ratio of *S2_PADDING*

    Args:
        images: image batch to preprocess. The padding specification has four entries, so
            a 4-D batch is expected, with axes 1 and 2 treated as height and width.

    Returns:
        A tuple ``(padded_images, padh, padw)`` : the padded batch, followed by the total
        padding added along the height and along the width.
    """
    batch = np.asarray(images)
    # normalize only if necessary (decided from the first image of the batch)
    already_scaled = is_scaled(batch[0])
    if not already_scaled:
        batch = batch / 255
    height, width = batch.shape[1], batch.shape[2]
    padh = get_pad_value(height, ratio=_S2_PADDING)
    padw = get_pad_value(width, ratio=_S2_PADDING)
    # Split each padding between both sides; the trailing side receives the extra pixel
    # when the amount is odd.
    pad_top = padh // 2
    pad_left = padw // 2
    paddings = [
        (0, 0),
        (pad_top, padh - pad_top),
        (pad_left, padw - pad_left),
        (0, 0),
    ]
    padded_images = tf.pad(batch, paddings, mode="REFLECT")
    return padded_images, padh, padw