Coverage for adaro_rl / attacks / registry.py: 92%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1from . import random, fgm, fgsm 

2import inspect 

3import numpy as np 

4 

5 

def make_attack(attack_name, **kwargs):
    """
    Create an attack object based on the specified attack name and initialization parameters.

    Only the entries of ``kwargs`` whose names match a named parameter of the
    attack class's ``__init__`` are forwarded to the constructor; all other
    entries are ignored. This lets one shared configuration be passed for any
    attack type.

    Parameters
    ----------
    attack_name : str
        The name of the attack to create. Must be a valid key in the `attack_names` dictionary.
    kwargs : dict
        Keyword arguments from which the attack's constructor arguments are picked.

    Returns
    -------
    object
        An instance of the attack class corresponding to the provided name.

    Raises
    ------
    ValueError
        If the attack name is unsupported or required arguments are missing.
    """
    if attack_name not in attack_names:
        raise ValueError(f"Unsupported attack type: {attack_name}")

    attack_class = attack_names[attack_name]

    # Named parameters of the constructor. Variadic catch-alls (*args, **kwargs)
    # are skipped: they are never required and cannot be matched by name.
    # Without this, classes with a **kwargs constructor, or with no __init__ at
    # all (object.__init__ is ``(self, /, *args, **kwargs)``), were always
    # rejected as "missing" those parameters.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    params = [
        param
        for param in inspect.signature(attack_class.__init__).parameters.values()
        if param.name != "self" and param.kind not in variadic_kinds
    ]

    # Compare against the sentinel with ``is``: ``==`` breaks for defaults such as
    # NumPy arrays, whose ``==`` is elementwise and has no single truth value.
    required_args = [p.name for p in params if p.default is inspect.Parameter.empty]

    # Collect the arguments the specific attack accepts
    init_kwargs = {p.name: kwargs[p.name] for p in params if p.name in kwargs}

    # Validate that all required arguments are provided
    missing_args = [arg for arg in required_args if arg not in init_kwargs]
    if missing_args:
        raise ValueError(
            f"Missing required arguments for {attack_name}: {', '.join(missing_args)}"
        )

    return attack_class(**init_kwargs)

51 

52 

class EnsembleAttackWrapper:
    """
    Wrapper that applies a randomly chosen adversarial attack on every call.

    One attack instance is built per factory at construction time. At each call
    to `generate_adv_obs` or `generate_perturbation`, one of them is drawn
    uniformly at random (with replacement) using the wrapper's own NumPy
    generator, and the call is delegated to it.

    Parameters
    ----------
    make_attack_fct_list : iterable of callable
        Zero-argument callables, each returning an attack instance that
        implements `generate_adv_obs` and `generate_perturbation`.
    seed : optional
        Seed forwarded to `numpy.random.default_rng`; ``None`` (default) uses
        fresh OS entropy.

    Attributes
    ----------
    attacks : list of object
        The attack instances built from `make_attack_fct_list`, in order.
    rng : numpy.random.Generator
        Generator used to pick the attack on each call.
    """

    def __init__(self, make_attack_fct_list, seed=None):
        """
        Build the attack instances and the random generator used to pick among them.

        Parameters
        ----------
        make_attack_fct_list : iterable of callable
            Zero-argument factories; each is called once, here, to create an attack.
        seed : optional
            Seed for `numpy.random.default_rng`, making the attack selection reproducible.
        """
        self.attacks = [make_attack_fct() for make_attack_fct in make_attack_fct_list]
        self.rng = np.random.default_rng(seed)

    def _select_attack(self):
        """Return one attack drawn uniformly at random from `self.attacks`."""
        if not self.attacks:
            raise ValueError("EnsembleAttackWrapper has no attacks to choose from")
        # Draw an index instead of calling ``rng.choice(self.attacks)``. Calling
        # choice on the list converts it to a NumPy array on every call, which
        # costs O(n) and unpacks attack objects that look like sequences into
        # array rows instead of returning the attack itself.
        return self.attacks[self.rng.integers(len(self.attacks))]

    def generate_adv_obs(self, obs):
        """
        Generate adversarial observations by applying a randomly selected attack.

        Parameters
        ----------
        obs : Any
            The original observation(s) from the environment.

        Returns
        -------
        adv_obs : Any
            The adversarially perturbed observation as returned by the selected attack's
            `generate_adv_obs` method.

        Raises
        ------
        ValueError
            If the wrapper holds no attacks.
        """
        return self._select_attack().generate_adv_obs(obs)

    def generate_perturbation(self, obs):
        """
        Generate perturbations for the given observation using a randomly selected attack.

        Parameters
        ----------
        obs : Any
            The original observation(s) for which perturbation is to be generated.

        Returns
        -------
        perturbation : Any
            The perturbation generated by the selected attack's `generate_perturbation` method.

        Raises
        ------
        ValueError
            If the wrapper holds no attacks.
        """
        return self._select_attack().generate_perturbation(obs)

117 

118 

# ================================================================
# Attacks
# ================================================================

# Short module-level aliases for every attack class, grouped by family.

# ----------Random Attacks----------#
RUA = random.RandomUniformAttack
RSA = random.RandomSignAttack
RNA = random.RandomNormalAttack

# ---------- Fast Gradient Methods ----------#
FGM_D = fgm.FastGradientMethodDiscreteAction
FGM_C = fgm.FastGradientMethodContinuousAction
FGM_V = fgm.FastGradientMethodVCritic
FGM_QC = fgm.FastGradientMethodQCritic
FGM_QAC = fgm.FastGradientMethodQActorCritic

# ----------Fast Gradient Sign Methods----------#
FGSM_D = fgsm.FastGradientSignMethodDiscreteAction
FGSM_C = fgsm.FastGradientSignMethodContinuousAction
FGSM_V = fgsm.FastGradientSignMethodVCritic
FGSM_QC = fgsm.FastGradientSignMethodQCritic
FGSM_QAC = fgsm.FastGradientSignMethodQActorCritic

# Registry mapping the attack names accepted by `make_attack` to their classes.
# Each key is the name of the alias above that refers to the same class.
attack_names = {
    "RUA": RUA,
    "RSA": RSA,
    "RNA": RNA,
    "FGM_D": FGM_D,
    "FGM_C": FGM_C,
    "FGM_V": FGM_V,
    "FGM_QC": FGM_QC,
    "FGM_QAC": FGM_QAC,
    "FGSM_D": FGSM_D,
    "FGSM_C": FGSM_C,
    "FGSM_V": FGSM_V,
    "FGSM_QC": FGSM_QC,
    "FGSM_QAC": FGSM_QAC,
}

176 

177 

# Registry mapping adversarial-policy training method names to their classes.
# No training methods are registered in this module, so it starts (and stays) empty here.
training_method_names = dict()

# ================================================================
# Adversarial Policy Training
# ================================================================