Source code for tadkit.base.registry

# tadkit/base/registry.py
import importlib
from typing import Any, Callable, Dict, List, Type, Union
import inspect

from tadkit.base.tadlearner import TADLearner

# ANSI terminal escape sequences used to color the diagnostic/catalog output
# printed by Registry.
HEADER = "\033[95m"  # bright magenta: highlights the learner name heading
FAIL = "\033[91m"  # bright red: hard failures (e.g. ModuleNotFoundError)
FAIL_PROP = "\033[93m"  # bright yellow: compliance/property problems
ENDC = "\033[0m"  # reset all attributes back to the terminal default


class Registry:
    """Registry for dynamically tracking and matching learner classes.

    Each entry maps a display name to the registered learner object and a
    compatibility predicate; :meth:`match_learners` evaluates every predicate
    against a formatter and returns the learners that accept it.
    """

    def __init__(self):
        # name -> {"class": object or callable, "condition": callable}
        self._learners: Dict[str, Dict[str, Any]] = {}

    # -----------------------
    # Learner registration
    # -----------------------

    def register_learner(
        self,
        name: str,
        learner: Union[Type, str],
        condition: Callable[[Any], bool],
        optional: bool = False,
    ):
        """
        Register a learner with a compatibility condition.

        Registering an already-used ``name`` replaces the previous entry.

        Parameters
        ----------
        name : str
            Display name for the learner.
        learner : class or str
            The learner class OR an import path ("module.submodule.ClassName").
        condition : callable(formatter) -> bool
            Determines whether this learner is compatible.
        optional : bool
            If True, an import path that cannot be resolved (the module, or
            one of the modules it imports, fails to import, or the class is
            not found in it) is skipped with a message instead of raising.

        Raises
        ------
        ValueError
            If ``learner`` is a string without a dot, i.e. not of the form
            "module.ClassName". Raised even when ``optional`` is True, since
            this is a malformed path rather than a missing dependency.
        ImportError, AttributeError
            If the import path cannot be resolved and ``optional`` is False.
        """
        if isinstance(learner, str):
            # Try dynamic import
            module_name, sep, cls_name = learner.rpartition(".")
            if not sep:
                raise ValueError(
                    f"Learner import path must look like 'module.ClassName', "
                    f"got {learner!r}."
                )
            try:
                mod = importlib.import_module(module_name)
                cls = getattr(mod, cls_name)
            except (ImportError, AttributeError) as e:
                # ImportError also covers ModuleNotFoundError and errors raised
                # by the target module's own imports (e.g. a missing extra).
                if optional:
                    print(f"[registry] Skipping optional learner '{name}': {e}")
                    return
                raise
        else:
            cls = learner
        self._learners[name] = {"class": cls, "condition": condition}

    # -----------------------
    # Matching
    # -----------------------

    def match_learners(self, formatter: Any) -> List[Type]:
        """Return learner classes compatible with the given formatter.

        A condition that raises is reported on stdout and treated as a
        non-match, so one faulty predicate cannot break matching for the rest.
        Results follow registration order.
        """
        matches = []
        for name, info in self._learners.items():
            try:
                if info["condition"](formatter):
                    matches.append(info["class"])
            except Exception as e:
                print(f"[registry] Learner '{name}' failed match check: {e}")
        return matches

    def list_learners(self) -> List[str]:
        """Return the registered learner names in registration order."""
        return list(self._learners.keys())

    # -----------------------
    # Utils
    # -----------------------

    @staticmethod
    def _validate_default_init(learner_class, learner_name):
        """
        Check that every ``__init__`` parameter of ``learner_class`` has a default.

        Problems are printed, never raised. ``self`` and variadic parameters
        (``*args`` / ``**kwargs``), which cannot carry default values, are
        ignored; every other parameter lacking a default is reported.
        """
        class_init = getattr(learner_class, "__init__", None)
        if not class_init:
            print(f"{FAIL_PROP}{learner_name} must have an __init__ method.{ENDC}")
            return
        try:
            class_sig = inspect.signature(class_init)
        except (TypeError, ValueError) as err:
            # Some builtins / C extensions expose no introspectable signature.
            print(
                f"{FAIL_PROP}{learner_name}.__init__ signature cannot be inspected: {err}.{ENDC}"
            )
            return
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        for param_name, param in class_sig.parameters.items():
            if param_name == "self" or param.kind in variadic_kinds:
                continue
            if param.default is param.empty:
                print(
                    f"{FAIL_PROP}{learner_name}.__init__ parameter '{param_name}' must have default value.{ENDC}"
                )

    @staticmethod
    def print_compliance_miss(learner):
        """Print which ``TADLearner`` members ``learner`` lacks.

        Annotated attributes of ``TADLearner`` that ``learner`` lacks are printed
        together on one ``Missing attributes:`` line (printed even when the list
        is empty). Each callable of ``TADLearner`` whose name does not start with
        ``__`` and that ``learner`` lacks gets its own ``Missing method:`` line.
        """
        missing = [
            attr for attr in TADLearner.__annotations__ if not hasattr(learner, attr)
        ]
        print("Missing attributes:", missing)
        for name in dir(TADLearner):
            if callable(getattr(TADLearner, name, None)) and not name.startswith("__"):
                if not hasattr(learner, name):
                    print(f"Missing method: {name}")

    def _print_class(self, learner_name, learner_class, detailed=False):
        """Print a diagnostic report for one registered learner.

        Reports whether the registered object is a class, whether it satisfies
        ``TADLearner`` (listing what is missing otherwise) and whether its
        ``__init__`` parameters all have defaults. With ``detailed=True`` it also
        prints ``learner_class.metadata`` and ``learner_class.required_properties``,
        reporting an ``AttributeError`` instead of raising when either is absent.
        """
        print(f"Class {HEADER}{learner_name=}{ENDC} is registered in TADKit.")
        try:
            if inspect.isclass(learner_class):
                print(f"{learner_name} is operational in this environment.")
                # NOTE(review): isinstance() on a class object is only meaningful
                # if TADLearner is a runtime-checkable Protocol (structural check
                # of the class's attributes) — confirm against tadkit.base.tadlearner.
                if isinstance(learner_class, TADLearner):
                    print(f"{learner_name} is implicit child of TADLearner.")
                else:
                    print(
                        f"{FAIL_PROP}{learner_name} doesn't implicitly inherit from TADLearner:{ENDC}"
                    )
                    self.print_compliance_miss(learner_class)
                Registry._validate_default_init(learner_class, learner_name)
        except ModuleNotFoundError as err:
            print(f"{FAIL}{learner_name} returns {err=}.{ENDC}")
            return
        if not detailed:
            return
        try:
            printed_params_description = {
                name: str(param_description)
                for name, param_description in learner_class.metadata.items()
            }
            print(f"{learner_name} has {printed_params_description=}.")
        except AttributeError as err:
            print(
                f"{FAIL_PROP}{learner_name} with signature {learner_class=} returns {err=}.{ENDC}"
            )
        try:
            print(f"{learner_name} has {learner_class.required_properties=}.")
        except AttributeError as err:
            print(
                f"{FAIL_PROP}{learner_name} with signature {learner_class=} returns {err=}.{ENDC}"
            )

    def print_catalog_classes(self, detailed=False):
        """Print a diagnostic report for every registered learner."""
        print("[TADkit registered Catalog]")
        for learner_name, learner in self._learners.items():
            self._print_class(
                learner_name, learner_class=learner["class"], detailed=detailed
            )
# Global instance: a single module-level Registry shared by all code that
# imports this module.
registry = Registry()