test_seq_dataset.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Author: Tom Dupre la Tour
  2. # Joan Massich <mailsik@gmail.com>
  3. #
  4. # License: BSD 3 clause
  5. import numpy as np
  6. import pytest
  7. import scipy.sparse as sp
  8. from numpy.testing import assert_array_equal
  9. from sklearn.datasets import load_iris
  10. from sklearn.utils._seq_dataset import (
  11. ArrayDataset32,
  12. ArrayDataset64,
  13. CSRDataset32,
  14. CSRDataset64,
  15. )
  16. from sklearn.utils._testing import assert_allclose
  17. iris = load_iris()
  18. X64 = iris.data.astype(np.float64)
  19. y64 = iris.target.astype(np.float64)
  20. X_csr64 = sp.csr_matrix(X64)
  21. sample_weight64 = np.arange(y64.size, dtype=np.float64)
  22. X32 = iris.data.astype(np.float32)
  23. y32 = iris.target.astype(np.float32)
  24. X_csr32 = sp.csr_matrix(X32)
  25. sample_weight32 = np.arange(y32.size, dtype=np.float32)
  26. def assert_csr_equal_values(current, expected):
  27. current.eliminate_zeros()
  28. expected.eliminate_zeros()
  29. expected = expected.astype(current.dtype)
  30. assert current.shape[0] == expected.shape[0]
  31. assert current.shape[1] == expected.shape[1]
  32. assert_array_equal(current.data, expected.data)
  33. assert_array_equal(current.indices, expected.indices)
  34. assert_array_equal(current.indptr, expected.indptr)
  35. def make_dense_dataset_32():
  36. return ArrayDataset32(X32, y32, sample_weight32, seed=42)
  37. def make_dense_dataset_64():
  38. return ArrayDataset64(X64, y64, sample_weight64, seed=42)
  39. def make_sparse_dataset_32():
  40. return CSRDataset32(
  41. X_csr32.data, X_csr32.indptr, X_csr32.indices, y32, sample_weight32, seed=42
  42. )
  43. def make_sparse_dataset_64():
  44. return CSRDataset64(
  45. X_csr64.data, X_csr64.indptr, X_csr64.indices, y64, sample_weight64, seed=42
  46. )
  47. @pytest.mark.parametrize(
  48. "dataset_constructor",
  49. [
  50. make_dense_dataset_32,
  51. make_dense_dataset_64,
  52. make_sparse_dataset_32,
  53. make_sparse_dataset_64,
  54. ],
  55. )
  56. def test_seq_dataset_basic_iteration(dataset_constructor):
  57. NUMBER_OF_RUNS = 5
  58. dataset = dataset_constructor()
  59. for _ in range(NUMBER_OF_RUNS):
  60. # next sample
  61. xi_, yi, swi, idx = dataset._next_py()
  62. xi = sp.csr_matrix((xi_), shape=(1, X64.shape[1]))
  63. assert_csr_equal_values(xi, X_csr64[idx])
  64. assert yi == y64[idx]
  65. assert swi == sample_weight64[idx]
  66. # random sample
  67. xi_, yi, swi, idx = dataset._random_py()
  68. xi = sp.csr_matrix((xi_), shape=(1, X64.shape[1]))
  69. assert_csr_equal_values(xi, X_csr64[idx])
  70. assert yi == y64[idx]
  71. assert swi == sample_weight64[idx]
  72. @pytest.mark.parametrize(
  73. "make_dense_dataset,make_sparse_dataset",
  74. [
  75. (make_dense_dataset_32, make_sparse_dataset_32),
  76. (make_dense_dataset_64, make_sparse_dataset_64),
  77. ],
  78. )
  79. def test_seq_dataset_shuffle(make_dense_dataset, make_sparse_dataset):
  80. dense_dataset, sparse_dataset = make_dense_dataset(), make_sparse_dataset()
  81. # not shuffled
  82. for i in range(5):
  83. _, _, _, idx1 = dense_dataset._next_py()
  84. _, _, _, idx2 = sparse_dataset._next_py()
  85. assert idx1 == i
  86. assert idx2 == i
  87. for i in [132, 50, 9, 18, 58]:
  88. _, _, _, idx1 = dense_dataset._random_py()
  89. _, _, _, idx2 = sparse_dataset._random_py()
  90. assert idx1 == i
  91. assert idx2 == i
  92. seed = 77
  93. dense_dataset._shuffle_py(seed)
  94. sparse_dataset._shuffle_py(seed)
  95. idx_next = [63, 91, 148, 87, 29]
  96. idx_shuffle = [137, 125, 56, 121, 127]
  97. for i, j in zip(idx_next, idx_shuffle):
  98. _, _, _, idx1 = dense_dataset._next_py()
  99. _, _, _, idx2 = sparse_dataset._next_py()
  100. assert idx1 == i
  101. assert idx2 == i
  102. _, _, _, idx1 = dense_dataset._random_py()
  103. _, _, _, idx2 = sparse_dataset._random_py()
  104. assert idx1 == j
  105. assert idx2 == j
  106. @pytest.mark.parametrize(
  107. "make_dataset_32,make_dataset_64",
  108. [
  109. (make_dense_dataset_32, make_dense_dataset_64),
  110. (make_sparse_dataset_32, make_sparse_dataset_64),
  111. ],
  112. )
  113. def test_fused_types_consistency(make_dataset_32, make_dataset_64):
  114. dataset_32, dataset_64 = make_dataset_32(), make_dataset_64()
  115. NUMBER_OF_RUNS = 5
  116. for _ in range(NUMBER_OF_RUNS):
  117. # next sample
  118. (xi_data32, _, _), yi32, _, _ = dataset_32._next_py()
  119. (xi_data64, _, _), yi64, _, _ = dataset_64._next_py()
  120. assert xi_data32.dtype == np.float32
  121. assert xi_data64.dtype == np.float64
  122. assert_allclose(xi_data64, xi_data32, rtol=1e-5)
  123. assert_allclose(yi64, yi32, rtol=1e-5)
  124. def test_buffer_dtype_mismatch_error():
  125. with pytest.raises(ValueError, match="Buffer dtype mismatch"):
  126. ArrayDataset64(X32, y32, sample_weight32, seed=42),
  127. with pytest.raises(ValueError, match="Buffer dtype mismatch"):
  128. ArrayDataset32(X64, y64, sample_weight64, seed=42),
  129. with pytest.raises(ValueError, match="Buffer dtype mismatch"):
  130. CSRDataset64(
  131. X_csr32.data, X_csr32.indptr, X_csr32.indices, y32, sample_weight32, seed=42
  132. ),
  133. with pytest.raises(ValueError, match="Buffer dtype mismatch"):
  134. CSRDataset32(
  135. X_csr64.data, X_csr64.indptr, X_csr64.indices, y64, sample_weight64, seed=42
  136. ),