1 """
2 MAXIM based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4
5 import functools
6
7 import tensorflow as tf
8 from tensorflow.keras import backend as K
9 from tensorflow.keras import layers
10
11 from .blocks.attentions import SAM
12 from .blocks.bottleneck import BottleneckBlock
13 from .blocks.misc_gating import CrossGatingBlock
14 from .blocks.others import UpSampleRatio
15 from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
16 from .layers import Resizing
17
# Convolution shorthands with pre-bound hyper-parameters, used throughout
# the model builder below.
# 1x1 pointwise convolution.
Conv1x1 = functools.partial(layers.Conv2D, padding="same", kernel_size=(1, 1))
# 3x3 same-padded convolution.
Conv3x3 = functools.partial(layers.Conv2D, padding="same", kernel_size=(3, 3))
# 2x2 stride-2 transposed convolution (2x spatial upsampling).
ConvT_up = functools.partial(
    layers.Conv2DTranspose,
    padding="same",
    kernel_size=(2, 2),
    strides=(2, 2),
)
# 4x4 stride-2 convolution (2x spatial downsampling).
Conv_down = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(4, 4),
    strides=(2, 2),
)
26
27
def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """The MAXIM model function with multi-stage and multi-scale supervision.

    For more model details, please check the CVPR paper:
    MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

    Args:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling depth for the model.
      num_stages: how many stages to use. It will also affect the output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/mlp layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
        layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many stages are specified as high-res stages. The
        rest (depth - high_res_stages) are called low_res_stages.
      block_size_hr: the block_size parameter for high-res stages.
      block_size_lr: the block_size parameter for low-res stages.
      grid_size_hr: the grid_size parameter for high-res stages.
      grid_size_lr: the grid_size parameter for low-res stages.
      num_bottleneck_blocks: how many bottleneck blocks.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for SE layer.
      num_outputs: the output channels.
      dropout_rate: Dropout rate.

    Returns:
      A function `apply(x)` mapping an input tensor to a list of arrays
      consisting of multi-stage multi-scale outputs. For example, if
      num_stages = num_supervision_scales = 3 (the model used in the paper),
      the output specs are: outputs =
      [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
       [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
       [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
      The final output can be retrieved by outputs[-1][-1].
    """

    def apply(x):
        # Only the spatial dims are needed (to size the multi-scale input
        # pyramid); batch size and channel count of `x` are never used.
        # NOTE(review): `h // (2 ** i)` below requires static spatial dims
        # (`K.int_shape` returning ints, not None) — confirm for dynamic-shape
        # inputs.
        h, w = K.int_shape(x)[1], K.int_shape(x)[2]

        # Full-resolution input is always the first shortcut.
        shortcuts = [x]

        # Get multi-scale input images (1/2, 1/4, ... of the full resolution).
        for i in range(1, num_supervision_scales):
            resizing_layer = Resizing(
                height=h // (2 ** i),
                width=w // (2 ** i),
                method="nearest",
                antialias=True,  # Following `jax.image.resize()`.
                name=f"initial_resizing_{K.get_uid('Resizing')}",
            )
            shortcuts.append(resizing_layer(x))

        # store outputs from all stages and all scales
        # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
        #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
        outputs_all = []
        # Cross-stage caches: SAM features, encoder and decoder activations
        # from the previous stage.
        sam_features, encs_prev, decs_prev = [], [], []

        for idx_stage in range(num_stages):
            # Input convolution, get multi-scale input features
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2 ** i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])

                # If later stages, fuse input features with SAM features from
                # the previous stage.
                if idx_stage > 0:
                    # use larger blocksize at high-res stages
                    if use_cross_gating:
                        block_size = (
                            block_size_hr if i < high_res_stages else block_size_lr
                        )
                        # NOTE(review): the low-res branch uses `block_size_lr`
                        # rather than `grid_size_lr` (same at L-defaults). This
                        # mirrors the upstream maxim.py and is kept unchanged
                        # for checkpoint compatibility — confirm upstream
                        # before "fixing". Same pattern recurs below.
                        grid_size = grid_size_hr if i < high_res_stages else block_size_lr
                        x_scale, _ = CrossGatingBlock(
                            features=(2 ** i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        # Simple concat + 1x1 fusion when cross-gating is off.
                        x_scale = Conv1x1(
                            filters=(2 ** i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)

            # start encoder blocks
            encs = []
            x = x_scales[0]  # First full-scale input feature

            for i in range(depth):  # 0, 1, 2
                # use larger blocksize at high-res stages, vice versa.
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else block_size_lr
                # Cross-gate with previous-stage features only after stage 0.
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # UNet Encoder block
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None

                x, bridge = UNetEncoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                # Cache skip signals
                encs.append(bridge)

            # Global MLP bottleneck blocks
            for i in range(num_bottleneck_blocks):
                # NOTE(review): `grid_size=block_size_lr` (not `grid_size_lr`)
                # mirrors the upstream implementation; kept for weight
                # compatibility — confirm upstream.
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=block_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)
            # cache global feature for cross-gating
            global_feature = x

            # start cross gating. Use multi-scale feature fusion
            skip_features = []
            for i in reversed(range(depth)):  # 2, 1, 0
                # use larger blocksize at high-res stages
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else block_size_lr

                # get additional multi-scale signals: resample every encoder
                # bridge to scale `i` and concatenate along channels.
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )

                # Use cross-gating to cross modulate features
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2 ** i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    # NOTE(review): fixed names "Conv_0"/"Conv_1" repeat on
                    # every loop iteration and stage — Keras rejects duplicate
                    # layer names when building a Model, so this branch likely
                    # fails for depth > 1. Kept as-is (name changes would break
                    # checkpoint mapping) — confirm intended behavior.
                    skips = Conv1x1(
                        filters=(2 ** i) * features, use_bias=use_bias, name="Conv_0"
                    )(signal)
                    skips = Conv3x3(
                        filters=(2 ** i) * features, use_bias=use_bias, name="Conv_1"
                    )(skips)

                skip_features.append(skips)

            # start decoder. Multi-scale feature fusion of cross-gated features
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                # use larger blocksize at high-res stages
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else block_size_lr

                # get multi-scale skip signals from cross-gating block
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )

                # Decoder block
                x = UNetDecoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                # Cache decoder features for later-stage's usage
                decs.append(x)

                # output conv, if not final stage, use supervised-attention-block.
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:  # not last stage, apply SAM
                        sam, output = SAM(
                            num_channels=(2 ** i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:  # Last stage, apply output convolutions
                        output = Conv3x3(
                            num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        # Residual connection to the (resized) input image.
                        output = output + shortcuts[i]
                        outputs.append(output)
            # Cache encoder and decoder features for later-stage's usage
            encs_prev = encs[::-1]
            decs_prev = decs

            # Store outputs
            outputs_all.append(outputs)
        return outputs_all

    return apply