# neural_de\external\maxim_tf\maxim\blocks\attentions.py

1 """
2 Blocks based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py
3 """
4  
5 import functools
6  
7 import tensorflow as tf
8 from tensorflow.keras import layers
9  
10 from .others import MlpBlock
11  
12 Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
13 Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
14  
15  
16 def CALayer(
17 num_channels: int,
18 reduction: int = 4,
19 use_bias: bool = True,
20 name: str = "channel_attention",
21 ):
22 """Squeeze-and-excitation block for channel attention.
23  
24 ref: https://arxiv.org/abs/1709.01507
25 """
26  
27 def apply(x):
28 # 2D global average pooling
29 y = layers.GlobalAvgPool2D(keepdims=True)(x)
30 # Squeeze (in Squeeze-Excitation)
31 y = Conv1x1(
32 filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
33 )(y)
34 y = tf.nn.relu(y)
35 # Excitation (in Squeeze-Excitation)
36 y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
37 y = tf.nn.sigmoid(y)
38 return x * y
39  
40 return apply
41  
42  
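# Example (a sketch; the tensor size and the "demo_ca" name are illustrative):
#     feat = tf.random.normal((1, 64, 64, 32))
#     out = CALayer(num_channels=32, reduction=4, name="demo_ca")(feat)
#     # out has shape (1, 64, 64, 32): each channel of feat is rescaled by a
#     # weight in (0, 1) computed from its global average.
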
def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block. Contains LN, Conv, lRelu, Conv, SELayer."""

    def apply(x):
        shortcut = x
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
        x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
        x = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(x)
        return x + shortcut

    return apply

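# Example (a sketch; sizes and the "demo_rcab" name are illustrative):
#     feat = tf.random.normal((1, 64, 64, 32))
#     out = RCAB(num_channels=32, reduction=4, name="demo_rcab")(feat)
#     # out has the same shape as feat; num_channels must match the input's
#     # channel count so the residual addition x + shortcut is valid.
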
def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block. Used in bottlenecks."""

    def apply(x):
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )(y)
        y = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        x = x + y
        return x

    return apply

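# Example (a sketch; sizes and the "demo_rdcab" name are illustrative). It assumes
# MlpBlock (from .others) returns the input channel count, as in the upstream
# MAXIM code, so the residual addition preserves the shape:
#     feat = tf.random.normal((1, 16, 16, 64))
#     out = RDCAB(num_channels=64, reduction=16, name="demo_rdcab")(feat)
#     # out has shape (1, 16, 16, 64)
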
def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
    """

    def apply(x, x_image):
        """Apply the SAM module to the input and features.

        Args:
            x: the output features from the UNet decoder with shape (h, w, c)
            x_image: the input image with shape (h, w, 3)

        Returns:
            A tuple of tensors (x1, image) where x1 is the SAM feature map used for
            the next stage, and image is the restored image output at the current
            stage.
        """
        # Get features
        x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Output restored image X_s
        if output_channels == 3:
            image = (
                Conv3x3(
                    filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
                )(x)
                + x_image
            )
        else:
            image = Conv3x3(
                filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)

        # Get attention maps for the features
        x2 = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        # Get attended feature maps
        x1 = x1 * x2

        # Residual connection
        x1 = x1 + x
        return x1, image

    return apply
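

# The demo below is a smoke test added for illustration; it is not part of the
# upstream MAXIM code. Tensor sizes and the "demo_*" layer names are arbitrary.
# Run it as a module (e.g. `python -m maxim.blocks.attentions`) so that the
# relative import of MlpBlock resolves.
if __name__ == "__main__":
    features = tf.random.normal((1, 64, 64, 32))
    image = tf.random.normal((1, 64, 64, 3))

    features = RCAB(num_channels=32, name="demo_rcab")(features)
    features = RDCAB(num_channels=32, name="demo_rdcab")(features)
    next_stage, restored = SAM(num_channels=32, output_channels=3, name="demo_sam")(
        features, image
    )
    print(next_stage.shape, restored.shape)  # (1, 64, 64, 32) (1, 64, 64, 3)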