Coverage for uqmodels / visualization / old_visualisation.py: 7%

107 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 08:15 +0000

1import numpy as np 

2import matplotlib.pyplot as plt 

3import uqmodels.postprocessing.UQ_processing as UQ_proc 

4 

5 

def plot_prediction_interval(
    y: np.array,
    y_pred_lower: np.array,
    y_pred_upper: np.array,
    X: np.array = None,
    y_pred: np.array = None,
    save_path: str = None,
    sort_X: bool = False,
    **kwargs,
) -> None:
    """Plot prediction intervals whose bounds are given by y_pred_lower and y_pred_upper.

    True values are always plotted, split into observations inside and outside
    the interval. Point estimates are also plotted if given as argument.

    Args:
        y: label true values.
        y_pred_lower: lower bounds of the prediction interval (may be None).
        y_pred_upper: upper bounds of the prediction interval (may be None).
        X <optional>: abscissa vector. Defaults to ``np.arange(len(y))``.
        y_pred <optional>: predicted values.
        save_path <optional>: if given, the figure is saved to this path in
            PDF format instead of being shown.
        sort_X <optional>: if True and X is given, every series is reordered
            by increasing X before plotting.
        kwargs: plot parameters. Recognised keys are ``figsize`` (default
            (15, 6)) and ``loc`` (legend location, default "upper left").
    """
    # Figure configuration
    figsize = kwargs.get("figsize", (15, 6))
    loc = kwargs.get("loc", "upper left")
    plt.figure(figsize=figsize)

    # NOTE: these rcParams updates are global and persist after this call.
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["ytick.labelsize"] = 15
    plt.rcParams["xtick.labelsize"] = 15
    plt.rcParams["axes.labelsize"] = 15
    plt.rcParams["legend.fontsize"] = 16

    # Accept any array-like; optional series stay None.
    y = np.asarray(y)
    y_pred, y_pred_lower, y_pred_upper = (
        None if arr is None else np.asarray(arr)
        for arr in (y_pred, y_pred_lower, y_pred_upper)
    )

    if X is None:
        X = np.arange(len(y))
    else:
        X = np.asarray(X)
        if sort_X:
            sorted_idx = np.argsort(X)
            X = X[sorted_idx]
            y = y[sorted_idx]
            # Only reorder the optional series that were actually provided.
            y_pred, y_pred_lower, y_pred_upper = (
                None if arr is None else arr[sorted_idx]
                for arr in (y_pred, y_pred_lower, y_pred_upper)
            )

    has_pi = y_pred_lower is not None and y_pred_upper is not None
    if has_pi:
        miscoverage = (y > y_pred_upper) | (y < y_pred_lower)
    else:
        miscoverage = np.zeros(len(y), dtype=bool)

    plt.plot(
        X[~miscoverage],
        y[~miscoverage],
        color="darkgreen",
        marker="X",
        markersize=2,
        linewidth=0,
        label="Observation (inside PI)" if has_pi else "Observation",
        zorder=20,
    )

    if has_pi:
        # Without an interval there is nothing "outside"; skipping avoids an
        # empty series with a duplicate legend entry.
        plt.plot(
            X[miscoverage],
            y[miscoverage],
            color="red",
            marker="o",
            markersize=2,
            linewidth=0,
            label="Observation (outside PI)",
            zorder=20,
        )
        plt.plot(X, y_pred_upper, "--", color="blue", linewidth=1, alpha=0.7)
        plt.plot(X, y_pred_lower, "--", color="blue", linewidth=1, alpha=0.7)
        plt.fill_between(
            x=X,
            y1=y_pred_upper,
            y2=y_pred_lower,
            alpha=0.2,
            fc="b",
            ec="None",
            label="Prediction Interval",
        )

    if y_pred is not None:
        plt.plot(X, y_pred, color="k", label="Prediction")

    plt.xlabel("X")
    plt.ylabel("Y")
    plt.legend(loc=loc)

    if save_path:
        plt.savefig(str(save_path), format="pdf")
    else:
        plt.show()

112 

113 

def plot_sorted_pi(
    y: np.array,
    y_pred_lower: np.array,
    y_pred_upper: np.array,
    X: np.array = None,
    y_pred: np.array = None,
    **kwargs,
) -> None:
    """Plot prediction intervals ordered from narrowest to widest.

    All series are centered on the prediction (``value - y_pred``), so the
    prediction itself is drawn as the zero line. Observations are colored by
    whether they fall inside or outside their interval.

    Args:
        y: label true values.
        y_pred_lower: lower bounds of the prediction interval.
        y_pred_upper: upper bounds of the prediction interval.
        X <optional>: abscissa vector for the sorted points. Defaults to
            ``np.arange(len(y_pred_lower))`` (rank by interval width).
        y_pred <optional>: predicted values. Defaults to the interval midpoint.
        kwargs: plot parameters. Recognised keys are ``figsize`` (default
            (15, 6)) and ``loc`` (legend location, default "best").
    """
    y = np.asarray(y)
    y_pred_lower = np.asarray(y_pred_lower)
    y_pred_upper = np.asarray(y_pred_upper)
    if y_pred is None:
        y_pred = (y_pred_upper + y_pred_lower) / 2
    else:
        y_pred = np.asarray(y_pred)

    width = np.abs(y_pred_upper - y_pred_lower)
    sorted_order = np.argsort(width)

    # Figure configuration
    figsize = kwargs.get("figsize", (15, 6))
    # "best" is matplotlib's own default, i.e. what a bare plt.legend() used.
    loc = kwargs.get("loc", "best")
    plt.figure(figsize=figsize)

    if X is None:
        X = np.arange(len(y_pred_lower))
    else:
        X = np.asarray(X)

    pred_sorted = y_pred[sorted_order]
    misscoverage = ((y > y_pred_upper) | (y < y_pred_lower))[sorted_order]
    residuals = y[sorted_order] - pred_sorted

    # Prediction: zero reference line, since everything is centered on it.
    plt.plot(
        X,
        np.zeros_like(pred_sorted, dtype=float),
        color="black",
        markersize=2,
        zorder=20,
        label="Prediction",
    )

    # True values inside the interval
    plt.plot(
        X[~misscoverage],
        residuals[~misscoverage],
        color="darkgreen",
        marker="o",
        markersize=2,
        linewidth=0,
        zorder=20,
        label="Observation (inside PI)",
    )

    # True values outside the interval
    plt.plot(
        X[misscoverage],
        residuals[misscoverage],
        color="red",
        marker="o",
        markersize=2,
        linewidth=0,
        zorder=20,
        label="Observation (outside PI)",
    )

    # PI lower bound (carries the single legend entry for both bounds)
    plt.plot(
        X,
        y_pred_lower[sorted_order] - pred_sorted,
        "--",
        label="Prediction Interval Bounds",
        color="blue",
        linewidth=1,
        alpha=0.7,
    )

    # PI upper bound
    plt.plot(
        X,
        y_pred_upper[sorted_order] - pred_sorted,
        "--",
        color="blue",
        linewidth=1,
        alpha=0.7,
    )

    plt.legend(loc=loc)
    plt.show()

213 

214 

def visu_latent_space(grid_dim, embedding, f_obs, context_grid, context_grid_name=None):
    """Draw one 3D scatter of the embedding per cell of a subplot grid.

    Each cell (row, col) shows the same points, colored by
    ``context_grid[row][col][f_obs]`` with the "jet" colormap.

    Args:
        grid_dim: sequence whose first two items are the grid's rows and columns.
        embedding: array indexed as ``embedding[f_obs, k]`` for k in 0, 1, 2
            (the x, y, z coordinates).
        f_obs: index or mask selecting which observations to draw.
        context_grid: nested sequence of per-observation color values,
            one entry per grid cell.
        context_grid_name: optional nested sequence of subplot titles.
    """
    n_rows, n_cols = grid_dim[0], grid_dim[1]
    fig = plt.figure(figsize=(15, 7))
    for row in range(n_rows):
        for col in range(n_cols):
            position = row * n_cols + col + 1
            ax = fig.add_subplot(n_rows, n_cols, position, projection="3d")
            if context_grid_name is not None:
                ax.set_title(context_grid_name[row][col])
            colors = context_grid[row][col][f_obs]
            ax.scatter(
                embedding[f_obs, 0],
                embedding[f_obs, 1],
                embedding[f_obs, 2],
                c=colors,
                cmap=plt.get_cmap("jet"),
                s=1,
            )

232 

233 

def show_dUQ_refinement(
    UQ,
    y=None,
    d=0,
    f_obs=None,
    max_cut_A=0.99,
    q_Eratio=2,
    E_cut_in_var_nominal=False,
    A_res_in_var_atypic=False,
):
    """Plot raw vs. refined aleatoric/epistemic variances for one output dimension.

    The variances are refined with ``UQ_proc.get_extremum_var_TOT_and_ndUQ_ratio``
    followed by ``UQ_proc.split_var_dUQ``. The figure then compares the raw and
    refined components, as well as their (std-normalized) ratios.

    Args:
        UQ: pair ``(var_A, var_E)``; a tuple is stacked with ``np.array``.
            Each component is indexed as ``var[f_obs, d:d + 1]``, so presumably
            of shape (n_samples, n_dims) — TODO confirm.
        y: optional targets indexed like the variances; when given they are
            plotted in an extra top panel.
        d: index of the output dimension to display.
        f_obs: observation indices to display; defaults to all of ``UQ.shape[1]``.
        max_cut_A: forwarded as ``max_cut`` to
            ``get_extremum_var_TOT_and_ndUQ_ratio``.
        q_Eratio: forwarded to ``get_extremum_var_TOT_and_ndUQ_ratio``.
        E_cut_in_var_nominal: forwarded to both UQ_proc helpers.
        A_res_in_var_atypic: forwarded to both UQ_proc helpers.

    Returns:
        (fig, ax): the figure and its axes array (3 panels, or 4 when y is given).
    """
    if isinstance(UQ, tuple):
        UQ = np.array(UQ)

    if f_obs is None:
        f_obs = np.arange(UQ.shape[1])

    var_A, var_E = UQ
    extremum_var_TOT, ndUQ_ratio = UQ_proc.get_extremum_var_TOT_and_ndUQ_ratio(
        UQ,
        min_cut=0,
        max_cut=max_cut_A,
        var_min=0,
        var_max=None,
        factor=1,
        q_var=1,
        q_Eratio=q_Eratio,
        mode_multidim=True,
        E_cut_in_var_nominal=E_cut_in_var_nominal,
        A_res_in_var_atypic=A_res_in_var_atypic,
    )

    var_A_cut, var_E_res = UQ_proc.split_var_dUQ(
        UQ,
        q_var=1,
        q_var_e=1,
        ndUQ_ratio=ndUQ_ratio,
        E_cut_in_var_nominal=E_cut_in_var_nominal,
        A_res_in_var_atypic=A_res_in_var_atypic,
        extremum_var_TOT=extremum_var_TOT,
    )

    # Complementary parts of each component after the split.
    var_A_res = var_A - var_A_cut
    var_E_cut = var_E - var_E_res

    # Selected observations, single output dimension kept 2D for plotting.
    sel = (f_obs, slice(d, d + 1))

    # Panels shift down by one when a target panel is added on top.
    offset = 0 if y is None else 1
    fig, ax = plt.subplots(3 + offset, 1, sharex=True, figsize=(20, 5))
    if y is not None:
        ax[0].plot(y[sel], label="true_val")

    ax[offset].plot(var_A[sel], label="raw_var_A")
    ax[offset].plot(var_A_cut[sel], label="refined_var_A")
    ax[offset].legend()

    ax[offset + 1].plot(var_E[sel], label="raw_var_E")
    ax[offset + 1].plot(var_E_res[sel], label="refined_var_E")
    ax[offset + 1].legend()

    # Ratios are divided by their std so both curves share a comparable scale.
    ratio = var_E[sel] / var_A[sel]
    ax[offset + 2].plot(ratio / ratio.std(), label="raw_ratio")
    refined_ratio = (var_A_res[sel] + var_E_res[sel]) / (var_A_cut[sel] + var_E_cut[sel])
    ax[offset + 2].plot(refined_ratio / refined_ratio.std(), label="refined_ratio")
    ax[offset + 2].legend()

    return (fig, ax)