"""
DeRain enhancer
"""
import logging
import os
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Union
import numpy as np
import torch
from neural_de.transformations.transformation import BaseTransformation
from neural_de.external.derain.blocks import ResNetModified
from neural_de.utils.model_manager import ModelManager
from neural_de.utils.math import get_pad_value, is_scaled
# This is the only validated configuration: users should not modify it.
[docs]
@dataclass
class DeRainConfig:
    """
    Internal configuration of the DeRain enhancer.

    The fields, taken as a dict, are forwarded as keyword arguments to the
    ``ResNetModified`` constructor; the defaults below are the only validated
    combination.

    Attributes:
        upsample_mode: Upsampling mode name (presumably an interpolation mode
            understood by the network; confirm against ``ResNetModified``).
        ngf: Integer width parameter of the network (presumably the base number
            of filters; confirm against ``ResNetModified``).
        n_blocks: Number of blocks (presumably residual blocks).
        use_dropout: Whether dropout is enabled in the network.
        input_nc: Number of input channels (presumably 3 for RGB).
        output_nc: Number of output channels (presumably 3 for RGB).
        padding_type: Name of the padding strategy used by the network.
    """

    # Field order is kept as-is: it defines both the positional constructor
    # signature and the key order produced by ``dataclasses.asdict``.
    upsample_mode: str = "bilinear"
    ngf: int = 64
    n_blocks: int = 9
    use_dropout: bool = False
    input_nc: int = 3
    output_nc: int = 3
    padding_type: str = "reflect"
# Model loaded from Minio.
# Enhancer identifier and checkpoint file name, as handed to the ModelManager.
_ENHANCER = "derain"
_MODEL_FILENAME = "derain_checkpoint.pth"
# Local checkpoint location: <user home>/.neuralde/<enhancer>/<checkpoint file>.
_DOWNLOADED_MODEL_PATH = Path(os.path.expanduser("~")).joinpath(
    ".neuralde", _ENHANCER, _MODEL_FILENAME
)
[docs]
class DeRainEnhancer(BaseTransformation):
"""
Provides a rain removal image transformation using the GT-Rain Derain Model.
Args:
device: Any torch-compatible device string.
logger: It is recommended to use the Confiance logger, obtainable with
neural_de.utils.get_logger(...). If None, one logging with stdout will be provided.
"""
def __init__(self, device: str = 'cpu', logger: logging.Logger = None):
    """
    Prepare the enhancer: validate the device, make the checkpoint available
    locally, build the network and load the trained weights into it.

    Args:
        device: Any torch-compatible device string. It is passed to
            ``check_device_validity`` before being stored.
        logger: Logger forwarded to the base class constructor. May be None;
            how None is handled is decided by the base class.

    Raises:
        ValueError: If a key of the checkpoint's ``"state_dict"`` contains no
            ``"."`` (raised by ``str.index`` while renaming the keys).
    """
    super().__init__(logger)
    self.check_device_validity(device)
    self._device = device

    # Ask the model manager to make the checkpoint available before reading it.
    self._logger.info("Checking model availability...")
    self._model_manager = ModelManager(
        enhancer=_ENHANCER,
        required_model=_MODEL_FILENAME,
        logger=self._logger,
    )
    self._model_manager.download_model()

    # Build the GT-Rain network from the validated configuration.
    network_kwargs = asdict(DeRainConfig())
    self._resnet = ResNetModified(**network_kwargs).to(self._device)

    # Read the trained weights.
    # NOTE(review): torch.load unpickles the file; the checkpoint must come
    # from a trusted source.
    checkpoint = torch.load(_DOWNLOADED_MODEL_PATH, map_location=self._device)

    # The stored keys carry a leading "<prefix>." component; strip everything
    # up to and including the first "." so they match the bare ResNet's names.
    renamed_weights = {}
    for name, tensor in checkpoint["state_dict"].items():
        renamed_weights[name[name.index(".") + 1:]] = tensor

    self._resnet.load_state_dict(renamed_weights, strict=True)
    self._resnet.eval()

    # The network expects input sizes to be a multiple of 4.
    self._ratio = 4
    self._logger.info("Trained model GT-rain loaded")