Coverage for robustAI/advertrain/dependencies/dropblock.py: 79%

62 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2Taken from https://github.com/rwightman/pytorch-image-models 

3 

4MIT License 

5""" 

6import torch 

7import torch.nn as nn 

8import torch.nn.functional as F 

9 

10 

def drop_block_2d(
    x,
    drop_prob: float = 0.1,
    block_size: int = 7,
    gamma_scale: float = 1.0,
    with_noise: bool = False,
    inplace: bool = False,
    batchwise: bool = False,
):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.

    Seeds are only sampled at positions where a full block of side ``clipped_block_size`` fits
    inside the feature map, so dropped blocks never hang over the edges.

    Args:
        x: input tensor, unpacked as ``(B, C, H, W)``.
        drop_prob: nominal drop probability used to derive the seed rate ``gamma``.
        block_size: side length of the square blocks; clipped to ``min(H, W)``.
        gamma_scale: extra multiplier applied to ``gamma``.
        with_noise: fill dropped positions with ``torch.randn`` noise instead of zeroing them
            and rescaling the kept activations.
        inplace: modify ``x`` in place instead of allocating a new tensor.
        batchwise: sample a single ``(1, C, H, W)`` mask shared by the whole batch.

    Returns:
        Tensor with the same shape as ``x`` (``x`` itself when ``inplace`` is True).
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter. The count of valid seed positions must use the
    # clipped block size: with the raw block_size the denominator is zero or negative whenever
    # block_size exceeds the feature map (e.g. a 6x6 map with block_size=7).
    gamma = (
        gamma_scale
        * drop_prob
        * total_size
        / clipped_block_size ** 2
        / ((W - clipped_block_size + 1) * (H - clipped_block_size + 1))
    )

    # Forces the block to be inside the feature map. (H, 1) and (1, W) index columns broadcast
    # to an (H, W) grid; a meshgrid(arange(W), arange(H)) grid is (W, H) and reshaping it to
    # (H, W) scrambles the mask for non-square inputs.
    h_i = torch.arange(H, device=x.device).view(H, 1)
    w_i = torch.arange(W, device=x.device).view(1, W)
    valid_block = (
        (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)
    ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    # 1 = keep, 0 = seed. A valid position becomes a seed when noise < gamma; positions where
    # valid_block == 0 stay 1 for any gamma <= 1.
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    # Min-pool (negated max-pool) grows every seed into a square block of zeros.
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,
        stride=1,
        padding=clipped_block_size // 2,
    )
    # With an even kernel, symmetric padding of k // 2 yields an (H + 1, W + 1) output. Dropping
    # the first row/column restores (H, W) and matches the valid_block convention above (a seed
    # at s covers [s - k // 2, s + (k - 1) // 2]). For an odd kernel the slice is a no-op.
    offset = 1 - clipped_block_size % 2
    block_mask = block_mask[..., offset:offset + H, offset:offset + W]

    if with_noise:
        normal_noise = (
            torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
            if batchwise
            else torch.randn_like(x)
        )
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        # Rescale kept activations by numel / kept_count; the 1e-7 guards against division by
        # zero when every position is dropped.
        normalize_scale = (
            block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
        ).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x

77 

78 

def drop_block_fast_2d(
    x: torch.Tensor,
    drop_prob: float = 0.1,
    block_size: int = 7,
    gamma_scale: float = 1.0,
    with_noise: bool = False,
    inplace: bool = False,
    batchwise: bool = False,
):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplified from ``drop_block_2d``:
    seeds may be sampled anywhere, without a valid-block mask at the edges.

    Args:
        x: input tensor, unpacked as ``(B, C, H, W)``.
        drop_prob: nominal drop probability used to derive the seed rate ``gamma``.
        block_size: side length of the square blocks; clipped to ``min(H, W)``.
        gamma_scale: extra multiplier applied to ``gamma``.
        with_noise: fill dropped positions with ``torch.randn`` noise instead of zeroing them
            and rescaling the kept activations.
        inplace: modify ``x`` in place instead of allocating a new tensor.
        batchwise: sample a single ``(1, C, H, W)`` mask shared by the whole batch.

    Returns:
        Tensor with the same shape as ``x`` (``x`` itself when ``inplace`` is True).
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # Use the clipped block size for the number of seed positions; the raw block_size makes
    # this denominator zero or negative when the block is larger than the feature map.
    gamma = (
        gamma_scale
        * drop_prob
        * total_size
        / clipped_block_size ** 2
        / ((W - clipped_block_size + 1) * (H - clipped_block_size + 1))
    )

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    # 1 = dropped. Max-pool grows every seed into a square block.
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype),
        kernel_size=clipped_block_size,
        stride=1,
        padding=clipped_block_size // 2,
    )
    # With an even kernel, symmetric padding of k // 2 yields an (H + 1, W + 1) output that no
    # longer matches x; crop it back to (H, W). For an odd kernel the slice is a no-op.
    offset = 1 - clipped_block_size % 2
    block_mask = block_mask[..., offset:offset + H, offset:offset + W]

    if with_noise:
        normal_noise = (
            torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
            if batchwise
            else torch.randn_like(x)
        )
        if inplace:
            x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1.0 - block_mask) + normal_noise * block_mask
    else:
        # Flip to a keep-mask, then rescale kept activations by numel / kept_count; the 1e-7
        # guards against division by zero when everything is dropped.
        block_mask = 1 - block_mask
        normalize_scale = (
            block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
        ).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x

136 

137 

class DropBlock2d(nn.Module):
    """DropBlock layer (https://arxiv.org/pdf/1810.12890.pdf).

    Stores the DropBlock hyper-parameters. While the module is training and ``drop_prob`` is
    truthy, ``forward`` delegates to ``drop_block_fast_2d`` (``fast=True``, the default) or to
    ``drop_block_2d``. Otherwise the input is returned unchanged.
    """

    def __init__(
        self,
        drop_prob=0.1,
        block_size=7,
        gamma_scale=1.0,
        with_noise=False,
        inplace=False,
        batchwise=False,
        fast=True,
    ):
        super().__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        # FIXME finish comparisons of fast vs not
        self.fast = fast

    def forward(self, x):
        """Apply DropBlock to ``x`` while training; act as identity otherwise."""
        if self.training and self.drop_prob:
            impl = drop_block_fast_2d if self.fast else drop_block_2d
            return impl(
                x,
                drop_prob=self.drop_prob,
                block_size=self.block_size,
                gamma_scale=self.gamma_scale,
                with_noise=self.with_noise,
                inplace=self.inplace,
                batchwise=self.batchwise,
            )
        return x