Coverage for tadkit / utils / param_spec.py: 82%
201 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-03 15:41 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-03 15:41 +0000
1from typing import Dict, Any, get_type_hints, get_args, Literal, Union
2import inspect
3from numbers import Integral
5import math
6from sklearn.utils._param_validation import Interval, StrOptions
# Finite stand-in that determine_widget substitutes for an infinite integer
# slider bound (float sliders fall back to 0.0 / 10.0 instead).
UI_INF = 1e6
def get_default_class_values(cls) -> Dict[str, Any]:
    """Return the default values declared by ``cls.__init__``.

    Parameters
    ----------
    cls : type
        Class whose constructor signature is inspected.

    Returns
    -------
    dict
        ``parameter name -> default value`` in signature order. ``self`` and
        parameters without a default are omitted.
    """
    defaults: Dict[str, Any] = {}
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
def get_param_descriptions(cls) -> Dict[str, str]:
    """Extract per-parameter descriptions from a numpydoc-style class docstring.

    Scanning starts after the first ``Parameters`` or ``Attributes`` header
    (case-insensitive, underlined with dashes) and continues to the end of
    the docstring. A line of the form ``name : type[, default=...]`` starts a
    new entry. The indented or blank lines that follow form its description.

    Parameters
    ----------
    cls : type
        Class whose docstring is parsed. The docstring is obtained through
        :func:`inspect.getdoc`, so it may be inherited.

    Returns
    -------
    dict
        ``name -> description``. Description lines are stripped and joined
        with single spaces; blank separator lines between paragraphs are
        dropped. Empty when there is no docstring or no matching section.
    """
    import re

    param_line = re.compile(r"^(\w+)\s*:\s*([^,]+)(?:,\s*default=.*)?")

    doc = inspect.getdoc(cls)
    if not doc:
        return {}

    lines = doc.splitlines()

    # Locate the section header: a title line underlined with dashes.
    for i, line in enumerate(lines):
        if (
            line.strip().lower() in {"parameters", "attributes"}
            and i + 1 < len(lines)
            and set(lines[i + 1].strip()) == {"-"}
        ):
            start_idx = i + 2
            break
    else:
        return {}

    param_descriptions: Dict[str, str] = {}
    i = start_idx
    while i < len(lines):
        match = param_line.match(lines[i].strip())
        i += 1
        if not match:
            continue
        desc_lines = []
        # inspect.getdoc dedents, so description lines are the indented ones.
        while i < len(lines) and (lines[i].startswith(" ") or not lines[i].strip()):
            desc_lines.append(lines[i].strip())
            i += 1
        # Drop empty parts so paragraphs are joined by exactly one space.
        param_descriptions[match.group(1)] = " ".join(p for p in desc_lines if p)

    return param_descriptions
def parse_sklearn_constraints(parameter_constraints) -> Dict[str, Dict[str, Any]]:
    """Convert sklearn-style constraints to structured type info.

    Parameters
    ----------
    parameter_constraints : dict
        ``param_name -> list of constraints``, in the shape of scikit-learn's
        ``_parameter_constraints`` class attribute. Recognized constraints:

        - ``None``
        - classes (``bool``, integral and real number types, ``str``)
        - the strings ``"integer"``, ``"number"``, ``"float"``,
          ``"boolean"`` and ``"random_state"``
        - bound strings ``">=x"`` / ``"<=x"``
        - plain sets of options
        - ``Interval`` and ``StrOptions``

        Anything else is ignored.

    Returns
    -------
    dict
        ``param_name -> {"type", "bounds", "options", "allow_none"}``.
        ``type`` is ``"categorical"`` when options were found. Otherwise it is
        the first non-None type encountered in constraint order. If there is
        no such type, it is ``NoneType`` when ``None`` is allowed and ``None``
        otherwise. An empty constraint list yields an unconstrained entry
        that allows ``None``.
    """
    from numbers import Real

    # String shorthands and the Python type each one maps to.
    string_types = {
        "integer": int,
        "number": float,
        "float": float,
        "boolean": bool,
        "random_state": int,
    }

    def class_type(c):
        # bool is registered as an Integral, so test it first; otherwise a bare
        # ``bool`` constraint would be reported as int.
        if issubclass(c, bool):
            return bool
        if issubclass(c, Integral):
            return int
        if c is str:
            return str
        if issubclass(c, Real):
            return float
        return None

    def interval_type(interval):
        # Only Integral intervals are integer-valued; any other interval type
        # (numbers.Real, sklearn's RealNotInt, ...) is treated as float.
        kind = interval.type
        if isinstance(kind, type) and issubclass(kind, Integral):
            return int
        return float

    def parse_bound(s):
        # ">=x" -> ("min", x), "<=x" -> ("max", x); integral values become int.
        for prefix, key in ((">=", "min"), ("<=", "max")):
            if s.startswith(prefix):
                try:
                    value = float(s[len(prefix):])
                except ValueError:
                    return None
                return key, int(value) if value.is_integer() else value
        return None

    param_spec = {}

    for param_name, constraints in parameter_constraints.items():
        bounds = {"min": None, "max": None, "closed": "both"}

        if not constraints:
            param_spec[param_name] = {
                "type": None,
                "bounds": bounds,
                "options": None,
                "allow_none": True,
            }
            continue

        allow_none = None in constraints
        found_types: Dict[Any, None] = {}  # insertion-ordered set of non-None types
        options = set()

        for c in constraints:
            t = None
            if c is None:
                continue
            if isinstance(c, type):
                t = class_type(c)
            elif isinstance(c, str):
                t = string_types.get(c)
                bound = parse_bound(c)
                if bound:
                    key, val = bound
                    bounds[key] = val
            elif isinstance(c, set):
                options.update(c)
                t = str
            elif isinstance(c, Interval):
                t = interval_type(c)
                bounds["min"], bounds["max"], bounds["closed"] = (
                    c.left,
                    c.right,
                    c.closed,
                )
            elif isinstance(c, StrOptions):
                options.update(c.options)
                t = str
            if t is not None:
                found_types.setdefault(t, None)

        # Deterministic main type: never NoneType when a real type exists.
        if options:
            selected_type = "categorical"
        elif found_types:
            selected_type = next(iter(found_types))
        else:
            selected_type = type(None) if allow_none else None

        param_spec[param_name] = {
            "type": selected_type,
            "bounds": bounds,
            "options": sorted(options, key=str) if options else None,
            "allow_none": allow_none,
        }

    return param_spec
160# --- Composite type resolver ---
def anchor_type_to_default(entry: Dict[str, Any]) -> Dict[str, Any]:
    """
    Narrow a parameter entry to the type of its default value.

    The entry is mutated in place and returned.

    - No default (``inspect.Parameter.empty``): the entry is left untouched.
    - Default ``None``: type becomes ``NoneType``, options are cleared and
      ``allow_none`` is set to True.
    - Otherwise:
        - A composite type (list/tuple/set, ``"multi"`` or
          ``"categorical"``) is replaced by ``type(default)``.
        - Options are dropped when they clearly do not match the default.
          A non-string default with a string first option drops them. A
          string default with all-numeric options drops them.
        - ``allow_none`` is forced to False.
    """
    default = entry.get("default")
    if default is inspect.Parameter.empty:
        return entry

    if default is None:
        entry.update(type=type(None), options=None, allow_none=True)
        return entry

    type_field = entry.get("type")
    options = entry.get("options")

    is_composite = isinstance(type_field, (list, tuple, set)) or type_field in (
        "multi",
        "categorical",
    )
    if is_composite:
        entry["type"] = type(default)

    if options:
        if not isinstance(default, str):
            if isinstance(options[0], str):
                entry["options"] = None
        elif all(isinstance(o, (int, float)) for o in options):
            entry["options"] = None

    # The default is not None here, so the widget need not offer None.
    if entry.get("allow_none"):
        entry["allow_none"] = False

    return entry
205# --- Widget inference ---
def determine_widget(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Infer a UI widget and its arguments from parameter metadata.

    Reads ``type``, ``options``, ``bounds``, ``default`` and ``allow_none``
    from *entry*. Writes ``entry["widget"]`` (``"select"``, ``"slider"``,
    ``"checkbox"`` or ``"text"``) and ``entry["widget_args"]``. The entry is
    mutated in place and returned.

    A missing default (``inspect.Parameter.empty``) is treated like ``None``.
    This means no checked box and no ``"<class 'inspect._empty'>"`` text.
    """
    t = entry.get("type")
    options = entry.get("options")
    bounds = entry.get("bounds") or {}
    default = entry.get("default")
    if default is inspect.Parameter.empty:
        default = None
    allow_none = entry.get("allow_none", False)

    if options:
        # Categorical / enum-like: dropdown, optionally led by None.
        opts = list(options)
        if allow_none:
            opts = [None] + opts
        widget = "select"
        widget_args = {
            "options": opts,
            "default": default if default in opts else opts[0],
        }

    elif t in (int, float):
        # Numeric slider; infinite bounds are replaced by finite UI limits.
        is_int = t is int
        min_val, max_val = bounds.get("min"), bounds.get("max")
        # Int sliders keep integral limits so min/max/step/default share a type.
        if isinstance(min_val, (int, float)) and not math.isfinite(min_val):
            min_val = -int(UI_INF) if is_int else 0.0
        if isinstance(max_val, (int, float)) and not math.isfinite(max_val):
            max_val = int(UI_INF) if is_int else 10.0
        widget = "slider"
        widget_args = {
            "min": min_val,
            "max": max_val,
            "step": 1 if is_int else 0.1,
            "default": default if default is not None else (0 if is_int else 0.0),
        }

    elif t is bool:
        widget = "checkbox"
        widget_args = {"default": bool(default)}

    else:
        # Callables, dicts, lists, strings, None and anything else: free text.
        widget = "text"
        widget_args = {"default": "" if default is None else str(default)}

    entry["widget"] = widget
    entry["widget_args"] = widget_args
    return entry
269# --- New unified builder ---
def _type_from_hint(hint: Any, default: Any):
    """Resolve one constructor type hint to ``(type, options_or_None)``.

    - ``Literal[...]`` gives ``(str, [literal values])``.
    - ``Union[...]`` / ``X | Y`` gives:
        - ``NoneType`` when the default is ``None``;
        - the first non-None member when there is no default;
        - otherwise the first class member the default is an instance of,
          falling back to ``type(default)``.
    - Any other hint is returned unchanged.

    ``options`` is ``None`` unless the hint is a ``Literal``.
    """
    import types as _types
    from typing import get_origin

    origin = get_origin(hint)
    if origin is Literal:
        return str, list(get_args(hint))

    union_type = getattr(_types, "UnionType", None)  # PEP 604 unions (3.10+)
    if origin is Union or (union_type is not None and origin is union_type):
        members = [m for m in get_args(hint) if m is not type(None)]
        if default is None:
            return type(None), None
        if default is inspect.Parameter.empty:
            return (members[0] if members else type(None)), None
        for member in members:
            # Subscripted generics (e.g. List[int]) reject isinstance checks.
            if isinstance(member, type) and isinstance(default, member):
                return member, None
        return type(default), None

    return hint, None


def params_from_class(cls) -> Dict[str, Dict[str, Any]]:
    """
    Combine default values, docstrings, and sklearn parameter constraints
    into a unified param specification dictionary.

    Parameters are ordered by constructor signature first, then by
    constraint-only names, then by docstring-only names. Names ending in
    ``_`` are skipped. The type of a parameter is ``type(default)`` when it
    has a default. Otherwise it is taken from its sklearn constraints or,
    failing that, its type hint.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        param_name -> {
            'default': Any,
            'type': type or str,
            'bounds': {'min':..., 'max':..., 'closed':...},
            'options': list[str] or None,
            'allow_none': bool,
            'description': str,
            'widget': str,
            'widget_args': dict,
        }
    """
    empty = inspect.Parameter.empty
    defaults = get_default_class_values(cls)
    try:
        type_hints = get_type_hints(cls.__init__)
    except (NameError, TypeError):
        # Unresolvable forward references or unsupported __init__ objects:
        # continue without hints instead of failing the whole spec.
        type_hints = {}
    docs = get_param_descriptions(cls)
    constraints = getattr(cls, "_parameter_constraints", {})
    parsed_constraints = parse_sklearn_constraints(constraints)

    # Ordered union so the resulting spec (and any UI built on it) is stable.
    all_params = dict.fromkeys([*defaults, *parsed_constraints, *docs])
    spec = {}

    for name in all_params:
        if name.endswith("_"):
            # Trailing-underscore names (e.g. documented fitted attributes)
            # are not constructor parameters.
            continue
        default = defaults.get(name, empty)
        entry = {
            "default": default,
            "description": docs.get(name, ""),
        }

        if name in parsed_constraints:
            entry.update(parsed_constraints[name])
        else:
            entry.update(
                {
                    "type": None,
                    "bounds": {"min": None, "max": None, "closed": "both"},
                    "options": None,
                    "allow_none": defaults.get(name) is None,
                }
            )
            # Type hints are consulted only when constraints are absent.
            if name in type_hints:
                hint_type, hint_options = _type_from_hint(type_hints[name], default)
                entry["type"] = hint_type
                if hint_options is not None:
                    entry["options"] = hint_options

        # The default value is authoritative when present; without one, keep
        # the constraint/hint-derived type rather than type(Parameter.empty).
        if default is not empty:
            entry["type"] = type(default)
        entry["allow_none"] = entry.get("allow_none", False) or default is None

        entry = anchor_type_to_default(entry)
        entry = determine_widget(entry)

        spec[name] = entry

    return spec