# neural_de\external\derain\blocks.py

1 """
2 2D convolution class
3 """
import functools

import torch
import torchvision
from torch import nn


class Conv2d(torch.nn.Module):
    """
    2D convolution class
    Args:
        in_channels : int - number of input channels
        out_channels : int - number of output channels
        kernel_size : int - size of kernel
        stride : int - stride of convolution
        activation_func : func - activation function after convolution
        norm_layer : functools.partial - normalization layer
        use_bias : bool - if set, then use bias
        padding_type : str - the name of padding layer: reflect | replicate | zero
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
        norm_layer=nn.BatchNorm2d,
        use_bias=False,
        padding_type="reflect",
    ):
        super(Conv2d, self).__init__()

        self.activation_func = activation_func
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(kernel_size // 2)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(kernel_size // 2)]
        elif padding_type == "zero":
            p = kernel_size // 2
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            nn.Conv2d(
                in_channels,
                out_channels,
                stride=stride,
                kernel_size=kernel_size,
                padding=p,
                bias=use_bias,
            ),
            norm_layer(out_channels),
        ]

        self.conv = nn.Sequential(*conv_block)

    def forward(self, x):
        conv = self.conv(x)

        if self.activation_func is not None:
            return self.activation_func(conv)
        return conv
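

# Illustrative usage sketch (added for this write-up, not part of the original
# module); the `_demo_conv2d` name and the shapes are assumptions chosen for
# the example. With stride 2 and the default "reflect" padding, a 3x3 Conv2d
# block halves the spatial resolution.
def _demo_conv2d():
    block = Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2)
    x = torch.randn(1, 3, 64, 64)
    y = block(x)  # pad -> conv -> batch norm -> leaky ReLU
    assert y.shape == (1, 16, 32, 32)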


class DeformableConv2d(nn.Module):
    """
    2D deformable convolution class
    Args:
        in_channels : int - number of input channels
        out_channels : int - number of output channels
        kernel_size : int - size of kernel
        stride : int - stride of convolution
        padding : int - padding
        bias : bool - if set, then use bias
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
    ):
        super(DeformableConv2d, self).__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding

        # Offset branch: predicts an (dx, dy) pair for each of the k*k
        # sampling locations; zero-initialized so the layer starts out
        # sampling on the regular grid.
        self.offset_conv = nn.Conv2d(
            in_channels,
            2 * kernel_size * kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )

        nn.init.constant_(self.offset_conv.weight, 0.0)
        nn.init.constant_(self.offset_conv.bias, 0.0)

        # Modulation branch: predicts a scalar weight per sampling location
        # (modulated deformable convolution, a.k.a. DCNv2).
        self.modulator_conv = nn.Conv2d(
            in_channels,
            1 * kernel_size * kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )

        nn.init.constant_(self.modulator_conv.weight, 0.0)
        nn.init.constant_(self.modulator_conv.bias, 0.0)

        self.regular_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, x):
        offset = self.offset_conv(x)
        # sigmoid maps to (0, 1); the factor 2 makes the modulation exactly
        # 1.0 (identity) at initialization, when the modulator outputs zero.
        modulator = 2.0 * torch.sigmoid(self.modulator_conv(x))

        x = torchvision.ops.deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulator,
            stride=self.stride,
        )
        return x
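

# Illustrative usage sketch (added for this write-up, not part of the original
# module); `_demo_deformable_conv2d` and the shapes are assumptions. Because
# both auxiliary branches are zero-initialized, a fresh layer behaves like a
# plain 3x3 convolution of the same geometry.
def _demo_deformable_conv2d():
    layer = DeformableConv2d(in_channels=8, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(2, 8, 16, 16)
    y = layer(x)
    assert y.shape == (2, 8, 16, 16)  # stride 1, padding 1: shape preserved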


class UpConv2d(torch.nn.Module):
    """
    Up-convolution (upsample + convolution) block class
    Args:
        in_channels : int - number of input channels
        out_channels : int - number of output channels
        kernel_size : int - size of kernel (k x k)
        activation_func : func - activation function after convolution
        norm_layer : functools.partial - normalization layer
        use_bias : bool - if set, then use bias
        padding_type : str - the name of padding layer: reflect | replicate | zero
        interpolate_mode : str - the mode for interpolation: bilinear | nearest
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
        norm_layer=nn.BatchNorm2d,
        use_bias=False,
        padding_type="reflect",
        interpolate_mode="bilinear",
    ):
        super(UpConv2d, self).__init__()
        self.interpolate_mode = interpolate_mode

        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            activation_func=activation_func,
            norm_layer=norm_layer,
            use_bias=use_bias,
            padding_type=padding_type,
        )

    def forward(self, x):
        n_height, n_width = x.shape[2:4]
        shape = (int(2 * n_height), int(2 * n_width))
        # align_corners is only valid for interpolating modes such as
        # "bilinear"; passing it with "nearest" raises a ValueError.
        align_corners = True if self.interpolate_mode == "bilinear" else None
        upsample = torch.nn.functional.interpolate(
            x, size=shape, mode=self.interpolate_mode, align_corners=align_corners
        )
        return self.conv(upsample)
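

# Illustrative usage sketch (added for this write-up, not part of the original
# module); `_demo_upconv2d` and the shapes are assumptions. The block doubles
# the spatial resolution, then applies a stride-1 3x3 Conv2d block.
def _demo_upconv2d():
    up = UpConv2d(in_channels=32, out_channels=16)
    x = torch.randn(1, 32, 20, 20)
    y = up(x)
    assert y.shape == (1, 16, 40, 40)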


class DeformableResnetBlock(nn.Module):
    """Define a ResNet block with deformable convolutions"""

    def __init__(
        self, dim, padding_type, norm_layer, use_dropout, use_bias, activation_func
    ):
        """
        Initialize the deformable ResNet block.
        A deformable ResNet block is a conv block with a skip connection.
        """
        super(DeformableResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias, activation_func
        )

    def build_conv_block(
        self, dim, padding_type, norm_layer, use_dropout, use_bias, activation_func
    ):
        """
        Construct a convolutional block.
        Parameters:
            dim (int) -- the number of channels in the conv layer
            padding_type (str) -- the name of padding layer: reflect | replicate | zero
            norm_layer -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            use_bias (bool) -- whether the conv layer uses bias
            activation_func (func) -- activation type
        Returns a conv block (conv layers, normalization layers, and a non-linearity).
        """
        conv_block = []

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            DeformableConv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            activation_func,
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [
            DeformableConv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connection)"""
        out = x + self.conv_block(x)  # add skip connection
        return out
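

# Illustrative usage sketch (added for this write-up, not part of the original
# module); `_demo_deformable_resnet_block` and the shapes are assumptions. The
# residual connection requires the block to preserve the tensor shape.
def _demo_deformable_resnet_block():
    block = DeformableResnetBlock(
        dim=64,
        padding_type="reflect",
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
        use_bias=False,
        activation_func=nn.LeakyReLU(negative_slope=0.10, inplace=True),
    )
    x = torch.randn(2, 64, 24, 24)
    assert block(x).shape == x.shape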


class DecoderBlock(torch.nn.Module):
    """
    Decoder block with skip connections
    Args:
        in_channels : int - number of input channels
        skip_channels : int - number of skip connection channels
        out_channels : int - number of output channels
        activation_func : func - activation function after convolution
        norm_layer : functools.partial - normalization layer
        use_bias : bool - if set, then use bias
        padding_type : str - the name of padding layer: reflect | replicate | zero
        upsample_mode : str - the mode for upsampling: transpose | bilinear | nearest
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
        norm_layer=nn.BatchNorm2d,
        use_bias=False,
        padding_type="reflect",
        upsample_mode="transpose",
    ):
        super(DecoderBlock, self).__init__()

        self.skip_channels = skip_channels
        self.upsample_mode = upsample_mode

        # Upsampling: either a learned transposed convolution or
        # interpolation followed by a convolution (UpConv2d).
        if upsample_mode == "transpose":
            self.deconv = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=use_bias,
                ),
                norm_layer(out_channels),
                activation_func,
            )
        else:
            self.deconv = UpConv2d(
                in_channels,
                out_channels,
                use_bias=use_bias,
                activation_func=activation_func,
                norm_layer=norm_layer,
                padding_type=padding_type,
                interpolate_mode=upsample_mode,
            )

        concat_channels = skip_channels + out_channels

        self.conv = Conv2d(
            concat_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            activation_func=activation_func,
            padding_type=padding_type,
            norm_layer=norm_layer,
            use_bias=use_bias,
        )

    def forward(self, x, skip=None):
        deconv = self.deconv(x)

        if self.skip_channels > 0:
            concat = torch.cat([deconv, skip], dim=1)
        else:
            concat = deconv

        return self.conv(concat)
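

# Illustrative usage sketch (added for this write-up, not part of the original
# module); `_demo_decoder_block` and the shapes are assumptions. The skip
# tensor must already be at the upsampled resolution so it can be concatenated
# along the channel dimension.
def _demo_decoder_block():
    dec = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64)
    x = torch.randn(1, 128, 16, 16)
    skip = torch.randn(1, 64, 32, 32)
    y = dec(x, skip)
    assert y.shape == (1, 64, 32, 32)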


class ResNetModified(nn.Module):
    """
    ResNet-based generator that consists of deformable ResNet blocks.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
        use_dropout=False,
        n_blocks=6,
        padding_type="reflect",
        upsample_mode="bilinear",
    ):
        """
        Construct a ResNet-based generator.
        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            ngf (int) -- the number of filters in the first conv layer
            norm_layer -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            n_blocks (int) -- the number of ResNet blocks
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
            upsample_mode (str) -- mode for upsampling: transpose | bilinear
        """
        assert n_blocks >= 0
        super(ResNetModified, self).__init__()
        # InstanceNorm2d has no affine parameters by default, so the conv
        # layers need their own bias in that case.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func is nn.InstanceNorm2d
        else:
            use_bias = norm_layer is nn.InstanceNorm2d

        # Initial convolution
        self.initial_conv = nn.Sequential(
            Conv2d(
                in_channels=input_nc,
                out_channels=ngf,
                kernel_size=7,
                padding_type=padding_type,
                norm_layer=norm_layer,
                activation_func=activation_func,
                use_bias=use_bias,
            ),
            Conv2d(
                in_channels=ngf,
                out_channels=ngf,
                kernel_size=3,
                padding_type=padding_type,
                norm_layer=norm_layer,
                activation_func=activation_func,
                use_bias=use_bias,
            ),
        )

        # Downsample blocks
        n_downsampling = 2
        mult = 2 ** 0
        self.downsample_1 = Conv2d(
            in_channels=ngf * mult,
            out_channels=ngf * mult * 2,
            kernel_size=3,
            stride=2,
            padding_type=padding_type,
            norm_layer=norm_layer,
            activation_func=activation_func,
            use_bias=use_bias,
        )

        mult = 2 ** 1
        self.downsample_2 = Conv2d(
            in_channels=ngf * mult,
            out_channels=ngf * mult * 2,
            kernel_size=3,
            stride=2,
            padding_type=padding_type,
            norm_layer=norm_layer,
            activation_func=activation_func,
            use_bias=use_bias,
        )

        # Residual blocks
        residual_blocks = []
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):  # add ResNet blocks
            residual_blocks += [
                DeformableResnetBlock(
                    ngf * mult,
                    padding_type=padding_type,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    use_bias=use_bias,
                    activation_func=activation_func,
                )
            ]

        self.residual_blocks = nn.Sequential(*residual_blocks)

        # Upsampling
        mult = 2 ** n_downsampling
        self.upsample_2 = DecoderBlock(
            ngf * mult,
            int(ngf * mult / 2),
            int(ngf * mult / 2),
            use_bias=use_bias,
            activation_func=activation_func,
            norm_layer=norm_layer,
            padding_type=padding_type,
            upsample_mode=upsample_mode,
        )

        mult = 2 ** (n_downsampling - 1)
        self.upsample_1 = DecoderBlock(
            ngf * mult,
            int(ngf * mult / 2),
            int(ngf * mult / 2),
            use_bias=use_bias,
            activation_func=activation_func,
            norm_layer=norm_layer,
            padding_type=padding_type,
            upsample_mode=upsample_mode,
        )

        # Output convolution
        self.output_conv_naive = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(ngf, output_nc, kernel_size=3, padding=0),
            nn.Tanh(),
        )

    def forward(self, input):
        """Standard forward"""

        # Downsample
        initial_conv_out = self.initial_conv(input)
        downsample_1_out = self.downsample_1(initial_conv_out)
        downsample_2_out = self.downsample_2(downsample_1_out)

        # Residual
        residual_blocks_out = self.residual_blocks(downsample_2_out)

        # Upsample with skip connections from the downsampling path
        upsample_2_out = self.upsample_2(residual_blocks_out, downsample_1_out)
        upsample_1_out = self.upsample_1(upsample_2_out, initial_conv_out)
        final_out = self.output_conv_naive(upsample_1_out)

        return (final_out,)
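

# Illustrative end-to-end sketch (added for this write-up, not part of the
# original module); `_demo_resnet_modified` and the shapes are assumptions,
# with n_blocks kept small for speed. The generator downsamples twice,
# applies deformable residual blocks, then decodes back to the input size.
def _demo_resnet_modified():
    net = ResNetModified(input_nc=3, output_nc=3, ngf=64, n_blocks=2)
    x = torch.randn(1, 3, 64, 64)
    (out,) = net(x)  # forward returns a 1-tuple
    assert out.shape == (1, 3, 64, 64)  # Tanh keeps values in [-1, 1]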