# test_arff_parser.py (7.9 KB)
import re
import textwrap
from io import BytesIO

import pytest

from sklearn.datasets._arff_parser import (
    _liac_arff_parser,
    _pandas_arff_parser,
    _post_process_frame,
    load_arff_from_gzip_file,
)
  10. @pytest.mark.parametrize(
  11. "feature_names, target_names",
  12. [
  13. (
  14. [
  15. "col_int_as_integer",
  16. "col_int_as_numeric",
  17. "col_float_as_real",
  18. "col_float_as_numeric",
  19. ],
  20. ["col_categorical", "col_string"],
  21. ),
  22. (
  23. [
  24. "col_int_as_integer",
  25. "col_int_as_numeric",
  26. "col_float_as_real",
  27. "col_float_as_numeric",
  28. ],
  29. ["col_categorical"],
  30. ),
  31. (
  32. [
  33. "col_int_as_integer",
  34. "col_int_as_numeric",
  35. "col_float_as_real",
  36. "col_float_as_numeric",
  37. ],
  38. [],
  39. ),
  40. ],
  41. )
  42. def test_post_process_frame(feature_names, target_names):
  43. """Check the behaviour of the post-processing function for splitting a dataframe."""
  44. pd = pytest.importorskip("pandas")
  45. X_original = pd.DataFrame(
  46. {
  47. "col_int_as_integer": [1, 2, 3],
  48. "col_int_as_numeric": [1, 2, 3],
  49. "col_float_as_real": [1.0, 2.0, 3.0],
  50. "col_float_as_numeric": [1.0, 2.0, 3.0],
  51. "col_categorical": ["a", "b", "c"],
  52. "col_string": ["a", "b", "c"],
  53. }
  54. )
  55. X, y = _post_process_frame(X_original, feature_names, target_names)
  56. assert isinstance(X, pd.DataFrame)
  57. if len(target_names) >= 2:
  58. assert isinstance(y, pd.DataFrame)
  59. elif len(target_names) == 1:
  60. assert isinstance(y, pd.Series)
  61. else:
  62. assert y is None
  63. def test_load_arff_from_gzip_file_error_parser():
  64. """An error will be raised if the parser is not known."""
  65. # None of the input parameters are required to be accurate since the check
  66. # of the parser will be carried out first.
  67. err_msg = "Unknown parser: 'xxx'. Should be 'liac-arff' or 'pandas'"
  68. with pytest.raises(ValueError, match=err_msg):
  69. load_arff_from_gzip_file("xxx", "xxx", "xxx", "xxx", "xxx", "xxx")
  70. @pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
  71. def test_pandas_arff_parser_strip_single_quotes(parser_func):
  72. """Check that we properly strip single quotes from the data."""
  73. pd = pytest.importorskip("pandas")
  74. arff_file = BytesIO(textwrap.dedent("""
  75. @relation 'toy'
  76. @attribute 'cat_single_quote' {'A', 'B', 'C'}
  77. @attribute 'str_single_quote' string
  78. @attribute 'str_nested_quote' string
  79. @attribute 'class' numeric
  80. @data
  81. 'A','some text','\"expect double quotes\"',0
  82. """).encode("utf-8"))
  83. columns_info = {
  84. "cat_single_quote": {
  85. "data_type": "nominal",
  86. "name": "cat_single_quote",
  87. },
  88. "str_single_quote": {
  89. "data_type": "string",
  90. "name": "str_single_quote",
  91. },
  92. "str_nested_quote": {
  93. "data_type": "string",
  94. "name": "str_nested_quote",
  95. },
  96. "class": {
  97. "data_type": "numeric",
  98. "name": "class",
  99. },
  100. }
  101. feature_names = [
  102. "cat_single_quote",
  103. "str_single_quote",
  104. "str_nested_quote",
  105. ]
  106. target_names = ["class"]
  107. # We don't strip single quotes for string columns with the pandas parser.
  108. expected_values = {
  109. "cat_single_quote": "A",
  110. "str_single_quote": (
  111. "some text" if parser_func is _liac_arff_parser else "'some text'"
  112. ),
  113. "str_nested_quote": (
  114. '"expect double quotes"'
  115. if parser_func is _liac_arff_parser
  116. else "'\"expect double quotes\"'"
  117. ),
  118. "class": 0,
  119. }
  120. _, _, frame, _ = parser_func(
  121. arff_file,
  122. output_arrays_type="pandas",
  123. openml_columns_info=columns_info,
  124. feature_names_to_select=feature_names,
  125. target_names_to_select=target_names,
  126. )
  127. assert frame.columns.tolist() == feature_names + target_names
  128. pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
  129. @pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
  130. def test_pandas_arff_parser_strip_double_quotes(parser_func):
  131. """Check that we properly strip double quotes from the data."""
  132. pd = pytest.importorskip("pandas")
  133. arff_file = BytesIO(textwrap.dedent("""
  134. @relation 'toy'
  135. @attribute 'cat_double_quote' {"A", "B", "C"}
  136. @attribute 'str_double_quote' string
  137. @attribute 'str_nested_quote' string
  138. @attribute 'class' numeric
  139. @data
  140. "A","some text","\'expect double quotes\'",0
  141. """).encode("utf-8"))
  142. columns_info = {
  143. "cat_double_quote": {
  144. "data_type": "nominal",
  145. "name": "cat_double_quote",
  146. },
  147. "str_double_quote": {
  148. "data_type": "string",
  149. "name": "str_double_quote",
  150. },
  151. "str_nested_quote": {
  152. "data_type": "string",
  153. "name": "str_nested_quote",
  154. },
  155. "class": {
  156. "data_type": "numeric",
  157. "name": "class",
  158. },
  159. }
  160. feature_names = [
  161. "cat_double_quote",
  162. "str_double_quote",
  163. "str_nested_quote",
  164. ]
  165. target_names = ["class"]
  166. expected_values = {
  167. "cat_double_quote": "A",
  168. "str_double_quote": "some text",
  169. "str_nested_quote": "'expect double quotes'",
  170. "class": 0,
  171. }
  172. _, _, frame, _ = parser_func(
  173. arff_file,
  174. output_arrays_type="pandas",
  175. openml_columns_info=columns_info,
  176. feature_names_to_select=feature_names,
  177. target_names_to_select=target_names,
  178. )
  179. assert frame.columns.tolist() == feature_names + target_names
  180. pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))
  181. @pytest.mark.parametrize(
  182. "parser_func",
  183. [
  184. # internal quotes are not considered to follow the ARFF spec in LIAC ARFF
  185. pytest.param(_liac_arff_parser, marks=pytest.mark.xfail),
  186. _pandas_arff_parser,
  187. ],
  188. )
  189. def test_pandas_arff_parser_strip_no_quotes(parser_func):
  190. """Check that we properly parse with no quotes characters."""
  191. pd = pytest.importorskip("pandas")
  192. arff_file = BytesIO(textwrap.dedent("""
  193. @relation 'toy'
  194. @attribute 'cat_without_quote' {A, B, C}
  195. @attribute 'str_without_quote' string
  196. @attribute 'str_internal_quote' string
  197. @attribute 'class' numeric
  198. @data
  199. A,some text,'internal' quote,0
  200. """).encode("utf-8"))
  201. columns_info = {
  202. "cat_without_quote": {
  203. "data_type": "nominal",
  204. "name": "cat_without_quote",
  205. },
  206. "str_without_quote": {
  207. "data_type": "string",
  208. "name": "str_without_quote",
  209. },
  210. "str_internal_quote": {
  211. "data_type": "string",
  212. "name": "str_internal_quote",
  213. },
  214. "class": {
  215. "data_type": "numeric",
  216. "name": "class",
  217. },
  218. }
  219. feature_names = [
  220. "cat_without_quote",
  221. "str_without_quote",
  222. "str_internal_quote",
  223. ]
  224. target_names = ["class"]
  225. expected_values = {
  226. "cat_without_quote": "A",
  227. "str_without_quote": "some text",
  228. "str_internal_quote": "'internal' quote",
  229. "class": 0,
  230. }
  231. _, _, frame, _ = parser_func(
  232. arff_file,
  233. output_arrays_type="pandas",
  234. openml_columns_info=columns_info,
  235. feature_names_to_select=feature_names,
  236. target_names_to_select=target_names,
  237. )
  238. assert frame.columns.tolist() == feature_names + target_names
  239. pd.testing.assert_series_equal(frame.iloc[0], pd.Series(expected_values, name=0))