# Source: neural_de/external/maxim_tf/maxim/blocks/misc_gating.py

"""
Blocks based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
"""
4  
5 import functools
6  
7 import tensorflow as tf
8 from tensorflow.keras import backend as K
9 from tensorflow.keras import layers
10  
11 from .block_gating import BlockGmlpLayer
12 from .grid_gating import GridGmlpLayer
13 from ..layers import BlockImages, SwapAxes, UnblockImages
14  
# Convenience factories for the conv layers used throughout the gating
# blocks; "same" padding keeps spatial sizes aligned with the stride.
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
# 2x spatial upsampling via transposed convolution.
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
# 2x spatial downsampling convolution.
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)
23  
24  
def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """The multi-axis gated MLP block.

    Projects the input channels up by ``input_proj_factor``, splits them
    into two halves, routes one half through a grid gMLP and the other
    through a block gMLP, concatenates the results, projects back to the
    original channel count and adds the residual shortcut.

    Args:
        block_size: spatial patch size for the block-gMLP branch.
        grid_size: grid size for the grid-gMLP branch.
        block_gmlp_factor: expansion factor inside the block gMLP.
        grid_gmlp_factor: expansion factor inside the grid gMLP.
        input_proj_factor: channel expansion of the input projection.
        use_bias: whether dense layers use a bias term.
        dropout_rate: dropout applied after the output projection.
        name: prefix for all sub-layer names.

    Returns:
        A callable that applies the block to a (N, H, W, C) tensor.
    """

    def apply(x):
        shortcut = x
        # Only the channel count is needed; the former n/h/w locals were
        # unused (flake8 F841) and have been removed.
        num_channels = K.int_shape(x)[-1]
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)

        # Input projection, expanding channels before the two-way split.
        x = layers.Dense(
            int(num_channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)

        u, v = tf.split(x, 2, axis=-1)

        # Grid gMLP branch: global (long-range) spatial mixing.
        u = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(u)

        # Block gMLP branch: local spatial mixing.
        v = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(v)

        x = tf.concat([u, v], axis=-1)

        # Project back to the original channel count, then residual add.
        x = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + shortcut
        return x

    return apply
86  
87  
def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Get gating weights for the cross-gating MLP block.

    Computes spatial gating weights by mixing one half of the projected
    channels along a coarse grid axis and the other half along local
    blocks, then fusing both halves with an output projection.

    Args:
        features: nominal channel count (kept for interface parity; the
            actual channel count is read from the input tensor).
        block_size: spatial patch size for the block-mixing branch.
        grid_size: grid size for the grid-mixing branch.
        input_proj_factor: channel expansion of the input projection.
        dropout_rate: dropout applied after the output projection.
        use_bias: whether dense layers use a bias term.
        name: prefix for all sub-layer names.

    Returns:
        A callable that maps a (N, H, W, C) tensor to gating weights of
        the same shape.
    """

    def apply(x):
        # Batch size is not used (was an F841 unused local); keep only
        # the spatial dims and channel count.
        _, h, w, num_channels = K.int_shape(x)

        # Input projection.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            num_channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)

        # Grid MLP weights: fixed grid (gh, gw), patch size derived from
        # the spatial dims; mix across the grid axis.
        gh, gw = grid_size
        fh, fw = h // gh, w // gw
        u = BlockImages()(u, patch_size=(fh, fw))
        dim_u = K.int_shape(u)[-3]
        u = SwapAxes()(u, -1, -3)
        u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
        u = SwapAxes()(u, -1, -3)
        u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))

        # Block MLP weights: fixed patch size (fh, fw), grid derived from
        # the spatial dims; mix within each block.
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        v = BlockImages()(v, patch_size=(fh, fw))
        dim_v = K.int_shape(v)[-2]
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
        v = SwapAxes()(v, -1, -2)
        v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))

        # Fuse both branches and project back to the input channel count.
        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
        x = layers.Dropout(dropout_rate)(x)
        return x

    return apply
143  
144  
def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Cross-gating MLP block.

    Computes spatial gating weights from two feature maps and applies
    them crosswise: ``x`` is gated by the weights derived from ``y`` and
    vice versa, each with its own residual shortcut.

    Args:
        features: channel count used to project both inputs.
        block_size: spatial patch size for the block-mixing branch.
        grid_size: grid size for the grid-mixing branch.
        dropout_rate: dropout applied after the output projections.
        input_proj_factor: channel expansion inside the gating weights.
        upsample_y: if True, 2x-upsample ``y`` before gating so its
            spatial size matches ``x``.
        use_bias: whether dense/conv layers use a bias term.
        name: prefix for all sub-layer names.

    Returns:
        A callable ``apply(x, y) -> (x, y)`` on (N, H, W, C) tensors.
    """

    def apply(x, y):
        # Upscale Y signal, y is the gating signal.
        if upsample_y:
            y = ConvT_up(
                filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
            )(y)

        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        # Only the channel count is needed; the former n/h/w locals were
        # unused (flake8 F841) and have been removed.
        num_channels = K.int_shape(x)[-1]

        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)

        shortcut_x = x
        shortcut_y = y

        # Get gating weights from X.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
        x = tf.nn.gelu(x, approximate=True)
        gx = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Get gating weights from Y.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
        y = tf.nn.gelu(y, approximate=True)
        gy = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Apply cross gating: X = X * GY, Y = Y * GX.
        y = y * gx
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
        y = layers.Dropout(dropout_rate)(y)
        y = y + shortcut_y

        x = x * gy  # gating x using y
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + shortcut_x  # get all aggregated signals
        return x, y

    return apply