Coverage for uqmodels / modelization / DL_estimator / utils.py: 97%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 08:15 +0000

1import os 

2import numpy as np 

3import tensorflow as tf 

4import random 

5 

6 

def sum_part_prod(array):
    """Sum the products of the proper prefixes of ``array``.

    array = [k1,...,kn]
    return (1 + k1 + k1k2 + ... + k1..k(n-1))

    The leading 1 is the empty product (np.prod of an empty slice); the
    product of the whole array is not part of the sum. An empty array
    gives 0.
    """
    return sum(np.prod(array[:end]) for end in range(len(array)))

16 

17 

def size_post_conv(w, l_k, l_st):
    """Size remaining after a stack of convolutions (with padding=valid).

    w : size of the input window
    l_k : list of kernel sizes, one per layer
    l_st : list of strides, one per layer

    Layers are applied in order; pairing stops at the shorter of the two
    lists. With no layer at all, w is returned unchanged.
    """
    size = w
    for kernel, stride in zip(l_k, l_st):
        # Number of valid kernel positions at stride 1, then subsampled.
        valid_positions = size - kernel + 1
        size = np.ceil(valid_positions / stride)
    return size

28 

29 

def find_conv_kernel(window_initial, size_final, list_strides):
    """Return kernel sizes for a stack of convolutions according to :
    window_initial : size of window
    size_final : size final
    list_strides : list of strides (one per layer, at least one)

    All layers but the last share one kernel size; the last layer gets its
    own kernel, computed from the size left after the preceding layers.

    return(list_kernel,list_strides)

    Raises ValueError if list_strides is empty or if the computed last
    kernel would be smaller than 1.
    """
    n_layers = len(list_strides)
    if n_layers == 0:
        raise ValueError("list_strides must contain at least one stride")

    head_strides = list_strides[:-1]
    if len(head_strides) > 0:
        val = sum_part_prod(head_strides)
        # Real-valued estimate of the shared kernel size (stored negated),
        # then floored, reduced by one and clamped to a minimum of 1.
        float_kernel = (size_final * np.prod(head_strides) - window_initial) / val - 1
        kernel = int(max(np.floor(-float_kernel) - 1, 1))
    else:
        # Single layer: there is no shared kernel to estimate. Without this
        # branch val would be 0 and the division above would produce an
        # infinite value that cannot be converted to int.
        kernel = 1

    before_last_size = size_post_conv(
        window_initial, [kernel] * len(head_strides), head_strides
    )
    last_kernel = (before_last_size - size_final + 1) / list_strides[-1]

    if last_kernel < 1:
        raise ValueError("Incompatible list_strides values")

    list_kernel = [kernel] * n_layers
    list_kernel[-1] = int(last_kernel)
    return (list_kernel, list_strides)

53 

54 

def set_seeds(seed=None):
    """Seed the random generators used by this module.

    When ``seed`` is None nothing is changed. Otherwise the PYTHONHASHSEED
    environment variable is set to ``str(seed)`` and the seed is passed to
    ``random``, ``tf.random`` and ``np.random``.
    """
    if seed is None:
        return
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)

61 

62 

# set_global_determinism: when a seed is given, forwards it to set_seeds, and
# sets the TF_DETERMINISTIC_OPS and TF_CUDNN_DETERMINISTIC environment
# variables to "1".
# NOTE(review): indentation is lost in this listing, so it cannot be told from
# here whether the two os.environ assignments belong to the
# `if seed is not None` branch or run on every call - confirm against the
# original utils.py before relying on either reading.
# The two tf.config.threading calls at the end are commented out (disabled).
63def set_global_determinism(seed=None): 

64 if seed is not None: 

65 set_seeds(seed=seed) 

66 

67 os.environ["TF_DETERMINISTIC_OPS"] = "1" 

68 os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 

69 

70 # tf.config.threading.set_inter_op_parallelism_threads(1) 

71 # tf.config.threading.set_intra_op_parallelism_threads(1)