Coverage for adaro_rl / viz / robustness_matrix.py: 85%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2import numpy as np 

3import pandas as pd 

4import matplotlib.pyplot as plt 

5import matplotlib.colors as mcolors 

6from matplotlib.patches import Rectangle 

7 

8 

def robustness_matrix(
    env_name: str,
    agent_dirs: list[str],
    agent_names: list[str],
    norm: float,
    attack_list: list[str],
    eps_list: list[float],
    output_dir: str,
    clusters: list[int] | None = None,
    fontsize=15,
):
    """
    Plot a heat-map of mean episode rewards for several agents under
    different adversarial attacks (plus the nominal run).

    Each *agent_dir* must contain:

    * ``results/result.csv`` – nominal evaluation with a **"mean reward"** column.
    * ``results/online_adv_attacks_norm_{norm}_mean_reward.csv`` – rewards for
      every (*eps*, *attack*) pair; first column = *eps*, and the remaining
      columns are named after the attacks.

    Matrix layout: one row per agent (in ``agent_dirs`` order), column 0 is
    the nominal reward, and column ``i + 1`` is the reward under
    ``attack_list[i]`` at ``eps_list[i]``. A pair absent from the attack CSV
    is left as NaN.

    Parameters
    ----------
    env_name : str
        Environment name, used in the title and in the output file name.
    agent_dirs : list[str]
        One results directory per agent, laid out as described above.
    agent_names : list[str]
        Y-axis labels, one per entry of ``agent_dirs``.
    norm : float
        Attack norm; selects the attack CSV and appears in the file name.
    attack_list : list[str]
        Attack names (column names in the attack CSV), paired element-wise
        with ``eps_list``.
    eps_list : list[float]
        Perturbation budgets, paired element-wise with ``attack_list``.
    output_dir : str
        Directory for the PDF; created if it does not exist.
    clusters : list[int] | None
        Optional sizes of consecutive agent groups; each group of rows is
        outlined with a black rectangle.
    fontsize : int
        Base font size; the title uses ``fontsize + 5``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, also saved as ``confusion_<env>_norm_<N>.pdf`` in
        *output_dir*.

    Raises
    ------
    ValueError
        If ``agent_dirs`` is empty, or if ``agent_names`` / ``eps_list`` do not
        have the same length as ``agent_dirs`` / ``attack_list`` respectively.
    """
    # Validate up front: a length mismatch would otherwise mislabel rows
    # (set_yticks uses len(agent_names)) or silently drop columns (zip truncates).
    if not agent_dirs:
        raise ValueError("agent_dirs must contain at least one agent directory")
    if len(agent_names) != len(agent_dirs):
        raise ValueError(
            f"agent_names has {len(agent_names)} entries but agent_dirs has "
            f"{len(agent_dirs)}; each agent row needs exactly one label"
        )
    if len(eps_list) != len(attack_list):
        raise ValueError(
            f"attack_list has {len(attack_list)} entries but eps_list has "
            f"{len(eps_list)}; they are paired element-wise"
        )

    os.makedirs(output_dir, exist_ok=True)

    # ---------------------------------------------------------------- labels
    x_labels = ["None"] + [
        f"{attack} eps={eps}" for attack, eps in zip(attack_list, eps_list)
    ]
    y_labels = agent_names

    # ---------------------------------------------------------------- CSV paths
    nominal_csvs = [os.path.join(d, "results", "result.csv") for d in agent_dirs]
    attack_csvs = [
        os.path.join(d, "results", f"online_adv_attacks_norm_{norm}_mean_reward.csv")
        for d in agent_dirs
    ]

    # Colour-scale reference: the first agent's nominal reward, taken from the
    # row whose first column equals 0.
    # NOTE(review): the matrix below reads row *label* 0 instead; the two agree
    # when the first column of result.csv is the saved DataFrame index — confirm.
    reference_df = pd.read_csv(nominal_csvs[0])
    nominal_reward = reference_df.loc[
        reference_df[reference_df.columns[0]] == 0, "mean reward"
    ].iloc[0]

    # ---------------------------------------------------------------- reward matrix
    matrix = np.zeros((len(agent_dirs), len(x_labels)))

    for row_idx, (nominal_csv, attack_csv) in enumerate(zip(nominal_csvs, attack_csvs)):
        # -------- nominal reward
        nominal_df = pd.read_csv(nominal_csv)
        matrix[row_idx, 0] = nominal_df.loc[0, "mean reward"]

        # -------- adversarial rewards, as {eps: {attack_name: reward}}
        attack_df = pd.read_csv(attack_csv)
        attack_df.set_index(attack_df.columns[0], inplace=True)
        reward_lookup = attack_df.to_dict(orient="index")

        # float(eps) so the key matches the numeric eps index parsed from the CSV;
        # missing (eps, attack) pairs become NaN.
        matrix[row_idx, 1:] = [
            reward_lookup.get(float(eps), {}).get(attack_name, np.nan)
            for attack_name, eps in zip(attack_list, eps_list)
        ]

    # ---------------------------------------------------------------- plotting
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.set_title(f"{env_name}: Robustness Matrix\n", fontsize=fontsize + 5)
    ax.set_xlabel("Attacks", fontsize=fontsize)
    ax.set_ylabel("Agents", fontsize=fontsize)

    # The colour scale tops out at the reference nominal reward. vmin is clamped
    # to it because Normalize raises when vmin > vmax, which can happen if every
    # cell exceeds the reference value.
    lowest = np.nanmin(matrix)
    colour_norm = mcolors.Normalize(vmin=min(lowest, nominal_reward), vmax=nominal_reward)
    colour_map = ax.matshow(matrix, cmap="YlOrRd", norm=colour_norm)

    colourbar = fig.colorbar(colour_map)
    colourbar.set_label("Mean reward", fontsize=fontsize)
    colourbar.ax.tick_params(labelsize=fontsize, rotation=30)

    ax.set_xticks(
        range(len(x_labels)), x_labels, rotation=60, ha="left", fontsize=fontsize
    )
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")

    ax.set_yticks(
        range(len(y_labels)), y_labels, rotation=30, va="top", fontsize=fontsize
    )

    if clusters is not None:
        base = 0
        for cluster in clusters:
            # Outline rows base .. base + cluster - 1 across every column; matshow
            # centres cells on integer coordinates, hence the -0.5 offsets.
            ax.add_patch(
                Rectangle(
                    (-0.5, base - 0.5),
                    matrix.shape[1],
                    cluster,
                    linewidth=2,
                    edgecolor="black",
                    facecolor="none",
                )
            )
            base += cluster

    fig.tight_layout()
    pdf_name = f"confusion_{env_name}_norm_{norm}.pdf"
    fig.savefig(os.path.join(output_dir, pdf_name), dpi=300)

    return fig