test_link.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import numpy as np
  2. import pytest
  3. from numpy.testing import assert_allclose, assert_array_equal
  4. from sklearn._loss.link import (
  5. _LINKS,
  6. HalfLogitLink,
  7. Interval,
  8. MultinomialLogit,
  9. _inclusive_low_high,
  10. )
  11. LINK_FUNCTIONS = list(_LINKS.values())
  12. def test_interval_raises():
  13. """Test that interval with low > high raises ValueError."""
  14. with pytest.raises(
  15. ValueError, match="One must have low <= high; got low=1, high=0."
  16. ):
  17. Interval(1, 0, False, False)
  18. @pytest.mark.parametrize(
  19. "interval",
  20. [
  21. Interval(0, 1, False, False),
  22. Interval(0, 1, False, True),
  23. Interval(0, 1, True, False),
  24. Interval(0, 1, True, True),
  25. Interval(-np.inf, np.inf, False, False),
  26. Interval(-np.inf, np.inf, False, True),
  27. Interval(-np.inf, np.inf, True, False),
  28. Interval(-np.inf, np.inf, True, True),
  29. Interval(-10, -1, False, False),
  30. Interval(-10, -1, False, True),
  31. Interval(-10, -1, True, False),
  32. Interval(-10, -1, True, True),
  33. ],
  34. )
  35. def test_is_in_range(interval):
  36. # make sure low and high are always within the interval, used for linspace
  37. low, high = _inclusive_low_high(interval)
  38. x = np.linspace(low, high, num=10)
  39. assert interval.includes(x)
  40. # x contains lower bound
  41. assert interval.includes(np.r_[x, interval.low]) == interval.low_inclusive
  42. # x contains upper bound
  43. assert interval.includes(np.r_[x, interval.high]) == interval.high_inclusive
  44. # x contains upper and lower bound
  45. assert interval.includes(np.r_[x, interval.low, interval.high]) == (
  46. interval.low_inclusive and interval.high_inclusive
  47. )
  48. @pytest.mark.parametrize("link", LINK_FUNCTIONS)
  49. def test_link_inverse_identity(link, global_random_seed):
  50. # Test that link of inverse gives identity.
  51. rng = np.random.RandomState(global_random_seed)
  52. link = link()
  53. n_samples, n_classes = 100, None
  54. # The values for `raw_prediction` are limited from -20 to 20 because in the
  55. # class `LogitLink` the term `expit(x)` comes very close to 1 for large
  56. # positive x and therefore loses precision.
  57. if link.is_multiclass:
  58. n_classes = 10
  59. raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples, n_classes))
  60. if isinstance(link, MultinomialLogit):
  61. raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
  62. elif isinstance(link, HalfLogitLink):
  63. raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples))
  64. else:
  65. raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples))
  66. assert_allclose(link.link(link.inverse(raw_prediction)), raw_prediction)
  67. y_pred = link.inverse(raw_prediction)
  68. assert_allclose(link.inverse(link.link(y_pred)), y_pred)
  69. @pytest.mark.parametrize("link", LINK_FUNCTIONS)
  70. def test_link_out_argument(link):
  71. # Test that out argument gets assigned the result.
  72. rng = np.random.RandomState(42)
  73. link = link()
  74. n_samples, n_classes = 100, None
  75. if link.is_multiclass:
  76. n_classes = 10
  77. raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes))
  78. if isinstance(link, MultinomialLogit):
  79. raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
  80. else:
  81. # So far, the valid interval of raw_prediction is (-inf, inf) and
  82. # we do not need to distinguish.
  83. raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples))
  84. y_pred = link.inverse(raw_prediction, out=None)
  85. out = np.empty_like(raw_prediction)
  86. y_pred_2 = link.inverse(raw_prediction, out=out)
  87. assert_allclose(y_pred, out)
  88. assert_array_equal(out, y_pred_2)
  89. assert np.shares_memory(out, y_pred_2)
  90. out = np.empty_like(y_pred)
  91. raw_prediction_2 = link.link(y_pred, out=out)
  92. assert_allclose(raw_prediction, out)
  93. assert_array_equal(out, raw_prediction_2)
  94. assert np.shares_memory(out, raw_prediction_2)