Coverage for tests/test_param_inference.py: 92%
62 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-03 15:41 +0000
1import pytest
2from numbers import Real, Integral
3from sklearn.utils._param_validation import Interval, StrOptions
5import tadkit.utils.param_spec as ps
# --- Test get_default_class_values --- #
def test_get_default_class_values():
    """get_default_class_values maps every __init__ keyword to its default value."""

    class A:
        def __init__(self, x=10, y="hello", z=None):
            pass

    expected = {"x": 10, "y": "hello", "z": None}
    assert ps.get_default_class_values(A) == expected
# --- Test get_param_descriptions --- #
def test_get_param_descriptions():
    """Descriptions are read from the numpydoc "Parameters" section of the class docstring."""

    class B:
        """
        Example class.

        Parameters
        ----------
        x : int, default=1
            The x parameter
        y : str
            The y parameter
        """

        def __init__(self, x=1, y="default"):
            pass

    descriptions = ps.get_param_descriptions(B)
    # Both documented parameters must be present with their exact description text.
    for name in ("x", "y"):
        assert name in descriptions
        assert descriptions[name] == f"The {name} parameter"
# --- Test parse_sklearn_constraints --- #
def test_parse_sklearn_constraints_basic():
    """parse_sklearn_constraints translates sklearn-style constraint lists.

    Covers a closed Real ``Interval``, a string type alias combined with
    ``None``, a ``StrOptions`` set, and an empty constraint list.
    """
    constraints = {
        # Interval(type, left, right, closed=...) as per sklearn's signature
        "a": [Interval(Real, 0.0, 10.0, closed="both")],
        "b": ["integer", None],
        "c": [StrOptions({"red", "green"})],
        "d": [],  # no constraint
    }

    parsed = ps.parse_sklearn_constraints(constraints)

    # 'a' should map to float, with min=0, max=10.
    # Types are classes, so compare them by identity rather than equality.
    assert parsed["a"]["type"] is float
    assert parsed["a"]["bounds"]["min"] == 0.0
    assert parsed["a"]["bounds"]["max"] == 10.0
    assert parsed["a"]["allow_none"] is False

    # 'b' integer + None -> type int, allow_none True
    assert parsed["b"]["type"] is int
    assert parsed["b"]["allow_none"] is True

    # 'c' categorical string options (set order is arbitrary, hence sorted)
    assert parsed["c"]["type"] == "categorical"
    assert sorted(parsed["c"]["options"]) == ["green", "red"]
    assert parsed["c"]["allow_none"] is False

    # 'd' no constraints -> no type and None allowed
    assert parsed["d"]["type"] is None
    assert parsed["d"]["allow_none"] is True
# --- Test anchor_type_to_default --- #
def test_anchor_type_to_default_behavior():
    """anchor_type_to_default narrows a type tuple using the entry's default value."""
    # Non-None default: the (int, float) tuple collapses to int and
    # allow_none is switched off.
    with_default = ps.anchor_type_to_default(
        {
            "default": 5,
            "type": (int, float),
            "options": None,
            "allow_none": True,
        }
    )
    assert with_default["type"] is int
    assert with_default["allow_none"] is False

    # None default: the type becomes NoneType, options are dropped and
    # allow_none is switched on.
    without_default = ps.anchor_type_to_default(
        {
            "default": None,
            "type": (int, float),
            "options": ["a", "b"],
            "allow_none": False,
        }
    )
    assert without_default["type"] is type(None)
    assert without_default["options"] is None
    assert without_default["allow_none"] is True
# --- Test determine_widget --- #
def _widget_case(type_, default, lo=None, hi=None, allow_none=False, options=None):
    """Build a fresh determine_widget entry (new dicts on every call)."""
    return {
        "type": type_,
        "default": default,
        "bounds": {"min": lo, "max": hi},
        "allow_none": allow_none,
        "options": options,
    }


@pytest.mark.parametrize(
    "entry,expected_widget",
    [
        (_widget_case(int, 5, lo=0, hi=10), "slider"),
        (_widget_case(float, 2.5, lo=0.0, hi=10.0), "slider"),
        (_widget_case(bool, True), "checkbox"),
        (_widget_case(str, "abc"), "text"),
        (_widget_case(type(None), None, allow_none=True), "text"),
        (_widget_case("categorical", "red", options=["red", "green"]), "select"),
    ],
)
def test_determine_widget_various(entry, expected_widget):
    """determine_widget picks the expected widget and always exposes a default arg."""
    # Pass a shallow copy so the parametrized entry itself is not modified.
    outcome = ps.determine_widget(dict(entry))
    assert outcome["widget"] == expected_widget
    assert "default" in outcome["widget_args"]
# --- Test params_from_class --- #
def test_params_from_class_combined():
    """params_from_class combines defaults, sklearn constraints and docstring text."""

    class C:
        """
        Sample class.

        Parameters
        ----------
        x : int, default=1
            x param
        y : str
            y param
        z : float
            z param
        """

        _parameter_constraints = {
            "x": [Interval(Integral, 0, 10, closed="both")],
            "y": [StrOptions({"a", "b"})],
            "z": ["number"],
        }

        def __init__(self, x=1, y="a", z=None):
            self.x = x
            self.y = y
            self.z = z

    spec = ps.params_from_class(C)

    # x: bounded Integral interval -> int rendered as a slider
    assert "x" in spec
    x_spec = spec["x"]
    assert x_spec["type"] is int
    assert x_spec["widget"] == "slider"

    # y: StrOptions -> select widget offering both options
    y_spec = spec["y"]
    assert y_spec["widget"] == "select"
    assert sorted(y_spec["options"]) == ["a", "b"]

    # z: default is None, so None must be accepted (its type is not checked here)
    assert spec["z"]["allow_none"] is True

    # Descriptions are taken from the class docstring
    assert x_spec["description"] == "x param"
    assert y_spec["description"] == "y param"