test_weight_vector.py 665 B

12345678910111213141516171819202122232425
  1. import numpy as np
  2. import pytest
  3. from sklearn.utils._weight_vector import (
  4. WeightVector32,
  5. WeightVector64,
  6. )
  7. @pytest.mark.parametrize(
  8. "dtype, WeightVector",
  9. [
  10. (np.float32, WeightVector32),
  11. (np.float64, WeightVector64),
  12. ],
  13. )
  14. def test_type_invariance(dtype, WeightVector):
  15. """Check the `dtype` consistency of `WeightVector`."""
  16. weights = np.random.rand(100).astype(dtype)
  17. average_weights = np.random.rand(100).astype(dtype)
  18. weight_vector = WeightVector(weights, average_weights)
  19. assert np.asarray(weight_vector.w).dtype is np.dtype(dtype)
  20. assert np.asarray(weight_vector.aw).dtype is np.dtype(dtype)