Source code for neural_de.transformations.transformation
"""
Parent class for any transformation method.
"""
from __future__ import annotations
import logging
from typing import Union
import numpy as np
from neural_de.utils.twe_logger import log_and_raise, get_logger
from neural_de.utils.validation import is_batch_valid, is_device_valid
[docs]
class BaseTransformation:
    """
    Parent class for any transformation methods of the library.

    Provides the shared helpers for logging and input validation used by the
    concrete transformations.

    Args:
        logger: logging.Logger used to report errors. It is recommended to use the
            Confiance one, obtainable with neural_de.utils.get_logger(...).
            When None (the default), the logger returned by ``get_logger()`` is used.
    """

    def __init__(self, logger: logging.Logger | None = None):
        # Fall back to the library's default logger when the caller gives none.
        self._logger = get_logger() if logger is None else logger

    def _check_batch_validity(self, images: Union[list, np.ndarray], same_dim: bool = False):
        """
        Check if the batch of images provided by the user conforms to the expected standards.

        Validation itself is delegated to ``is_batch_valid``; this method only turns
        a negative verdict into a logged error.

        Args:
            images: list / batch of images to validate.
            same_dim: If True, also check that all images have the same dimension.

        Returns:
            None

        Raises:
            ValueError: (through ``log_and_raise``) when ``is_batch_valid`` reports the
                batch as invalid; the message includes the reason it returned.
        """
        is_valid, reason = is_batch_valid(images, same_dim=same_dim)
        if not is_valid:
            # f-string instead of "+": adds the missing space after the colon and
            # cannot itself fail with a TypeError if `reason` is not a str.
            log_and_raise(self._logger, ValueError,
                          f"Parameter images is not a valid input batch: {reason}")

    def check_device_validity(self, device: str = "cpu"):
        """
        Check if the selected device is valid.

        The check is delegated to ``is_device_valid``.

        Args:
            device: str - cpu / cuda

        Returns:
            None

        Raises:
            TypeError: (through ``log_and_raise``) when ``is_device_valid`` rejects
                ``device``. Kept as TypeError for compatibility with existing callers.
        """
        if not is_device_valid(device):
            log_and_raise(self._logger, TypeError,
                          f"Device {device} is not a valid Pytorch device")