Coverage for adaro_rl / pipelines / utils.py: 89%

18 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1from typing import Callable 

2 

3 

def normalize_lists(**kwargs):
    """
    Wrap scalar keyword arguments in single-element lists and check lengths.

    Each value that is not a ``list`` instance (tuples included) is wrapped
    as ``[value]``. Values that already are lists are returned as-is: the
    same object, not a copy. Scalars are not broadcast, so a scalar next to
    a list of length > 1 counts as a length mismatch.

    Args:
        **kwargs: Values that are either lists or scalars.

    Returns:
        Dict mapping each keyword to its list form, in the order the
        keywords were given. Empty when no keyword arguments are passed.

    Raises:
        ValueError: If the resulting lists do not all have the same length.
    """
    lists = {k: (v if isinstance(v, list) else [v]) for k, v in kwargs.items()}
    lengths = {len(v) for v in lists.values()}
    # Zero arguments yield an empty set of lengths; that trivially satisfies
    # "all the same length", so only more than one distinct length is an error.
    if len(lengths) > 1:
        raise ValueError(f"All arguments must have the same length, got {lengths}")
    return lists

15 

16 

def make_attack_list(
    base_env,
    attack_name_list: list[str],
    make_agent_fct_list_for_attack: list[Callable],
    target_list: list[str],
    eps_list: list[float],
    config,
    make_attack: Callable,
    norm: float,
    device: str,
):
    """
    Build one zero-argument attack factory per entry of the per-attack lists.

    The i-th factory calls ``make_attack`` with keyword arguments built from
    these sources:

    - the i-th attack name, agent factory, target and epsilon;
    - the observation space and the perturbation settings read from
      ``base_env``;
    - the shared ``norm`` and ``device``.

    ``make_attack`` itself is not called here; each factory calls it when
    invoked.

    Args:
        base_env: Object exposing ``observation_space`` and
            ``get_attr(name)``. The first element returned by ``get_attr``
            is used for ``"observation_perturbation_space"`` and
            ``"proportional_obs_perturbation_mask"``.
        attack_name_list: Attack names, one per attack.
        make_agent_fct_list_for_attack: Agent factories, one per attack.
        target_list: Targets, one per attack. The string ``"target_fct"``
            is replaced by ``config.target_fct``.
        eps_list: Perturbation budgets, one per attack.
        config: Object providing ``target_fct``. It is only read when a
            target equals ``"target_fct"``.
        make_attack: Callable that builds an attack from keyword arguments.
        norm: Norm forwarded to every attack.
        device: Device forwarded to every attack.

    Returns:
        List of zero-argument callables, in input order.

    Raises:
        ValueError: If the four per-attack lists have different lengths.
    """
    from functools import partial

    # Materialize once so lengths can be compared. The inputs may be any
    # iterable, exactly as zip() accepted before.
    attack_names = list(attack_name_list)
    agent_fcts = list(make_agent_fct_list_for_attack)
    targets = list(target_list)
    epsilons = list(eps_list)

    lengths = {
        "attack_name_list": len(attack_names),
        "make_agent_fct_list_for_attack": len(agent_fcts),
        "target_list": len(targets),
        "eps_list": len(epsilons),
    }
    if len(set(lengths.values())) > 1:
        # zip() would silently drop the surplus entries; fail loudly instead.
        raise ValueError(f"All attack lists must have the same length, got {lengths}")

    obs_space = base_env.observation_space
    perturb_space = base_env.get_attr("observation_perturbation_space")[0]
    is_proportional_mask = base_env.get_attr("proportional_obs_perturbation_mask")[0]

    make_attack_fct_list = []
    for attack_name, make_agent_fct_for_attack, target, eps in zip(
        attack_names, agent_fcts, targets, epsilons
    ):
        # Only a string can be the "target_fct" sentinel. The isinstance
        # check avoids elementwise == on array-like or other custom targets.
        if isinstance(target, str) and target == "target_fct":
            target = config.target_fct

        # partial binds the current loop values eagerly (no late-binding
        # closure), and the result is picklable whenever its contents are.
        make_attack_fct = partial(
            make_attack,
            attack_name=attack_name,
            make_agent_fct=make_agent_fct_for_attack,
            target=target,
            obs_space=obs_space,
            perturb_space=perturb_space,
            eps=eps,
            norm=norm,
            is_proportional_mask=is_proportional_mask,
            device=device,
        )
        make_attack_fct_list.append(make_attack_fct)

    return make_attack_fct_list