Coverage for uqmodels/modelization/DL_estimator/transformer_ed.py: 63%
237 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-05 14:29 +0000
1import keras.backend as K
2import numpy as np
3import tensorflow as tf
4from keras.layers import TimeDistributed
5from tensorflow import keras
6from tensorflow.keras import Input, layers
8from uqmodels.modelization.DL_estimator.data_embedding import (
9 Factice_Time_Extension,
10 Mouving_conv_Embedding,
11 Mouving_Windows_Embedding,
12 PositionalEmbedding,
13)
14from uqmodels.modelization.DL_estimator.neural_network_UQ import (
15 NN_UQ,
16 get_training_parameters,
17 get_UQEstimator_parameters,
18 mlp,
19)
20from uqmodels.modelization.DL_estimator.utils import (
21 Folder_Generator,
22 set_global_determinism,
23)
24from uqmodels.utils import add_random_state, stack_and_roll
# Thin alias subclass registered under the "UQModels_layers" package so that
# saved models serialize/deserialize this layer with a stable identifier.
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class MultiHeadAttention(tf.keras.layers.MultiHeadAttention):
    pass
# Thin alias subclass registered under the "UQModels_layers" package so that
# saved models serialize/deserialize this layer with a stable identifier.
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class LayerNormalization(tf.keras.layers.LayerNormalization):
    pass
# Thin alias subclass registered under the "UQModels_layers" package so that
# saved models serialize/deserialize this layer with a stable identifier.
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class Dropout(tf.keras.layers.Dropout):
    pass
# Serializable Dense layer alias
# Thin alias subclass registered under the "UQModels_layers" package so that
# saved models serialize/deserialize this layer with a stable identifier.
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class Dense(tf.keras.layers.Dense):
    pass
48# Transformer Encoder Layer
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class TransformerEncoder(layers.Layer):
    """Transformer encoder block: self-attention + position-wise FFN,
    each followed by dropout, a residual connection and layer norm.

    Adapted from https://keras.io/examples/audio/transformer_asr/
    """

    def __init__(
        self,
        dim_z,
        num_heads,
        feed_forward_dim,
        dp_rec=0.1,
        flag_mc=False,
        random_state=None,
        **kwargs
    ):
        """Instantiate the encoder sub-layers.

        Args:
            dim_z (int): Model/latent dimension (attention key dim and output dim).
            num_heads (int): Number of attention heads.
            feed_forward_dim (int): Hidden width of the feed-forward sub-layer.
            dp_rec (float, optional): Dropout rate; 0 disables dropout. Defaults to 0.1.
            flag_mc (bool, optional): If True, dropout stays active at inference
                time (MC-dropout style uncertainty). Defaults to False.
            random_state (int, optional): Seed for deterministic behaviour.
            **kwargs: Accepted and ignored; deliberately NOT forwarded to
                Layer.__init__ because callers pass extra keys (e.g.
                ``num_feed_forward``) that keras would reject.
        """
        super().__init__()
        self.dim_z = dim_z
        self.num_heads = num_heads
        self.feed_forward_dim = feed_forward_dim
        self.dp_rec = dp_rec
        self.flag_mc = flag_mc
        self.random_state = random_state
        set_global_determinism(self.random_state)

        # Sub-layer instantiation
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=dim_z)
        self.dense1 = Dense(feed_forward_dim, activation="relu")
        self.dense2 = Dense(dim_z)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dp_rec, seed=self.random_state)
        self.dropout2 = Dropout(dp_rec, seed=add_random_state(self.random_state, 1))

    def call(self, inputs, training=None):
        """Apply self-attention then the FFN, with residual + norm after each.

        Args:
            inputs: Input tensor; assumed shape (batch, seq, dim_z) — TODO confirm.
            training (bool, optional): Training-mode flag; None is treated as False.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        if training is None:
            training = False

        attn_output = self.att(inputs, inputs)
        if self.dp_rec > 0:
            # `training | flag_mc`: dropout remains stochastic at inference
            # when MC-dropout is requested.
            attn_output = self.dropout1(attn_output, training=training | self.flag_mc)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.dense2(self.dense1(out1))
        if self.dp_rec > 0:
            ffn_output = self.dropout2(ffn_output, training=training | self.flag_mc)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Serialize hyper-parameters and all sub-layers."""
        return {
            "dim_z": self.dim_z,
            "num_heads": self.num_heads,
            "feed_forward_dim": self.feed_forward_dim,
            "dp_rec": self.dp_rec,
            "flag_mc": self.flag_mc,
            "random_state": self.random_state,
            "att": tf.keras.utils.serialize_keras_object(self.att),
            "layernorm1": tf.keras.utils.serialize_keras_object(self.layernorm1),
            "layernorm2": tf.keras.utils.serialize_keras_object(self.layernorm2),
            "dense1": tf.keras.utils.serialize_keras_object(self.dense1),
            "dense2": tf.keras.utils.serialize_keras_object(self.dense2),
            "dropout1": tf.keras.utils.serialize_keras_object(self.dropout1),
            "dropout2": tf.keras.utils.serialize_keras_object(self.dropout2),
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, restoring the serialized sub-layers in place of
        the freshly created ones."""
        att = config.pop("att")
        layernorm1 = config.pop("layernorm1")
        layernorm2 = config.pop("layernorm2")
        dropout1 = config.pop("dropout1")
        dropout2 = config.pop("dropout2")
        dense1 = config.pop("dense1")
        dense2 = config.pop("dense2")

        obj = cls(**config)
        obj.att = tf.keras.utils.deserialize_keras_object(att)
        obj.layernorm1 = tf.keras.utils.deserialize_keras_object(layernorm1)
        obj.layernorm2 = tf.keras.utils.deserialize_keras_object(layernorm2)
        obj.dropout1 = tf.keras.utils.deserialize_keras_object(dropout1)
        obj.dropout2 = tf.keras.utils.deserialize_keras_object(dropout2)
        obj.dense1 = tf.keras.utils.deserialize_keras_object(dense1)
        obj.dense2 = tf.keras.utils.deserialize_keras_object(dense2)

        return obj
148# Transformer Decoder Layer
@tf.keras.utils.register_keras_serializable(package="UQModels_layers")
class TransformerDecoder(layers.Layer):
    """Transformer decoder block: causal self-attention, encoder-decoder
    attention and a position-wise FFN, each with dropout, residual and norm.

    Adapted from https://keras.io/examples/audio/transformer_asr/
    """

    def __init__(
        self,
        dim_z,
        dim_horizon,
        num_heads,
        feed_forward_dim,
        dp_rec=0.1,
        flag_mc=False,
        random_state=None,
        **kwargs
    ):
        """Instantiate the decoder sub-layers.

        Args:
            dim_z (int): Model/latent dimension (attention key dim and output dim).
            dim_horizon (int): Number of future steps; controls the causal mask.
            num_heads (int): Number of attention heads.
            feed_forward_dim (int): Hidden width of the feed-forward sub-layer.
            dp_rec (float, optional): Dropout rate; 0 still creates the layers. Defaults to 0.1.
            flag_mc (bool, optional): If True, dropout stays active at inference
                time (MC-dropout style uncertainty). Defaults to False.
            random_state (int, optional): Seed for deterministic behaviour.
            **kwargs: Accepted and ignored; deliberately NOT forwarded to
                Layer.__init__ because callers pass extra keys (e.g.
                ``num_feed_forward``) that keras would reject.
        """
        super().__init__()
        self.dim_z = dim_z
        self.dim_horizon = dim_horizon
        self.num_heads = num_heads
        self.feed_forward_dim = feed_forward_dim
        self.dp_rec = dp_rec
        self.flag_mc = flag_mc
        self.random_state = random_state
        set_global_determinism(self.random_state)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.layernorm3 = LayerNormalization(epsilon=1e-6)
        self.self_att = MultiHeadAttention(num_heads=num_heads, key_dim=dim_z)
        self.enc_att = MultiHeadAttention(num_heads=num_heads, key_dim=dim_z)
        self.self_dropout = Dropout(dp_rec, seed=random_state)
        self.enc_dropout = Dropout(dp_rec, seed=add_random_state(random_state, 1))
        self.ffn_dropout = Dropout(dp_rec, seed=add_random_state(random_state, 2))
        self.dense1 = Dense(feed_forward_dim, activation="relu")
        self.dense2 = Dense(dim_z)

    def causal_attention_mask(self, batch_size, n_dest, n_src, dim_horizon, dtype):
        """Mask the upper half of the dot-product matrix in self-attention.

        This prevents flow of information from future tokens to the current
        token: 1's in the lower triangle, counting from the lower right corner.
        Past positions (the first ``n_dest - dim_horizon`` rows) all attend up
        to the last past position; horizon rows attend causally.
        """
        len_past = n_dest - dim_horizon
        i = tf.concat(
            [
                tf.zeros(len_past, dtype=tf.int32) + len_past - 1,
                tf.range(dim_horizon) + len_past,
            ],
            0,
        )[:, None]
        j = tf.range(n_src)
        m = (i) >= (j - n_src + n_dest)
        mask = tf.cast(m, dtype)
        mask = tf.reshape(mask, [1, n_dest, n_src])
        # Tile the (1, n_dest, n_src) mask across the batch dimension.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
        )
        return tf.tile(mask, mult)

    def call(self, enc_out, target, training=None):
        """Run the decoder block.

        Args:
            enc_out: Encoder output used as key/value of the cross-attention.
            target: Decoder-side input; assumed shape (batch, seq, dim_z) — TODO confirm.
            training (bool, optional): Training-mode flag; None is treated as False.

        Returns:
            Tensor with the same shape as ``target``.
        """
        if training is None:
            training = False

        input_shape = tf.shape(target)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = self.causal_attention_mask(
            batch_size, seq_len, seq_len, self.dim_horizon, tf.bool
        )
        target_att = self.self_att(target, target, attention_mask=causal_mask)
        # `training | flag_mc`: dropout remains stochastic at inference when
        # MC-dropout is requested.
        target_norm = self.layernorm1(
            target + self.self_dropout(target_att, training=training | self.flag_mc)
        )
        enc_out = self.enc_att(target_norm, enc_out)
        enc_out_norm = self.layernorm2(
            self.enc_dropout(enc_out, training=training | self.flag_mc) + target_norm
        )
        ffn_out = self.dense2(self.dense1(enc_out_norm))
        ffn_out_norm = self.layernorm3(
            enc_out_norm + self.ffn_dropout(ffn_out, training=training | self.flag_mc)
        )
        return ffn_out_norm

    def get_config(self):
        """Serialize hyper-parameters and all sub-layers."""
        return {
            "dim_z": self.dim_z,
            "dim_horizon": self.dim_horizon,
            "num_heads": self.num_heads,
            "feed_forward_dim": self.feed_forward_dim,
            "dp_rec": self.dp_rec,
            "flag_mc": self.flag_mc,
            "random_state": self.random_state,
            "layernorm1": tf.keras.utils.serialize_keras_object(self.layernorm1),
            "layernorm2": tf.keras.utils.serialize_keras_object(self.layernorm2),
            "layernorm3": tf.keras.utils.serialize_keras_object(self.layernorm3),
            "self_att": tf.keras.utils.serialize_keras_object(self.self_att),
            "enc_att": tf.keras.utils.serialize_keras_object(self.enc_att),
            "self_dropout": tf.keras.utils.serialize_keras_object(self.self_dropout),
            "enc_dropout": tf.keras.utils.serialize_keras_object(self.enc_dropout),
            "ffn_dropout": tf.keras.utils.serialize_keras_object(self.ffn_dropout),
            "dense1": tf.keras.utils.serialize_keras_object(self.dense1),
            "dense2": tf.keras.utils.serialize_keras_object(self.dense2),
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, restoring the serialized sub-layers in place of
        the freshly created ones."""
        layernorm1 = config.pop("layernorm1")
        layernorm2 = config.pop("layernorm2")
        layernorm3 = config.pop("layernorm3")
        self_att = config.pop("self_att")
        enc_att = config.pop("enc_att")
        self_dropout = config.pop("self_dropout")
        enc_dropout = config.pop("enc_dropout")
        ffn_dropout = config.pop("ffn_dropout")
        dense1 = config.pop("dense1")
        dense2 = config.pop("dense2")
        obj = cls(**config)

        obj.layernorm1 = tf.keras.utils.deserialize_keras_object(layernorm1)
        obj.layernorm2 = tf.keras.utils.deserialize_keras_object(layernorm2)
        obj.layernorm3 = tf.keras.utils.deserialize_keras_object(layernorm3)
        obj.self_att = tf.keras.utils.deserialize_keras_object(self_att)
        obj.enc_att = tf.keras.utils.deserialize_keras_object(enc_att)
        obj.self_dropout = tf.keras.utils.deserialize_keras_object(self_dropout)
        obj.enc_dropout = tf.keras.utils.deserialize_keras_object(enc_dropout)
        obj.ffn_dropout = tf.keras.utils.deserialize_keras_object(ffn_dropout)
        obj.dense1 = tf.keras.utils.deserialize_keras_object(dense1)
        obj.dense2 = tf.keras.utils.deserialize_keras_object(dense2)
        return obj
292# encoder
def build_transformer(
    size_window=10,
    n_windows=5,
    step=1,
    dim_target=1,
    dim_chan=1,
    dim_horizon=3,
    dim_ctx=20,
    dim_z=100,
    num_heads=2,
    num_feed_forward=128,
    num_layers_enc=3,
    num_layers_dec=2,
    layers_enc=[150],
    layers_dec=[150, 75],
    dp=0.05,
    dp_rec=0.03,
    k_reg=(0.00001, 0.00001),
    list_strides=[2, 1],
    list_filters=None,
    list_kernels=None,
    dim_dyn=None,
    with_positional_embedding=False,
    with_ctx_input=True,
    with_convolution=True,
    type_output=None,
    random_state=None,
    **kwargs
):
    """Builder for Transformer ED with convolutive preprocessing

    Args:
        size_window (int, optional): Size of window for lag values. Defaults to 10.
        n_windows (int, optional): Number of window in past. Defaults to 5.
        step (int, optional): step between windows. Defaults to 1.
        dim_target (int, optional): dimension of TS. Defaults to 1.
        dim_chan (int, optional): Number of channel of TS. Defaults to 1.
        dim_horizon (int, optional): futur_horizon to predict. Defaults to 3.
        dim_ctx (int, optional): Number of ctx_features. Defaults to 20.
        dim_z (int, optional): Size of latent space. Defaults to 100.
        num_heads (int, optional): num of heads transformer. Defaults to 2.
        num_feed_forward (int, optional): feed_forward transformer dimension. Defaults to 128.
        num_layers_enc (int, optional): num of transformer enc block
            (after concatenation of past values embedding + ctx). Defaults to 3.
        num_layers_dec (int, optional): num of transformer dec block. Defaults to 2.
        layers_enc (list, optional): size of MLP preprocessing
            (after concatenation of past values embedding + ctx). Defaults to [150].
        layers_dec (list, optional): size of MLP interpretor. Defaults to [150, 75].
        dp (float, optional): dropout. Defaults to 0.05.
        dp_rec (float, optional): transformer dropout. Defaults to 0.03.
        k_reg (tuple, optional): kernel regularization (l1, l2). Defaults to (0.00001, 0.00001).
        dim_dyn (int, None): size of dyn inputs, if None consider dim_dyn have same size than dim target
        with_positional_embedding (bool, optional): add static positional embedding. Defaults to False.
        with_ctx_input (bool, optional): Expect ctx features in addition to lag. Defaults to True.
        with_convolution (bool, optional): use convolution rather than
            whole lag values in the windows. Defaults to True.
        type_output (_type_, optional): mode of UQ (see NN_UQ). Defaults to None.
        random_state (bool): handle experimental random using seed.

    Returns:
        transformer : multi-step forecaster with UQ
    """
    if dim_dyn is None:
        dim_dyn = dim_target

    # MC-dropout is kept active at inference for sampling-based UQ modes.
    flag_mc = 0
    if type_output in ["BNN", "MC_Dropout"]:
        flag_mc = 1

    set_global_determinism(random_state)

    # Embedding_interpretor : maps latent space back to the target space.
    Interpretor = mlp(
        dim_in=dim_z,
        dim_out=dim_target,
        layers_size=layers_dec,
        dp=dp,
        type_output=type_output,
        name="Interpretor",
        random_state=random_state,
    )

    Pos_Embeddor = None
    if with_positional_embedding:
        Pos_Embeddor = PositionalEmbedding(dim_z, max_len=size_window + dim_horizon - 1)

    # Input definition
    list_input = []
    if with_ctx_input:
        CTX_inputs = Input(shape=(n_windows, dim_ctx), name="encoder_inputs")
        list_input.append(CTX_inputs)

    Y_past_in = Input(shape=(size_window, dim_dyn), name="past_inputs")
    list_input.append(Y_past_in)

    Y_past = Y_past_in

    # Preprocessing layers definition
    if with_convolution:
        MWE = Mouving_conv_Embedding(
            size_window,
            n_windows,
            step=step,
            dim_d=dim_dyn,
            dim_chan=dim_chan,
            use_conv2D=True,
            list_strides=list_strides,
            list_filters=list_filters,
            list_kernels=list_kernels,
            dp=0.05,
            flag_mc=flag_mc,
            seed=add_random_state(random_state, 100),
        )
    else:
        MWE = Mouving_Windows_Embedding(
            size_window,
            n_windows,
            step=step,
            dim_d=dim_dyn,
            dim_chan=dim_chan,
            seed=add_random_state(random_state, 100),
        )

    FTE = Factice_Time_Extension(dim_horizon)
    # Build a fresh list rather than appending: `layers_enc` may be the shared
    # mutable default argument, and mutating it would grow it on every call.
    layers_enc = list(layers_enc) + [dim_z]

    dim_embedding = MWE.last_shape
    if with_ctx_input:
        dim_embedding += dim_ctx

    Embeddor_ctx = mlp(
        dim_in=dim_embedding,
        dim_out=None,
        layers_size=layers_enc,
        dp=dp,
        name="Embeddor",
        regularizer_W=k_reg,
        random_state=add_random_state(random_state, 200),
    )

    # Preprocessing computation
    Data = MWE(Y_past)
    # Concat with cat features
    if with_ctx_input:
        Data = K.concatenate([CTX_inputs, Data], axis=-1)
    # Factice time augmentation (actually useless but can be usefull for extended predict horizon)
    Data = FTE(Data)

    Embedding = TimeDistributed(Embeddor_ctx)(Data)

    # Static Pe that encode window position
    if Pos_Embeddor:
        Pe_Embedding = Pos_Embeddor(Embedding)
        Embedding = Embedding + Pe_Embedding

    # Encode past information (all but the last dim_horizon steps).
    enc_out = Embedding[:, :(-dim_horizon), :]
    encoder = []
    for i in range(num_layers_enc):
        # NOTE(review): feed_forward_dim is hard-coded to 50 while
        # num_feed_forward lands in **kwargs and is ignored by the layer —
        # looks unintended but preserved for behavioural compatibility.
        encoder.append(
            TransformerEncoder(
                dim_z,
                num_heads,
                feed_forward_dim=50,
                num_feed_forward=num_feed_forward,
                dp_rec=dp_rec,
                flag_mc=flag_mc,
                random_state=add_random_state(random_state, 300 + i),
            )
        )
        enc_out = encoder[-1](enc_out)

    # Decoder stack: cross-attends the encoder output with the full embedding.
    decoder = []
    dec_out = enc_out
    for i in range(num_layers_dec):
        decoder.append(
            TransformerDecoder(
                dim_z=dim_z,
                dim_horizon=dim_horizon,
                feed_forward_dim=50,
                num_heads=num_heads,
                num_feed_forward=num_feed_forward,
                dp_rec=dp_rec,
                flag_mc=flag_mc,
                random_state=add_random_state(random_state, 400 + i),
            )
        )
        dec_out = decoder[-1](dec_out, Embedding)

    # Only the last dim_horizon steps are interpreted as forecasts.
    outputs = TimeDistributed(Interpretor)(dec_out[:, -(dim_horizon):])

    model = tf.keras.Model(list_input, outputs, name="model")
    return model
class Transformer_ED_UQ(NN_UQ):
    """Transformer_ED for forecasting with UQ : see build_transformer to check model parameters"""

    def __init__(
        self,
        model_parameters,
        factory_parameters={"factory_lag_lt": 0, "factory_lag_st": 0},
        training_parameters=dict(),
        type_output=None,
        rescale=False,
        n_ech=5,
        train_ratio=0.9,
        name="Lstm_stacked",
        random_state=None,
    ):
        """Initialization

        Args:
            model_parameters (dict): parameters forwarded to build_transformer.
            factory_parameters (dict, optional): lags used by factory() when
                stacking windows. Defaults to {'factory_lag_lt': 0, 'factory_lag_st': 0}.
                NOTE(review): mutable default argument — shared across instances
                if ever mutated; confirm it is treated as read-only.
            training_parameters (dict, optional): training configuration passed
                to NN_UQ. Defaults to dict().
            type_output (_type_, optional): UQ mode (see NN_UQ). Defaults to None.
            rescale (bool, optional): whether NN_UQ rescales data. Defaults to False.
            n_ech (int, optional): number of stochastic samples. Defaults to 5.
            train_ratio (float, optional): train/validation split ratio. Defaults to 0.9.
            name (str, optional): estimator name. Defaults to "Lstm_stacked".
            random_state (bool): handle experimental random using seed.
        """
        if (random_state) is not None:
            print("Warning : issues non-deterministic behaviour even with random state")

        super().__init__(
            model_initializer=build_transformer,
            model_parameters=model_parameters,
            factory_parameters=factory_parameters,
            training_parameters=training_parameters,
            type_output=type_output,
            rescale=rescale,
            n_ech=n_ech,
            train_ratio=train_ratio,
            name=name,
            random_state=random_state,
        )

    def factory(self, X, y, mask=None, only_fit_scaler=False, **kwarg):
        """Turn raw series into the stacked/rolled inputs the model expects.

        Builds lagged window tensors with stack_and_roll: when ctx inputs are
        used, X is expected to be a pair (ctx features, lag features) —
        presumably; confirm against callers. Targets y are stacked over
        dim_horizon steps. Returns (inputs, new_y, mask), or None when
        only_fit_scaler is True.
        """
        model_params = self.model_parameters
        factory_params = self.factory_parameters

        with_ctx_input = model_params["with_ctx_input"]

        step = 1
        if "step" in model_params.keys():
            step = model_params["step"]

        X_none = False
        if X is None:
            X_none = True

        if X_none:
            inputs = None
        else:
            if with_ctx_input:
                # X holds (ctx features, lagged dynamic features).
                X, X_lag = X
                X, X_lag, mask = super().factory(X, X_lag, mask)
                if only_fit_scaler:
                    return None
                # Long-term context windows (n_windows steps).
                X_lt = stack_and_roll(
                    X,
                    model_params["n_windows"],
                    lag=factory_params["factory_lag_lt"],
                    step=step,
                )

                # Short-term lag windows (size_window steps).
                X_st = stack_and_roll(
                    X_lag,
                    model_params["size_window"],
                    lag=factory_params["factory_lag_st"] - 1,
                    step=step,
                )

                inputs = [X_lt, X_st]
            else:
                X, _, _ = super().factory(X, None, mask)
                if only_fit_scaler:
                    return None
                X_lag = X
                X_st = stack_and_roll(
                    X,
                    model_params["size_window"],
                    lag=factory_params["factory_lag_st"] - 1,
                    step=step,
                )
                inputs = [X_st]

        # Stack targets over the forecast horizon.
        new_y = None
        if y is not None:
            _, y, _ = super().factory(None, y, mask)
            new_y = stack_and_roll(
                y,
                model_params["dim_horizon"],
                lag=model_params["dim_horizon"] - 1,
                step=step,
            )
        return inputs, new_y, mask

    def Build_generator(self, X, y, batch=32, shuffle=True, train=True):
        """Wrap (X, y) in a Folder_Generator batching helper tied to this estimator."""
        return Folder_Generator(
            X,
            y,
            self,
            batch=batch,
            shuffle=shuffle,
            train=train,
            random_state=self.random_state,
        )
def get_params_dict(
    dim_ctx,
    dim_dyn,
    dim_target,
    dim_chan=1,
    size_window=20,
    n_windows=5,
    dim_horizon=5,
    dim_z=50,
    dp=0.05,
    dp_rec=0.02,
    num_heads=2,
    num_feed_forward=128,
    num_layers_enc=3,
    num_layers_dec=2,
    layers_enc=None,
    layers_dec=None,
    list_strides=None,
    list_filters=None,
    list_kernels=None,
    with_convolution=True,
    with_ctx_input=True,
    n_ech=3,
    type_output="MC_Dropout",
    random_state=None,
):
    """Build the default model-parameter dict for Transformer_ED_UQ /
    build_transformer.

    List arguments default to None and are replaced by fresh copies of the
    historical defaults on every call; previously they were shared mutable
    default lists that callers could corrupt by mutating the returned dict.

    Args:
        dim_ctx (int): number of contextual features.
        dim_dyn (int): size of dynamic (lag) inputs.
        dim_target (int): dimension of the target series.
        layers_enc (list, optional): MLP encoder sizes. Defaults to [75, 150, 75].
        layers_dec (list, optional): MLP decoder sizes. Defaults to [200, 125, 75].
        list_strides (list, optional): conv strides. Defaults to [2, 1, 1, 1].
        list_filters (list, optional): conv filters. Defaults to [128, 128, 128].
        (remaining arguments map one-to-one onto build_transformer parameters)

    Returns:
        dict: keyword arguments for the transformer builder.
    """
    # Fresh copies so no two calls ever share list objects.
    if layers_enc is None:
        layers_enc = [75, 150, 75]
    if layers_dec is None:
        layers_dec = [200, 125, 75]
    if list_strides is None:
        list_strides = [2, 1, 1, 1]
    if list_filters is None:
        list_filters = [128, 128, 128]

    dict_params = {
        "dim_ctx": dim_ctx,
        "size_window": size_window,
        "n_windows": n_windows,
        "dim_horizon": dim_horizon,
        "dim_target": dim_target,
        "dim_chan": dim_chan,
        "step": 1,
        "dim_z": dim_z,
        "dp": dp,
        "dp_rec": dp_rec,
        "dim_dyn": dim_dyn,
        "type_output": type_output,
        "num_heads": num_heads,
        "num_feed_forward": num_feed_forward,
        "num_layers_enc": num_layers_enc,
        "num_layers_dec": num_layers_dec,
        "k_reg": (10e-6, 10e-6),
        "layers_enc": layers_enc,
        "layers_dec": layers_dec,
        "list_strides": list_strides,
        "list_filters": list_filters,
        "list_kernels": list_kernels,
        "with_convolution": with_convolution,
        "with_ctx_input": with_ctx_input,
        "n_ech": n_ech,
        "random_state": random_state,
    }
    return dict_params