_svmlight_format_io.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. """This module implements a loader and dumper for the svmlight format
This format is a text-based format, with one sample per line. It does
not store zero-valued features and is hence suitable for sparse datasets.
  4. The first element of each line can be used to store a target variable to
  5. predict.
  6. This format is used as the default format for both svmlight and the
  7. libsvm command line programs.
  8. """
  9. # Authors: Mathieu Blondel <mathieu@mblondel.org>
  10. # Lars Buitinck
  11. # Olivier Grisel <olivier.grisel@ensta.org>
  12. # License: BSD 3 clause
  13. import os.path
  14. from contextlib import closing
  15. from numbers import Integral
  16. import numpy as np
  17. import scipy.sparse as sp
  18. from .. import __version__
  19. from ..utils import IS_PYPY, check_array
  20. from ..utils._param_validation import HasMethods, Interval, StrOptions, validate_params
  21. if not IS_PYPY:
  22. from ._svmlight_format_fast import (
  23. _dump_svmlight_file,
  24. _load_svmlight_file,
  25. )
  26. else:
  27. def _load_svmlight_file(*args, **kwargs):
  28. raise NotImplementedError(
  29. "load_svmlight_file is currently not "
  30. "compatible with PyPy (see "
  31. "https://github.com/scikit-learn/scikit-learn/issues/11543 "
  32. "for the status updates)."
  33. )
  34. @validate_params(
  35. {
  36. "f": [
  37. str,
  38. Interval(Integral, 0, None, closed="left"),
  39. os.PathLike,
  40. HasMethods("read"),
  41. ],
  42. "n_features": [Interval(Integral, 1, None, closed="left"), None],
  43. "dtype": "no_validation", # delegate validation to numpy
  44. "multilabel": ["boolean"],
  45. "zero_based": ["boolean", StrOptions({"auto"})],
  46. "query_id": ["boolean"],
  47. "offset": [Interval(Integral, 0, None, closed="left")],
  48. "length": [Integral],
  49. },
  50. prefer_skip_nested_validation=True,
  51. )
  52. def load_svmlight_file(
  53. f,
  54. *,
  55. n_features=None,
  56. dtype=np.float64,
  57. multilabel=False,
  58. zero_based="auto",
  59. query_id=False,
  60. offset=0,
  61. length=-1,
  62. ):
  63. """Load datasets in the svmlight / libsvm format into sparse CSR matrix.
  64. This format is a text-based format, with one sample per line. It does
  65. not store zero valued features hence is suitable for sparse dataset.
  66. The first element of each line can be used to store a target variable
  67. to predict.
  68. This format is used as the default format for both svmlight and the
  69. libsvm command line programs.
  70. Parsing a text based source can be expensive. When repeatedly
  71. working on the same dataset, it is recommended to wrap this
  72. loader with joblib.Memory.cache to store a memmapped backup of the
  73. CSR results of the first call and benefit from the near instantaneous
  74. loading of memmapped structures for the subsequent calls.
  75. In case the file contains a pairwise preference constraint (known
  76. as "qid" in the svmlight format) these are ignored unless the
  77. query_id parameter is set to True. These pairwise preference
  78. constraints can be used to constraint the combination of samples
  79. when using pairwise loss functions (as is the case in some
  80. learning to rank problems) so that only pairs with the same
  81. query_id value are considered.
  82. This implementation is written in Cython and is reasonably fast.
  83. However, a faster API-compatible loader is also available at:
  84. https://github.com/mblondel/svmlight-loader
  85. Parameters
  86. ----------
  87. f : str, path-like, file-like or int
  88. (Path to) a file to load. If a path ends in ".gz" or ".bz2", it will
  89. be uncompressed on the fly. If an integer is passed, it is assumed to
  90. be a file descriptor. A file-like or file descriptor will not be closed
  91. by this function. A file-like object must be opened in binary mode.
  92. .. versionchanged:: 1.2
  93. Path-like objects are now accepted.
  94. n_features : int, default=None
  95. The number of features to use. If None, it will be inferred. This
  96. argument is useful to load several files that are subsets of a
  97. bigger sliced dataset: each subset might not have examples of
  98. every feature, hence the inferred shape might vary from one
  99. slice to another.
  100. n_features is only required if ``offset`` or ``length`` are passed a
  101. non-default value.
  102. dtype : numpy data type, default=np.float64
  103. Data type of dataset to be loaded. This will be the data type of the
  104. output numpy arrays ``X`` and ``y``.
  105. multilabel : bool, default=False
  106. Samples may have several labels each (see
  107. https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html).
  108. zero_based : bool or "auto", default="auto"
  109. Whether column indices in f are zero-based (True) or one-based
  110. (False). If column indices are one-based, they are transformed to
  111. zero-based to match Python/NumPy conventions.
  112. If set to "auto", a heuristic check is applied to determine this from
  113. the file contents. Both kinds of files occur "in the wild", but they
  114. are unfortunately not self-identifying. Using "auto" or True should
  115. always be safe when no ``offset`` or ``length`` is passed.
  116. If ``offset`` or ``length`` are passed, the "auto" mode falls back
  117. to ``zero_based=True`` to avoid having the heuristic check yield
  118. inconsistent results on different segments of the file.
  119. query_id : bool, default=False
  120. If True, will return the query_id array for each file.
  121. offset : int, default=0
  122. Ignore the offset first bytes by seeking forward, then
  123. discarding the following bytes up until the next new line
  124. character.
  125. length : int, default=-1
  126. If strictly positive, stop reading any new line of data once the
  127. position in the file has reached the (offset + length) bytes threshold.
  128. Returns
  129. -------
  130. X : scipy.sparse matrix of shape (n_samples, n_features)
  131. The data matrix.
  132. y : ndarray of shape (n_samples,), or a list of tuples of length n_samples
  133. The target. It is a list of tuples when ``multilabel=True``, else a
  134. ndarray.
  135. query_id : array of shape (n_samples,)
  136. The query_id for each sample. Only returned when query_id is set to
  137. True.
  138. See Also
  139. --------
  140. load_svmlight_files : Similar function for loading multiple files in this
  141. format, enforcing the same number of features/columns on all of them.
  142. Examples
  143. --------
  144. To use joblib.Memory to cache the svmlight file::
  145. from joblib import Memory
  146. from .datasets import load_svmlight_file
  147. mem = Memory("./mycache")
  148. @mem.cache
  149. def get_data():
  150. data = load_svmlight_file("mysvmlightfile")
  151. return data[0], data[1]
  152. X, y = get_data()
  153. """
  154. return tuple(
  155. load_svmlight_files(
  156. [f],
  157. n_features=n_features,
  158. dtype=dtype,
  159. multilabel=multilabel,
  160. zero_based=zero_based,
  161. query_id=query_id,
  162. offset=offset,
  163. length=length,
  164. )
  165. )
  166. def _gen_open(f):
  167. if isinstance(f, int): # file descriptor
  168. return open(f, "rb", closefd=False)
  169. elif isinstance(f, os.PathLike):
  170. f = os.fspath(f)
  171. elif not isinstance(f, str):
  172. raise TypeError("expected {str, int, path-like, file-like}, got %s" % type(f))
  173. _, ext = os.path.splitext(f)
  174. if ext == ".gz":
  175. import gzip
  176. return gzip.open(f, "rb")
  177. elif ext == ".bz2":
  178. from bz2 import BZ2File
  179. return BZ2File(f, "rb")
  180. else:
  181. return open(f, "rb")
  182. def _open_and_load(f, dtype, multilabel, zero_based, query_id, offset=0, length=-1):
  183. if hasattr(f, "read"):
  184. actual_dtype, data, ind, indptr, labels, query = _load_svmlight_file(
  185. f, dtype, multilabel, zero_based, query_id, offset, length
  186. )
  187. else:
  188. with closing(_gen_open(f)) as f:
  189. actual_dtype, data, ind, indptr, labels, query = _load_svmlight_file(
  190. f, dtype, multilabel, zero_based, query_id, offset, length
  191. )
  192. # convert from array.array, give data the right dtype
  193. if not multilabel:
  194. labels = np.frombuffer(labels, np.float64)
  195. data = np.frombuffer(data, actual_dtype)
  196. indices = np.frombuffer(ind, np.longlong)
  197. indptr = np.frombuffer(indptr, dtype=np.longlong) # never empty
  198. query = np.frombuffer(query, np.int64)
  199. data = np.asarray(data, dtype=dtype) # no-op for float{32,64}
  200. return data, indices, indptr, labels, query
  201. @validate_params(
  202. {
  203. "files": [
  204. "array-like",
  205. str,
  206. os.PathLike,
  207. HasMethods("read"),
  208. Interval(Integral, 0, None, closed="left"),
  209. ],
  210. "n_features": [Interval(Integral, 1, None, closed="left"), None],
  211. "dtype": "no_validation", # delegate validation to numpy
  212. "multilabel": ["boolean"],
  213. "zero_based": ["boolean", StrOptions({"auto"})],
  214. "query_id": ["boolean"],
  215. "offset": [Interval(Integral, 0, None, closed="left")],
  216. "length": [Integral],
  217. },
  218. prefer_skip_nested_validation=True,
  219. )
  220. def load_svmlight_files(
  221. files,
  222. *,
  223. n_features=None,
  224. dtype=np.float64,
  225. multilabel=False,
  226. zero_based="auto",
  227. query_id=False,
  228. offset=0,
  229. length=-1,
  230. ):
  231. """Load dataset from multiple files in SVMlight format.
  232. This function is equivalent to mapping load_svmlight_file over a list of
  233. files, except that the results are concatenated into a single, flat list
  234. and the samples vectors are constrained to all have the same number of
  235. features.
  236. In case the file contains a pairwise preference constraint (known
  237. as "qid" in the svmlight format) these are ignored unless the
  238. query_id parameter is set to True. These pairwise preference
  239. constraints can be used to constraint the combination of samples
  240. when using pairwise loss functions (as is the case in some
  241. learning to rank problems) so that only pairs with the same
  242. query_id value are considered.
  243. Parameters
  244. ----------
  245. files : array-like, dtype=str, path-like, file-like or int
  246. (Paths of) files to load. If a path ends in ".gz" or ".bz2", it will
  247. be uncompressed on the fly. If an integer is passed, it is assumed to
  248. be a file descriptor. File-likes and file descriptors will not be
  249. closed by this function. File-like objects must be opened in binary
  250. mode.
  251. .. versionchanged:: 1.2
  252. Path-like objects are now accepted.
  253. n_features : int, default=None
  254. The number of features to use. If None, it will be inferred from the
  255. maximum column index occurring in any of the files.
  256. This can be set to a higher value than the actual number of features
  257. in any of the input files, but setting it to a lower value will cause
  258. an exception to be raised.
  259. dtype : numpy data type, default=np.float64
  260. Data type of dataset to be loaded. This will be the data type of the
  261. output numpy arrays ``X`` and ``y``.
  262. multilabel : bool, default=False
  263. Samples may have several labels each (see
  264. https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html).
  265. zero_based : bool or "auto", default="auto"
  266. Whether column indices in f are zero-based (True) or one-based
  267. (False). If column indices are one-based, they are transformed to
  268. zero-based to match Python/NumPy conventions.
  269. If set to "auto", a heuristic check is applied to determine this from
  270. the file contents. Both kinds of files occur "in the wild", but they
  271. are unfortunately not self-identifying. Using "auto" or True should
  272. always be safe when no offset or length is passed.
  273. If offset or length are passed, the "auto" mode falls back
  274. to zero_based=True to avoid having the heuristic check yield
  275. inconsistent results on different segments of the file.
  276. query_id : bool, default=False
  277. If True, will return the query_id array for each file.
  278. offset : int, default=0
  279. Ignore the offset first bytes by seeking forward, then
  280. discarding the following bytes up until the next new line
  281. character.
  282. length : int, default=-1
  283. If strictly positive, stop reading any new line of data once the
  284. position in the file has reached the (offset + length) bytes threshold.
  285. Returns
  286. -------
  287. [X1, y1, ..., Xn, yn] or [X1, y1, q1, ..., Xn, yn, qn]: list of arrays
  288. Each (Xi, yi) pair is the result from load_svmlight_file(files[i]).
  289. If query_id is set to True, this will return instead (Xi, yi, qi)
  290. triplets.
  291. See Also
  292. --------
  293. load_svmlight_file: Similar function for loading a single file in this
  294. format.
  295. Notes
  296. -----
  297. When fitting a model to a matrix X_train and evaluating it against a
  298. matrix X_test, it is essential that X_train and X_test have the same
  299. number of features (X_train.shape[1] == X_test.shape[1]). This may not
  300. be the case if you load the files individually with load_svmlight_file.
  301. """
  302. if (offset != 0 or length > 0) and zero_based == "auto":
  303. # disable heuristic search to avoid getting inconsistent results on
  304. # different segments of the file
  305. zero_based = True
  306. if (offset != 0 or length > 0) and n_features is None:
  307. raise ValueError("n_features is required when offset or length is specified.")
  308. r = [
  309. _open_and_load(
  310. f,
  311. dtype,
  312. multilabel,
  313. bool(zero_based),
  314. bool(query_id),
  315. offset=offset,
  316. length=length,
  317. )
  318. for f in files
  319. ]
  320. if (
  321. zero_based is False
  322. or zero_based == "auto"
  323. and all(len(tmp[1]) and np.min(tmp[1]) > 0 for tmp in r)
  324. ):
  325. for _, indices, _, _, _ in r:
  326. indices -= 1
  327. n_f = max(ind[1].max() if len(ind[1]) else 0 for ind in r) + 1
  328. if n_features is None:
  329. n_features = n_f
  330. elif n_features < n_f:
  331. raise ValueError(
  332. "n_features was set to {}, but input file contains {} features".format(
  333. n_features, n_f
  334. )
  335. )
  336. result = []
  337. for data, indices, indptr, y, query_values in r:
  338. shape = (indptr.shape[0] - 1, n_features)
  339. X = sp.csr_matrix((data, indices, indptr), shape)
  340. X.sort_indices()
  341. result += X, y
  342. if query_id:
  343. result.append(query_values)
  344. return result
  345. def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
  346. if comment:
  347. f.write(
  348. (
  349. "# Generated by dump_svmlight_file from scikit-learn %s\n" % __version__
  350. ).encode()
  351. )
  352. f.write(
  353. ("# Column indices are %s-based\n" % ["zero", "one"][one_based]).encode()
  354. )
  355. f.write(b"#\n")
  356. f.writelines(b"# %s\n" % line for line in comment.splitlines())
  357. X_is_sp = sp.issparse(X)
  358. y_is_sp = sp.issparse(y)
  359. if not multilabel and not y_is_sp:
  360. y = y[:, np.newaxis]
  361. _dump_svmlight_file(
  362. X,
  363. y,
  364. f,
  365. multilabel,
  366. one_based,
  367. query_id,
  368. X_is_sp,
  369. y_is_sp,
  370. )
  371. @validate_params(
  372. {
  373. "X": ["array-like", "sparse matrix"],
  374. "y": ["array-like", "sparse matrix"],
  375. "f": [str, HasMethods(["write"])],
  376. "zero_based": ["boolean"],
  377. "comment": [str, bytes, None],
  378. "query_id": ["array-like", None],
  379. "multilabel": ["boolean"],
  380. },
  381. prefer_skip_nested_validation=True,
  382. )
  383. def dump_svmlight_file(
  384. X,
  385. y,
  386. f,
  387. *,
  388. zero_based=True,
  389. comment=None,
  390. query_id=None,
  391. multilabel=False,
  392. ):
  393. """Dump the dataset in svmlight / libsvm file format.
  394. This format is a text-based format, with one sample per line. It does
  395. not store zero valued features hence is suitable for sparse dataset.
  396. The first element of each line can be used to store a target variable
  397. to predict.
  398. Parameters
  399. ----------
  400. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  401. Training vectors, where `n_samples` is the number of samples and
  402. `n_features` is the number of features.
  403. y : {array-like, sparse matrix}, shape = (n_samples,) or (n_samples, n_labels)
  404. Target values. Class labels must be an
  405. integer or float, or array-like objects of integer or float for
  406. multilabel classifications.
  407. f : str or file-like in binary mode
  408. If string, specifies the path that will contain the data.
  409. If file-like, data will be written to f. f should be opened in binary
  410. mode.
  411. zero_based : bool, default=True
  412. Whether column indices should be written zero-based (True) or one-based
  413. (False).
  414. comment : str or bytes, default=None
  415. Comment to insert at the top of the file. This should be either a
  416. Unicode string, which will be encoded as UTF-8, or an ASCII byte
  417. string.
  418. If a comment is given, then it will be preceded by one that identifies
  419. the file as having been dumped by scikit-learn. Note that not all
  420. tools grok comments in SVMlight files.
  421. query_id : array-like of shape (n_samples,), default=None
  422. Array containing pairwise preference constraints (qid in svmlight
  423. format).
  424. multilabel : bool, default=False
  425. Samples may have several labels each (see
  426. https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html).
  427. .. versionadded:: 0.17
  428. parameter `multilabel` to support multilabel datasets.
  429. """
  430. if comment is not None:
  431. # Convert comment string to list of lines in UTF-8.
  432. # If a byte string is passed, then check whether it's ASCII;
  433. # if a user wants to get fancy, they'll have to decode themselves.
  434. if isinstance(comment, bytes):
  435. comment.decode("ascii") # just for the exception
  436. else:
  437. comment = comment.encode("utf-8")
  438. if b"\0" in comment:
  439. raise ValueError("comment string contains NUL byte")
  440. yval = check_array(y, accept_sparse="csr", ensure_2d=False)
  441. if sp.issparse(yval):
  442. if yval.shape[1] != 1 and not multilabel:
  443. raise ValueError(
  444. "expected y of shape (n_samples, 1), got %r" % (yval.shape,)
  445. )
  446. else:
  447. if yval.ndim != 1 and not multilabel:
  448. raise ValueError("expected y of shape (n_samples,), got %r" % (yval.shape,))
  449. Xval = check_array(X, accept_sparse="csr")
  450. if Xval.shape[0] != yval.shape[0]:
  451. raise ValueError(
  452. "X.shape[0] and y.shape[0] should be the same, got %r and %r instead."
  453. % (Xval.shape[0], yval.shape[0])
  454. )
  455. # We had some issues with CSR matrices with unsorted indices (e.g. #1501),
  456. # so sort them here, but first make sure we don't modify the user's X.
  457. # TODO We can do this cheaper; sorted_indices copies the whole matrix.
  458. if yval is y and hasattr(yval, "sorted_indices"):
  459. y = yval.sorted_indices()
  460. else:
  461. y = yval
  462. if hasattr(y, "sort_indices"):
  463. y.sort_indices()
  464. if Xval is X and hasattr(Xval, "sorted_indices"):
  465. X = Xval.sorted_indices()
  466. else:
  467. X = Xval
  468. if hasattr(X, "sort_indices"):
  469. X.sort_indices()
  470. if query_id is None:
  471. # NOTE: query_id is passed to Cython functions using a fused type on query_id.
  472. # Yet as of Cython>=3.0, memory views can't be None otherwise the runtime
  473. # would not known which concrete implementation to dispatch the Python call to.
  474. # TODO: simplify interfaces and implementations in _svmlight_format_fast.pyx.
  475. query_id = np.array([], dtype=np.int32)
  476. else:
  477. query_id = np.asarray(query_id)
  478. if query_id.shape[0] != y.shape[0]:
  479. raise ValueError(
  480. "expected query_id of shape (n_samples,), got %r" % (query_id.shape,)
  481. )
  482. one_based = not zero_based
  483. if hasattr(f, "write"):
  484. _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id)
  485. else:
  486. with open(f, "wb") as f:
  487. _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id)