Coverage for uqmodels / modelization / DL_estimator / data_generator.py: 67%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 08:15 +0000

1import tensorflow as tf 

2import numpy as np 

3from uqmodels.utils import apply_mask 

4 

5 

class default_Generator(tf.keras.utils.Sequence):
    """Standard batch ``Sequence`` generator for supervised learning.

    Builds batches from ``X`` and ``y``, passes each raw batch through
    ``metamodel.factory`` and returns the transformed inputs/targets as
    NumPy arrays so that Keras can consume them directly.
    """

    def __init__(
        self, X, y, metamodel, batch=64, shuffle=True, train=True, random_state=None
    ):
        """Store the data, the metamodel hooks and the sample ordering.

        Parameters
        ----------
        X, y :
            Inputs and targets. Both are indexed with an integer index array
            in ``__getitem__`` (assumes NumPy-style fancy indexing is
            supported - TODO confirm against callers). ``len(y)`` defines the
            number of samples.
        metamodel :
            Object exposing ``factory``, ``_format`` and ``rescale``; only
            ``factory`` is called by this class.
        batch : int
            Maximum number of samples per batch.
        shuffle : bool
            Shuffle the sample order at construction and after every epoch.
        train : bool
            Stored on the instance; not read by this class.
        random_state :
            Seed forwarded to ``np.random.default_rng`` for every shuffle.
        """
        # Initialise the Keras base class. Keras 3 (PyDataset) expects this
        # call; on older Keras versions it resolves to object.__init__ and is
        # a no-op.
        super().__init__()

        self.X = X
        self.y = y
        self.len_ = len(y)  # number of samples
        self.train = train
        self.random_state = random_state
        self.shuffle = shuffle
        self.batch = batch

        self.factory = metamodel.factory
        self._format = metamodel._format
        self.rescale = metamodel.rescale

        # Sample indices, optionally shuffled once up front.
        self.indices = np.arange(self.len_)
        if self.shuffle:
            self._shuffle_indices()

    def _shuffle_indices(self):
        """Shuffle ``self.indices`` in place.

        NOTE: a new generator is built from ``self.random_state`` on every
        call, so with a fixed seed the same positional permutation is
        re-applied to the current ordering each time.
        """
        rng = np.random.default_rng(self.random_state)
        rng.shuffle(self.indices)

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(self.len_ / self.batch))

    def __getitem__(self, idx):
        """Return batch number ``idx`` as ``(inputs, outputs)`` ndarrays."""
        # idx is a batch index (0, 1, 2, ...); the final batch is truncated
        # to the number of available samples.
        start = idx * self.batch
        end = min((idx + 1) * self.batch, self.len_)
        batch_indices = self.indices[start:end]

        # Raw batch.
        X_batch = self.X[batch_indices]
        y_batch = self.y[batch_indices]

        # factory returns three values; the third one (described as a mask
        # by the original author) is not used here.
        X_trans, y_trans, _ = self.factory(X_batch, y_batch)

        # Force np.ndarray so Keras / tf.data can infer a clean
        # output signature.
        return np.asarray(X_trans), np.asarray(y_trans)

    def on_epoch_end(self):
        """Reshuffle the sample order at the end of each epoch if enabled."""
        if self.shuffle:
            self._shuffle_indices()

64 

65 

class Folder_Generator(tf.keras.utils.Sequence):
    """Sliding-window batch ``Sequence`` generator for temporal models.

    For each batch, loads a contiguous slice of the data that includes past
    and future context around the batch, passes it through
    ``metamodel.factory`` and keeps only the rows selected by a boolean mask
    whose shape depends on ``train`` and on the batch position.
    """

    def __init__(
        self, X, y, metamodel, batch=64, shuffle=True, train=True, random_state=None,
        dtype=np.float32
    ):
        """Store the data, the metamodel hooks and the batch ordering.

        Parameters
        ----------
        X :
            ``None`` or a sequence whose first two elements are sliceable
            arrays (``X[0]`` and ``X[1]`` are the only elements read).
        y :
            ``None`` or an array with a ``shape`` attribute.
        metamodel :
            Object exposing ``factory``, ``_format``, ``rescale`` and a
            ``model_parameters`` mapping with the keys ``"size_window"``,
            ``"dim_horizon"`` and ``"step"``.
        batch : int
            Number of samples per batch.
        shuffle : bool
            Shuffle the batch order at construction and after every epoch.
        train : bool
            Selects the masking rule applied in ``__getitem__``.
        random_state :
            Seed forwarded to ``np.random.default_rng`` for every shuffle.
        dtype :
            dtype used when converting the outputs (and single-array inputs)
            to ``np.ndarray``.

        Raises
        ------
        ValueError
            If both ``X`` and ``y`` are ``None``.
        """
        # Initialise the Keras base class. Keras 3 (PyDataset) expects this
        # call; on older Keras versions it resolves to object.__init__ and is
        # a no-op.
        super().__init__()

        self.X = X
        self.y = y
        self.random_state = random_state
        # Honour the caller-supplied dtype (it used to be hard-coded to
        # np.float32 regardless of the argument).
        self.dtype = dtype
        if X is not None:
            # X is assumed to be a list/tuple of arrays: [X0, X1, ...]
            self.len_ = X[0].shape[0]
        elif y is not None:
            self.len_ = y.shape[0]
        else:
            raise ValueError("Folder_Generator requires at least X or y to be non-None.")

        self.train = train
        self.shuffle = shuffle
        self.batch = batch

        self.factory = metamodel.factory
        self._format = metamodel._format
        self.rescale = metamodel.rescale

        self.causality_remove = None
        self.model_parameters = metamodel.model_parameters
        self.past_horizon = metamodel.model_parameters["size_window"]
        self.futur_horizon = (
            metamodel.model_parameters["dim_horizon"]
            * metamodel.model_parameters["step"]
        )
        # Length of the raw slice loaded per batch: context on both sides
        # plus the batch itself.
        self.size_seq = self.past_horizon + self.futur_horizon + self.batch
        self.size_window_futur = 1

        # Number of batches.
        self.n_batch = int(np.ceil(self.len_ / self.batch))

        # Batch indices (0, 1, ..., n_batch - 1) used for shuffling.
        self.indices = np.arange(self.n_batch)
        if self.shuffle:
            self._shuffle_indices()

    def _shuffle_indices(self):
        """Shuffle ``self.indices`` in place.

        NOTE: a new generator is built from ``self.random_state`` on every
        call, so with a fixed seed the same positional permutation is
        re-applied to the current ordering each time.
        """
        rng = np.random.default_rng(self.random_state)
        rng.shuffle(self.indices)

    def load(self, idx):
        """Load the raw data slice surrounding batch ``idx``.

        The slice starts ``past_horizon`` rows before the first sample of the
        batch (clamped at 0) and spans at least ``size_seq`` rows.

        Returns
        -------
        tuple
            ``([X0_slice, X1_slice], y_slice)``; the X part is
            ``[None, None]`` when ``self.X`` is ``None`` and ``y_slice`` is
            ``None`` when ``self.y`` is ``None``.
        """
        start = idx * self.batch

        idx_min = max(0, start - self.past_horizon)
        idx_max = max(self.size_seq + idx_min, start + self.futur_horizon)

        # Last-batch case: when the window overshoots the end of the data,
        # move its start backwards by the overshoot to complete the window.
        if start > 0:
            idx_min = max(idx_min - max(0, idx_max - self.len_), 0)

        y_batch = None
        if self.y is not None:
            y_batch = self.y[idx_min:idx_max]

        if self.X is None:
            return [None, None], y_batch

        # X is assumed to be a list [X0, X1].
        return [self.X[0][idx_min:idx_max], self.X[1][idx_min:idx_max]], y_batch

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.n_batch

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(Inputs, Outputs)`` after masking.

        When shuffling is enabled, ``idx`` is first mapped through the
        shuffled batch order.
        """
        if self.shuffle:
            idx = self.indices[idx]

        x, y = self.load(idx)

        Inputs, Outputs, _ = self.factory(x, y, fit_rescale=False)

        n_rows = len(Inputs[0])
        selection = np.zeros(n_rows, dtype=bool)

        # Stop positions are computed explicitly (clamped at 0) rather than
        # with negative slice bounds: a bound of ``-0`` is ``0`` and would
        # select nothing when a horizon is zero. For positive horizons the
        # result is identical to the negative-bound form.
        if self.train:
            # Drop the leading past context and the trailing future context.
            stop = max(n_rows - self.futur_horizon, 0)
            selection[self.past_horizon:stop] = True
        elif idx == 0:
            if self.batch >= self.len_:
                # A single batch covers the whole data set: keep everything.
                selection[:] = True
            else:
                # First batch: keep the head, drop the trailing context rows.
                stop = max(n_rows - self.past_horizon - self.futur_horizon, 0)
                selection[:stop] = True
        else:
            idx_min = max(0, idx * self.batch - self.past_horizon)
            idx_max = max(
                self.size_seq + idx_min,
                idx * self.batch + self.batch + self.futur_horizon,
            )
            # Skip the past context plus either the future horizon or the
            # amount by which the window overshoots the data, whichever is
            # larger.
            padding_test = max(self.futur_horizon, idx_max - self.len_)
            selection[padding_test + self.past_horizon:] = True

        Inputs = apply_mask(Inputs, selection)
        Outputs = apply_mask(Outputs, selection)

        # Handle the multi-input case: each part is converted separately and
        # keeps its own dtype (no dtype is forced here).
        if isinstance(Inputs, (list, tuple)):
            Inputs = tuple(np.asarray(xi) for xi in Inputs)
        else:
            Inputs = np.asarray(Inputs, dtype=self.dtype)

        Outputs = np.asarray(Outputs, dtype=self.dtype)

        return Inputs, Outputs

    def on_epoch_end(self):
        """Reshuffle the batch order at the end of each epoch if enabled."""
        if self.shuffle:
            self._shuffle_indices()