Coverage for uqmodels / modelization / DL_estimator / utils.py: 97%
36 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 08:15 +0000
1import os
2import numpy as np
3import tensorflow as tf
4import random
def sum_part_prod(array):
    """Return the sum of the leading partial products of ``array``.

    For ``array = [k1, ..., kn]`` this is::

        1 + k1 + k1*k2 + ... + k1*k2*...*k(n-1)

    i.e. the empty product (1) is included and the full product
    ``k1*...*kn`` is NOT. An empty ``array`` gives 0.

    Args:
        array: sequence of numbers (e.g. a list of strides).

    Returns:
        The sum described above.
    """
    total = 0
    # Product of the elements visited so far; starts as the empty product.
    running_prod = 1
    for k in array:
        total += running_prod
        running_prod *= k
    return total
def size_post_conv(w, l_k, l_st):
    """Return the length after a stack of convolutions with padding='valid'.

    Each layer maps a length ``s`` to ``ceil((s - k + 1) / st)``.

    Args:
        w: size of the input window.
        l_k: list of kernel sizes, one per layer.
        l_st: list of strides, one per layer. It is paired with ``l_k`` by
            position; extra entries in the longer list are ignored (``zip``).

    Returns:
        The length after the last layer, as produced by ``np.ceil`` (a float).
        If the lists are empty, ``w`` is returned unchanged.
    """
    current_size = w
    for kernel, stride in zip(l_k, l_st):
        current_size = np.ceil((current_size - kernel + 1) / stride)
    return current_size
def find_conv_kernel(window_initial, size_final, list_strides):
    """Choose kernel sizes so that a stack of 'valid' convolutions with the
    given strides maps ``window_initial`` to ``size_final``.

    All layers except the last share one kernel size. The last layer's kernel
    is then derived from the size remaining after the first layers.

    Args:
        window_initial: size of the input window.
        size_final: target output size.
        list_strides: list of strides, one per layer.

    Returns:
        (list_kernel, list_strides): the kernel sizes (ints), and the same
        ``list_strides`` object that was passed in.

    Raises:
        ValueError: if the computed last kernel is smaller than 1.
    """
    # Sum of the leading partial products of all strides but the last:
    # 1 + s1 + s1*s2 + ... (see sum_part_prod).
    # NOTE(review): with a single stride this is sum_part_prod([]) == 0, so the
    # division below yields inf/nan and int() further down raises; callers
    # presumably always pass >= 2 strides -- confirm.
    val = sum_part_prod(list_strides[:-1])
    # Ignore the ceil rounding in size_post_conv. With a shared kernel k over
    # the first layers, size * prod(strides) = window_initial - (k - 1) * val.
    # Solving for k with size = size_final gives k == -float_kernel.
    float_kernel = (size_final * np.prod(list_strides[:-1]) - window_initial) / val - 1
    # Use floor(k) - 1 (but at least 1), i.e. a slightly smaller kernel.
    # Presumably this leaves the remaining reduction to the last layer -- confirm.
    kernel = int(max(np.floor(-float_kernel) - 1, 1))
    # Actual size after the first layers, including the ceil rounding.
    before_last_size = size_post_conv(
        window_initial, [kernel for i in list_strides[:-1]], list_strides[:-1]
    )
    # With a last stride of 1, this is exactly the k for which
    # before_last_size - k + 1 == size_final.
    # NOTE(review): with a last stride > 1, it does not solve
    # ceil((before_last_size - k + 1) / stride) == size_final. Example:
    # before_last_size=10, stride=2, size_final=3 gives k=4 and an output of 4,
    # while k=5 would give 3. Confirm the intended formula.
    last_kernel = (before_last_size - size_final + 1) / list_strides[-1]

    if last_kernel < 1:
        raise (ValueError("Incompatible list_strides values"))

    list_kernel = [kernel for i in list_strides]
    # int() truncates a fractional last kernel toward zero.
    list_kernel[-1] = int(last_kernel)
    return (list_kernel, list_strides)
def set_seeds(seed=None):
    """Seed the Python, TensorFlow and NumPy random generators with ``seed``.

    Does nothing when ``seed`` is None. Also exports ``PYTHONHASHSEED``.
    Python reads that variable at interpreter start-up, so it only affects
    processes launched after this call, not the current process's
    ``hash()`` randomization.
    """
    if seed is None:
        return
    # Environment variables must be strings.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
def set_global_determinism(seed=None):
    """Request deterministic TensorFlow execution, optionally seeding RNGs.

    When ``seed`` is given, every generator is seeded through ``set_seeds``.
    The environment flags ``TF_DETERMINISTIC_OPS`` and
    ``TF_CUDNN_DETERMINISTIC`` are set to "1" in every case.
    """
    if seed is not None:
        set_seeds(seed=seed)

    # NOTE(review): whether TensorFlow honours these flags depends on its
    # version and on when they are set relative to TF initialisation. Newer TF
    # also offers tf.config.experimental.enable_op_determinism(); confirm
    # against the TF version in use.
    for flag in ("TF_DETERMINISTIC_OPS", "TF_CUDNN_DETERMINISTIC"):
        os.environ[flag] = "1"
    # Forcing single-threaded TF (set_inter_op_parallelism_threads(1) /
    # set_intra_op_parallelism_threads(1)) was considered but is left disabled.