# Source code for neural_de.external.maxim_tf.maxim.layers

"""
Layers based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
"""

import einops
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from tensorflow.keras import backend as K
from tensorflow.keras import layers


@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Split a 4-D feature map into non-overlapping patches.

    Uses ``einops.rearrange`` with the pattern
    ``"n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c"``, i.e. an input laid out
    as ``(n, H, W, c)`` becomes ``(n, gh * gw, fh * fw, c)`` where
    ``fh, fw = patch_size`` and ``gh, gw = H // fh, W // fw``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Block ``x`` into patches of size ``patch_size``.

        Args:
            x: 4-D tensor treated as ``(n, H, W, c)`` by the einops pattern.
                ``H`` and ``W`` are read with ``K.int_shape``, so they must be
                statically known (``None`` dims cannot be floor-divided).
            patch_size: indexable pair ``(patch_h, patch_w)``.

        Returns:
            Tensor laid out as ``(n, grid_h * grid_w, patch_h * patch_w, c)``.
            If ``H``/``W`` are not multiples of the patch size, the floor
            division below yields a grid that does not cover the input and
            ``einops.rearrange`` is expected to reject the shape.
        """
        # Only the spatial dims are needed: batch and channel axes are passed
        # through by the einops pattern unchanged, so read the shape once.
        static_shape = K.int_shape(x)
        h, w = static_shape[1], static_shape[2]
        grid_height, grid_width = h // patch_size[0], w // patch_size[1]
        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            gh=grid_height,
            gw=grid_width,
            fh=patch_size[0],
            fw=patch_size[1],
        )

    def get_config(self):
        """Return the base Keras layer config; this layer adds no fields."""
        config = super().get_config().copy()
        return config
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Reassemble patches back into a 4-D feature map.

    Applies the reverse of the ``BlockImages`` einops pattern:
    ``"n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Unblock ``x`` given the patch grid and patch dimensions.

        Args:
            x: tensor treated as ``(n, gh * gw, fh * fw, c)`` by the pattern.
            grid_size: indexable pair ``(gh, gw)`` giving the patch grid.
            patch_size: indexable pair ``(fh, fw)`` giving each patch's size.

        Returns:
            Tensor laid out as ``(n, gh * fh, gw * fw, c)``.
        """
        rows_of_patches, cols_of_patches = grid_size[0], grid_size[1]
        patch_rows, patch_cols = patch_size[0], patch_size[1]
        pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
        return einops.rearrange(
            x,
            pattern,
            gh=rows_of_patches,
            gw=cols_of_patches,
            fh=patch_rows,
            fw=patch_cols,
        )

    def get_config(self):
        """Return a copy of the base Keras layer config (no extra fields)."""
        return dict(super().get_config())
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer wrapper around ``tf.experimental.numpy.swapaxes``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with axes ``axis_one`` and ``axis_two`` interchanged."""
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        """Return a copy of the base Keras layer config (no extra fields)."""
        return dict(super().get_config())
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Resize inputs to a fixed ``(height, width)`` with ``tf.image.resize``.

    Args:
        height: target output height passed to ``tf.image.resize``.
        width: target output width passed to ``tf.image.resize``.
        antialias: forwarded to ``tf.image.resize`` (default ``True``).
        method: interpolation method name forwarded to ``tf.image.resize``
            (default ``"bilinear"``).
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias = antialias
        self.method = method

    def call(self, x):
        """Resize ``x`` to the configured target size."""
        target_size = (self.height, self.width)
        return tf.image.resize(
            x,
            size=target_size,
            antialias=self.antialias,
            method=self.method,
        )

    def get_config(self):
        """Return the base layer config extended with the resize settings."""
        return dict(
            super().get_config(),
            height=self.height,
            width=self.width,
            antialias=self.antialias,
            method=self.method,
        )