1 """
2 Blocks based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4
5 import functools
6
7 import tensorflow as tf
8 from tensorflow.keras import backend as K
9 from tensorflow.keras import layers
10
11 from ..layers import Resizing
12
# Factory for pointwise (1x1) "same"-padded convolutions; callers supply
# `filters` (and optionally `use_bias`, `name`) at call time.
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
14
15
def MlpBlock(
    mlp_dim: int,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "mlp_block",
):
    """A 1-hidden-layer MLP block, applied over the last dimension.

    Returns a closure that expands the last axis to ``mlp_dim`` (Dense +
    approximate GELU + Dropout), then projects back to the input's original
    channel count, so the output shape matches the input shape.
    """

    def apply(x):
        # Remember the incoming channel count so the block is shape-preserving.
        in_features = K.int_shape(x)[-1]
        hidden = layers.Dense(mlp_dim, use_bias=use_bias, name=f"{name}_Dense_0")(x)
        hidden = tf.nn.gelu(hidden, approximate=True)
        hidden = layers.Dropout(dropout_rate)(hidden)
        return layers.Dense(in_features, use_bias=use_bias, name=f"{name}_Dense_1")(
            hidden
        )

    return apply
33
34
def UpSampleRatio(
    num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
):
    """Upsample features given a ratio > 0.

    Returns a closure that bilinearly resizes the spatial dims of a NHWC
    tensor by ``ratio`` (following `jax.image.resize()` semantics via the
    project's `Resizing` layer), then maps to ``num_channels`` with a 1x1
    convolution.
    """

    def apply(x):
        # Only the spatial dims are needed; batch and channel dims of the
        # static shape were previously bound but never used (flake8 F841).
        h, w = K.int_shape(x)[1], K.int_shape(x)[2]

        # Following `jax.image.resize()`
        x = Resizing(
            height=int(h * ratio),
            width=int(w * ratio),
            method="bilinear",
            antialias=True,
            # Unique uid keeps layer names distinct across repeated calls.
            name=f"{name}_resizing_{K.get_uid('Resizing')}",
        )(x)

        x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        return x

    return apply