test_pprint.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680
  1. import re
  2. from pprint import PrettyPrinter
  3. import numpy as np
  4. from sklearn.utils._pprint import _EstimatorPrettyPrinter
  5. from sklearn.linear_model import LogisticRegressionCV
  6. from sklearn.pipeline import make_pipeline
  7. from sklearn.base import BaseEstimator, TransformerMixin
  8. from sklearn.feature_selection import SelectKBest, chi2
  9. from sklearn import config_context
  10. # Ignore flake8 (lots of line too long issues)
  11. # ruff: noqa
  12. # Constructors excerpted to test pprinting
  # Excerpted estimator: only the constructor signature matters for pprint.
class LogisticRegression(BaseEstimator):
    """Stand-in for ``LogisticRegression`` used to exercise repr formatting.

    Every constructor argument is stored unchanged under an attribute of the
    same name. The defaults are the ones the expected reprs in this file are
    written against.
    """

    def __init__(
        self,
        penalty="l2",
        dual=False,
        tol=1e-4,
        C=1.0,
        fit_intercept=True,
        intercept_scaling=1,
        class_weight=None,
        random_state=None,
        solver="warn",
        max_iter=100,
        multi_class="warn",
        verbose=0,
        warm_start=False,
        n_jobs=None,
        l1_ratio=None,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.C = C
        self.class_weight = class_weight
        self.dual = dual
        self.fit_intercept = fit_intercept
        self.intercept_scaling = intercept_scaling
        self.l1_ratio = l1_ratio
        self.max_iter = max_iter
        self.multi_class = multi_class
        self.n_jobs = n_jobs
        self.penalty = penalty
        self.random_state = random_state
        self.solver = solver
        self.tol = tol
        self.verbose = verbose
        self.warm_start = warm_start

    def fit(self, X, y):
        """Do nothing and return the estimator itself."""
        return self
  # Excerpted transformer used inside the pipeline repr test.
class StandardScaler(TransformerMixin, BaseEstimator):
    """Stand-in for ``StandardScaler``; only its parameters are pprinted."""

    def __init__(self, copy=True, with_mean=True, with_std=True):
        self.copy = copy
        self.with_mean = with_mean
        self.with_std = with_std

    def transform(self, X, copy=None):
        """Stub transform: ignores ``X`` and returns the transformer itself."""
        return self
  # Excerpted meta-estimator, nested repeatedly in test_deeply_nested.
class RFE(BaseEstimator):
    """Stand-in for ``RFE`` wrapping another estimator."""

    def __init__(self, estimator, n_features_to_select=None, step=1, verbose=0):
        self.estimator = estimator  # wrapped estimator, pprinted recursively
        self.n_features_to_select = n_features_to_select
        self.step = step
        self.verbose = verbose
  # Excerpted search meta-estimator with its older (deprecated) defaults.
class GridSearchCV(BaseEstimator):
    """Stand-in for ``GridSearchCV`` holding an estimator and a param grid.

    The ``"warn"`` / ``"raise-deprecating"`` defaults are kept verbatim
    because the expected reprs in this file spell them out.
    """

    def __init__(
        self,
        estimator,
        param_grid,
        scoring=None,
        n_jobs=None,
        iid="warn",
        refit=True,
        cv="warn",
        verbose=0,
        pre_dispatch="2*n_jobs",
        error_score="raise-deprecating",
        return_train_score=False,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.cv = cv
        self.error_score = error_score
        self.estimator = estimator
        self.iid = iid
        self.n_jobs = n_jobs
        self.param_grid = param_grid
        self.pre_dispatch = pre_dispatch
        self.refit = refit
        self.return_train_score = return_train_score
        self.scoring = scoring
        self.verbose = verbose
  # Excerpted vectorizer whose ``vocabulary`` dict exercises truncation.
class CountVectorizer(BaseEstimator):
    """Stand-in for ``CountVectorizer`` with many scalar parameters.

    ``dtype`` defaults to the ``np.int64`` class object, which the expected
    repr renders as ``<class 'numpy.int64'>``.
    """

    def __init__(
        self,
        input="content",
        encoding="utf-8",
        decode_error="strict",
        strip_accents=None,
        lowercase=True,
        preprocessor=None,
        tokenizer=None,
        stop_words=None,
        token_pattern=r"(?u)\b\w\w+\b",
        ngram_range=(1, 1),
        analyzer="word",
        max_df=1.0,
        min_df=1,
        max_features=None,
        vocabulary=None,
        binary=False,
        dtype=np.int64,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.analyzer = analyzer
        self.binary = binary
        self.decode_error = decode_error
        self.dtype = dtype
        self.encoding = encoding
        self.input = input
        self.lowercase = lowercase
        self.max_df = max_df
        self.max_features = max_features
        self.min_df = min_df
        self.ngram_range = ngram_range
        self.preprocessor = preprocessor
        self.stop_words = stop_words
        self.strip_accents = strip_accents
        self.token_pattern = token_pattern
        self.tokenizer = tokenizer
        self.vocabulary = vocabulary
  # Excerpted pipeline (without ``verbose``) used in test_gridsearch_pipeline.
class Pipeline(BaseEstimator):
    """Stand-in for ``Pipeline``: a list of ``(name, estimator)`` steps."""

    def __init__(self, steps, memory=None):
        self.memory = memory
        self.steps = steps
  # Excerpted classifier with an older default ``gamma="auto_deprecated"``.
class SVC(BaseEstimator):
    """Stand-in for ``SVC``; every argument is kept as a same-named attribute."""

    def __init__(
        self,
        C=1.0,
        kernel="rbf",
        degree=3,
        gamma="auto_deprecated",
        coef0=0.0,
        shrinking=True,
        probability=False,
        tol=1e-3,
        cache_size=200,
        class_weight=None,
        verbose=False,
        max_iter=-1,
        decision_function_shape="ovr",
        random_state=None,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.C = C
        self.cache_size = cache_size
        self.class_weight = class_weight
        self.coef0 = coef0
        self.decision_function_shape = decision_function_shape
        self.degree = degree
        self.gamma = gamma
        self.kernel = kernel
        self.max_iter = max_iter
        self.probability = probability
        self.random_state = random_state
        self.shrinking = shrinking
        self.tol = tol
        self.verbose = verbose
  # Excerpted decomposition step used in the pipeline/grid-search reprs.
class PCA(BaseEstimator):
    """Stand-in for ``PCA``; only its constructor parameters are pprinted."""

    def __init__(
        self,
        n_components=None,
        copy=True,
        whiten=False,
        svd_solver="auto",
        tol=0.0,
        iterated_power="auto",
        random_state=None,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.copy = copy
        self.iterated_power = iterated_power
        self.n_components = n_components
        self.random_state = random_state
        self.svd_solver = svd_solver
        self.tol = tol
        self.whiten = whiten
  # Excerpted decomposition step that appears inside a param_grid list.
class NMF(BaseEstimator):
    """Stand-in for ``NMF``; only its constructor parameters are pprinted."""

    def __init__(
        self,
        n_components=None,
        init=None,
        solver="cd",
        beta_loss="frobenius",
        tol=1e-4,
        max_iter=200,
        random_state=None,
        alpha=0.0,
        l1_ratio=0.0,
        verbose=0,
        shuffle=False,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.alpha = alpha
        self.beta_loss = beta_loss
        self.init = init
        self.l1_ratio = l1_ratio
        self.max_iter = max_iter
        self.n_components = n_components
        self.random_state = random_state
        self.shuffle = shuffle
        self.solver = solver
        self.tol = tol
        self.verbose = verbose
  # Excerpted imputer whose NaN default tests "changed only" detection.
class SimpleImputer(BaseEstimator):
    """Stand-in for ``SimpleImputer`` with ``missing_values=np.nan``."""

    def __init__(
        self,
        missing_values=np.nan,
        strategy="mean",
        fill_value=None,
        verbose=0,
        copy=True,
    ):
        # Listed alphabetically, the order in which the repr shows them.
        self.copy = copy
        self.fill_value = fill_value
        self.missing_values = missing_values
        self.strategy = strategy
        self.verbose = verbose
  # Full (not "changed only") repr of an estimator left at its defaults.
def test_basic(print_changed_only_false):
    """Check the pprint layout of a single estimator with every parameter shown.

    ``print_changed_only_false`` is a fixture defined outside this file
    (presumably a conftest) that turns ``print_changed_only`` off.
    Continuation lines are aligned just after ``LogisticRegression(``
    (column 19); that whitespace is part of the expected value.
    """
    lr = LogisticRegression()
    expected = """
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)"""
    expected = expected[1:]  # remove first \n
    assert lr.__repr__() == expected
  # print_changed_only=True (the default): only non-default params appear.
def test_changed_only():
    """Check that the default ``print_changed_only=True`` mode is honoured.

    Only parameters that differ from their constructor default are shown,
    and an explicitly passed NaN is treated like the ``np.nan`` default.
    """
    lr = LogisticRegression(C=99)
    expected = """LogisticRegression(C=99)"""
    assert lr.__repr__() == expected

    # Check with a repr that doesn't fit on a single line; continuation lines
    # are aligned after "LogisticRegression(".
    lr = LogisticRegression(
        C=99, class_weight=0.4, fit_intercept=False, tol=1234, verbose=True
    )
    expected = """
LogisticRegression(C=99, class_weight=0.4, fit_intercept=False, tol=1234,
                   verbose=True)"""
    expected = expected[1:]  # remove first \n
    assert lr.__repr__() == expected

    imputer = SimpleImputer(missing_values=0)
    expected = """SimpleImputer(missing_values=0)"""
    assert imputer.__repr__() == expected

    # Defaults to np.nan, trying with float('NaN')
    imputer = SimpleImputer(missing_values=float("NaN"))
    expected = """SimpleImputer()"""
    assert imputer.__repr__() == expected

    # make sure array parameters don't throw error (see #13583)
    repr(LogisticRegressionCV(Cs=np.array([0.1, 1])))
  # Repr of a pipeline built with make_pipeline, with every parameter shown.
def test_pipeline(print_changed_only_false):
    """Check that nested estimators inside ``steps`` are aligned correctly.

    Each step tuple opens at column 16, its items sit at column 17, and the
    ``LogisticRegression`` parameters continue at column 36.
    """
    pipeline = make_pipeline(StandardScaler(), LogisticRegression(C=999))
    expected = """
Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('logisticregression',
                 LogisticRegression(C=999, class_weight=None, dual=False,
                                    fit_intercept=True, intercept_scaling=1,
                                    l1_ratio=None, max_iter=100,
                                    multi_class='warn', n_jobs=None,
                                    penalty='l2', random_state=None,
                                    solver='warn', tol=0.0001, verbose=0,
                                    warm_start=False))],
         verbose=False)"""
    expected = expected[1:]  # remove first \n
    assert pipeline.__repr__() == expected
  # Repr of seven nested RFE meta-estimators around a LogisticRegression.
def test_deeply_nested(print_changed_only_false):
    """Check alignment when estimators are nested far past the line width.

    Each ``RFE(estimator=`` prefix is 14 characters, so the k-th RFE's own
    parameters continue at column ``14 * (k - 1) + 4`` and the innermost
    ``LogisticRegression`` parameters at column 117.
    """
    rfe = RFE(RFE(RFE(RFE(RFE(RFE(RFE(LogisticRegression())))))))
    expected = """
RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=LogisticRegression(C=1.0,
                                                                                                                     class_weight=None,
                                                                                                                     dual=False,
                                                                                                                     fit_intercept=True,
                                                                                                                     intercept_scaling=1,
                                                                                                                     l1_ratio=None,
                                                                                                                     max_iter=100,
                                                                                                                     multi_class='warn',
                                                                                                                     n_jobs=None,
                                                                                                                     penalty='l2',
                                                                                                                     random_state=None,
                                                                                                                     solver='warn',
                                                                                                                     tol=0.0001,
                                                                                                                     verbose=0,
                                                                                                                     warm_start=False),
                                                                                        n_features_to_select=None,
                                                                                        step=1,
                                                                                        verbose=0),
                                                                          n_features_to_select=None,
                                                                          step=1,
                                                                          verbose=0),
                                                            n_features_to_select=None,
                                                            step=1, verbose=0),
                                              n_features_to_select=None, step=1,
                                              verbose=0),
                                n_features_to_select=None, step=1, verbose=0),
                  n_features_to_select=None, step=1, verbose=0),
    n_features_to_select=None, step=1, verbose=0)"""
    expected = expected[1:]  # remove first \n
    assert rfe.__repr__() == expected
  # Repr of a grid search whose param_grid is a list of dicts.
def test_gridsearch(print_changed_only_false):
    """Check alignment of a nested estimator and a list-of-dicts parameter.

    ``GridSearchCV`` parameters continue at column 13 and the ``SVC``
    parameters at column 27. The ``param_grid`` dicts open at column 25
    and their continued entries sit at column 26.
    """
    param_grid = [
        {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
        {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
    ]
    gs = GridSearchCV(SVC(), param_grid, cv=5)
    expected = """
GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                           decision_function_shape='ovr', degree=3,
                           gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                           probability=False, random_state=None, shrinking=True,
                           tol=0.001, verbose=False),
             iid='warn', n_jobs=None,
             param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']},
                         {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)"""
    expected = expected[1:]  # remove first \n
    assert gs.__repr__() == expected
  # Repr of a pipeline inside a grid search, with estimators in param_grid.
def test_gridsearch_pipeline(print_changed_only_false):
    """Check pprinting of estimators nested inside lists inside dicts.

    Uses ``_EstimatorPrettyPrinter`` directly with ``compact=True``,
    ``indent=1`` and ``indent_at_name=True``. The memory address in the
    ``chi2`` function repr is replaced by a fixed placeholder before
    comparing.
    """
    pp = _EstimatorPrettyPrinter(compact=True, indent=1, indent_at_name=True)
    pipeline = Pipeline([("reduce_dim", PCA()), ("classify", SVC())])
    N_FEATURES_OPTIONS = [2, 4, 8]
    C_OPTIONS = [1, 10, 100, 1000]
    param_grid = [
        {
            "reduce_dim": [PCA(iterated_power=7), NMF()],
            "reduce_dim__n_components": N_FEATURES_OPTIONS,
            "classify__C": C_OPTIONS,
        },
        {
            "reduce_dim": [SelectKBest(chi2)],
            "reduce_dim__k": N_FEATURES_OPTIONS,
            "classify__C": C_OPTIONS,
        },
    ]
    gs_pipeline = GridSearchCV(pipeline, cv=3, n_jobs=1, param_grid=param_grid)
    expected = """
GridSearchCV(cv=3, error_score='raise-deprecating',
             estimator=Pipeline(memory=None,
                                steps=[('reduce_dim',
                                        PCA(copy=True, iterated_power='auto',
                                            n_components=None,
                                            random_state=None,
                                            svd_solver='auto', tol=0.0,
                                            whiten=False)),
                                       ('classify',
                                        SVC(C=1.0, cache_size=200,
                                            class_weight=None, coef0=0.0,
                                            decision_function_shape='ovr',
                                            degree=3, gamma='auto_deprecated',
                                            kernel='rbf', max_iter=-1,
                                            probability=False,
                                            random_state=None, shrinking=True,
                                            tol=0.001, verbose=False))]),
             iid='warn', n_jobs=1,
             param_grid=[{'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [PCA(copy=True, iterated_power=7,
                                             n_components=None,
                                             random_state=None,
                                             svd_solver='auto', tol=0.0,
                                             whiten=False),
                                         NMF(alpha=0.0, beta_loss='frobenius',
                                             init=None, l1_ratio=0.0,
                                             max_iter=200, n_components=None,
                                             random_state=None, shuffle=False,
                                             solver='cd', tol=0.0001,
                                             verbose=0)],
                          'reduce_dim__n_components': [2, 4, 8]},
                         {'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [SelectKBest(k=10,
                                                     score_func=<function chi2 at some_address>)],
                          'reduce_dim__k': [2, 4, 8]}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)"""
    expected = expected[1:]  # remove first \n
    repr_ = pp.pformat(gs_pipeline)
    # Remove the address of '<function chi2 at 0x.....>' for reproducibility.
    # Match only the hex digits (either case, as printed on different
    # platforms) so the substitution cannot run past the closing '>'.
    repr_ = re.sub(
        r"function chi2 at 0x[0-9a-fA-F]+>", "function chi2 at some_address>", repr_
    )
    assert repr_ == expected
  # Dicts/lists longer than n_max_elements_to_show are truncated with "...".
def test_n_max_elements_to_show(print_changed_only_false):
    """Check element-count truncation for dict and list parameters.

    With ``n_max_elements_to_show=30``, a 30-element dict or list is printed
    in full. A 31-element one shows its first 30 items followed by ``...``.
    """
    n_max_elements_to_show = 30
    pp = _EstimatorPrettyPrinter(
        compact=True,
        indent=1,
        indent_at_name=True,
        n_max_elements_to_show=n_max_elements_to_show,
    )

    # No ellipsis
    vocabulary = {i: i for i in range(n_max_elements_to_show)}
    vectorizer = CountVectorizer(vocabulary=vocabulary)

    expected = r"""
CountVectorizer(analyzer='word', binary=False, decode_error='strict',
                dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
                lowercase=True, max_df=1.0, max_features=None, min_df=1,
                ngram_range=(1, 1), preprocessor=None, stop_words=None,
                strip_accents=None, token_pattern='(?u)\\b\\w\\w+\\b',
                tokenizer=None,
                vocabulary={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7,
                            8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14,
                            15: 15, 16: 16, 17: 17, 18: 18, 19: 19, 20: 20,
                            21: 21, 22: 22, 23: 23, 24: 24, 25: 25, 26: 26,
                            27: 27, 28: 28, 29: 29})"""
    expected = expected[1:]  # remove first \n
    assert pp.pformat(vectorizer) == expected

    # Now with ellipsis
    vocabulary = {i: i for i in range(n_max_elements_to_show + 1)}
    vectorizer = CountVectorizer(vocabulary=vocabulary)

    expected = r"""
CountVectorizer(analyzer='word', binary=False, decode_error='strict',
                dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
                lowercase=True, max_df=1.0, max_features=None, min_df=1,
                ngram_range=(1, 1), preprocessor=None, stop_words=None,
                strip_accents=None, token_pattern='(?u)\\b\\w\\w+\\b',
                tokenizer=None,
                vocabulary={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7,
                            8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14,
                            15: 15, 16: 16, 17: 17, 18: 18, 19: 19, 20: 20,
                            21: 21, 22: 22, 23: 23, 24: 24, 25: 25, 26: 26,
                            27: 27, 28: 28, 29: 29, ...})"""
    expected = expected[1:]  # remove first \n
    assert pp.pformat(vectorizer) == expected

    # Also test with lists
    param_grid = {"C": list(range(n_max_elements_to_show))}
    gs = GridSearchCV(SVC(), param_grid)
    expected = """
GridSearchCV(cv='warn', error_score='raise-deprecating',
             estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                           decision_function_shape='ovr', degree=3,
                           gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                           probability=False, random_state=None, shrinking=True,
                           tol=0.001, verbose=False),
             iid='warn', n_jobs=None,
             param_grid={'C': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                               15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
                               27, 28, 29]},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)"""
    expected = expected[1:]  # remove first \n
    assert pp.pformat(gs) == expected

    # Now with ellipsis
    param_grid = {"C": list(range(n_max_elements_to_show + 1))}
    gs = GridSearchCV(SVC(), param_grid)
    expected = """
GridSearchCV(cv='warn', error_score='raise-deprecating',
             estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                           decision_function_shape='ovr', degree=3,
                           gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                           probability=False, random_state=None, shrinking=True,
                           tol=0.001, verbose=False),
             iid='warn', n_jobs=None,
             param_grid={'C': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                               15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
                               27, 28, 29, ...]},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring=None, verbose=0)"""
    expected = expected[1:]  # remove first \n
    assert pp.pformat(gs) == expected
  # Character-count truncation of the repr via the N_CHAR_MAX argument.
def test_bruteforce_ellipsis(print_changed_only_false):
    """Check the brute-force ellipsis used when a repr is too long.

    It is used when the number of non-blank characters exceeds
    ``N_CHAR_MAX``. The cases below cover ellipses whose two sides are on
    different lines, on the same line, and cases where no ellipsis is wanted.
    """
    lr = LogisticRegression()

    # test when the left and right side of the ellipsis aren't on the same
    # line.
    expected = """
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   in...
                   multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)"""

    expected = expected[1:]  # remove first \n
    assert expected == lr.__repr__(N_CHAR_MAX=150)

    # test with very small N_CHAR_MAX
    # Note that N_CHAR_MAX is not strictly enforced, but it's normal: to avoid
    # weird reprs we still keep the whole line of the right part (after the
    # ellipsis).
    expected = """
Lo...
                   warm_start=False)"""

    expected = expected[1:]  # remove first \n
    assert expected == lr.__repr__(N_CHAR_MAX=4)

    # test with N_CHAR_MAX == number of non-blank characters: In this case we
    # don't want ellipsis
    full_repr = lr.__repr__(N_CHAR_MAX=float("inf"))
    n_nonblank = len("".join(full_repr.split()))
    assert lr.__repr__(N_CHAR_MAX=n_nonblank) == full_repr
    assert "..." not in full_repr

    # test with N_CHAR_MAX == number of non-blank characters - 10: the left and
    # right side of the ellipsis are on different lines. In this case we
    # want to expand the whole line of the right side
    expected = """
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_i...
                   multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)"""
    expected = expected[1:]  # remove first \n
    assert expected == lr.__repr__(N_CHAR_MAX=n_nonblank - 10)

    # test with N_CHAR_MAX == number of non-blank characters - 4: the left and
    # right side of the ellipsis are on the same line. In this case we don't
    # want to expand the whole line of the right side, just add the ellipsis
    # between the 2 sides.
    expected = """
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter...,
                   multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)"""
    expected = expected[1:]  # remove first \n
    assert expected == lr.__repr__(N_CHAR_MAX=n_nonblank - 4)

    # test with N_CHAR_MAX == number of non-blank characters - 2: the left and
    # right side of the ellipsis are on the same line, but adding the ellipsis
    # would actually make the repr longer. So we don't add the ellipsis.
    expected = """
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)"""
    expected = expected[1:]  # remove first \n
    assert expected == lr.__repr__(N_CHAR_MAX=n_nonblank - 2)
  # The stdlib PrettyPrinter must still accept estimators.
def test_builtin_prettyprinter():
    """Non-regression test that ensures the builtin ``PrettyPrinter`` still works.

    The builtin ``pprint.PrettyPrinter`` class can still be used on
    estimators (as done e.g. by joblib). This used to be a bug.
    """
    PrettyPrinter().pprint(LogisticRegression())
  # "Changed only" repr for an estimator that accepts **kwargs.
def test_kwargs_in_init():
    """Make sure the ``print_changed_only=True`` mode handles ``**kwargs``.

    Non-regression test for
    https://github.com/scikit-learn/scikit-learn/issues/17206
    """

    class WithKWargs(BaseEstimator):
        # Estimator taking extra keyword arguments. Such estimators have to
        # hack around set_params and get_params; this mimics what LightGBM
        # does.
        def __init__(self, a="willchange", b="unchanged", **kwargs):
            self.a = a
            self.b = b
            self._other_params = {}
            self.set_params(**kwargs)

        def get_params(self, deep=True):
            # Declared params first, then the extra kwargs (which win on clash).
            return {**super().get_params(deep=deep), **self._other_params}

        def set_params(self, **params):
            for name, val in params.items():
                setattr(self, name, val)
                self._other_params[name] = val
            return self

    estimator = WithKWargs(a="something", c="abcd", d=None)

    # Default mode: the unchanged ``b`` is omitted, the kwargs are kept.
    assert "WithKWargs(a='something', c='abcd', d=None)" == estimator.__repr__()

    # Full mode: every parameter, including ``b``, is listed.
    with config_context(print_changed_only=False):
        assert (
            "WithKWargs(a='something', b='unchanged', c='abcd', d=None)"
            == estimator.__repr__()
        )
  # Nested __repr__ calls must not depend on print_changed_only.
def test_complexity_print_changed_only():
    """Check that ``__repr__`` is called equally often in both print modes.

    The count must be the same whether ``print_changed_only`` is True or
    False. Non-regression test for
    https://github.com/scikit-learn/scikit-learn/issues/18490
    """

    class DummyEstimator(TransformerMixin, BaseEstimator):
        # Class-level counter bumped on every __repr__ call of any instance.
        nb_times_repr_called = 0

        def __init__(self, estimator=None):
            self.estimator = estimator

        def __repr__(self):
            DummyEstimator.nb_times_repr_called += 1
            return super().__repr__()

        def transform(self, X, copy=None):  # pragma: no cover
            return X

    estimator = DummyEstimator(
        make_pipeline(DummyEstimator(DummyEstimator()), DummyEstimator(), "passthrough")
    )

    # Count __repr__ calls for a single repr() under each setting.
    n_calls = {}
    for changed_only in (False, True):
        DummyEstimator.nb_times_repr_called = 0
        with config_context(print_changed_only=changed_only):
            repr(estimator)
        n_calls[changed_only] = DummyEstimator.nb_times_repr_called

    assert n_calls[False] == n_calls[True]