Coverage for uqmodels / data_generation / Gen_basic_times_series.py: 97%

90 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 08:15 +0000

1from dataclasses import dataclass, field 

2import numpy as np 

3import scipy.stats 

4import math 

5import pandas as pd 

6from sklearn.preprocessing import StandardScaler 

7# from dataclasses import dataclass, field 

8from typing import Callable, Dict, Any, Sequence, Optional 

9from uqmodels.utils import cut, base_cos_freq 

10from uqmodels.preprocessing.preprocessing import rolling_statistics 

11 

12rng = np.random.RandomState(42) 

13 

14 

15def attack_mean(y: np.ndarray, 

16 loc: np.ndarray, 

17 dim: Sequence[int] = (0,), 

18 f: float = 1.0) -> np.ndarray: 

19 """ 

20 Ajoute une dérive progressive sur la moyenne sur les indices `loc`. 

21 """ 

22 y = np.copy(y) 

23 len_ = len(loc) 

24 for d in dim: 

25 for t in np.arange(len_): 

26 y[loc[t], d] += f * np.sqrt(min(t, len_ - t) / len_) 

27 return y 

28 

29 

30def attack_var(y: np.ndarray, 

31 loc: np.ndarray, 

32 dim: Sequence[int] = (0,), 

33 f: float = 1.0) -> np.ndarray: 

34 """ 

35 Augmente (ou diminue) la variance sur une fenêtre d'indices `loc`. 

36 """ 

37 y = np.copy(y) 

38 len_ = len(loc) 

39 ext_min = scipy.stats.norm.ppf(0.15, 0, f * 0.2) 

40 ext_max = scipy.stats.norm.ppf(0.85, 0, f * 0.2) 

41 for d in dim: 

42 noise = rng.normal(0, f * 0.2, len_) 

43 noise = np.maximum(np.minimum(noise, ext_max), ext_min) 

44 y[loc, d] += noise 

45 return y 

46 

47 

48def attack_spike(y: np.ndarray, 

49 loc: np.ndarray, 

50 dim: Sequence[int] = (0,), 

51 f: float = 1.0) -> np.ndarray: 

52 """ 

53 Ajoute un spike ponctuel sur le milieu de la fenêtre `loc`. 

54 """ 

55 y = np.copy(y) 

56 len_ = len(loc) 

57 idx = loc[int(len_ / 2)] 

58 for d in dim: 

59 y[idx, d] += f 

60 return y 

61 

62 

63@dataclass 

64class AttackSpec: 

65 """ 

66 Spécification d'une attaque à appliquer sur un signal. 

67 

68 Parameters 

69 ---------- 

70 func : Callable 

71 Fonction d'attaque (e.g. attack_mean, attack_var, attack_spike). 

72 loc : np.ndarray 

73 Indices sur lesquels appliquer l'attaque. 

74 kwargs : dict 

75 Arguments additionnels à passer à la fonction (e.g. f=1.1, dim=[0]). 

76 """ 

77 func: Callable[[np.ndarray, np.ndarray], np.ndarray] 

78 loc: np.ndarray 

79 kwargs: Dict[str, Any] = field(default_factory=dict) 

80 

81 

82def apply_attacks(y: np.ndarray, 

83 attacks: Sequence[AttackSpec]) -> np.ndarray: 

84 """ 

85 Applique séquentiellement une liste d'attaques sur le signal y. 

86 """ 

87 y_out = np.copy(y) 

88 for spec in attacks: 

89 y_out = spec.func(y_out, spec.loc, **spec.kwargs) 

90 return y_out 

91 

92 

93def core_gen( 

94 N: int = 10000, 

95 freq: float = 100.0, 

96 r1: float = -0.31, 

97 r2: float = 5.1, 

98 r3: float = 1.3, 

99 r4: float = 0.4, 

100 seed: int = 42, 

101 train_ratio: float = 0.7, 

102 name: str = "Unnoised ML-task", 

103 # attaques sur y et Z 

104 attacks_y: Optional[Sequence[AttackSpec]] = None, 

105 attacks_z: Optional[Sequence[AttackSpec]] = None, 

106 # fonctions externes 

107 cut_func: Callable[[np.ndarray, float, float], np.ndarray] = cut, 

108 base_cos_freq_func: Callable[[np.ndarray, Sequence[float]], np.ndarray] = base_cos_freq, 

109 rolling_statistics_func: Callable[..., pd.DataFrame] = rolling_statistics): 

110 """ 

111 Version enrichie : retourne aussi 

112 - y_no_obs : y et Z perturbés, sans statistiques 

113 - y_old : y et Z non perturbés, sans statistiques 

114 """ 

115 

116 if cut_func is None: 

117 raise ValueError("cut_func doit être fourni (par ex. `cut`).") 

118 if base_cos_freq_func is None: 

119 raise ValueError("base_cos_freq_func doit être fourni (par ex. `base_cos_freq`).") 

120 if rolling_statistics_func is None: 

121 raise ValueError("rolling_statistics_func doit être fourni (par ex. `rolling_statistics`).") 

122 

123 # RNG 

124 local_rng = np.random.RandomState(seed) 

125 

126 # ----------------------------- 

127 # 1. Grille temporelle 

128 # ----------------------------- 

129 X = np.arange(0, freq, freq / N) 

130 

131 # ----------------------------- 

132 # 2. Signaux propres 

133 # ----------------------------- 

134 y_mean = np.cos(X * math.pi + r1) + np.cos(r2 * X * math.pi) 

135 Z_mean = np.power(np.sin(X * math.pi + r3), 3) + np.cos(r4 * X * math.pi) 

136 

137 # ----------------------------- 

138 # 3. Bruit 

139 # ----------------------------- 

140 Y_noise = cut_func(local_rng.normal(0, 1.5, N), 0.25, 0.75) * (0.02 + np.abs(y_mean) / 5) 

141 Z_noise = cut_func(local_rng.normal(0, 1.2, N), 0.25, 0.75) * (0.02 + np.abs(1 - Z_mean) / 5) 

142 

143 y_clean = cut_func(y_mean + Y_noise, 0.01, 0.99).reshape(-1, 1) 

144 Z_clean = cut_func(Z_mean + Z_noise, 0.01, 0.99).reshape(-1, 1) 

145 

146 # Non perturbés (référence) 

147 old_y = np.copy(y_clean) 

148 old_Z = np.copy(Z_clean) 

149 

150 # ----------------------------- 

151 # 4. Attaques 

152 # ----------------------------- 

153 if attacks_y is None: 

154 attacks_y = [ 

155 AttackSpec(attack_mean, np.arange(9100, 9120), dict(f=-1.1)), 

156 AttackSpec(attack_spike, np.arange(9501, 9502), dict(f=1.55)), 

157 AttackSpec(attack_var, np.arange(9750, 9770), dict(f=2.5)), 

158 ] 

159 

160 if attacks_z is None: 

161 attacks_z = [ 

162 AttackSpec(attack_var, np.arange(9100, 9120), dict(f=2.0)), 

163 AttackSpec(attack_spike, np.arange(9770, 9771), dict(f=-0.6)), 

164 ] 

165 

166 y_attacked = apply_attacks(y_clean, attacks_y) 

167 Z_attacked = apply_attacks(Z_clean, attacks_z) 

168 

169 # ----------------------------- 

170 # NOUVEAU : sorties brutes pour visualisation 

171 # ----------------------------- 

172 y_no_obs = np.concatenate([y_attacked, Z_attacked], axis=1) 

173 y_old = np.concatenate([old_y, old_Z], axis=1) 

174 

175 # ----------------------------- 

176 # 5. Rolling statistics 

177 # ----------------------------- 

178 mean_var_10 = rolling_statistics_func( 

179 pd.DataFrame(np.roll(y_attacked, 0)), 

180 10, 1, 

181 ['mean', 'std', 'extremum'], 

182 ['mean', 'std', 'extremum'] 

183 ).replace(np.nan, 0).values[:, :1] 

184 

185 var_and_ext = rolling_statistics_func( 

186 pd.DataFrame(np.roll(y_attacked, 0)), 

187 30, 1, 

188 ['mean', 'std', 'extremum'], 

189 ['mean', 'std', 'extremum'] 

190 ).replace(np.nan, 0).values[:, 1:] 

191 

192 Z_mean_var_10 = rolling_statistics_func( 

193 pd.DataFrame(np.roll(Z_attacked, 0)), 

194 10, 1, 

195 ['mean', 'std', 'extremum'], 

196 ['mean', 'std', 'extremum'] 

197 ).replace(np.nan, 0).values[:, :1] 

198 

199 Z_var_and_ext = rolling_statistics_func( 

200 pd.DataFrame(np.roll(Z_attacked, 0)), 

201 30, 1, 

202 ['mean', 'std', 'extremum'], 

203 ['mean', 'std', 'extremum'] 

204 ).replace(np.nan, 0).values[:, 1:] 

205 

206 # --- Stats sur signaux NON perturbés (old_y / old_Z) --- 

207 

208 old_mean_var_10 = rolling_statistics_func( 

209 pd.DataFrame(np.roll(old_y, 0)), 

210 10, 1, 

211 ['mean', 'std', 'extremum'], 

212 ['mean', 'std', 'extremum'] 

213 ).replace(np.nan, 0).values[:, :1] 

214 

215 old_var_and_ext = rolling_statistics_func( 

216 pd.DataFrame(np.roll(old_y, 0)), 

217 30, 1, 

218 ['mean', 'std', 'extremum'], 

219 ['mean', 'std', 'extremum'] 

220 ).replace(np.nan, 0).values[:, 1:] 

221 

222 old_Z_mean_var_10 = rolling_statistics_func( 

223 pd.DataFrame(np.roll(old_Z, 0)), 

224 10, 1, 

225 ['mean', 'std', 'extremum'], 

226 ['mean', 'std', 'extremum'] 

227 ).replace(np.nan, 0).values[:, :1] 

228 

229 old_Z_var_and_ext = rolling_statistics_func( 

230 pd.DataFrame(np.roll(old_Z, 0)), 

231 30, 1, 

232 ['mean', 'std', 'extremum'], 

233 ['mean', 'std', 'extremum'] 

234 ).replace(np.nan, 0).values[:, 1:] 

235 

236 # --- y_old = même format que y_target mais sur données non perturbées --- 

237 y_old = np.concatenate( 

238 [old_mean_var_10, old_var_and_ext, 

239 old_Z_mean_var_10, old_Z_var_and_ext], 

240 axis=1 

241 ) 

242 

243 # cible ML 

244 y_target = np.concatenate( 

245 [mean_var_10, var_and_ext, Z_mean_var_10, Z_var_and_ext], 

246 axis=1 

247 ) 

248 

249 y_old = np.concatenate([ 

250 old_mean_var_10, 

251 old_var_and_ext, 

252 old_Z_mean_var_10, 

253 old_Z_var_and_ext], axis=1) 

254 

255 # ----------------------------- 

256 # 6. Features contextuelles 

257 # ----------------------------- 

258 feat1 = base_cos_freq_func(X, [0.5, 2, 0.5 * r2, 2 * r2]) 

259 feat2 = base_cos_freq_func(X, [0.5, 2, 0.5 * r4, 2 * r4]) 

260 features = np.concatenate([feat1, feat2], axis=1) 

261 

262 features = StandardScaler().fit_transform(features) 

263 

264 # ----------------------------- 

265 # 7. Train/test + state 

266 # ----------------------------- 

267 state = np.zeros(N) 

268 train_mask = np.arange(N) < int(train_ratio * N) 

269 

270 # ----------------------------- 

271 # 8. RETOUR enrichi 

272 # ----------------------------- 

273 return { 

274 "X": features, 

275 "Y": y_target, 

276 "Context": state, 

277 "train": train_mask, 

278 "test": np.invert(train_mask), 

279 "split": train_mask, 

280 "aux": {"y_no_obs": y_no_obs, "y_old": y_old} 

281 } 

282 

283 

284def generate_default(dict_params=dict()): 

285 dict_data = core_gen(**dict_params) 

286 return dict_data