⬅ neural_de\external\maxim_tf\maxim\blocks\grid_gating.py source

1 """
2 Blocks based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4  
5 import tensorflow as tf
6 from tensorflow.keras import backend as K
7 from tensorflow.keras import layers
8  
9 from ..layers import BlockImages, SwapAxes, UnblockImages
10  
11  
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """Build a gating unit modelled on the SpatialGatingUnit of the gMLP paper.

    The returned callable splits its input in two along the last axis,
    layer-normalizes one half, mixes that half with a Dense projection along
    axis -3 (the axis is swapped into last position for the projection and
    swapped back afterwards), and uses the result to gate the other half.
    If the axis to mix over is not at position -3, swap axes before calling.

    Args:
        use_bias: Forwarded to the Dense projection.
        name: Prefix for the LayerNormalization and Dense layer names.

    Returns:
        A function taking a tensor ``x`` and returning the gated tensor.
    """

    def apply(x):
        gated, gate = tf.split(x, 2, axis=-1)

        norm = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )
        gate = norm(gate)

        # The projection maps the mixing axis onto itself, so its width is
        # the static size of axis -3 of the input.
        mix_size = K.int_shape(x)[-3]
        gate = SwapAxes()(gate, -1, -3)
        mixer = layers.Dense(mix_size, use_bias=use_bias, name=f"{name}_Dense_0")
        gate = mixer(gate)
        gate = SwapAxes()(gate, -1, -3)

        # With the +1.0 offset, a zero projection output leaves `gated` unchanged.
        return gated * (gate + 1.0)

    return apply
31  
32  
def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Grid gMLP layer that performs global mixing of tokens.

    Args:
        grid_size: Pair ``(gh, gw)``. The input's static height and width are
            floor-divided by these to obtain the patch size used for blocking.
        use_bias: Forwarded to every Dense layer and to the gating unit.
        factor: Multiplier applied to the channel count for the input
            projection.
        dropout_rate: Rate of the Dropout layer applied to the residual branch.
        name: Prefix for the names of the layers created here.

    Returns:
        A function taking a tensor ``x`` and returning the result of the
        residual gMLP branch after ``UnblockImages``.
    """

    def apply(x):
        # Only the static sizes of axes 1..3 are needed (named height, width
        # and channels here -- presumably a channels-last image tensor; TODO
        # confirm against callers). Axis 0 is not used.
        h, w, num_channels = K.int_shape(x)[1:4]
        gh, gw = grid_size
        fh, fw = h // gh, w // gw

        x = BlockImages()(x, patch_size=(fh, fw))
        # gMLP1: Global (grid) mixing part, provides global grid communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        # Residual connection around the gMLP branch.
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply