# Source code for neural_de.external.maxim_tf.create_maxim_model

"""
Factory function for building Keras MAXIM models from preset configurations.
"""
from tensorflow import keras

from .maxim import maxim
from .maxim.configs import MAXIM_CONFIGS


def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
    """Build a Keras MAXIM model, optionally from a named preset configuration.

    Args:
        variant: Key into ``MAXIM_CONFIGS`` selecting a preset configuration
            (earlier docs listed 'S-1' | 'S-2' | 'S-3' | 'M-1' | 'M-2' | 'M-3';
            check ``MAXIM_CONFIGS`` for the exact keys available). When
            ``None``, the whole configuration must be supplied through ``**kw``.
        input_resolution: Spatial size of the input images; the model input
            has shape ``(*input_resolution, 3)``.
        **kw: Configuration entries forwarded to ``maxim.MAXIM``. Values given
            here take precedence over the preset's values. A ``name`` entry
            must be available (from here or from the preset); it is used to
            name the returned model ``f"{name}_model"`` and is not forwarded
            to ``maxim.MAXIM``.

    Returns:
        A ``keras.Model`` wrapping ``maxim.MAXIM(**kw)`` applied to a
        ``keras.Input`` of shape ``(*input_resolution, 3)``.

    Raises:
        KeyError: If ``variant`` is not a key of ``MAXIM_CONFIGS``, or if no
            ``name`` is provided by either ``**kw`` or the preset.
    """
    if variant is not None:
        try:
            config = MAXIM_CONFIGS[variant]
        except KeyError:
            # Same exception type as before, but tell the caller what exists.
            raise KeyError(
                f"Unknown MAXIM variant {variant!r}; "
                f"available variants: {list(MAXIM_CONFIGS)}"
            ) from None
        # Explicit keyword arguments override the preset's values.
        for key, value in config.items():
            kw.setdefault(key, value)

    # `variant` and `input_resolution` are named parameters of this function,
    # so these keys can only reach `kw` from the preset config. They are
    # dropped before forwarding (presumably maxim.MAXIM does not accept them);
    # the input size always comes from the explicit `input_resolution` arg.
    kw.pop("variant", None)
    kw.pop("input_resolution", None)

    try:
        model_name = kw.pop("name")
    except KeyError:
        raise KeyError(
            "A model 'name' is required: pass name=... or choose a variant "
            "whose configuration defines one."
        ) from None

    maxim_model = maxim.MAXIM(**kw)

    # Wrap the MAXIM layer stack in a functional keras.Model with a fixed
    # 3-channel input of the requested spatial resolution.
    inputs = keras.Input((*input_resolution, 3))
    outputs = maxim_model(inputs)
    return keras.Model(inputs, outputs, name=f"{model_name}_model")