Coverage for tadkit / utils / param_spec.py: 82%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-03 15:41 +0000

1from typing import Dict, Any, get_type_hints, get_args, Literal, Union 

2import inspect 

3from numbers import Integral 

4 

5import math 

6from sklearn.utils._param_validation import Interval, StrOptions 

7 

# Finite stand-in used by determine_widget when a numeric bound is infinite,
# so integer sliders get a large-but-finite range instead of +/-inf.
UI_INF = 1_000_000.0

9 

10 

def get_default_class_values(cls) -> Dict[str, Any]:
    """Return the default value of every ``cls.__init__`` parameter that has one.

    ``self`` and parameters without a default are omitted. The result keeps
    the order in which parameters appear in the signature.
    """
    signature_params = inspect.signature(cls.__init__).parameters
    found: Dict[str, Any] = {}
    for param_name, param in signature_params.items():
        if param_name == "self":
            continue
        if param.default is inspect.Parameter.empty:
            continue
        found[param_name] = param.default
    return found

18 

19 

def get_param_descriptions(cls) -> Dict[str, str]:
    """Extract entry descriptions from a NumPy-style class docstring.

    Entries are collected only from the ``Parameters``, ``Other Parameters``
    and ``Attributes`` sections. A section header is a title line followed
    by a line made only of dashes. An entry is a line of the form
    ``name : type[, default=...]``, followed by indented description lines.
    Those lines are joined with single spaces, and blank lines inside a
    description are dropped.

    Lines in other sections (e.g. ``See Also``, whose lines also look like
    ``name : text``) are ignored, so they never surface as parameters.

    Returns
    -------
    Dict[str, str]
        Mapping of entry name to its description. It is empty when the class
        has no docstring or no matching section.
    """
    import re

    doc = inspect.getdoc(cls)
    if not doc:
        return {}

    entry_re = re.compile(r"^(\w+)\s*:\s*([^,]+)(?:,\s*default=.*)?")
    wanted_sections = {"parameters", "other parameters", "attributes"}
    lines = doc.splitlines()
    n_lines = len(lines)
    descriptions: Dict[str, str] = {}
    in_section = False

    i = 0
    while i < n_lines:
        stripped = lines[i].strip()

        # Section header: a non-empty title underlined by a line of dashes.
        if stripped and i + 1 < n_lines and set(lines[i + 1].strip()) == {"-"}:
            in_section = stripped.lower() in wanted_sections
            i += 2
            continue

        match = entry_re.match(stripped) if in_section else None
        i += 1
        if match is None:
            continue

        # The description is the run of following indented (or blank) lines.
        # Blank lines are consumed but not kept, to avoid double spaces.
        desc_lines = []
        while i < n_lines and (lines[i].startswith(" ") or not lines[i].strip()):
            text = lines[i].strip()
            if text:
                desc_lines.append(text)
            i += 1
        descriptions[match.group(1)] = " ".join(desc_lines)

    return descriptions

60 

61 

def parse_sklearn_constraints(parameter_constraints) -> Dict[str, Dict[str, Any]]:
    """Convert sklearn-style constraints to structured type info.

    Parameters
    ----------
    parameter_constraints : dict
        Mapping of parameter name to a list of sklearn constraints. Handled
        kinds are ``Interval``, ``StrOptions``, constraint strings (e.g.
        ``"boolean"``, ``">=0"``), plain types, sets of allowed values, and
        ``None``. Other constraint kinds are ignored.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        For each parameter, a dict with these keys:

        - ``type``: ``"categorical"`` when any options were found. Otherwise
          the first recognised non-None type in constraint order. Otherwise
          ``NoneType`` if only ``None`` is allowed, else ``None``.
        - ``bounds``: ``{"min", "max", "closed"}``.
        - ``options``: a sorted list, or ``None``.
        - ``allow_none``: bool.

        An empty constraint list yields an unconstrained entry that allows
        None.
    """
    from numbers import Real

    def map_type(c):
        """Map one constraint to a Python scalar type, or None if unknown."""
        if isinstance(c, Interval):
            # Interval.type is Integral, Real or RealNotInt (a Real subclass);
            # only the Integral family describes an integer range.
            kind = c.type
            return int if isinstance(kind, type) and issubclass(kind, Integral) else float
        if isinstance(c, str):
            return {
                "integer": int,
                "number": float,
                "float": float,
                "boolean": bool,
                "random_state": int,
            }.get(c)
        if isinstance(c, type):
            # bool subclasses Integral, so it must be tested first.
            if c is bool:
                return bool
            if issubclass(c, Integral):
                return int
            if issubclass(c, Real):  # float, numbers.Real, RealNotInt
                return float
            if c is str:
                return str
        return None

    def parse_bound(s):
        """Parse a ``">=x"`` / ``"<=x"`` string into ``("min"|"max", float)``."""
        if not isinstance(s, str):
            return None
        if s.startswith(">="):
            return ("min", float(s[2:]))
        if s.startswith("<="):
            return ("max", float(s[2:]))
        return None

    param_spec = {}

    for param_name, constraints in parameter_constraints.items():
        bounds = {"min": None, "max": None, "closed": "both"}

        if not constraints:
            param_spec[param_name] = {
                "type": None,
                "bounds": bounds,
                "options": None,
                "allow_none": True,
            }
            continue

        # A dict is used as an insertion-ordered set. Picking the main type
        # from a plain set depended on the hash order of type objects, which
        # varies between interpreter runs.
        types: Dict[Any, None] = {}
        options = set()

        if None in constraints:
            types[type(None)] = None

        for c in constraints:
            if isinstance(c, Interval):
                t = map_type(c)
                if t:
                    types.setdefault(t)
                # NOTE: with several Intervals, the last one's bounds win.
                bounds["min"], bounds["max"], bounds["closed"] = (
                    c.left,
                    c.right,
                    c.closed,
                )
            elif isinstance(c, StrOptions):
                options.update(c.options)
                types.setdefault(str)
            elif isinstance(c, str):
                t = map_type(c)
                if t:
                    types.setdefault(t)
                b = parse_bound(c)
                if b:
                    key, val = b
                    bounds[key] = int(val) if val.is_integer() else val
            elif isinstance(c, type):
                t = map_type(c)
                if t:
                    types.setdefault(t)
            elif isinstance(c, set):
                options.update(c)
                types.setdefault(str)

        # Determine the main type deterministically, preferring a concrete
        # (non-None) type.
        non_none = [t for t in types if t is not type(None)]
        if options:
            selected_type = "categorical"
        elif non_none:
            selected_type = non_none[0]
        elif types:
            selected_type = type(None)
        else:
            selected_type = None

        param_spec[param_name] = {
            "type": selected_type,
            "bounds": bounds,
            "options": sorted(options, key=str) if options else None,
            "allow_none": type(None) in types,
        }

    return param_spec

158 

159 

160# --- Composite type resolver --- 

def anchor_type_to_default(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Narrow a parameter entry so that it agrees with its default value.

    The entry is modified in place and also returned.

    - No default (``inspect.Parameter.empty``): the entry is left untouched.
    - A ``None`` default, or no ``"default"`` key at all: ``type`` becomes
      ``NoneType``, ``options`` is cleared and ``allow_none`` becomes True.
    - Any other default:

      - A composite type (list/tuple/set, ``"multi"`` or ``"categorical"``)
        is replaced by ``type(default)``.
      - Options are cleared when they clearly do not fit the default. This
        covers string options with a non-string default (judged by the first
        option), and all-numeric options with a string default.
      - A truthy ``allow_none`` is reset to False, whether or not options
        exist.
    """
    default = entry.get("default")
    if default is inspect.Parameter.empty:
        return entry

    if default is None:
        entry.update(type=type(None), options=None, allow_none=True)
        return entry

    declared = entry.get("type")
    choices = entry.get("options")

    is_composite = isinstance(declared, (list, tuple, set)) or declared in (
        "multi",
        "categorical",
    )
    if is_composite:
        entry["type"] = type(default)

    if choices:
        default_is_text = isinstance(default, str)
        text_choices_for_non_text = not default_is_text and isinstance(choices[0], str)
        numeric_choices_for_text = default_is_text and all(
            isinstance(o, (int, float)) for o in choices
        )
        if text_choices_for_non_text or numeric_choices_for_text:
            entry["options"] = None

    # The default is known to be non-None here, so None is not the anchor.
    if entry.get("allow_none"):
        entry["allow_none"] = False

    return entry

203 

204 

205# --- Widget inference --- 

def determine_widget(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Infer a UI widget and its arguments from parameter metadata.

    Reads ``type``, ``options``, ``bounds``, ``default`` and ``allow_none``
    from *entry*. Stores ``entry["widget"]`` (``"select"``, ``"slider"``,
    ``"checkbox"`` or ``"text"``) and ``entry["widget_args"]``, then
    returns the same (mutated) entry.

    A default of ``inspect.Parameter.empty`` (a parameter with no default)
    is treated like ``None`` for widget purposes. This keeps the sentinel
    class out of widget defaults; ``entry["default"]`` itself is unchanged.
    """
    t = entry.get("type")
    options = entry.get("options")
    bounds = entry.get("bounds") or {"min": None, "max": None}
    default = entry.get("default")
    if default is inspect.Parameter.empty:
        default = None
    allow_none = entry.get("allow_none", False)

    widget, widget_args = None, {}

    # 1. Categorical / enum-like: a select over the options, with None
    #    prepended when allowed.
    if options:
        widget = "select"
        opts = list(options)
        if allow_none:
            opts = [None] + opts
        widget_args = {
            "options": opts,
            "default": default if default in opts else opts[0],
        }

    # 2. Numeric sliders.
    elif t in (int, float):
        min_val, max_val = bounds.get("min"), bounds.get("max")
        # Replace infinite bounds with finite UI values. None (unbounded)
        # is passed through unchanged.
        if isinstance(min_val, (int, float)) and not math.isfinite(min_val):
            min_val = -UI_INF if t is int else 0.0
        if isinstance(max_val, (int, float)) and not math.isfinite(max_val):
            max_val = UI_INF if t is int else 10.0
        widget = "slider"
        widget_args = {
            "min": min_val,
            "max": max_val,
            "step": 1 if t is int else 0.1,
            "default": default if default is not None else (0 if t is int else 0.0),
        }

    # 3. Booleans.
    elif t is bool:
        widget = "checkbox"
        widget_args["default"] = bool(default)

    # 4. Everything else (callable / dict / list / str / NoneType / unknown)
    #    becomes a free-text field.
    else:
        widget = "text"
        widget_args["default"] = "" if default is None else str(default)

    entry["widget"] = widget
    entry["widget_args"] = widget_args
    return entry

267 

268 

269# --- New unified builder --- 

def _apply_type_hint(entry: Dict[str, Any], hint: Any, default: Any) -> None:
    """Fill ``entry["type"]`` (and ``entry["options"]`` for Literal) from a type hint."""
    origin = getattr(hint, "__origin__", None)

    # Literal["a", "b"] -> a choice among the literal values.
    if origin is Literal:
        entry["type"] = str
        entry["options"] = list(get_args(hint))
        return

    # typing.Union / Optional expose __origin__. PEP 604 ``X | Y`` unions are
    # types.UnionType instances without it, so they are matched by class name.
    if origin is Union or type(hint).__name__ == "UnionType":
        if default is None:
            entry["type"] = type(None)
            return
        # Only plain classes are safe isinstance() targets; parameterized
        # generics such as List[int] would raise TypeError.
        members = [
            m
            for m in get_args(hint)
            if isinstance(m, type)
            and m is not type(None)
            and getattr(m, "__origin__", None) is None
        ]
        matching = [m for m in members if isinstance(default, m)]
        if matching:
            entry["type"] = matching[0]
        elif default is not inspect.Parameter.empty:
            entry["type"] = type(default)
        else:
            # Required parameter: fall back to the first concrete member.
            entry["type"] = members[0] if members else None
        return

    entry["type"] = hint


def params_from_class(cls) -> Dict[str, Dict[str, Any]]:
    """
    Combine default values, docstrings, and sklearn parameter constraints
    into a unified param specification dictionary.

    Parameter names come from three sources: the ``__init__`` defaults, the
    class ``_parameter_constraints`` (if any), and the documented docstring
    entries. Names ending in ``_`` are skipped. Keys are returned in a stable
    order: signature order first, then constraint-only names, then
    docstring-only names.

    For a parameter with a default, ``type`` is ``type(default)``. For a
    parameter without one (``default`` is ``inspect.Parameter.empty``), the
    type comes from its sklearn constraints, else its type hint, else
    ``None``.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        param_name -> {
            'default': Any,
            'type': type or str,
            'bounds': {'min':..., 'max':..., 'closed':...},
            'options': list[str] or None,
            'allow_none': bool,
            'description': str,
            'widget': str,
            'widget_args': dict
        }
    """
    defaults = get_default_class_values(cls)
    try:
        type_hints = get_type_hints(cls.__init__)
    except (NameError, TypeError):
        # Unresolvable forward references or exotic annotations: skip hints
        # rather than failing the whole spec.
        type_hints = {}
    docs = get_param_descriptions(cls)
    constraints = getattr(cls, "_parameter_constraints", {})
    parsed_constraints = parse_sklearn_constraints(constraints)
    empty = inspect.Parameter.empty

    # dict.fromkeys keeps the first-seen order. A set union would give a
    # hash-seed-dependent order that changes between runs.
    ordered_names = dict.fromkeys([*defaults, *parsed_constraints, *docs])
    spec: Dict[str, Dict[str, Any]] = {}

    for name in ordered_names:
        if name.endswith("_"):
            continue
        default = defaults.get(name, empty)
        entry: Dict[str, Any] = {
            "default": default,
            "description": docs.get(name, ""),
        }

        if name in parsed_constraints:
            entry.update(parsed_constraints[name])
        else:
            entry.update(
                {
                    "type": None,
                    "bounds": {"min": None, "max": None, "closed": "both"},
                    "options": None,
                    # True for a None default, and also for parameters
                    # with no default at all.
                    "allow_none": defaults.get(name) is None,
                }
            )
            # Type hints are only consulted when constraints are absent.
            if name in type_hints:
                _apply_type_hint(entry, type_hints[name], default)

        # The default value is the most concrete type signal. Only override
        # when a default exists: for required parameters type(default) would
        # be ``type`` (the class of the empty sentinel), discarding the
        # constraint/hint type.
        if default is not empty:
            entry["type"] = type(default)
        entry["allow_none"] = entry.get("allow_none", False) or default is None

        entry = anchor_type_to_default(entry)
        entry = determine_widget(entry)

        spec[name] = entry

    return spec