1 """
2 Blocks based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4
5 import tensorflow as tf
6 from tensorflow.keras import backend as K
7 from tensorflow.keras import layers
8
9 from ..layers import BlockImages, SwapAxes, UnblockImages
10
11
def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper.

    The 'spatial' dim is defined as the **second last**. If applied to
    other dims, swap the axes first (e.g. with `SwapAxes`).
    """

    def apply(x):
        # Split channels into a pass-through half (u) and a gating half (v).
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)
        n = K.int_shape(x)[-2]  # get spatial dim
        # Mix v along the spatial dim: move it to the last axis, project
        # across it with a Dense layer, then move it back.
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
        v = SwapAxes()(v, -1, -2)
        # The +1.0 biases the gate towards identity at initialization.
        return u * (v + 1.0)

    return apply
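
# Usage sketch (illustrative shapes, not part of the module API): for an
# input of shape (batch, num_blocks, block_len, 2 * c), the u/v split halves
# the channel axis and the Dense mixes over block_len, so:
#
#   x = tf.random.normal((1, 4, 16, 64))
#   y = BlockGatingUnit(name="demo_bgu")(x)
#   # y.shape == (1, 4, 16, 32)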


def BlockGmlpLayer(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Block gMLP layer that performs local mixing of tokens."""

    def apply(x):
        _, h, w, num_channels = K.int_shape(x)
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        # Partition the feature map into non-overlapping (fh, fw) blocks.
        x = BlockImages()(x, patch_size=(fh, fw))
        # MLP2: Local (block) mixing part, provides within-block communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        # Residual connection, then stitch the blocks back into an image.
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply
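

# A minimal smoke test, assuming the module is run from within its package
# (e.g. `python -m <package>.<module>`); the 8x8 input and (2, 2) block size
# are illustrative choices, not values from the original repository.
if __name__ == "__main__":
    inputs = layers.Input(shape=(8, 8, 32))
    outputs = BlockGmlpLayer(block_size=(2, 2), name="demo_block_gmlp")(inputs)
    model = tf.keras.Model(inputs, outputs)
    # The layer is shape-preserving: blocks are split, mixed, and re-stitched.
    print(model(tf.random.normal((1, 8, 8, 32))).shape)  # (1, 8, 8, 32)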