test_openml.py

  1. """Test the openml loader."""
  2. import gzip
  3. import json
  4. import os
  5. import re
  6. from functools import partial
  7. from io import BytesIO
  8. from urllib.error import HTTPError
  9. import numpy as np
  10. import pytest
  11. import scipy.sparse
  12. import sklearn
  13. from sklearn import config_context
  14. from sklearn.datasets import fetch_openml as fetch_openml_orig
  15. from sklearn.datasets._openml import (
  16. _OPENML_PREFIX,
  17. _get_local_path,
  18. _open_openml_url,
  19. _retry_with_clean_cache,
  20. )
  21. from sklearn.utils import Bunch, check_pandas_support
  22. from sklearn.utils._testing import (
  23. SkipTest,
  24. assert_allclose,
  25. assert_array_equal,
  26. fails_if_pypy,
  27. )
  28. from sklearn.utils.fixes import _open_binary
  29. OPENML_TEST_DATA_MODULE = "sklearn.datasets.tests.data.openml"
  30. # if True, urlopen will be monkey patched to only use local files
  31. test_offline = True
  32. class _MockHTTPResponse:
  33. def __init__(self, data, is_gzip):
  34. self.data = data
  35. self.is_gzip = is_gzip
  36. def read(self, amt=-1):
  37. return self.data.read(amt)
  38. def close(self):
  39. self.data.close()
  40. def info(self):
  41. if self.is_gzip:
  42. return {"Content-Encoding": "gzip"}
  43. return {}
  44. def __iter__(self):
  45. return iter(self.data)
  46. def __enter__(self):
  47. return self
  48. def __exit__(self, exc_type, exc_val, exc_tb):
  49. return False
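

# A minimal usage sketch (illustrative only): the loader only relies on the
# small subset of the ``http.client.HTTPResponse`` API reproduced above,
# i.e. ``read``, ``info``, iteration and the context-manager protocol::
#
#     resp = _MockHTTPResponse(BytesIO(b"payload"), is_gzip=False)
#     with resp as r:
#         assert r.read() == b"payload"
#         assert r.info() == {}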


# Disable the disk-based cache when testing `fetch_openml`:
# the mock data in sklearn/datasets/tests/data/openml/ is not always consistent
# with the version on openml.org. If one were to load the dataset outside of
# the tests, it may result in data that does not represent openml.org.
fetch_openml = partial(fetch_openml_orig, data_home=None)


def _monkey_patch_webbased_functions(context, data_id, gzip_response):
    # monkey patches the urlopen function. Important note: Do NOT use this
    # in combination with a regular cache directory, as the files that are
    # stored as cache should not be mixed up with real openml datasets
    url_prefix_data_description = "https://api.openml.org/api/v1/json/data/"
    url_prefix_data_features = "https://api.openml.org/api/v1/json/data/features/"
    url_prefix_download_data = "https://api.openml.org/data/v1/"
    url_prefix_data_list = "https://api.openml.org/api/v1/json/data/list/"

    path_suffix = ".gz"
    read_fn = gzip.open

    data_module = OPENML_TEST_DATA_MODULE + "." + f"id_{data_id}"

    def _file_name(url, suffix):
        output = (
            re.sub(r"\W", "-", url[len("https://api.openml.org/") :])
            + suffix
            + path_suffix
        )
        # Shorten the filenames to have better compatibility with windows 10
        # and filenames > 260 characters
        return (
            output.replace("-json-data-list", "-jdl")
            .replace("-json-data-features", "-jdf")
            .replace("-json-data-qualities", "-jdq")
            .replace("-json-data", "-jd")
            .replace("-data_name", "-dn")
            .replace("-download", "-dl")
            .replace("-limit", "-l")
            .replace("-data_version", "-dv")
            .replace("-status", "-s")
            .replace("-deactivated", "-dact")
            .replace("-active", "-act")
        )
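
    # For instance (an illustrative mapping derived from the rules above; the
    # exact file names shipped in the test data module may differ): the URL
    # "https://api.openml.org/api/v1/json/data/list/data_name/iris/limit/2/status/active/"
    # would be shortened to the cached file name
    # "api-v1-jdl-dn-iris-l-2-s-act-.json.gz".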

    def _mock_urlopen_shared(url, has_gzip_header, expected_prefix, suffix):
        assert url.startswith(expected_prefix)

        data_file_name = _file_name(url, suffix)

        with _open_binary(data_module, data_file_name) as f:
            if has_gzip_header and gzip_response:
                fp = BytesIO(f.read())
                return _MockHTTPResponse(fp, True)
            else:
                decompressed_f = read_fn(f, "rb")
                fp = BytesIO(decompressed_f.read())
                return _MockHTTPResponse(fp, False)

    def _mock_urlopen_data_description(url, has_gzip_header):
        return _mock_urlopen_shared(
            url=url,
            has_gzip_header=has_gzip_header,
            expected_prefix=url_prefix_data_description,
            suffix=".json",
        )

    def _mock_urlopen_data_features(url, has_gzip_header):
        return _mock_urlopen_shared(
            url=url,
            has_gzip_header=has_gzip_header,
            expected_prefix=url_prefix_data_features,
            suffix=".json",
        )

    def _mock_urlopen_download_data(url, has_gzip_header):
        return _mock_urlopen_shared(
            url=url,
            has_gzip_header=has_gzip_header,
            expected_prefix=url_prefix_download_data,
            suffix=".arff",
        )

    def _mock_urlopen_data_list(url, has_gzip_header):
        assert url.startswith(url_prefix_data_list)

        data_file_name = _file_name(url, ".json")

        # load the file itself, to simulate a http error
        with _open_binary(data_module, data_file_name) as f:
            decompressed_f = read_fn(f, "rb")
            decoded_s = decompressed_f.read().decode("utf-8")
            json_data = json.loads(decoded_s)
        if "error" in json_data:
            raise HTTPError(
                url=None, code=412, msg="Simulated mock error", hdrs=None, fp=None
            )

        with _open_binary(data_module, data_file_name) as f:
            if has_gzip_header:
                fp = BytesIO(f.read())
                return _MockHTTPResponse(fp, True)
            else:
                decompressed_f = read_fn(f, "rb")
                fp = BytesIO(decompressed_f.read())
                return _MockHTTPResponse(fp, False)

    def _mock_urlopen(request, *args, **kwargs):
        url = request.get_full_url()
        has_gzip_header = request.get_header("Accept-encoding") == "gzip"
        if url.startswith(url_prefix_data_list):
            return _mock_urlopen_data_list(url, has_gzip_header)
        elif url.startswith(url_prefix_data_features):
            return _mock_urlopen_data_features(url, has_gzip_header)
        elif url.startswith(url_prefix_download_data):
            return _mock_urlopen_download_data(url, has_gzip_header)
        elif url.startswith(url_prefix_data_description):
            return _mock_urlopen_data_description(url, has_gzip_header)
        else:
            raise ValueError("Unknown mocking URL pattern: %s" % url)

    # XXX: Global variable
    if test_offline:
        context.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen)
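

# A typical usage pattern in the tests below (sketch only): patch the network
# layer first, then call `fetch_openml` with `cache=False` so that only the
# local mock files are read, e.g.::
#
#     def test_something(monkeypatch):
#         _monkey_patch_webbased_functions(monkeypatch, data_id=61, gzip_response=True)
#         bunch = fetch_openml(
#             data_id=61, as_frame=False, cache=False, parser="liac-arff"
#         )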


###############################################################################
# Test the behaviour of `fetch_openml` depending on the input parameters.


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize(
    "data_id, dataset_params, n_samples, n_features, n_targets",
    [
        # iris
        (61, {"data_id": 61}, 150, 4, 1),
        (61, {"name": "iris", "version": 1}, 150, 4, 1),
        # anneal
        (2, {"data_id": 2}, 11, 38, 1),
        (2, {"name": "anneal", "version": 1}, 11, 38, 1),
        # cpu
        (561, {"data_id": 561}, 209, 7, 1),
        (561, {"name": "cpu", "version": 1}, 209, 7, 1),
        # emotions
        (40589, {"data_id": 40589}, 13, 72, 6),
        # adult-census
        (1119, {"data_id": 1119}, 10, 14, 1),
        (1119, {"name": "adult-census"}, 10, 14, 1),
        # miceprotein
        (40966, {"data_id": 40966}, 7, 77, 1),
        (40966, {"name": "MiceProtein"}, 7, 77, 1),
        # titanic
        (40945, {"data_id": 40945}, 1309, 13, 1),
    ],
)
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
@pytest.mark.parametrize("gzip_response", [True, False])
def test_fetch_openml_as_frame_true(
    monkeypatch,
    data_id,
    dataset_params,
    n_samples,
    n_features,
    n_targets,
    parser,
    gzip_response,
):
    """Check the behaviour of `fetch_openml` with `as_frame=True`.

    Fetch by ID and/or by name, depending on whether the file was previously
    cached.
    """
    pd = pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)
    bunch = fetch_openml(
        as_frame=True,
        cache=False,
        parser=parser,
        **dataset_params,
    )

    assert int(bunch.details["id"]) == data_id
    assert isinstance(bunch, Bunch)

    assert isinstance(bunch.frame, pd.DataFrame)
    assert bunch.frame.shape == (n_samples, n_features + n_targets)

    assert isinstance(bunch.data, pd.DataFrame)
    assert bunch.data.shape == (n_samples, n_features)

    if n_targets == 1:
        assert isinstance(bunch.target, pd.Series)
        assert bunch.target.shape == (n_samples,)
    else:
        assert isinstance(bunch.target, pd.DataFrame)
        assert bunch.target.shape == (n_samples, n_targets)

    assert bunch.categories is None


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize(
    "data_id, dataset_params, n_samples, n_features, n_targets",
    [
        # iris
        (61, {"data_id": 61}, 150, 4, 1),
        (61, {"name": "iris", "version": 1}, 150, 4, 1),
        # anneal
        (2, {"data_id": 2}, 11, 38, 1),
        (2, {"name": "anneal", "version": 1}, 11, 38, 1),
        # cpu
        (561, {"data_id": 561}, 209, 7, 1),
        (561, {"name": "cpu", "version": 1}, 209, 7, 1),
        # emotions
        (40589, {"data_id": 40589}, 13, 72, 6),
        # adult-census
        (1119, {"data_id": 1119}, 10, 14, 1),
        (1119, {"name": "adult-census"}, 10, 14, 1),
        # miceprotein
        (40966, {"data_id": 40966}, 7, 77, 1),
        (40966, {"name": "MiceProtein"}, 7, 77, 1),
    ],
)
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_as_frame_false(
    monkeypatch,
    data_id,
    dataset_params,
    n_samples,
    n_features,
    n_targets,
    parser,
):
    """Check the behaviour of `fetch_openml` with `as_frame=False`.

    Fetch both by ID and by name + version.
    """
    pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    bunch = fetch_openml(
        as_frame=False,
        cache=False,
        parser=parser,
        **dataset_params,
    )
    assert int(bunch.details["id"]) == data_id
    assert isinstance(bunch, Bunch)

    assert bunch.frame is None

    assert isinstance(bunch.data, np.ndarray)
    assert bunch.data.shape == (n_samples, n_features)

    assert isinstance(bunch.target, np.ndarray)
    if n_targets == 1:
        assert bunch.target.shape == (n_samples,)
    else:
        assert bunch.target.shape == (n_samples, n_targets)

    assert isinstance(bunch.categories, dict)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("data_id", [61, 1119, 40945])
def test_fetch_openml_consistency_parser(monkeypatch, data_id):
    """Check the consistency of the LIAC-ARFF and pandas parsers."""
    pd = pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    bunch_liac = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser="liac-arff",
    )
    bunch_pandas = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser="pandas",
    )

    # The data frames for the input features should match up to some numerical
    # dtype conversions (e.g. float64 <=> Int64) due to limitations of the
    # LIAC-ARFF parser.
    data_liac, data_pandas = bunch_liac.data, bunch_pandas.data

    def convert_numerical_dtypes(series):
        pandas_series = data_pandas[series.name]
        if pd.api.types.is_numeric_dtype(pandas_series):
            return series.astype(pandas_series.dtype)
        else:
            return series

    data_liac_with_fixed_dtypes = data_liac.apply(convert_numerical_dtypes)
    pd.testing.assert_frame_equal(data_liac_with_fixed_dtypes, data_pandas)

    # Let's also check that the .frame attributes match
    frame_liac, frame_pandas = bunch_liac.frame, bunch_pandas.frame

    # Note that the .frame attribute is a superset of the .data attribute:
    pd.testing.assert_frame_equal(frame_pandas[bunch_pandas.feature_names], data_pandas)

    # However the remaining columns, typically the target(s), are not necessarily
    # dtyped similarly by both parsers due to limitations of the LIAC-ARFF parser.
    # Therefore, extra dtype conversions are required for those columns:

    def convert_numerical_and_categorical_dtypes(series):
        pandas_series = frame_pandas[series.name]
        if pd.api.types.is_numeric_dtype(pandas_series):
            return series.astype(pandas_series.dtype)
        elif isinstance(pandas_series.dtype, pd.CategoricalDtype):
            # LIAC-ARFF uses strings to denote the categories, so we rename the
            # categories of the LIAC-ARFF column to make them comparable to the
            # ones inferred by the pandas parser. Fixing this behaviour in
            # LIAC-ARFF would allow checking the consistency directly, but we do
            # not plan to maintain the LIAC-ARFF parser in the long term.
            return series.cat.rename_categories(pandas_series.cat.categories)
        else:
            return series

    frame_liac_with_fixed_dtypes = frame_liac.apply(
        convert_numerical_and_categorical_dtypes
    )
    pd.testing.assert_frame_equal(frame_liac_with_fixed_dtypes, frame_pandas)
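

# Small illustration of the dtype gap mentioned above (not used by the tests):
# for an integer ARFF column the LIAC-ARFF parser yields plain float64 values,
# while the pandas parser can yield the nullable Int64 dtype, e.g.::
#
#     pd.Series([1.0, 2.0]).astype("Int64")  # float64 -> nullable Int64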


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_equivalence_array_dataframe(monkeypatch, parser):
    """Check the equivalence of the dataset when using `as_frame=False` and
    `as_frame=True`.
    """
    pytest.importorskip("pandas")

    data_id = 61
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    bunch_as_frame_true = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser=parser,
    )
    bunch_as_frame_false = fetch_openml(
        data_id=data_id,
        as_frame=False,
        cache=False,
        parser=parser,
    )
    assert_allclose(bunch_as_frame_false.data, bunch_as_frame_true.data)
    assert_array_equal(bunch_as_frame_false.target, bunch_as_frame_true.target)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_iris_pandas(monkeypatch, parser):
    """Check fetching on a numerical only dataset with string labels."""
    pd = pytest.importorskip("pandas")
    CategoricalDtype = pd.api.types.CategoricalDtype
    data_id = 61
    data_shape = (150, 4)
    target_shape = (150,)
    frame_shape = (150, 5)

    target_dtype = CategoricalDtype(
        ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
    )
    data_dtypes = [np.float64] * 4
    data_names = ["sepallength", "sepalwidth", "petallength", "petalwidth"]
    target_name = "class"

    _monkey_patch_webbased_functions(monkeypatch, data_id, True)

    bunch = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser=parser,
    )
    data = bunch.data
    target = bunch.target
    frame = bunch.frame

    assert isinstance(data, pd.DataFrame)
    assert np.all(data.dtypes == data_dtypes)
    assert data.shape == data_shape
    assert np.all(data.columns == data_names)
    assert np.all(bunch.feature_names == data_names)
    assert bunch.target_names == [target_name]

    assert isinstance(target, pd.Series)
    assert target.dtype == target_dtype
    assert target.shape == target_shape
    assert target.name == target_name
    assert target.index.is_unique

    assert isinstance(frame, pd.DataFrame)
    assert frame.shape == frame_shape
    assert np.all(frame.dtypes == data_dtypes + [target_dtype])
    assert frame.index.is_unique


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
@pytest.mark.parametrize("target_column", ["petalwidth", ["petalwidth", "petallength"]])
def test_fetch_openml_forcing_targets(monkeypatch, parser, target_column):
    """Check that we can force the target to not be the default target."""
    pd = pytest.importorskip("pandas")

    data_id = 61
    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
    bunch_forcing_target = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        target_column=target_column,
        parser=parser,
    )
    bunch_default = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser=parser,
    )

    pd.testing.assert_frame_equal(bunch_forcing_target.frame, bunch_default.frame)
    if isinstance(target_column, list):
        pd.testing.assert_index_equal(
            bunch_forcing_target.target.columns, pd.Index(target_column)
        )
        assert bunch_forcing_target.data.shape == (150, 3)
    else:
        assert bunch_forcing_target.target.name == target_column
        assert bunch_forcing_target.data.shape == (150, 4)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("data_id", [61, 2, 561, 40589, 1119])
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_equivalence_frame_return_X_y(monkeypatch, data_id, parser):
    """Check the behaviour of `return_X_y=True` when `as_frame=True`."""
    pd = pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    bunch = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        return_X_y=False,
        parser=parser,
    )
    X, y = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        return_X_y=True,
        parser=parser,
    )
    pd.testing.assert_frame_equal(bunch.data, X)
    if isinstance(y, pd.Series):
        pd.testing.assert_series_equal(bunch.target, y)
    else:
        pd.testing.assert_frame_equal(bunch.target, y)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize("data_id", [61, 561, 40589, 1119])
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_equivalence_array_return_X_y(monkeypatch, data_id, parser):
    """Check the behaviour of `return_X_y=True` when `as_frame=False`."""
    pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    bunch = fetch_openml(
        data_id=data_id,
        as_frame=False,
        cache=False,
        return_X_y=False,
        parser=parser,
    )
    X, y = fetch_openml(
        data_id=data_id,
        as_frame=False,
        cache=False,
        return_X_y=True,
        parser=parser,
    )
    assert_array_equal(bunch.data, X)
    assert_array_equal(bunch.target, y)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
def test_fetch_openml_difference_parsers(monkeypatch):
    """Check the difference between liac-arff and pandas parser."""
    pytest.importorskip("pandas")

    data_id = 1119
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=True)
    # When `as_frame=False`, the categories will be ordinally encoded with
    # liac-arff parser while this is not the case with pandas parser.
    as_frame = False
    bunch_liac_arff = fetch_openml(
        data_id=data_id,
        as_frame=as_frame,
        cache=False,
        parser="liac-arff",
    )
    bunch_pandas = fetch_openml(
        data_id=data_id,
        as_frame=as_frame,
        cache=False,
        parser="pandas",
    )

    assert bunch_liac_arff.data.dtype.kind == "f"
    assert bunch_pandas.data.dtype == "O"
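

# Concretely (an illustrative sketch of the assertions above, not extra test
# code): with `as_frame=False`, a categorical column such as ["a", "b", "a"]
# comes back as float codes like [0.0, 1.0, 0.0] from the liac-arff parser,
# whereas the pandas parser keeps the original string values in an object
# array.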


###############################################################################
# Test the ARFF parsing on several datasets to check that the correct types
# (categories, integers, floats) are detected.


@pytest.fixture(scope="module")
def datasets_column_names():
    """Return the column names for each dataset."""
    return {
        61: ["sepallength", "sepalwidth", "petallength", "petalwidth", "class"],
        2: [
            "family",
            "product-type",
            "steel",
            "carbon",
            "hardness",
            "temper_rolling",
            "condition",
            "formability",
            "strength",
            "non-ageing",
            "surface-finish",
            "surface-quality",
            "enamelability",
            "bc",
            "bf",
            "bt",
            "bw%2Fme",
            "bl",
            "m",
            "chrom",
            "phos",
            "cbond",
            "marvi",
            "exptl",
            "ferro",
            "corr",
            "blue%2Fbright%2Fvarn%2Fclean",
            "lustre",
            "jurofm",
            "s",
            "p",
            "shape",
            "thick",
            "width",
            "len",
            "oil",
            "bore",
            "packing",
            "class",
        ],
        561: ["vendor", "MYCT", "MMIN", "MMAX", "CACH", "CHMIN", "CHMAX", "class"],
        40589: [
            "Mean_Acc1298_Mean_Mem40_Centroid",
            "Mean_Acc1298_Mean_Mem40_Rolloff",
            "Mean_Acc1298_Mean_Mem40_Flux",
            "Mean_Acc1298_Mean_Mem40_MFCC_0",
            "Mean_Acc1298_Mean_Mem40_MFCC_1",
            "Mean_Acc1298_Mean_Mem40_MFCC_2",
            "Mean_Acc1298_Mean_Mem40_MFCC_3",
            "Mean_Acc1298_Mean_Mem40_MFCC_4",
            "Mean_Acc1298_Mean_Mem40_MFCC_5",
            "Mean_Acc1298_Mean_Mem40_MFCC_6",
            "Mean_Acc1298_Mean_Mem40_MFCC_7",
            "Mean_Acc1298_Mean_Mem40_MFCC_8",
            "Mean_Acc1298_Mean_Mem40_MFCC_9",
            "Mean_Acc1298_Mean_Mem40_MFCC_10",
            "Mean_Acc1298_Mean_Mem40_MFCC_11",
            "Mean_Acc1298_Mean_Mem40_MFCC_12",
            "Mean_Acc1298_Std_Mem40_Centroid",
            "Mean_Acc1298_Std_Mem40_Rolloff",
            "Mean_Acc1298_Std_Mem40_Flux",
            "Mean_Acc1298_Std_Mem40_MFCC_0",
            "Mean_Acc1298_Std_Mem40_MFCC_1",
            "Mean_Acc1298_Std_Mem40_MFCC_2",
            "Mean_Acc1298_Std_Mem40_MFCC_3",
            "Mean_Acc1298_Std_Mem40_MFCC_4",
            "Mean_Acc1298_Std_Mem40_MFCC_5",
            "Mean_Acc1298_Std_Mem40_MFCC_6",
            "Mean_Acc1298_Std_Mem40_MFCC_7",
            "Mean_Acc1298_Std_Mem40_MFCC_8",
            "Mean_Acc1298_Std_Mem40_MFCC_9",
            "Mean_Acc1298_Std_Mem40_MFCC_10",
            "Mean_Acc1298_Std_Mem40_MFCC_11",
            "Mean_Acc1298_Std_Mem40_MFCC_12",
            "Std_Acc1298_Mean_Mem40_Centroid",
            "Std_Acc1298_Mean_Mem40_Rolloff",
            "Std_Acc1298_Mean_Mem40_Flux",
            "Std_Acc1298_Mean_Mem40_MFCC_0",
            "Std_Acc1298_Mean_Mem40_MFCC_1",
            "Std_Acc1298_Mean_Mem40_MFCC_2",
            "Std_Acc1298_Mean_Mem40_MFCC_3",
            "Std_Acc1298_Mean_Mem40_MFCC_4",
            "Std_Acc1298_Mean_Mem40_MFCC_5",
            "Std_Acc1298_Mean_Mem40_MFCC_6",
            "Std_Acc1298_Mean_Mem40_MFCC_7",
            "Std_Acc1298_Mean_Mem40_MFCC_8",
            "Std_Acc1298_Mean_Mem40_MFCC_9",
            "Std_Acc1298_Mean_Mem40_MFCC_10",
            "Std_Acc1298_Mean_Mem40_MFCC_11",
            "Std_Acc1298_Mean_Mem40_MFCC_12",
            "Std_Acc1298_Std_Mem40_Centroid",
            "Std_Acc1298_Std_Mem40_Rolloff",
            "Std_Acc1298_Std_Mem40_Flux",
            "Std_Acc1298_Std_Mem40_MFCC_0",
            "Std_Acc1298_Std_Mem40_MFCC_1",
            "Std_Acc1298_Std_Mem40_MFCC_2",
            "Std_Acc1298_Std_Mem40_MFCC_3",
            "Std_Acc1298_Std_Mem40_MFCC_4",
            "Std_Acc1298_Std_Mem40_MFCC_5",
            "Std_Acc1298_Std_Mem40_MFCC_6",
            "Std_Acc1298_Std_Mem40_MFCC_7",
            "Std_Acc1298_Std_Mem40_MFCC_8",
            "Std_Acc1298_Std_Mem40_MFCC_9",
            "Std_Acc1298_Std_Mem40_MFCC_10",
            "Std_Acc1298_Std_Mem40_MFCC_11",
            "Std_Acc1298_Std_Mem40_MFCC_12",
            "BH_LowPeakAmp",
            "BH_LowPeakBPM",
            "BH_HighPeakAmp",
            "BH_HighPeakBPM",
            "BH_HighLowRatio",
            "BHSUM1",
            "BHSUM2",
            "BHSUM3",
            "amazed.suprised",
            "happy.pleased",
            "relaxing.calm",
            "quiet.still",
            "sad.lonely",
            "angry.aggresive",
        ],
        1119: [
            "age",
            "workclass",
            "fnlwgt:",
            "education:",
            "education-num:",
            "marital-status:",
            "occupation:",
            "relationship:",
            "race:",
            "sex:",
            "capital-gain:",
            "capital-loss:",
            "hours-per-week:",
            "native-country:",
            "class",
        ],
        40966: [
            "DYRK1A_N",
            "ITSN1_N",
            "BDNF_N",
            "NR1_N",
            "NR2A_N",
            "pAKT_N",
            "pBRAF_N",
            "pCAMKII_N",
            "pCREB_N",
            "pELK_N",
            "pERK_N",
            "pJNK_N",
            "PKCA_N",
            "pMEK_N",
            "pNR1_N",
            "pNR2A_N",
            "pNR2B_N",
            "pPKCAB_N",
            "pRSK_N",
            "AKT_N",
            "BRAF_N",
            "CAMKII_N",
            "CREB_N",
            "ELK_N",
            "ERK_N",
            "GSK3B_N",
            "JNK_N",
            "MEK_N",
            "TRKA_N",
            "RSK_N",
            "APP_N",
            "Bcatenin_N",
            "SOD1_N",
            "MTOR_N",
            "P38_N",
            "pMTOR_N",
            "DSCR1_N",
            "AMPKA_N",
            "NR2B_N",
            "pNUMB_N",
            "RAPTOR_N",
            "TIAM1_N",
            "pP70S6_N",
            "NUMB_N",
            "P70S6_N",
            "pGSK3B_N",
            "pPKCG_N",
            "CDK5_N",
            "S6_N",
            "ADARB1_N",
            "AcetylH3K9_N",
            "RRP1_N",
            "BAX_N",
            "ARC_N",
            "ERBB4_N",
            "nNOS_N",
            "Tau_N",
            "GFAP_N",
            "GluR3_N",
            "GluR4_N",
            "IL1B_N",
            "P3525_N",
            "pCASP9_N",
            "PSD95_N",
            "SNCA_N",
            "Ubiquitin_N",
            "pGSK3B_Tyr216_N",
            "SHH_N",
            "BAD_N",
            "BCL2_N",
            "pS6_N",
            "pCFOS_N",
            "SYP_N",
            "H3AcK18_N",
            "EGR1_N",
            "H3MeK4_N",
            "CaNA_N",
            "class",
        ],
        40945: [
            "pclass",
            "survived",
            "name",
            "sex",
            "age",
            "sibsp",
            "parch",
            "ticket",
            "fare",
            "cabin",
            "embarked",
            "boat",
            "body",
            "home.dest",
        ],
    }


@pytest.fixture(scope="module")
def datasets_missing_values():
    return {
        61: {},
        2: {
            "family": 11,
            "temper_rolling": 9,
            "condition": 2,
            "formability": 4,
            "non-ageing": 10,
            "surface-finish": 11,
            "enamelability": 11,
            "bc": 11,
            "bf": 10,
            "bt": 11,
            "bw%2Fme": 8,
            "bl": 9,
            "m": 11,
            "chrom": 11,
            "phos": 11,
            "cbond": 10,
            "marvi": 11,
            "exptl": 11,
            "ferro": 11,
            "corr": 11,
            "blue%2Fbright%2Fvarn%2Fclean": 11,
            "lustre": 8,
            "jurofm": 11,
            "s": 11,
            "p": 11,
            "oil": 10,
            "packing": 11,
        },
        561: {},
        40589: {},
        1119: {},
        40966: {"BCL2_N": 7},
        40945: {
            "age": 263,
            "fare": 1,
            "cabin": 1014,
            "embarked": 2,
            "boat": 823,
            "body": 1188,
            "home.dest": 564,
        },
    }


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize(
    "data_id, parser, expected_n_categories, expected_n_floats, expected_n_ints",
    [
        # iris dataset
        (61, "liac-arff", 1, 4, 0),
        (61, "pandas", 1, 4, 0),
        # anneal dataset
        (2, "liac-arff", 33, 6, 0),
        (2, "pandas", 33, 2, 4),
        # cpu dataset
        (561, "liac-arff", 1, 7, 0),
        (561, "pandas", 1, 0, 7),
        # emotions dataset
        (40589, "liac-arff", 6, 72, 0),
        (40589, "pandas", 6, 69, 3),
        # adult-census dataset
        (1119, "liac-arff", 9, 6, 0),
        (1119, "pandas", 9, 0, 6),
        # miceprotein
        (40966, "liac-arff", 1, 77, 0),
        (40966, "pandas", 1, 77, 0),
        # titanic
        (40945, "liac-arff", 3, 6, 0),
        (40945, "pandas", 3, 3, 3),
    ],
)
@pytest.mark.parametrize("gzip_response", [True, False])
def test_fetch_openml_types_inference(
    monkeypatch,
    data_id,
    parser,
    expected_n_categories,
    expected_n_floats,
    expected_n_ints,
    gzip_response,
    datasets_column_names,
    datasets_missing_values,
):
    """Check that `fetch_openml` infers the right number of categories, integers,
    and floats."""
    pd = pytest.importorskip("pandas")
    CategoricalDtype = pd.api.types.CategoricalDtype

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)

    bunch = fetch_openml(
        data_id=data_id,
        as_frame=True,
        cache=False,
        parser=parser,
    )
    frame = bunch.frame

    n_categories = len(
        [dtype for dtype in frame.dtypes if isinstance(dtype, CategoricalDtype)]
    )
    n_floats = len([dtype for dtype in frame.dtypes if dtype.kind == "f"])
    n_ints = len([dtype for dtype in frame.dtypes if dtype.kind == "i"])

    assert n_categories == expected_n_categories
    assert n_floats == expected_n_floats
    assert n_ints == expected_n_ints

    assert frame.columns.tolist() == datasets_column_names[data_id]

    frame_feature_to_n_nan = frame.isna().sum().to_dict()
    for name, n_missing in frame_feature_to_n_nan.items():
        expected_missing = datasets_missing_values[data_id].get(name, 0)
        assert n_missing == expected_missing


###############################################################################
# Test some more specific behaviour


# TODO(1.4): remove this filterwarning decorator
@pytest.mark.filterwarnings("ignore:The default value of `parser` will change")
@pytest.mark.parametrize(
    "params, err_msg",
    [
        (
            {"parser": "unknown"},
            "The 'parser' parameter of fetch_openml must be a str among",
        ),
        (
            {"as_frame": "unknown"},
            "The 'as_frame' parameter of fetch_openml must be an instance",
        ),
    ],
)
def test_fetch_openml_validation_parameter(monkeypatch, params, err_msg):
    data_id = 1119
    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
    with pytest.raises(ValueError, match=err_msg):
        fetch_openml(data_id=data_id, **params)


@pytest.mark.parametrize(
    "params",
    [
        {"as_frame": True, "parser": "auto"},
        {"as_frame": "auto", "parser": "auto"},
        {"as_frame": False, "parser": "pandas"},
    ],
)
def test_fetch_openml_requires_pandas_error(monkeypatch, params):
    """Check that we raise the proper errors when we require pandas."""
    data_id = 1119
    try:
        check_pandas_support("test_fetch_openml_requires_pandas")
    except ImportError:
        _monkey_patch_webbased_functions(monkeypatch, data_id, True)
        err_msg = "requires pandas to be installed. Alternatively, explicitly"
        with pytest.raises(ImportError, match=err_msg):
            fetch_openml(data_id=data_id, **params)
    else:
        raise SkipTest("This test requires pandas to not be installed.")


# TODO(1.4): move this parameter option into `test_fetch_openml_requires_pandas_error`
def test_fetch_openml_requires_pandas_in_future(monkeypatch):
    """Check that we raise a warning that pandas will be required in the future."""
    params = {"as_frame": False, "parser": "auto"}
    data_id = 1119
    try:
        check_pandas_support("test_fetch_openml_requires_pandas")
    except ImportError:
        _monkey_patch_webbased_functions(monkeypatch, data_id, True)
        warn_msg = (
            "From version 1.4, `parser='auto'` with `as_frame=False` will use pandas"
        )
        with pytest.warns(FutureWarning, match=warn_msg):
            fetch_openml(data_id=data_id, **params)
    else:
        raise SkipTest("This test requires pandas to not be installed.")


@pytest.mark.filterwarnings("ignore:Version 1 of dataset Australian is inactive")
# TODO(1.4): remove this filterwarning decorator for `parser`
@pytest.mark.filterwarnings("ignore:The default value of `parser` will change")
@pytest.mark.parametrize(
    "params, err_msg",
    [
        (
            {"parser": "pandas"},
            "Sparse ARFF datasets cannot be loaded with parser='pandas'",
        ),
        (
            {"as_frame": True},
            "Sparse ARFF datasets cannot be loaded with as_frame=True.",
        ),
        (
            {"parser": "pandas", "as_frame": True},
            "Sparse ARFF datasets cannot be loaded with as_frame=True.",
        ),
    ],
)
def test_fetch_openml_sparse_arff_error(monkeypatch, params, err_msg):
    """Check that we raise the expected error for sparse ARFF datasets when an
    incompatible set of parameters is requested.
    """
    pytest.importorskip("pandas")
    data_id = 292

    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
    with pytest.raises(ValueError, match=err_msg):
        fetch_openml(
            data_id=data_id,
            cache=False,
            **params,
        )


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.filterwarnings("ignore:Version 1 of dataset Australian is inactive")
@pytest.mark.parametrize(
    "data_id, data_type",
    [
        (61, "dataframe"),  # iris dataset version 1
        (292, "sparse"),  # Australian dataset version 1
    ],
)
def test_fetch_openml_auto_mode(monkeypatch, data_id, data_type):
    """Check the auto mode of `fetch_openml`."""
    pd = pytest.importorskip("pandas")

    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
    data = fetch_openml(data_id=data_id, as_frame="auto", parser="auto", cache=False)
    klass = pd.DataFrame if data_type == "dataframe" else scipy.sparse.csr_matrix
    assert isinstance(data.data, klass)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
def test_convert_arff_data_dataframe_warning_low_memory_pandas(monkeypatch):
    """Check that we raise a warning regarding the working memory when using
    LIAC-ARFF parser."""
    pytest.importorskip("pandas")

    data_id = 1119
    _monkey_patch_webbased_functions(monkeypatch, data_id, True)

    msg = "Could not adhere to working_memory config."
    with pytest.warns(UserWarning, match=msg):
        with config_context(working_memory=1e-6):
            fetch_openml(
                data_id=data_id,
                as_frame=True,
                cache=False,
                parser="liac-arff",
            )


@pytest.mark.parametrize("gzip_response", [True, False])
def test_fetch_openml_iris_warn_multiple_version(monkeypatch, gzip_response):
    """Check that a warning is raised when multiple versions exist and no version is
    requested."""
    data_id = 61
    data_name = "iris"

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)

    msg = (
        "Multiple active versions of the dataset matching the name"
        " iris exist. Versions may be fundamentally different, "
        "returning version 1."
    )
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(
            name=data_name,
            as_frame=False,
            cache=False,
            parser="liac-arff",
        )


@pytest.mark.parametrize("gzip_response", [True, False])
def test_fetch_openml_no_target(monkeypatch, gzip_response):
    """Check that we can get a dataset without target."""
    data_id = 61
    target_column = None
    expected_observations = 150
    expected_features = 5

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    data = fetch_openml(
        data_id=data_id,
        target_column=target_column,
        cache=False,
        as_frame=False,
        parser="liac-arff",
    )
    assert data.data.shape == (expected_observations, expected_features)
    assert data.target is None


@pytest.mark.parametrize("gzip_response", [True, False])
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_missing_values_pandas(monkeypatch, gzip_response, parser):
    """Check that missing values in categories are compatible with pandas
    categoricals."""
    pytest.importorskip("pandas")

    data_id = 42585
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response=gzip_response)
    penguins = fetch_openml(
        data_id=data_id,
        cache=False,
        as_frame=True,
        parser=parser,
    )

    cat_dtype = penguins.data.dtypes["sex"]
    # there are nans in the categorical
    assert penguins.data["sex"].isna().any()
    assert_array_equal(cat_dtype.categories, ["FEMALE", "MALE", "_"])


@pytest.mark.parametrize("gzip_response", [True, False])
@pytest.mark.parametrize(
    "dataset_params",
    [
        {"data_id": 40675},
        {"data_id": None, "name": "glass2", "version": 1},
    ],
)
def test_fetch_openml_inactive(monkeypatch, gzip_response, dataset_params):
    """Check that we raise a warning when the dataset is inactive."""
    data_id = 40675
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    msg = "Version 1 of dataset glass2 is inactive,"
    with pytest.warns(UserWarning, match=msg):
        glass2 = fetch_openml(
            cache=False, as_frame=False, parser="liac-arff", **dataset_params
        )
    assert glass2.data.shape == (163, 9)
    assert glass2.details["id"] == "40675"


@pytest.mark.parametrize("gzip_response", [True, False])
@pytest.mark.parametrize(
    "data_id, params, err_type, err_msg",
    [
        (40675, {"name": "glass2"}, ValueError, "No active dataset glass2 found"),
        (
            61,
            {"data_id": 61, "target_column": ["sepalwidth", "class"]},
            ValueError,
            "Can only handle homogeneous multi-target datasets",
        ),
        (
            40945,
            {"data_id": 40945, "as_frame": False},
            ValueError,
            (
                "STRING attributes are not supported for array representation. Try"
                " as_frame=True"
            ),
        ),
        (
            2,
            {"data_id": 2, "target_column": "family", "as_frame": True},
            ValueError,
            "Target column 'family'",
        ),
        (
            2,
            {"data_id": 2, "target_column": "family", "as_frame": False},
            ValueError,
            "Target column 'family'",
        ),
        (
            61,
            {"data_id": 61, "target_column": "undefined"},
            KeyError,
            "Could not find target_column='undefined'",
        ),
        (
            61,
            {"data_id": 61, "target_column": ["undefined", "class"]},
            KeyError,
            "Could not find target_column='undefined'",
        ),
    ],
)
@pytest.mark.parametrize("parser", ["liac-arff", "pandas"])
def test_fetch_openml_error(
    monkeypatch, gzip_response, data_id, params, err_type, err_msg, parser
):
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    if params.get("as_frame", True) or parser == "pandas":
        pytest.importorskip("pandas")
    with pytest.raises(err_type, match=err_msg):
        fetch_openml(cache=False, parser=parser, **params)


@pytest.mark.parametrize(
    "params, err_type, err_msg",
    [
        (
            {"data_id": -1, "name": None, "version": "version"},
            ValueError,
            "The 'version' parameter of fetch_openml must be an int in the range",
        ),
        (
            {"data_id": -1, "name": "nAmE"},
            ValueError,
            "The 'data_id' parameter of fetch_openml must be an int in the range",
        ),
        (
            {"data_id": -1, "name": "nAmE", "version": "version"},
            ValueError,
            "The 'version' parameter of fetch_openml must be an int",
        ),
        (
            {},
            ValueError,
            "Neither name nor data_id are provided. Please provide name or data_id.",
        ),
    ],
)
def test_fetch_openml_raises_illegal_argument(params, err_type, err_msg):
    with pytest.raises(err_type, match=err_msg):
        fetch_openml(**params)


@pytest.mark.parametrize("gzip_response", [True, False])
def test_warn_ignore_attribute(monkeypatch, gzip_response):
    data_id = 40966
    expected_row_id_msg = "target_column='{}' has flag is_row_identifier."
    expected_ignore_msg = "target_column='{}' has flag is_ignore."
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    # single column test
    target_col = "MouseID"
    msg = expected_row_id_msg.format(target_col)
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(
            data_id=data_id,
            target_column=target_col,
            cache=False,
            as_frame=False,
            parser="liac-arff",
        )
    target_col = "Genotype"
    msg = expected_ignore_msg.format(target_col)
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(
            data_id=data_id,
            target_column=target_col,
            cache=False,
            as_frame=False,
            parser="liac-arff",
        )
    # multi column test
    target_col = "MouseID"
    msg = expected_row_id_msg.format(target_col)
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(
            data_id=data_id,
            target_column=[target_col, "class"],
            cache=False,
            as_frame=False,
            parser="liac-arff",
        )
    target_col = "Genotype"
    msg = expected_ignore_msg.format(target_col)
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(
            data_id=data_id,
            target_column=[target_col, "class"],
            cache=False,
            as_frame=False,
            parser="liac-arff",
        )


@pytest.mark.parametrize("gzip_response", [True, False])
def test_dataset_with_openml_error(monkeypatch, gzip_response):
    data_id = 1
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    msg = "OpenML registered a problem with the dataset. It might be unusable. Error:"
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")


@pytest.mark.parametrize("gzip_response", [True, False])
def test_dataset_with_openml_warning(monkeypatch, gzip_response):
    data_id = 3
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    msg = "OpenML raised a warning on the dataset. It might be unusable. Warning:"
    with pytest.warns(UserWarning, match=msg):
        fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")


def test_fetch_openml_overwrite_default_params_read_csv(monkeypatch):
    """Check that we can overwrite the default parameters of `read_csv`."""
    pytest.importorskip("pandas")
    data_id = 1590
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    common_params = {
        "data_id": data_id,
        "as_frame": True,
        "cache": False,
        "parser": "pandas",
    }

    # By default, the initial spaces are skipped. We check that setting the parameter
    # `skipinitialspace` to False has an effect.
    adult_without_spaces = fetch_openml(**common_params)
    adult_with_spaces = fetch_openml(
        **common_params, read_csv_kwargs={"skipinitialspace": False}
    )
    assert all(
        cat.startswith(" ") for cat in adult_with_spaces.frame["class"].cat.categories
    )
    assert not any(
        cat.startswith(" ")
        for cat in adult_without_spaces.frame["class"].cat.categories
    )


###############################################################################
# Test cache, retry mechanisms, checksum, etc.


@pytest.mark.parametrize("gzip_response", [True, False])
def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
    data_id = 61

    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
    cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
    # first fill the cache
    response1 = _open_openml_url(openml_path, cache_directory)
    # assert file exists
    location = _get_local_path(openml_path, cache_directory)
    assert os.path.isfile(location)
    # redownload, to utilize cache
    response2 = _open_openml_url(openml_path, cache_directory)
    assert response1.read() == response2.read()


@pytest.mark.parametrize("write_to_disk", [True, False])
def test_open_openml_url_unlinks_local_path(monkeypatch, tmpdir, write_to_disk):
    data_id = 61
    openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
    cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
    location = _get_local_path(openml_path, cache_directory)

    def _mock_urlopen(request, *args, **kwargs):
        if write_to_disk:
            with open(location, "w") as f:
                f.write("")
        raise ValueError("Invalid request")

    monkeypatch.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen)

    with pytest.raises(ValueError, match="Invalid request"):
        _open_openml_url(openml_path, cache_directory)

    assert not os.path.exists(location)


def test_retry_with_clean_cache(tmpdir):
    data_id = 61
    openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
    cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
    location = _get_local_path(openml_path, cache_directory)
    os.makedirs(os.path.dirname(location))

    with open(location, "w") as f:
        f.write("")

    @_retry_with_clean_cache(openml_path, cache_directory)
    def _load_data():
        # The first call will raise an error since location exists
        if os.path.exists(location):
            raise Exception("File exists!")
        return 1

    warn_msg = "Invalid cache, redownloading file"
    with pytest.warns(RuntimeWarning, match=warn_msg):
        result = _load_data()
    assert result == 1


def test_retry_with_clean_cache_http_error(tmpdir):
    data_id = 61
    openml_path = sklearn.datasets._openml._DATA_FILE.format(data_id)
    cache_directory = str(tmpdir.mkdir("scikit_learn_data"))

    @_retry_with_clean_cache(openml_path, cache_directory)
    def _load_data():
        raise HTTPError(
            url=None, code=412, msg="Simulated mock error", hdrs=None, fp=None
        )

    error_msg = "Simulated mock error"
    with pytest.raises(HTTPError, match=error_msg):
        _load_data()


@pytest.mark.parametrize("gzip_response", [True, False])
def test_fetch_openml_cache(monkeypatch, gzip_response, tmpdir):
    def _mock_urlopen_raise(request, *args, **kwargs):
        raise ValueError(
            "This mechanism intends to test correct cache "
            "handling. As such, urlopen should never be "
            "accessed. URL: %s" % request.get_full_url()
        )

    data_id = 61
    cache_directory = str(tmpdir.mkdir("scikit_learn_data"))
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
    X_fetched, y_fetched = fetch_openml(
        data_id=data_id,
        cache=True,
        data_home=cache_directory,
        return_X_y=True,
        as_frame=False,
        parser="liac-arff",
    )

    monkeypatch.setattr(sklearn.datasets._openml, "urlopen", _mock_urlopen_raise)

    X_cached, y_cached = fetch_openml(
        data_id=data_id,
        cache=True,
        data_home=cache_directory,
        return_X_y=True,
        as_frame=False,
        parser="liac-arff",
    )
    np.testing.assert_array_equal(X_fetched, X_cached)
    np.testing.assert_array_equal(y_fetched, y_cached)


# Known failure of PyPy for OpenML. See the following issue:
# https://github.com/scikit-learn/scikit-learn/issues/18906
@fails_if_pypy
@pytest.mark.parametrize(
    "as_frame, parser",
    [
        (True, "liac-arff"),
        (False, "liac-arff"),
        (True, "pandas"),
        (False, "pandas"),
    ],
)
def test_fetch_openml_verify_checksum(monkeypatch, as_frame, cache, tmpdir, parser):
    """Check that the checksum is working as expected."""
    if as_frame or parser == "pandas":
        pytest.importorskip("pandas")

    data_id = 2
    _monkey_patch_webbased_functions(monkeypatch, data_id, True)

    # create a temporary modified arff file
    original_data_module = OPENML_TEST_DATA_MODULE + "." + f"id_{data_id}"
    original_data_file_name = "data-v1-dl-1666876.arff.gz"
    corrupt_copy_path = tmpdir / "test_invalid_checksum.arff"
    with _open_binary(original_data_module, original_data_file_name) as orig_file:
        orig_gzip = gzip.open(orig_file, "rb")
        data = bytearray(orig_gzip.read())
        data[len(data) - 1] = 37

    with gzip.GzipFile(corrupt_copy_path, "wb") as modified_gzip:
        modified_gzip.write(data)

    # Requests are already mocked by monkey_patch_webbased_functions.
    # We want to reuse that mock for all requests except file download,
    # hence creating a thin mock over the original mock
    mocked_openml_url = sklearn.datasets._openml.urlopen

    def swap_file_mock(request, *args, **kwargs):
        url = request.get_full_url()
        if url.endswith("data/v1/download/1666876"):
            with open(corrupt_copy_path, "rb") as f:
                corrupted_data = f.read()
            return _MockHTTPResponse(BytesIO(corrupted_data), is_gzip=True)
        else:
            return mocked_openml_url(request)

    monkeypatch.setattr(sklearn.datasets._openml, "urlopen", swap_file_mock)

    # validate failed checksum
    with pytest.raises(ValueError) as exc:
        sklearn.datasets.fetch_openml(
            data_id=data_id, cache=False, as_frame=as_frame, parser=parser
        )
    # exception message should have file-path
    assert exc.match("1666876")


def test_open_openml_url_retry_on_network_error(monkeypatch):
    def _mock_urlopen_network_error(request, *args, **kwargs):
        raise HTTPError("", 404, "Simulated network error", None, None)

    monkeypatch.setattr(
        sklearn.datasets._openml, "urlopen", _mock_urlopen_network_error
    )

    invalid_openml_url = "invalid-url"

    with pytest.warns(
        UserWarning,
        match=re.escape(
            "A network error occurred while downloading"
            f" {_OPENML_PREFIX + invalid_openml_url}. Retrying..."
        ),
    ) as record:
        with pytest.raises(HTTPError, match="Simulated network error"):
            _open_openml_url(invalid_openml_url, None, delay=0)
        assert len(record) == 3


###############################################################################
# Non-regression tests


@pytest.mark.parametrize("gzip_response", [True, False])
@pytest.mark.parametrize("parser", ("liac-arff", "pandas"))
def test_fetch_openml_with_ignored_feature(monkeypatch, gzip_response, parser):
    """Check that we can load the "zoo" dataset.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/14340
    """
    if parser == "pandas":
        pytest.importorskip("pandas")
    data_id = 62
    _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)

    dataset = sklearn.datasets.fetch_openml(
        data_id=data_id, cache=False, as_frame=False, parser=parser
    )
    assert dataset is not None
    # The dataset has 17 features, including 1 ignored (animal),
    # so we assert that we don't have the ignored feature in the final Bunch
    assert dataset["data"].shape == (101, 16)
    assert "animal" not in dataset["feature_names"]


def test_fetch_openml_strip_quotes(monkeypatch):
    """Check that we strip the single quotes when used as a string delimiter.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/23381
    """
    pd = pytest.importorskip("pandas")
    data_id = 40966
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    common_params = {"as_frame": True, "cache": False, "data_id": data_id}
    mice_pandas = fetch_openml(parser="pandas", **common_params)
    mice_liac_arff = fetch_openml(parser="liac-arff", **common_params)
    pd.testing.assert_series_equal(mice_pandas.target, mice_liac_arff.target)
    assert not mice_pandas.target.str.startswith("'").any()
    assert not mice_pandas.target.str.endswith("'").any()

    # similar behaviour should be observed when the column is not the target
    mice_pandas = fetch_openml(parser="pandas", target_column="NUMB_N", **common_params)
    mice_liac_arff = fetch_openml(
        parser="liac-arff", target_column="NUMB_N", **common_params
    )
    pd.testing.assert_series_equal(
        mice_pandas.frame["class"], mice_liac_arff.frame["class"]
    )
    assert not mice_pandas.frame["class"].str.startswith("'").any()
    assert not mice_pandas.frame["class"].str.endswith("'").any()


def test_fetch_openml_leading_whitespace(monkeypatch):
    """Check that we can strip leading whitespace in pandas parser.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25311
    """
    pd = pytest.importorskip("pandas")
    data_id = 1590
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    common_params = {"as_frame": True, "cache": False, "data_id": data_id}
    adult_pandas = fetch_openml(parser="pandas", **common_params)
    adult_liac_arff = fetch_openml(parser="liac-arff", **common_params)
    pd.testing.assert_series_equal(
        adult_pandas.frame["class"], adult_liac_arff.frame["class"]
    )


def test_fetch_openml_quotechar_escapechar(monkeypatch):
    """Check that we can handle escapechar and single/double quotechar.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25478
    """
    pd = pytest.importorskip("pandas")
    data_id = 42074
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    common_params = {"as_frame": True, "cache": False, "data_id": data_id}
    adult_pandas = fetch_openml(parser="pandas", **common_params)
    adult_liac_arff = fetch_openml(parser="liac-arff", **common_params)
    pd.testing.assert_frame_equal(adult_pandas.frame, adult_liac_arff.frame)


###############################################################################
# Deprecation-changed parameters


# TODO(1.4): remove this test
def test_fetch_openml_deprecation_parser(monkeypatch):
    """Check that we raise a deprecation warning for parser parameter."""
    pytest.importorskip("pandas")
    data_id = 61
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    with pytest.warns(FutureWarning, match="The default value of `parser` will change"):
        sklearn.datasets.fetch_openml(data_id=data_id)