Coverage for adaro_rl / attacks / registry.py: 92%
52 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from . import random, fgm, fgsm
2import inspect
3import numpy as np
def make_attack(attack_name, **kwargs):
    """
    Create an attack object based on the specified attack name and initialization parameters.

    Parameters
    ----------
    attack_name : str
        The name of the attack to create. Must be a key of the module-level
        ``attack_names`` registry.
    kwargs : dict
        Keyword arguments for the attack class. Only the arguments that the
        class's ``__init__`` names explicitly are forwarded; any others are
        ignored. Callers can therefore pass a superset of options shared by
        several attacks.

    Returns
    -------
    object
        An instance of the attack class corresponding to the provided name.

    Raises
    ------
    ValueError
        If the attack name is unsupported or required arguments are missing.
    """
    if attack_name not in attack_names:
        raise ValueError(f"Unsupported attack type: {attack_name}")

    attack_class = attack_names[attack_name]

    # Named constructor parameters, excluding ``self`` and the variadic
    # ``*args`` / ``**kwargs`` catch-alls. Variadic parameters report
    # ``Parameter.empty`` as their default. Without this filter they would be
    # wrongly flagged as missing required arguments.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    named_params = [
        param
        for param in inspect.signature(attack_class.__init__).parameters.values()
        if param.name != "self" and param.kind not in variadic_kinds
    ]

    # Collect the arguments this specific attack accepts by name.
    init_kwargs = {
        param.name: kwargs[param.name]
        for param in named_params
        if param.name in kwargs
    }

    # A parameter is required when it has no default. Compare by identity:
    # ``==`` against an array-valued default (e.g. a numpy array) is evaluated
    # element-wise and cannot be used as a truth value.
    missing_args = [
        param.name
        for param in named_params
        if param.default is inspect.Parameter.empty and param.name not in init_kwargs
    ]
    if missing_args:
        raise ValueError(
            f"Missing required arguments for {attack_name}: {', '.join(missing_args)}"
        )

    return attack_class(**init_kwargs)
class EnsembleAttackWrapper:
    """
    Wrapper that applies a randomly chosen adversarial attack on every call.

    Several attack instances are built from factory callables. On each call to
    `generate_adv_obs` or `generate_perturbation`, one of them is drawn
    uniformly at random, with replacement, from the wrapper's own NumPy
    generator.

    Parameters
    ----------
    make_attack_fct_list : iterable of callable
        Zero-argument callables, each returning an attack instance that
        implements `generate_adv_obs` and `generate_perturbation`.
    seed : int or None, optional
        Seed passed to ``numpy.random.default_rng`` so the sequence of selected
        attacks is reproducible. The default, ``None``, uses fresh entropy.

    Attributes
    ----------
    attacks : list of object
        The attack instances created from ``make_attack_fct_list``.
    rng : numpy.random.Generator
        Generator used to select an attack on each call.

    Raises
    ------
    ValueError
        If ``make_attack_fct_list`` yields no factories.
    """

    def __init__(self, make_attack_fct_list, seed=None):
        """
        Build the attack instances and the random generator used to pick among them.

        Parameters
        ----------
        make_attack_fct_list : iterable of callable
            Zero-argument callables. Each is called once here to create one attack.
        seed : int or None, optional
            Seed for ``numpy.random.default_rng``.

        Raises
        ------
        ValueError
            If no attack factories are provided.
        """
        self.attacks = [make_attack_fct() for make_attack_fct in make_attack_fct_list]
        if not self.attacks:
            # Fail early with a clear message rather than on the first draw.
            raise ValueError("EnsembleAttackWrapper requires at least one attack factory.")
        self.rng = np.random.default_rng(seed)

    def _select_attack(self):
        """Return one element of ``self.attacks`` drawn uniformly at random."""
        # Draw an index instead of calling ``self.rng.choice(self.attacks)``.
        # ``choice`` first converts the list into a NumPy array, which mangles
        # attack objects that are themselves sequence-like or array-like.
        return self.attacks[self.rng.integers(len(self.attacks))]

    def generate_adv_obs(self, obs):
        """
        Generate adversarial observations by applying a randomly selected attack.

        Parameters
        ----------
        obs : Any
            The original observation(s) from the environment.

        Returns
        -------
        adv_obs : Any
            The adversarially perturbed observation as returned by the selected
            attack's `generate_adv_obs` method.
        """
        return self._select_attack().generate_adv_obs(obs)

    def generate_perturbation(self, obs):
        """
        Generate perturbations for the given observation using a randomly selected attack.

        Parameters
        ----------
        obs : Any
            The original observation(s) for which perturbation is to be generated.

        Returns
        -------
        perturbation : Any
            The perturbation generated by the selected attack's
            `generate_perturbation` method.
        """
        return self._select_attack().generate_perturbation(obs)
# Registry mapping each public attack name to the class that implements it.
attack_names = {}

# ================================================================
# Attacks
# ================================================================
# Each chained assignment below first binds the module-level alias and then
# registers the very same class in ``attack_names`` under an identical key.
# ``random``, ``fgm`` and ``fgsm`` are this package's attack submodules
# (imported relatively at the top of the file), not standard-library modules.

# ---------- Random Attacks ----------#
RUA = attack_names["RUA"] = random.RandomUniformAttack
RSA = attack_names["RSA"] = random.RandomSignAttack
RNA = attack_names["RNA"] = random.RandomNormalAttack

# ---------- Fast Gradient Methods ----------#
FGM_D = attack_names["FGM_D"] = fgm.FastGradientMethodDiscreteAction
FGM_C = attack_names["FGM_C"] = fgm.FastGradientMethodContinuousAction
FGM_V = attack_names["FGM_V"] = fgm.FastGradientMethodVCritic
FGM_QC = attack_names["FGM_QC"] = fgm.FastGradientMethodQCritic
FGM_QAC = attack_names["FGM_QAC"] = fgm.FastGradientMethodQActorCritic

# ---------- Fast Gradient Sign Methods ----------#
FGSM_D = attack_names["FGSM_D"] = fgsm.FastGradientSignMethodDiscreteAction
FGSM_C = attack_names["FGSM_C"] = fgsm.FastGradientSignMethodContinuousAction
FGSM_V = attack_names["FGSM_V"] = fgsm.FastGradientSignMethodVCritic
FGSM_QC = attack_names["FGSM_QC"] = fgsm.FastGradientSignMethodQCritic
FGSM_QAC = attack_names["FGSM_QAC"] = fgsm.FastGradientSignMethodQActorCritic

# Registry mapping training-method names to their corresponding classes.
training_method_names = {}
181# ================================================================
182# Adversarial Policy Training
183# ================================================================