⬅ neural_de\external\maxim_tf\maxim\layers.py source

1 """
2 Layers based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4  
5 import einops
6 import tensorflow as tf
7 from tensorflow.experimental import numpy as tnp
8 from tensorflow.keras import backend as K
9 from tensorflow.keras import layers
10  
11  
@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Split a batch of feature maps into non-overlapping patches ("blocks").

    The input is rearranged with the einops pattern
    ``n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c``, i.e. every
    ``fh x fw`` patch becomes one row of the second output axis, and the
    ``gh x gw`` patches are enumerated along the first output axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Block ``x`` into patches of size ``patch_size``.

        Args:
            x: 4-D tensor whose dims 1 and 2 are treated as height and width
                (per the einops pattern). The spatial dims are read with
                ``K.int_shape``, so they must be statically known; an unknown
                (``None``) dim makes the floor division below fail.
            patch_size: Indexable pair ``(patch_height, patch_width)``.

        Returns:
            Tensor of shape ``(n, gh * gw, fh * fw, c)`` where
            ``gh = height // patch_height`` and ``gw = width // patch_width``.
        """
        # Only the spatial dims are needed here; the batch and channel axes
        # are carried through unchanged by the einops pattern.
        static_shape = K.int_shape(x)
        height, width = static_shape[1], static_shape[2]

        patch_h, patch_w = patch_size[0], patch_size[1]
        # Floor division: if height/width are not multiples of the patch
        # size, gh * fh != height and einops is expected to reject the
        # rearrangement rather than silently crop.
        grid_h, grid_w = height // patch_h, width // patch_w

        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            gh=grid_h,
            gw=grid_w,
            fh=patch_h,
            fw=patch_w,
        )

    def get_config(self):
        config = super().get_config().copy()
        return config
41  
42  
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Reassemble patch-major tokens back into a spatial feature map.

    Applies the einops pattern
    ``n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c``, undoing a
    patch-blocking rearrangement given the grid and patch dimensions.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Merge ``grid_size`` x ``patch_size`` blocks into one image.

        Args:
            x: 4-D tensor laid out as ``(n, gh * gw, fh * fw, c)`` per the
                einops pattern.
            grid_size: Indexable pair ``(gh, gw)`` - number of patches along
                height and width.
            patch_size: Indexable pair ``(fh, fw)`` - patch height and width.

        Returns:
            Tensor of shape ``(n, gh * fh, gw * fw, c)``.
        """
        grid_h, grid_w = grid_size[0], grid_size[1]
        patch_h, patch_w = patch_size[0], patch_size[1]
        pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
        return einops.rearrange(
            x, pattern, gh=grid_h, gw=grid_w, fh=patch_h, fw=patch_w
        )

    def get_config(self):
        return dict(super().get_config())
63  
64  
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer that exchanges two axes of its input tensor.

    Thin wrapper around ``tf.experimental.numpy.swapaxes`` so the operation
    can be placed (and serialized) inside a functional Keras model.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with axes ``axis_one`` and ``axis_two`` interchanged."""
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        base_config = super().get_config()
        return dict(base_config)
76  
77  
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Resize images to a fixed ``(height, width)`` with ``tf.image.resize``.

    Args:
        height: Target output height.
        width: Target output width.
        antialias: Forwarded to ``tf.image.resize`` (default ``True``).
        method: Interpolation method name forwarded to ``tf.image.resize``
            (default ``"bilinear"``).
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias = antialias
        self.method = method

    def call(self, x):
        """Resize ``x`` to the configured target size."""
        target_size = (self.height, self.width)
        return tf.image.resize(
            x,
            size=target_size,
            method=self.method,
            antialias=self.antialias,
        )

    def get_config(self):
        # Persist the constructor arguments so the layer round-trips through
        # Keras serialization.
        config = dict(super().get_config())
        config["height"] = self.height
        config["width"] = self.width
        config["antialias"] = self.antialias
        config["method"] = self.method
        return config