Coverage for uqmodels / visualization / old_visualisation.py: 7%
107 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 08:15 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 08:15 +0000
1import numpy as np
2import matplotlib.pyplot as plt
3import uqmodels.postprocessing.UQ_processing as UQ_proc
def plot_prediction_interval(
    y: np.array,
    y_pred_lower: np.array,
    y_pred_upper: np.array,
    X: np.array = None,
    y_pred: np.array = None,
    save_path: str = None,
    sort_X: bool = False,
    **kwargs,
) -> None:
    """Plot prediction intervals whose bounds are given by y_pred_lower and y_pred_upper.

    True values and point estimates are also plotted if given as argument.
    Observations inside the interval are drawn in dark green, observations
    outside it in red, and the interval itself as a shaded blue band.

    Args:
        y: label true values.
        y_pred_lower: lower bounds of the prediction interval. If either bound
            is None, no interval is drawn and only observations are plotted.
        y_pred_upper: upper bounds of the prediction interval.
        X <optional>: abscissa vector; defaults to ``np.arange(len(y))``.
        y_pred <optional>: predicted values.
        save_path <optional>: if truthy, the figure is saved there in PDF
            format instead of being shown.
        sort_X <optional>: if True and X is given, every provided series is
            reordered by increasing X before plotting.
        kwargs: plot parameters; ``figsize`` (default (15, 6)) and ``loc``
            (legend location, default "upper left") are read.
    """
    # Figure configuration
    figsize = kwargs.get("figsize", (15, 6))
    loc = kwargs.get("loc", "upper left")
    plt.figure(figsize=figsize)

    # NOTE: these rcParams updates are global and persist after this call.
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["ytick.labelsize"] = 15
    plt.rcParams["xtick.labelsize"] = 15
    plt.rcParams["axes.labelsize"] = 15
    plt.rcParams["legend.fontsize"] = 16

    has_interval = y_pred_upper is not None and y_pred_lower is not None

    # Work on arrays so boolean masking / fancy indexing also accept lists.
    y = np.asarray(y)
    if y_pred is not None:
        y_pred = np.asarray(y_pred)
    if has_interval:
        y_pred_lower = np.asarray(y_pred_lower)
        y_pred_upper = np.asarray(y_pred_upper)

    if X is None:
        X = np.arange(len(y))
    else:
        X = np.asarray(X)
        if sort_X:
            # Reorder every provided series consistently; optional series
            # are skipped when absent instead of failing on None indexing.
            sorted_idx = np.argsort(X)
            X = X[sorted_idx]
            y = y[sorted_idx]
            if y_pred is not None:
                y_pred = y_pred[sorted_idx]
            if has_interval:
                y_pred_lower = y_pred_lower[sorted_idx]
                y_pred_upper = y_pred_upper[sorted_idx]

    if has_interval:
        miscoverage = (y > y_pred_upper) | (y < y_pred_lower)
    else:
        miscoverage = np.zeros(len(y), dtype=bool)

    # Observations inside the interval (or all observations if no interval).
    plt.plot(
        X[~miscoverage],
        y[~miscoverage],
        color="darkgreen",
        marker="X",
        markersize=2,
        linewidth=0,
        label="Observation (inside PI)" if has_interval else "Observation",
        zorder=20,
    )

    if has_interval:
        # Observations outside the interval; only meaningful when bounds exist,
        # otherwise this would add an empty duplicate legend entry.
        plt.plot(
            X[miscoverage],
            y[miscoverage],
            color="red",
            marker="o",
            markersize=2,
            linewidth=0,
            label="Observation (outside PI)",
            zorder=20,
        )
        plt.plot(X, y_pred_upper, "--", color="blue", linewidth=1, alpha=0.7)
        plt.plot(X, y_pred_lower, "--", color="blue", linewidth=1, alpha=0.7)
        plt.fill_between(
            x=X,
            y1=y_pred_upper,
            y2=y_pred_lower,
            alpha=0.2,
            fc="b",
            ec="None",
            label="Prediction Interval",
        )

    if y_pred is not None:
        plt.plot(X, y_pred, color="k", label="Prediction")

    plt.xlabel("X")
    plt.ylabel("Y")

    plt.legend(loc=loc)
    if save_path:
        plt.savefig(f"{save_path}", format="pdf")
    else:
        plt.show()
def plot_sorted_pi(
    y: np.array,
    y_pred_lower: np.array,
    y_pred_upper: np.array,
    X: np.array = None,
    y_pred: np.array = None,
    **kwargs,
) -> None:
    """Plot prediction intervals in an ordered fashion (lowest to largest width).

    All series are reordered by increasing interval width and centred on the
    point prediction (``y_pred`` is subtracted from each of them). The
    prediction therefore appears as a flat line, the bounds as dashed blue
    curves, and the observations as residual points. Residuals inside the
    interval are dark green and those outside are red.

    Args:
        y: label true values.
        y_pred_lower: lower bounds of the prediction interval.
        y_pred_upper: upper bounds of the prediction interval.
        X <optional>: abscissa positions of the sorted points; defaults to
            ``np.arange(len(y_pred_lower))``.
        y_pred <optional>: predicted values; defaults to the interval midpoint.
        kwargs: plot parameters; ``figsize`` (default (15, 6)) and ``loc``
            (legend location; when absent matplotlib's rcParams
            ``legend.loc`` is used) are read.
    """
    # Arrays are required for fancy indexing / boolean masking below.
    y = np.asarray(y)
    y_pred_lower = np.asarray(y_pred_lower)
    y_pred_upper = np.asarray(y_pred_upper)
    if y_pred is None:
        y_pred = (y_pred_upper + y_pred_lower) / 2
    else:
        y_pred = np.asarray(y_pred)

    width = np.abs(y_pred_upper - y_pred_lower)
    sorted_order = np.argsort(width)

    # Figure configuration
    figsize = kwargs.get("figsize", (15, 6))
    loc = kwargs.get("loc")  # None -> matplotlib uses rcParams["legend.loc"]
    plt.figure(figsize=figsize)

    if X is None:
        X = np.arange(len(y_pred_lower))
    else:
        X = np.asarray(X)

    # Reorder every series once by interval width.
    y_sorted = y[sorted_order]
    pred_sorted = y_pred[sorted_order]
    lower_sorted = y_pred_lower[sorted_order]
    upper_sorted = y_pred_upper[sorted_order]

    misscoverage = (y_sorted > upper_sorted) | (y_sorted < lower_sorted)
    residual = y_sorted - pred_sorted

    # Point prediction, centred on itself (a flat line).
    plt.plot(
        X,
        pred_sorted - pred_sorted,
        color="black",
        markersize=2,
        zorder=20,
        label="Prediction",
    )

    # Observations inside the interval
    plt.plot(
        X[~misscoverage],
        residual[~misscoverage],
        color="darkgreen",
        marker="o",
        markersize=2,
        linewidth=0,
        zorder=20,
        label="Observation (inside PI)",
    )

    # Observations outside the interval
    plt.plot(
        X[misscoverage],
        residual[misscoverage],
        color="red",
        marker="o",
        markersize=2,
        linewidth=0,
        zorder=20,
        label="Observation (outside PI)",
    )

    # PI Lower bound
    plt.plot(
        X,
        lower_sorted - pred_sorted,
        "--",
        label="Prediction Interval Bounds",
        color="blue",
        linewidth=1,
        alpha=0.7,
    )

    # PI upper bound
    plt.plot(
        X,
        upper_sorted - pred_sorted,
        "--",
        color="blue",
        linewidth=1,
        alpha=0.7,
    )

    plt.legend(loc=loc)

    plt.show()
def visu_latent_space(grid_dim, embedding, f_obs, context_grid, context_grid_name=None):
    """Draw a grid of 3-D scatter plots of one embedding, coloured per context.

    Subplot (row, col) scatters ``embedding[f_obs, 0]``, ``embedding[f_obs, 1]``
    and ``embedding[f_obs, 2]``, coloured by ``context_grid[row][col][f_obs]``
    with the "jet" colormap and point size 1.

    Args:
        grid_dim: pair (n_rows, n_cols) giving the subplot layout.
        embedding: array indexed as ``embedding[f_obs, k]`` for k in 0, 1, 2.
        f_obs: index or mask selecting the points to display.
        context_grid: nested sequence; ``context_grid[row][col]`` holds the
            colour values of subplot (row, col).
        context_grid_name <optional>: nested sequence of subplot titles.
    """
    n_rows = grid_dim[0]
    n_cols = grid_dim[1]
    figure = plt.figure(figsize=(15, 7))
    colormap = plt.get_cmap("jet")
    for row in range(n_rows):
        for col in range(n_cols):
            panel_number = row * n_cols + col + 1  # subplot indices are 1-based
            axis = figure.add_subplot(n_rows, n_cols, panel_number, projection="3d")
            if context_grid_name is not None:
                plt.title(context_grid_name[row][col])
            xs = embedding[f_obs, 0]
            ys = embedding[f_obs, 1]
            zs = embedding[f_obs, 2]
            colours = context_grid[row][col][f_obs]
            axis.scatter(xs, ys, zs, c=colours, cmap=colormap, s=1)
def show_dUQ_refinement(
    UQ,
    y=None,
    d=0,
    f_obs=None,
    max_cut_A=0.99,
    q_Eratio=2,
    E_cut_in_var_nominal=False,
    A_res_in_var_atypic=False,
):
    """Compare raw and refined aleatoric/epistemic variances for one dimension.

    ``UQ`` is processed with ``UQ_proc.get_extremum_var_TOT_and_ndUQ_ratio``
    and then ``UQ_proc.split_var_dUQ``. For dimension ``d`` and the
    observations ``f_obs``, the figure shows these panels, top to bottom:

    * (only if ``y`` is given) the true values;
    * raw ``var_A`` vs refined ``var_A_cut``;
    * raw ``var_E`` vs refined ``var_E_res``;
    * raw ratio ``var_E / var_A`` vs the refined ratio
      ``(var_A_res + var_E_res) / (var_A_cut + var_E_cut)``, each divided by
      its own standard deviation.

    Args:
        UQ: pair (var_A, var_E), given as a tuple/list or as an already
            stacked array; it is unpacked as ``var_A, var_E = UQ``.
        y <optional>: true values, indexed as ``y[f_obs, d:d+1]``.
        d: index of the output dimension to display.
        f_obs <optional>: observations to display; defaults to
            ``np.arange(UQ.shape[1])``.
        max_cut_A: forwarded as ``max_cut`` to
            ``get_extremum_var_TOT_and_ndUQ_ratio``.
        q_Eratio: forwarded to ``get_extremum_var_TOT_and_ndUQ_ratio``.
        E_cut_in_var_nominal: forwarded to both UQ_proc helpers.
        A_res_in_var_atypic: forwarded to both UQ_proc helpers.

    Returns:
        (fig, ax): the matplotlib figure and its array of axes.
    """
    if isinstance(UQ, (tuple, list)):
        UQ = np.array(UQ)

    if f_obs is None:
        f_obs = np.arange(UQ.shape[1])

    var_A, var_E = UQ
    extremum_var_TOT, ndUQ_ratio = UQ_proc.get_extremum_var_TOT_and_ndUQ_ratio(
        UQ,
        min_cut=0,
        max_cut=max_cut_A,
        var_min=0,
        var_max=None,
        factor=1,
        q_var=1,
        q_Eratio=q_Eratio,
        mode_multidim=True,
        E_cut_in_var_nominal=E_cut_in_var_nominal,
        A_res_in_var_atypic=A_res_in_var_atypic,
    )

    var_A_cut, var_E_res = UQ_proc.split_var_dUQ(
        UQ,
        q_var=1,
        q_var_e=1,
        ndUQ_ratio=ndUQ_ratio,
        E_cut_in_var_nominal=E_cut_in_var_nominal,
        A_res_in_var_atypic=A_res_in_var_atypic,
        extremum_var_TOT=extremum_var_TOT,
    )

    # Remainders: the part of each raw variance not kept in its refined version.
    var_A_res = var_A - var_A_cut
    var_E_cut = var_E - var_E_res

    def _sel(values):
        # Displayed observations, dimension d kept as a single column.
        return values[f_obs, d: d + 1]

    offset = 0 if y is None else 1
    fig, ax = plt.subplots(3 + offset, 1, sharex=True, figsize=(20, 5))
    if y is not None:
        ax[0].plot(_sel(y), label="true_val")
        ax[0].legend()

    ax[0 + offset].plot(_sel(var_A), label="row_var_A")
    ax[0 + offset].plot(_sel(var_A_cut), label="refined_var_A")
    ax[0 + offset].legend()

    ax[1 + offset].plot(_sel(var_E), label="row_var_E")
    ax[1 + offset].plot(_sel(var_E_res), label="refined_var_E")
    ax[1 + offset].legend()

    ratio = _sel(var_E) / _sel(var_A)
    ax[2 + offset].plot(ratio / ratio.std(), label="row_ratio")
    refined_ratio = (_sel(var_A_res) + _sel(var_E_res)) / (
        _sel(var_A_cut) + _sel(var_E_cut)
    )
    ax[2 + offset].plot(refined_ratio / refined_ratio.std(), label="refined_ratio")
    ax[2 + offset].legend()
    return (fig, ax)