Coverage for tdaad/utils/local_pipeline.py: 50%

12 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-13 13:45 +0000

1"""Window Functions.""" 

2 

3# Author: Martin Royer 

4 

5from sklearn.pipeline import Pipeline 

6 

7 

class LocalPipeline(Pipeline):
    """``Pipeline`` subclass that keeps slicing usable for its own subclasses."""

    def __getitem__(self, ind):
        """Return a single estimator, a named step, or a sub-pipeline.

        @hack: scikit-learn's own ``Pipeline.__getitem__`` rebuilds a slice
        with ``self.__class__(...)``. For a subclass, that invokes the
        subclass's constructor rather than ``Pipeline``'s, which breaks
        slicing for classes inheriting from ``LocalPipeline``. Here a slice is
        therefore always rebuilt as a plain ``Pipeline``.

        Parameters
        ----------
        ind : int, str or slice
            An integer position selects the estimator stored at that position.
            A key that positional indexing rejects with ``TypeError`` (e.g. a
            step name) is looked up in ``self.named_steps``. A slice selects a
            contiguous range of steps.

        Returns
        -------
        The selected estimator. For a slice, the result is a new ``Pipeline``
        (not a ``LocalPipeline``) built from ``self.steps[ind]``. Only this
        pipeline's ``memory`` and ``verbose`` settings are forwarded to it.

        The copy is shallow: the sub-pipeline holds the very same estimator
        objects, so fitting or mutating them affects both pipelines. Replacing
        an entry of ``steps`` in one pipeline does not affect the other.

        Raises
        ------
        ValueError
            If a slice uses a step other than ``1`` or ``None``.

        Lookup errors raised by ``self.steps`` / ``self.named_steps``
        propagate unchanged.
        """
        if not isinstance(ind, slice):
            try:
                _name, estimator = self.steps[ind]
            except TypeError:
                # Positional indexing rejected the key (not an int):
                # fall back to looking the step up by its name.
                return self.named_steps[ind]
            return estimator

        if ind.step not in (1, None):
            raise ValueError("Pipeline slicing only supports a step of 1")
        steps_subset = self.steps[ind]
        # Plain ``Pipeline`` on purpose instead of ``self.__class__``; see the
        # docstring for why subclasses need this.
        return Pipeline(steps_subset, memory=self.memory, verbose=self.verbose)