Source code for neural_de.utils.model_manager

"""
This module implements utility functions to manage the external models used within the component.
"""

import hashlib
import json
import logging
import os.path
from pathlib import Path
from .twe_logger import log_and_raise
import urllib.request
import yaml

# This variable should point to the python env root path
ROOT_PATH = Path(__file__).parent.parent.parent.resolve()


class ModelManager:
    """
    Manages all external models required by the library.
    """

    def __init__(self, enhancer: str, required_model: str, logger: logging.Logger):
        self._logger = logger
        self._checksums = self._load_checksums()
        self._enhancer = enhancer
        self._model_filename = required_model
        self._enhancer_directory = (
            Path(os.path.expanduser("~")) / ".neuralde" / enhancer
        )
        self._model_filepath = self._enhancer_directory / required_model
        # Load the model repositories
        with open(
            ROOT_PATH / "neural_de/external/_repositories/external_models_list.yaml",
            encoding="utf-8",
        ) as stream:
            try:
                self.external_models_list = yaml.safe_load(stream)
            except yaml.YAMLError:
                log_and_raise(
                    self._logger,
                    ValueError,
                    "The external models list file could not be parsed; it should be a"
                    " valid YAML file at"
                    " neural_de/external/_repositories/external_models_list.yaml",
                )

    @staticmethod
    def _load_checksums():
        """
        Load each available model's checksum.
        """
        with open(
            ROOT_PATH / "neural_de/external/_checksums/checksums.json",
            "r",
            encoding="utf-8",
        ) as checksum_file:
            checksums = json.load(checksum_file)
        return checksums
    def download_model(self) -> None:
        """
        Download the weights for an enhancer if they are not already available locally.

        Weights are stored at ~/.neuralde/{enhancer_name}/model.pth
        """
        if not (self._is_model_available() and self._is_model_valid()):
            self._logger.info(
                "Model %s not found locally or corrupted, downloading it from server",
                self._model_filename,
            )
            if self._enhancer in self.external_models_list:
                self._logger.info(
                    "Required pretrained model for the %s enhancer is not present in"
                    " the local cache, downloading it...",
                    self._enhancer,
                )
                urllib.request.urlretrieve(
                    self.external_models_list[self._enhancer], self._model_filepath
                )
                self._check_download_status()
                self._logger.info("Pretrained model has been downloaded to the cache")
            else:
                raise ValueError(
                    f"There is no model repository defined for the enhancer"
                    f" {self._enhancer} in the external models list"
                )
        else:
            self._logger.info("Model already available locally, skipping download")
    def _check_download_status(self):
        """
        Validate that the model is locally present and valid, and raise an error if not.
        """
        if not self._is_model_valid():
            self._remove_corrupted_model()
            log_and_raise(
                self._logger,
                ValueError,
                "The downloaded file does not pass the checksum validation,"
                " it might be invalid. It has been removed from your machine",
            )
        else:
            self._logger.info("Model downloaded and validated")

    def _remove_corrupted_model(self):
        """
        Delete the local model file, or raise an error if it cannot be found.
        """
        if self._model_filepath.is_file():
            os.remove(self._model_filepath)
        else:
            log_and_raise(
                self._logger, FileNotFoundError, "Expected model was not found"
            )

    def _is_model_available(self) -> bool:
        """
        Check whether the model file is present in the local cache directory.
        """
        self._enhancer_directory.mkdir(parents=True, exist_ok=True)
        return self._model_filepath.is_file()

    def _is_model_valid(self) -> bool:
        """
        Compare the local model's checksum against the reference checksum.
        """
        return self._calculate_checksum() == self._checksums[self._model_filename]

    def _calculate_checksum(self):
        """
        Compute the MD5 checksum of the local model file, reading it in chunks.
        """
        hash_md5 = hashlib.md5()
        if self._model_filepath.is_file():
            with open(self._model_filepath, "rb") as file:
                for chunk in iter(lambda: file.read(4096), b""):
                    hash_md5.update(chunk)
        else:
            log_and_raise(
                self._logger, FileNotFoundError, "Expected model was not found"
            )
        return hash_md5.hexdigest()
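
For reference, the following is a minimal usage sketch, not part of the module itself. The enhancer name "derain" and the weights filename "model.pth" are hypothetical placeholders chosen for illustration; real values must correspond to a key in neural_de/external/_repositories/external_models_list.yaml (enhancer name mapped to a download URL) and an entry in neural_de/external/_checksums/checksums.json (filename mapped to an MD5 digest).

# Usage sketch (illustrative only): "derain" and "model.pth" are hypothetical
# names that must match the entries shipped in external_models_list.yaml and
# checksums.json.
import logging

from neural_de.utils.model_manager import ModelManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("neural_de")

manager = ModelManager(enhancer="derain", required_model="model.pth", logger=logger)
# Downloads the weights to ~/.neuralde/derain/model.pth if they are missing or
# fail the checksum validation; otherwise the local copy is reused.
manager.download_model()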