Coverage for tests / test_param_inference.py: 92%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-03 15:41 +0000

1import pytest 

2from numbers import Real, Integral 

3from sklearn.utils._param_validation import Interval, StrOptions 

4 

5import tadkit.utils.param_spec as ps 

6 

7 

8# --- Test get_default_class_values --- # 

def test_get_default_class_values():
    """Check that ``__init__`` keyword defaults come back as a name -> value dict."""

    class A:
        def __init__(self, x=10, y="hello", z=None):
            pass

    expected = {"x": 10, "y": "hello", "z": None}
    assert ps.get_default_class_values(A) == expected

16 

17 

18# --- Test get_param_descriptions --- # 

def test_get_param_descriptions():
    """Check that descriptions are read from the class docstring's Parameters section."""

    class B:
        """
        Example class.

        Parameters
        ----------
        x : int, default=1
            The x parameter
        y : str
            The y parameter
        """

        def __init__(self, x=1, y="default"):
            pass

    descriptions = ps.get_param_descriptions(B)

    # Each documented parameter must be present and carry its exact text.
    for name, text in (("x", "The x parameter"), ("y", "The y parameter")):
        assert name in descriptions
        assert descriptions[name] == text

38 

39 

40# --- Test parse_sklearn_constraints --- # 

def test_parse_sklearn_constraints_basic():
    """Check parsing of interval, string-alias, string-options and empty constraints."""
    parsed = ps.parse_sklearn_constraints(
        {
            # Interval arguments spelled out: type, left, right, closed.
            "a": [Interval(Real, 0.0, 10.0, closed="both")],
            "b": ["integer", None],
            "c": [StrOptions({"red", "green"})],
            "d": [],  # no constraint
        }
    )

    # 'a': real interval -> float with min=0, max=10, None not allowed.
    interval_spec = parsed["a"]
    assert interval_spec["type"] == float
    assert interval_spec["bounds"]["min"] == 0.0
    assert interval_spec["bounds"]["max"] == 10.0
    assert interval_spec["allow_none"] is False

    # 'b': "integer" plus None -> int, None allowed.
    integer_spec = parsed["b"]
    assert integer_spec["type"] == int
    assert integer_spec["allow_none"] is True

    # 'c': string options -> categorical with the given choices.
    choice_spec = parsed["c"]
    assert choice_spec["type"] == "categorical"
    assert sorted(choice_spec["options"]) == ["green", "red"]
    assert choice_spec["allow_none"] is False

    # 'd': empty constraint list -> no type, None allowed.
    unconstrained_spec = parsed["d"]
    assert unconstrained_spec["type"] is None
    assert unconstrained_spec["allow_none"] is True

70 

71 

72# --- Test anchor_type_to_default --- # 

def test_anchor_type_to_default_behavior():
    """Check that the entry's type is anchored to the type of its default value."""
    # Case 1: a concrete (non-None) default.
    with_default = {
        "default": 5,
        "type": (int, float),
        "options": None,
        "allow_none": True,
    }
    # A copy is passed so the literal above is never mutated by the call.
    anchored = ps.anchor_type_to_default(with_default.copy())
    assert anchored["type"] is int
    # default is not None, so allow_none should become False
    assert anchored["allow_none"] is False

    # Case 2: a None default.
    without_default = {
        "default": None,
        "type": (int, float),
        "options": ["a", "b"],
        "allow_none": False,
    }
    anchored_none = ps.anchor_type_to_default(without_default.copy())
    assert anchored_none["type"] is type(None)
    assert anchored_none["options"] is None
    assert anchored_none["allow_none"] is True

95 

96 

97# --- Test determine_widget --- # 

@pytest.mark.parametrize(
    "entry,expected_widget",
    [
        # Bounded integer.
        pytest.param(
            {
                "type": int,
                "default": 5,
                "bounds": {"min": 0, "max": 10},
                "allow_none": False,
                "options": None,
            },
            "slider",
        ),
        # Bounded float.
        pytest.param(
            {
                "type": float,
                "default": 2.5,
                "bounds": {"min": 0.0, "max": 10.0},
                "allow_none": False,
                "options": None,
            },
            "slider",
        ),
        # Boolean.
        pytest.param(
            {
                "type": bool,
                "default": True,
                "bounds": {"min": None, "max": None},
                "allow_none": False,
                "options": None,
            },
            "checkbox",
        ),
        # Free-form string.
        pytest.param(
            {
                "type": str,
                "default": "abc",
                "bounds": {"min": None, "max": None},
                "allow_none": False,
                "options": None,
            },
            "text",
        ),
        # NoneType with a None default.
        pytest.param(
            {
                "type": type(None),
                "default": None,
                "bounds": {"min": None, "max": None},
                "allow_none": True,
                "options": None,
            },
            "text",
        ),
        # Categorical with explicit options.
        pytest.param(
            {
                "type": "categorical",
                "default": "red",
                "bounds": {"min": None, "max": None},
                "allow_none": False,
                "options": ["red", "green"],
            },
            "select",
        ),
    ],
)
def test_determine_widget_various(entry, expected_widget):
    """Check the widget chosen for each entry and that widget_args has a default."""
    # Pass a shallow copy so the parametrized dict itself is not handed over.
    outcome = ps.determine_widget(entry.copy())
    assert outcome["widget"] == expected_widget
    assert "default" in outcome["widget_args"]

167 

168 

169# --- Test params_from_class --- # 

def test_params_from_class_combined():
    """Check the combined spec built from defaults, docstring and constraints."""

    class C:
        """
        Sample class.

        Parameters
        ----------
        x : int, default=1
            x param
        y : str
            y param
        z : float
            z param
        """

        _parameter_constraints = {
            "x": [Interval(Integral, 0, 10, closed="both")],
            "y": [StrOptions({"a", "b"})],
            "z": ["number"],
        }

        def __init__(self, x=1, y="a", z=None):
            self.x = x
            self.y = y
            self.z = z

    spec = ps.params_from_class(C)

    # x must be present, typed int, and rendered as a slider (int bounds).
    assert "x" in spec
    x_spec = spec["x"]
    assert x_spec["type"] is int
    assert x_spec["widget"] == "slider"

    # y is a categorical select offering exactly the declared options.
    y_spec = spec["y"]
    assert y_spec["widget"] == "select"
    assert sorted(y_spec["options"]) == ["a", "b"]

    # z defaults to None, so None must be allowed.
    assert spec["z"]["allow_none"] is True

    # Descriptions come from the class docstring.
    assert x_spec["description"] == "x param"
    assert y_spec["description"] == "y param"

208 assert spec["y"]["description"] == "y param"