Coverage for adaro_rl / viz / robustness_matrix.py: 85%
52 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
2import numpy as np
3import pandas as pd
4import matplotlib.pyplot as plt
5import matplotlib.colors as mcolors
6from matplotlib.patches import Rectangle
def _reference_reward(nominal_csv: str):
    """Return the "mean reward" of the first row whose first column equals 0.

    This value is used as the top of the heat-map colour scale.

    :raises ValueError: if no row of *nominal_csv* has 0 in its first column.
    """
    df = pd.read_csv(nominal_csv)
    matches = df.loc[df[df.columns[0]] == 0, "mean reward"]
    if matches.empty:
        raise ValueError(
            f"{nominal_csv}: no row with 0 in column {df.columns[0]!r}; "
            "cannot determine the reference (nominal) reward"
        )
    return matches.iloc[0]


def _reward_row(
    nominal_csv: str,
    attack_csv: str,
    attack_list: list[str],
    eps_list: list[float],
) -> list:
    """Return one agent's matrix row: nominal reward, then one value per attack.

    The nominal value is the "mean reward" at index label 0 of *nominal_csv*
    (its first row, since ``read_csv`` yields a default RangeIndex). Each
    attack value is looked up in *attack_csv* by (eps, attack name); pairs
    absent from the file become NaN instead of aborting the whole plot.
    """
    nominal_df = pd.read_csv(nominal_csv)
    row = [nominal_df.loc[0, "mean reward"]]

    attack_df = pd.read_csv(attack_csv)
    attack_df.set_index(attack_df.columns[0], inplace=True)
    # {eps: {attack_name: reward, ...}, ...}
    reward_lookup = attack_df.to_dict(orient="index")

    for attack_name, eps in zip(attack_list, eps_list):
        # Cast so that e.g. an int eps of 1 matches a parsed float index of 1.0.
        row.append(reward_lookup.get(float(eps), {}).get(attack_name, np.nan))
    return row


def _outline_clusters(ax, clusters: list[int], n_cols: int) -> None:
    """Draw a black box around each consecutive group of rows given by *clusters*."""
    first_row = 0
    for size in clusters:
        ax.add_patch(
            Rectangle(
                # Cell centres sit on integer coordinates, so edges are at +-0.5.
                (-0.5, first_row - 0.5),
                n_cols,
                size,
                linewidth=2,
                edgecolor="black",
                facecolor="none",
            )
        )
        first_row += size


def robustness_matrix(
    env_name: str,
    agent_dirs: list[str],
    agent_names: list[str],
    norm: float,
    attack_list: list[str],
    eps_list: list[float],
    output_dir: str,
    clusters: list[int] = None,
    fontsize=15,
):
    """
    Plot a heat‑map of mean episode rewards for several agents under
    different adversarial attacks (plus the nominal run).

    Each *agent_dir* must contain:

    * ``results/result.csv`` – nominal evaluation with a **“mean reward”** column.
      Its first row gives the agent's "None" cell. For the first agent, the row
      whose first column equals 0 also sets the top of the colour scale, so
      rewards above it share the colormap's top colour.
    * ``results/online_adv_attacks_norm_{norm}_mean_reward.csv`` – rewards for
      every (*eps*, *attack*) pair; first column = *eps*, one column per attack
      name. Pairs missing from the file are stored as NaN.

    A PDF ``confusion_<env>_norm_<N>.pdf`` is saved to *output_dir*, and the
    Matplotlib ``Figure`` is returned (it is not closed; the caller owns it).

    :param env_name: environment name, used in the title and the PDF name.
    :param agent_dirs: one directory per agent (one matrix row each).
    :param agent_names: row labels, same length as *agent_dirs*.
    :param norm: attack norm, used to locate the attack CSV and name the PDF.
    :param attack_list: attack names, paired positionally with *eps_list*.
    :param eps_list: attack budgets, same length as *attack_list*.
    :param output_dir: directory for the PDF (created if needed).
    :param clusters: optional sizes of consecutive row groups to outline;
        each must be >= 1 and they must sum to at most ``len(agent_dirs)``.
    :param fontsize: base font size for labels and ticks.
    :returns: the Matplotlib ``Figure``.
    :raises ValueError: on empty *agent_dirs*, mismatched list lengths,
        invalid *clusters*, or a first nominal CSV without a 0 row.
    """
    # ---------------------------------------------------------------- validation
    # Fail before touching the filesystem or reading any CSV.
    n_agents = len(agent_dirs)
    if n_agents == 0:
        raise ValueError("agent_dirs must contain at least one directory")
    if len(agent_names) != n_agents:
        raise ValueError(
            f"agent_names has {len(agent_names)} entries but agent_dirs has {n_agents}"
        )
    if len(attack_list) != len(eps_list):
        raise ValueError(
            f"attack_list has {len(attack_list)} entries but eps_list has "
            f"{len(eps_list)}; they are paired positionally"
        )
    if clusters is not None:
        if any(size < 1 for size in clusters):
            raise ValueError(f"cluster sizes must be >= 1, got {clusters}")
        if sum(clusters) > n_agents:
            raise ValueError(
                f"cluster sizes sum to {sum(clusters)} but there are only "
                f"{n_agents} agents"
            )

    os.makedirs(output_dir, exist_ok=True)

    # ---------------------------------------------------------------- labels
    x_labels = ["None"] + [
        f"{attack} eps={eps}" for attack, eps in zip(attack_list, eps_list)
    ]
    y_labels = agent_names

    # ---------------------------------------------------------------- CSV paths
    nominal_csvs = [os.path.join(d, "results", "result.csv") for d in agent_dirs]
    attack_csvs = [
        os.path.join(d, "results", f"online_adv_attacks_norm_{norm}_mean_reward.csv")
        for d in agent_dirs
    ]

    # Top of the colour scale: the first agent's nominal reward.
    reference_reward = _reference_reward(nominal_csvs[0])

    # ---------------------------------------------------------------- reward matrix
    matrix = np.zeros((n_agents, len(x_labels)))
    for row_idx, (nominal_csv, attack_csv) in enumerate(zip(nominal_csvs, attack_csvs)):
        matrix[row_idx] = _reward_row(nominal_csv, attack_csv, attack_list, eps_list)

    # ---------------------------------------------------------------- plotting
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.set_title(f"{env_name}: Robustness Matrix\n", fontsize=fontsize + 5)
    ax.set_xlabel("Attacks", fontsize=fontsize)
    ax.set_ylabel("Agents", fontsize=fontsize)

    vmin = np.nanmin(matrix)
    # Normalize raises at draw time when vmin > vmax, i.e. when every cell
    # exceeds the reference reward; keep vmax >= vmin to avoid that.
    colour_norm = mcolors.Normalize(vmin=vmin, vmax=max(reference_reward, vmin))
    colour_map = ax.matshow(matrix, cmap="YlOrRd", norm=colour_norm)

    colourbar = fig.colorbar(colour_map, ax=ax)
    colourbar.set_label("Mean reward", fontsize=fontsize)
    colourbar.ax.tick_params(labelsize=fontsize, rotation=30)

    ax.set_xticks(
        range(len(x_labels)), x_labels, rotation=60, ha="left", fontsize=fontsize
    )
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")

    ax.set_yticks(
        range(len(y_labels)), y_labels, rotation=30, va="top", fontsize=fontsize
    )

    if clusters is not None:
        _outline_clusters(ax, clusters, matrix.shape[1])

    fig.tight_layout()
    pdf_name = f"confusion_{env_name}_norm_{norm}.pdf"
    fig.savefig(os.path.join(output_dir, pdf_name), dpi=300)

    return fig