Coverage for adaro_rl / pipelines / utils.py: 89%
18 statements
« prev ^ index » next · coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from typing import Callable
def normalize_lists(**kwargs):
    """
    Wrap every keyword argument in a list and check that all lists share one length.

    A value that is already a ``list`` is kept as-is (the same object is
    returned, not a copy); any other value is wrapped as a one-element list.
    Scalars are not broadcast: a scalar next to a list of length > 1 counts
    as a length mismatch.

    Args:
        **kwargs: Values that may each be a list or a single (scalar) value.

    Returns:
        dict mapping each keyword name to its list, in argument order.
        Calling with no keyword arguments returns an empty dict.

    Raises:
        ValueError: If the resulting lists do not all have the same length.
    """
    lists = {k: (v if isinstance(v, list) else [v]) for k, v in kwargs.items()}
    lengths = {len(v) for v in lists.values()}
    # With no kwargs there is nothing to compare (lengths is empty), so only
    # more than one distinct length is an error.
    if len(lengths) > 1:
        raise ValueError(f"All arguments must have the same length, got {lengths}")
    return lists
def make_attack_list(
    base_env,
    attack_name_list: list[str],
    make_agent_fct_list_for_attack: list[Callable],
    target_list: list[str],
    eps_list: list[float],
    config,
    make_attack: Callable,
    norm: float,
    device: str,
):
    """
    Build one zero-argument attack factory per configured attack.

    The i-th entries of ``attack_name_list``, ``make_agent_fct_list_for_attack``,
    ``target_list`` and ``eps_list`` describe the i-th attack. Each returned
    factory, when called with no arguments, calls ``make_attack`` with those
    entries plus the settings shared by every attack: the observation space,
    perturbation space and proportional mask read from ``base_env``, and
    ``norm`` / ``device``. No attack is constructed inside this function.

    A target equal to the string ``"target_fct"`` is replaced by
    ``config.target_fct``.

    Args:
        base_env: Environment exposing ``observation_space`` and ``get_attr``.
            The first element of ``get_attr(name)`` is used for
            ``"observation_perturbation_space"`` and
            ``"proportional_obs_perturbation_mask"`` (presumably a vectorized
            env whose sub-envs share these attributes — confirm).
        attack_name_list: Name of each attack, forwarded as ``attack_name``.
        make_agent_fct_list_for_attack: Per-attack agent factory, forwarded
            as ``make_agent_fct``.
        target_list: Per-attack target, forwarded as ``target``.
        eps_list: Per-attack perturbation budget, forwarded as ``eps``.
        config: Object providing ``target_fct`` (read only when a target is
            ``"target_fct"``).
        make_attack: Callable invoked with keyword arguments to build an attack.
        norm: Forwarded unchanged as ``norm``.
        device: Forwarded unchanged as ``device``.

    Returns:
        list of zero-argument callables, one per attack, in input order.

    Raises:
        ValueError: If the four per-attack lists do not all have the same length.
    """
    from functools import partial

    per_attack_lists = {
        "attack_name_list": attack_name_list,
        "make_agent_fct_list_for_attack": make_agent_fct_list_for_attack,
        "target_list": target_list,
        "eps_list": eps_list,
    }
    lengths = {len(v) for v in per_attack_lists.values()}
    # zip() would silently drop the trailing entries of longer lists; fail
    # loudly instead so no configured attack is lost.
    if len(lengths) != 1:
        raise ValueError(f"All arguments must have the same length, got {lengths}")

    obs_space = base_env.observation_space
    perturb_space = base_env.get_attr("observation_perturbation_space")[0]
    is_proportional_mask = base_env.get_attr("proportional_obs_perturbation_mask")[0]

    make_attack_fct_list = []
    for attack_name, make_agent_fct_for_attack, target, eps in zip(
        attack_name_list, make_agent_fct_list_for_attack, target_list, eps_list
    ):
        # Only a str can be the placeholder; comparing arbitrary objects
        # (e.g. arrays) to a str with == need not yield a plain bool.
        if isinstance(target, str) and target == "target_fct":
            target = config.target_fct

        # partial binds the current loop values at creation time, so each
        # factory keeps its own settings (no late-binding closure).
        make_attack_fct_list.append(
            partial(
                make_attack,
                attack_name=attack_name,
                make_agent_fct=make_agent_fct_for_attack,
                target=target,
                obs_space=obs_space,
                perturb_space=perturb_space,
                eps=eps,
                norm=norm,
                is_proportional_mask=is_proportional_mask,
                device=device,
            )
        )

    return make_attack_fct_list