Coverage for tadkit/catalog/learners/_confiance_components/_sbad_wrapper.py: 18%
17 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:09 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:09 +0000
1from typing import List
def get_wrapped_dilanodetectm():
    """Return the TADlearner wrapped from the DiLAnoDetectm method
    of the Sparsity-Based Anomaly Detection framework.

    The function is intended for use if the dependency is available:
    ``sbad_fnn`` is imported lazily here, so it is only needed when the
    wrapper is actually requested.

    Returns:
        The ``DiLAnoDetectmWrapper`` class, a subclass of
        ``sbad_fnn.models.DiLAnoDetectm`` with the ``required_properties``
        and ``params_description`` class attributes attached.

    Raises:
        ImportError: if ``sbad_fnn`` is not installed.
    """
    from typing import Optional

    from sbad_fnn.models import DiLAnoDetectm

    class DiLAnoDetectmWrapper(DiLAnoDetectm):
        """Adapter exposing DiLAnoDetectm through a single-array fit/score API."""

        def __init__(
            self,
            nb_atoms: int,
            nb_blocks: int,
            kernel_size: int,
            nb_scales: int,
            overlap_perc: float,
            layers_sizes: Optional[List[int]] = None,
            activation: str = "id",
            wvlt_name: str = "db2",
            same: bool = False,
            share_weights: bool = False,
            soft: bool = False,
            shrink: bool = True,
            b: float = 1.5,
            min_nbsegments: int = 50,
            min_lag: int = 64,
        ) -> None:
            # Build the default list per call rather than using a shared
            # mutable default argument.
            if layers_sizes is None:
                layers_sizes = [39, 3]
            super().__init__(
                nb_atoms=nb_atoms,
                nb_blocks=nb_blocks,
                kernel_size=kernel_size,
                layers_sizes=layers_sizes,
                nb_scales=nb_scales,
                overlap_perc=overlap_perc,
                activation=activation,
                wvlt_name=wvlt_name,
                same=same,
                share_weights=share_weights,
                soft=soft,
                shrink=shrink,
                b=b,
                min_nbsegments=min_nbsegments,
                min_lag=min_lag,
            )  # to be removed when lists are supported by widgets

        def fit(self, x):
            """Fit the underlying model on a single array.

            The first two axes of ``x`` are swapped and the result is nested
            as ``[[...]]`` before being passed to ``DiLAnoDetectm.fit``.
            NOTE(review): assumes ``x`` is a numpy-like array, presumably of
            shape (n_samples, n_features), and that the parent expects a
            nested list of (n_features, n_samples) series -- confirm against
            sbad_fnn.

            Returns:
                Whatever ``DiLAnoDetectm.fit`` returns.
            """
            wrapped = [[x.swapaxes(0, 1)]]
            return super().fit(wrapped)

        def score_samples(self, x):
            """Score each sample of ``x`` with the underlying model.

            ``x`` is reshaped exactly as in ``fit``. The parent's scores are
            taken at index ``[0][0][0]``, truncated to ``len(x)`` entries and
            negated.
            NOTE(review): assumes the parent's output at ``[0][0][0]`` is a
            1-D per-timestep score array, possibly longer than ``x``.

            Returns:
                The negated scores, at most ``len(x)`` of them.
            """
            # Capture the sample count BEFORE wrapping: the wrapper list
            # always has length 1, so measuring it afterwards would keep
            # only the first score.
            n_samples = len(x)
            wrapped = [[x.swapaxes(0, 1)]]
            raw_scores = super().score_samples(wrapped)[0][0][0]
            return -raw_scores[:n_samples]

    DiLAnoDetectmWrapper.required_properties = ["multiple_time_series"]
    # Parameter metadata consumed by the widgets. Only parameters listed here
    # are exposed; the others (activation, wvlt_name, b, layers_sizes) keep
    # the constructor defaults.
    DiLAnoDetectmWrapper.params_description = {
        "nb_atoms": {
            "description": "Number of patterns to learn.",
            "value_type": "range",
            "start": 1,
            "step": 1,
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 2,
        },
        "nb_blocks": {
            "description": "Number of unrolled gradient steps blocks.",
            "value_type": "range",
            "start": 1,
            "step": 1,
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 1,
        },
        "kernel_size": {
            "description": "Convolutional patterns size, must be odd.",
            "value_type": "range",
            "start": 1,
            "step": 2,  # start 1 + step 2 -> only odd sizes are offered
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 5,
        },
        # "layers_sizes": {
        #     "description": "Spectral autoencoder layers sizes.",
        #     "value_type": "array",
        #     "length": 2,
        #     "elements": {
        #         "value_type": "range", "start": 1, "step": 1,
        #     },
        #     "default": (39, 3),
        # },
        "nb_scales": {
            "description": "Number of wavelet scales for decomposing time series.",
            "value_type": "range",
            "start": 1,
            "step": 1,
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 1,
        },
        "overlap_perc": {
            "description": "Overlapping ratio between sliding windows.",
            "value_type": "range",
            "start": 0,
            "step": 0.01,
            "stop": 1,
            "default": 0.98,
        },
        "same": {
            "description": "Whether the blocks are identical or have different weights.",
            "value_type": "bool_choice",
            "default": False,
        },
        "share_weights": {
            "description": "Whether within a given block, the conv and conv_transpose layers share weights or not.",
            "value_type": "bool_choice",
            "default": False,
        },
        "soft": {
            "description": "Whether the shrinking is soft.",
            "value_type": "bool_choice",
            "default": False,
        },
        "shrink": {
            "description": "Whether a shrinking activation function is applied to each block's output.",
            "value_type": "bool_choice",
            "default": True,
        },
        # NOTE(review): the widget defaults for min_nbsegments (1) and
        # min_lag (2) differ from the constructor defaults (50 and 64) --
        # confirm which is intended.
        "min_nbsegments": {
            "description": "Minimum number of segments.",
            "value_type": "range",
            "start": 1,
            "step": 1,
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 1,
        },
        "min_lag": {
            "description": "Minimum lag.",
            "value_type": "range",
            "start": 1,
            "step": 1,
            "stop": 100,  # @martin: stop put to 100 without any info
            "default": 2,
        },
    }

    return DiLAnoDetectmWrapper