"""tadkit.utils.decomposable_tadlearner: compose a preprocessor with a TADLearner."""

import inspect
from typing import Type, Optional
from tadkit.base.tadlearner import TADLearner
from tadkit.base.typing import ArrayLike


class DecomposableTADLearner(TADLearner):
    """Abstract base class combining a Preprocessor and a TADLearner.

    Concrete subclasses set two class attributes:

    * ``Preprocessor`` -- a class whose instances provide ``fit_transform``,
      or ``fit`` and/or ``transform``, or are callable on the data
      (see :meth:`fit` and :meth:`embed`).
    * ``Learner`` -- a ``TADLearner`` subclass that is fitted on, and scores,
      the preprocessed data.
    """

    Preprocessor: Type
    Learner: Type[TADLearner]

    @staticmethod
    def _init_parameters(component: Type):
        """Return ``(keyword-passable __init__ parameter names, accepts **kwargs)``."""
        if component.__init__ is object.__init__:
            # object.__init__ advertises **kwargs but rejects any argument.
            return set(), False
        # Skip the leading ``self`` parameter.
        parameters = list(inspect.signature(component.__init__).parameters.values())[1:]
        names = {
            p.name
            for p in parameters
            if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        accepts_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters)
        return names, accepts_var_keyword

    def __init__(self, **params):
        """Instantiate preprocessor and learner using relevant subset of kwargs.

        Each keyword argument is forwarded to every component whose
        ``__init__`` declares a parameter of that name, so a name declared by
        both components reaches both. Arguments declared by neither component
        are forwarded to the component(s) whose ``__init__`` accepts
        ``**kwargs``.

        Raises:
            TypeError: if a keyword argument is declared by neither component
                and neither component accepts ``**kwargs``.
        """
        pre_names, pre_takes_extra = self._init_parameters(self.Preprocessor)
        learner_names, learner_takes_extra = self._init_parameters(self.Learner)

        extra = {
            k: v for k, v in params.items() if k not in pre_names and k not in learner_names
        }
        if extra and not (pre_takes_extra or learner_takes_extra):
            # Fail loudly, like an ordinary Python call, instead of silently
            # dropping e.g. a misspelled hyper-parameter.
            raise TypeError(
                f"{type(self).__name__}.__init__() got unexpected keyword "
                f"argument(s): {', '.join(sorted(extra))}"
            )

        pre_params = {k: v for k, v in params.items() if k in pre_names}
        learner_params = {k: v for k, v in params.items() if k in learner_names}
        if pre_takes_extra:
            pre_params.update(extra)
        if learner_takes_extra:
            learner_params.update(extra)

        self.preprocessor = self.Preprocessor(**pre_params)
        self.learner = self.Learner(**learner_params)

    def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> "DecomposableTADLearner":
        """Fit the preprocessor on ``X``, then fit the learner on its output.

        If the preprocessor has ``fit_transform``, it is called as
        ``fit_transform(X, y)``. Otherwise ``fit(X, y)`` is called when
        present, and ``X`` is transformed with :meth:`embed`.

        Returns:
            self
        """
        if hasattr(self.preprocessor, "fit_transform"):
            X_transformed = self.preprocessor.fit_transform(X, y)
        else:
            if hasattr(self.preprocessor, "fit"):
                self.preprocessor.fit(X, y)
            X_transformed = self.embed(X)
        self.learner.fit(X_transformed, y)
        return self

    def embed(self, X: ArrayLike) -> ArrayLike:
        """Map ``X`` through the preprocessor.

        Uses ``preprocessor.transform(X)`` when available, otherwise calls the
        preprocessor instance directly as ``preprocessor(X)``.
        """
        if hasattr(self.preprocessor, "transform"):
            return self.preprocessor.transform(X)
        return self.preprocessor(X)

    def score_samples(self, X: ArrayLike) -> ArrayLike:
        """Return the learner's ``score_samples`` on the preprocessed ``X``."""
        X_transformed = self.embed(X)
        return self.learner.score_samples(X_transformed)
def _combined_init_signature(*components: Type) -> inspect.Signature:
    """Merge the ``__init__`` signatures of *components* into one valid signature.

    The leading ``self`` of each ``__init__`` is dropped. The result is always
    a legal :class:`inspect.Signature`:

    * a name declared by several components appears once, with the kind and
      default of its first occurrence (the bound value is later forwarded to
      every component declaring that name);
    * ``*args`` parameters are dropped, because values are forwarded to the
      components by keyword only;
    * the first ``**kwargs`` parameter found is kept, in last position;
    * keyword-only parameters come after all positional-or-keyword ones, and
      a required positional parameter that would follow one with a default
      is made keyword-only.
    """
    Parameter = inspect.Parameter
    positional = []
    keyword_only = []
    var_keyword = None
    seen = set()
    has_default = False
    for component in components:
        for param in list(inspect.signature(component.__init__).parameters.values())[1:]:
            if param.kind == Parameter.VAR_POSITIONAL:
                continue
            if param.kind == Parameter.VAR_KEYWORD:
                if var_keyword is None:
                    var_keyword = param
                continue
            if param.name in seen:
                continue
            seen.add(param.name)
            if param.kind == Parameter.KEYWORD_ONLY:
                keyword_only.append(param)
            elif param.default is Parameter.empty and has_default:
                # A required positional parameter cannot follow a defaulted one.
                keyword_only.append(param.replace(kind=Parameter.KEYWORD_ONLY))
            else:
                has_default = has_default or param.default is not Parameter.empty
                positional.append(param.replace(kind=Parameter.POSITIONAL_OR_KEYWORD))
    combined = positional + keyword_only
    if var_keyword is not None and var_keyword.name not in seen:
        combined.append(var_keyword)
    return inspect.Signature(combined)


def decomposable_tadlearner_factory(
    Preprocessor: Type,
    Learner: Type[TADLearner],
    name: Optional[str] = None,
) -> Type[TADLearner]:
    """Create a TADLearner class combining a preprocessor and a learner with a proper __init__ signature.

    Args:
        Preprocessor: class used as the generated class's ``Preprocessor``.
        Learner: ``TADLearner`` subclass used as the generated class's ``Learner``.
        name: name of the generated class; defaults to the two class names
            concatenated.

    Returns:
        A ``DecomposableTADLearner`` subclass whose ``__init__`` exposes the
        merged parameters of both components' ``__init__`` methods (see
        ``_combined_init_signature``) and forwards them by keyword.
    """
    cls_name = name or f"{Preprocessor.__name__}{Learner.__name__}"

    class Wrapped(DecomposableTADLearner):
        __doc__ = f"{cls_name}: DecomposableTADLearner combining {Preprocessor.__name__} and {Learner.__name__}"

    Wrapped.Preprocessor = Preprocessor
    Wrapped.Learner = Learner
    # Set after creation: a ``__name__`` entry in the class body is shadowed by
    # the metaclass's ``__name__`` descriptor, so it would remain "Wrapped".
    Wrapped.__name__ = cls_name
    Wrapped.__qualname__ = cls_name

    combined_sig = _combined_init_signature(Preprocessor, Learner)
    # Name of the merged ``**kwargs`` parameter, if any. Its contents are
    # flattened back into plain keyword arguments before forwarding.
    var_keyword_name = next(
        (
            p.name
            for p in combined_sig.parameters.values()
            if p.kind == inspect.Parameter.VAR_KEYWORD
        ),
        None,
    )

    def __init__(self, *args, **kwargs):
        bound = combined_sig.bind(*args, **kwargs)
        bound.apply_defaults()
        params = dict(bound.arguments)
        if var_keyword_name is not None:
            # apply_defaults() guarantees the key exists (default: {}).
            params.update(params.pop(var_keyword_name))
        super(Wrapped, self).__init__(**params)

    __init__.__signature__ = combined_sig
    __init__.__doc__ = f"{cls_name} __init__ combines {Preprocessor.__name__} and {Learner.__name__} parameters"
    __init__.__qualname__ = f"{cls_name}.__init__"
    Wrapped.__init__ = __init__
    return Wrapped