_olivetti_faces.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
"""Modified Olivetti faces dataset.

The original database was formerly available from a page that is now defunct:

    https://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html

The version retrieved here comes in MATLAB format from the personal
web page of Sam Roweis:

    https://cs.nyu.edu/~roweis/
"""
  8. # Copyright (c) 2011 David Warde-Farley <wardefar at iro dot umontreal dot ca>
  9. # License: BSD 3 clause
  10. from os import PathLike, makedirs, remove
  11. from os.path import exists
  12. import joblib
  13. import numpy as np
  14. from scipy.io import loadmat
  15. from ..utils import Bunch, check_random_state
  16. from ..utils._param_validation import validate_params
  17. from . import get_data_home
  18. from ._base import RemoteFileMetadata, _fetch_remote, _pkl_filepath, load_descr
  19. # The original data can be found at:
  20. # https://cs.nyu.edu/~roweis/data/olivettifaces.mat
  21. FACES = RemoteFileMetadata(
  22. filename="olivettifaces.mat",
  23. url="https://ndownloader.figshare.com/files/5976027",
  24. checksum="b612fb967f2dc77c9c62d3e1266e0c73d5fca46a4b8906c18e454d41af987794",
  25. )
  26. @validate_params(
  27. {
  28. "data_home": [str, PathLike, None],
  29. "shuffle": ["boolean"],
  30. "random_state": ["random_state"],
  31. "download_if_missing": ["boolean"],
  32. "return_X_y": ["boolean"],
  33. },
  34. prefer_skip_nested_validation=True,
  35. )
  36. def fetch_olivetti_faces(
  37. *,
  38. data_home=None,
  39. shuffle=False,
  40. random_state=0,
  41. download_if_missing=True,
  42. return_X_y=False,
  43. ):
  44. """Load the Olivetti faces data-set from AT&T (classification).
  45. Download it if necessary.
  46. ================= =====================
  47. Classes 40
  48. Samples total 400
  49. Dimensionality 4096
  50. Features real, between 0 and 1
  51. ================= =====================
  52. Read more in the :ref:`User Guide <olivetti_faces_dataset>`.
  53. Parameters
  54. ----------
  55. data_home : str or path-like, default=None
  56. Specify another download and cache folder for the datasets. By default
  57. all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
  58. shuffle : bool, default=False
  59. If True the order of the dataset is shuffled to avoid having
  60. images of the same person grouped.
  61. random_state : int, RandomState instance or None, default=0
  62. Determines random number generation for dataset shuffling. Pass an int
  63. for reproducible output across multiple function calls.
  64. See :term:`Glossary <random_state>`.
  65. download_if_missing : bool, default=True
  66. If False, raise an OSError if the data is not locally available
  67. instead of trying to download the data from the source site.
  68. return_X_y : bool, default=False
  69. If True, returns `(data, target)` instead of a `Bunch` object. See
  70. below for more information about the `data` and `target` object.
  71. .. versionadded:: 0.22
  72. Returns
  73. -------
  74. data : :class:`~sklearn.utils.Bunch`
  75. Dictionary-like object, with the following attributes.
  76. data: ndarray, shape (400, 4096)
  77. Each row corresponds to a ravelled
  78. face image of original size 64 x 64 pixels.
  79. images : ndarray, shape (400, 64, 64)
  80. Each row is a face image
  81. corresponding to one of the 40 subjects of the dataset.
  82. target : ndarray, shape (400,)
  83. Labels associated to each face image.
  84. Those labels are ranging from 0-39 and correspond to the
  85. Subject IDs.
  86. DESCR : str
  87. Description of the modified Olivetti Faces Dataset.
  88. (data, target) : tuple if `return_X_y=True`
  89. Tuple with the `data` and `target` objects described above.
  90. .. versionadded:: 0.22
  91. """
  92. data_home = get_data_home(data_home=data_home)
  93. if not exists(data_home):
  94. makedirs(data_home)
  95. filepath = _pkl_filepath(data_home, "olivetti.pkz")
  96. if not exists(filepath):
  97. if not download_if_missing:
  98. raise OSError("Data not found and `download_if_missing` is False")
  99. print("downloading Olivetti faces from %s to %s" % (FACES.url, data_home))
  100. mat_path = _fetch_remote(FACES, dirname=data_home)
  101. mfile = loadmat(file_name=mat_path)
  102. # delete raw .mat data
  103. remove(mat_path)
  104. faces = mfile["faces"].T.copy()
  105. joblib.dump(faces, filepath, compress=6)
  106. del mfile
  107. else:
  108. faces = joblib.load(filepath)
  109. # We want floating point data, but float32 is enough (there is only
  110. # one byte of precision in the original uint8s anyway)
  111. faces = np.float32(faces)
  112. faces = faces - faces.min()
  113. faces /= faces.max()
  114. faces = faces.reshape((400, 64, 64)).transpose(0, 2, 1)
  115. # 10 images per class, 400 images total, each class is contiguous.
  116. target = np.array([i // 10 for i in range(400)])
  117. if shuffle:
  118. random_state = check_random_state(random_state)
  119. order = random_state.permutation(len(faces))
  120. faces = faces[order]
  121. target = target[order]
  122. faces_vectorized = faces.reshape(len(faces), -1)
  123. fdescr = load_descr("olivetti_faces.rst")
  124. if return_X_y:
  125. return faces_vectorized, target
  126. return Bunch(data=faces_vectorized, images=faces, target=target, DESCR=fdescr)