"""
Implementation of the ResolutionEnhancer transformation.
Attributes:
UPSCALE_MODEL: version of the transformer model used for image upscaling
"""
import logging
from typing import Union
import cv2
import numpy as np
from torch import no_grad
from transformers import Swin2SRForImageSuperResolution
from transformers import Swin2SRImageProcessor
from neural_de.transformations.centered_zoom import CenteredZoom
from neural_de.utils.twe_logger import log_and_raise
from neural_de.transformations.transformation import BaseTransformation
# Checkpoint identifier passed to ``Swin2SRForImageSuperResolution.from_pretrained``
# when the super-resolution model is lazily loaded.
UPSCALE_MODEL: str = "caidas/swin2SR-classical-sr-x2-64"
[docs]
class ResolutionEnhancer(BaseTransformation):
"""
BaseTransformation for changing image resolution.
It uses a neural-network-based method for resolution enhancement, and OpenCV for reducing
the resolution.
Example :
See the notebook `examples/ResolutionEnhancer_example.ipynb` for more usage details.
1- Import the class
.. code-block:: python
from neural_de.transformations import ResolutionEnhancer
2- Create an instance of ResolutionEnhancer.
``device="cuda"`` is recommended if you have a GPU and a CUDA-enabled torch build.
.. code-block:: python
res_shift = ResolutionEnhancer(device= "cpu")
3- Apply the resolution change to a batch of images to a given shape
.. code-block:: python
out_images = res_shift.transform(images, ratio=2)
Args:
device: Any torch-compatible device string.
logger: It is recommended to use the Confiance logger, obtainable with
neural_de.utils.get_logger(...). If None, one logging with stdout will be provided.
"""
def __init__(self, device: str = 'cpu', logger: logging.Logger = None):
    """Validate the target device and prepare lazy loading of the Swin2SR network.

    Args:
        device: Torch device string the model is moved to once it is loaded.
        logger: Forwarded to ``BaseTransformation.__init__``.
    """
    super().__init__(logger)
    self.check_device_validity(device)
    self._device = device
    # Neither the model nor its processor is built here: _upsample() triggers
    # _init_nn() the first time an upsampling is actually requested.
    self._model = None
    self._processor = None
    self._logger.info("ResolutionEnhancer Initialized ")
def _init_nn(self) -> None:
    """
    Initialise the Swin2SR image processor and network used for image upsampling.

    Both objects are first built into local variables. They are stored on the
    instance only after the model has been loaded and moved to ``self._device``.
    If ``from_pretrained`` or ``.to`` raises, ``self._processor`` and
    ``self._model`` keep their previous values. The instance is therefore never
    left half-initialised (processor set, model missing), and a later call can
    simply retry.
    """
    processor = Swin2SRImageProcessor(do_pad=True)
    model = Swin2SRForImageSuperResolution.from_pretrained(UPSCALE_MODEL).to(self._device)
    # Commit both only once every step above has succeeded.
    self._processor = processor
    self._model = model
    self._logger.info("Swin2 model loaded to %s", self._device)
def _intermediate_sampling(self, image: np.ndarray, shape: tuple) -> np.ndarray:
    """Resize ``image`` with **OpenCV** to half of the final target size.

    The neural upsampler doubles the resolution afterwards, so this step aims
    at ``shape[0] // 2`` rows and ``shape[1] // 2`` columns.

    Args:
        image: Image to resize; it is cast to ``np.uint8`` before resizing.
        shape: Final target size, read as ``(rows, cols, ...)``.

    Returns:
        The uint8 image, bilinearly resized to half of the target size.

    Raises:
        ValueError: reported through ``log_and_raise`` when half of either
            target dimension is smaller than one pixel.
    """
    uint8_image = image.astype(np.uint8)
    half_width = shape[1] // 2
    half_height = shape[0] // 2
    if half_width < 1 or half_height < 1:
        log_and_raise(
            self._logger,
            ValueError,
            "Target target_shape is too small : no pixel on at least one dimension",
        )
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    target_dsize = (half_width, half_height)
    return cv2.resize(uint8_image, target_dsize, interpolation=cv2.INTER_LINEAR)
def _upsample(self, images: np.ndarray) -> np.ndarray:
    """Use a **SwinTransformer** (Swin2SR) to raise the resolution of images by a factor 2.

    The processor and model are created lazily on first use.

    Args:
        images: Batch of identically shaped images, in a format accepted by
            ``Swin2SRImageProcessor``.

    Returns:
        A ``float32`` numpy array with values clamped to ``[0, 1]``. Axis 1 of the
        model output (presumably channels) is moved to the last position.
        NOTE(review): the processor is built with ``do_pad=True``, so the output may
        be slightly larger than exactly twice the input size; callers presumably
        crop it — confirm.
    """
    # Check both attributes so that a previously failed model load (processor
    # present, model missing) triggers a fresh initialisation instead of a
    # "'NoneType' object is not callable" error below.
    if self._processor is None or self._model is None:
        self._init_nn()
    pixel_values = self._processor(images, return_tensors="pt").pixel_values
    with no_grad():
        outputs = self._model(pixel_values.to(self._device))
    output = outputs.reconstruction.detach().float().cpu().clamp_(0, 1).numpy()
    return np.moveaxis(output, source=1, destination=-1)