| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import numpy as np
- import pytest
- from numpy.testing import assert_allclose, assert_array_equal
- from sklearn._loss.link import (
- _LINKS,
- HalfLogitLink,
- Interval,
- MultinomialLogit,
- _inclusive_low_high,
- )
- LINK_FUNCTIONS = list(_LINKS.values())
- def test_interval_raises():
- """Test that interval with low > high raises ValueError."""
- with pytest.raises(
- ValueError, match="One must have low <= high; got low=1, high=0."
- ):
- Interval(1, 0, False, False)
- @pytest.mark.parametrize(
- "interval",
- [
- Interval(0, 1, False, False),
- Interval(0, 1, False, True),
- Interval(0, 1, True, False),
- Interval(0, 1, True, True),
- Interval(-np.inf, np.inf, False, False),
- Interval(-np.inf, np.inf, False, True),
- Interval(-np.inf, np.inf, True, False),
- Interval(-np.inf, np.inf, True, True),
- Interval(-10, -1, False, False),
- Interval(-10, -1, False, True),
- Interval(-10, -1, True, False),
- Interval(-10, -1, True, True),
- ],
- )
- def test_is_in_range(interval):
- # make sure low and high are always within the interval, used for linspace
- low, high = _inclusive_low_high(interval)
- x = np.linspace(low, high, num=10)
- assert interval.includes(x)
- # x contains lower bound
- assert interval.includes(np.r_[x, interval.low]) == interval.low_inclusive
- # x contains upper bound
- assert interval.includes(np.r_[x, interval.high]) == interval.high_inclusive
- # x contains upper and lower bound
- assert interval.includes(np.r_[x, interval.low, interval.high]) == (
- interval.low_inclusive and interval.high_inclusive
- )
- @pytest.mark.parametrize("link", LINK_FUNCTIONS)
- def test_link_inverse_identity(link, global_random_seed):
- # Test that link of inverse gives identity.
- rng = np.random.RandomState(global_random_seed)
- link = link()
- n_samples, n_classes = 100, None
- # The values for `raw_prediction` are limited from -20 to 20 because in the
- # class `LogitLink` the term `expit(x)` comes very close to 1 for large
- # positive x and therefore loses precision.
- if link.is_multiclass:
- n_classes = 10
- raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples, n_classes))
- if isinstance(link, MultinomialLogit):
- raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
- elif isinstance(link, HalfLogitLink):
- raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples))
- else:
- raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples))
- assert_allclose(link.link(link.inverse(raw_prediction)), raw_prediction)
- y_pred = link.inverse(raw_prediction)
- assert_allclose(link.inverse(link.link(y_pred)), y_pred)
- @pytest.mark.parametrize("link", LINK_FUNCTIONS)
- def test_link_out_argument(link):
- # Test that out argument gets assigned the result.
- rng = np.random.RandomState(42)
- link = link()
- n_samples, n_classes = 100, None
- if link.is_multiclass:
- n_classes = 10
- raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes))
- if isinstance(link, MultinomialLogit):
- raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
- else:
- # So far, the valid interval of raw_prediction is (-inf, inf) and
- # we do not need to distinguish.
- raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples))
- y_pred = link.inverse(raw_prediction, out=None)
- out = np.empty_like(raw_prediction)
- y_pred_2 = link.inverse(raw_prediction, out=out)
- assert_allclose(y_pred, out)
- assert_array_equal(out, y_pred_2)
- assert np.shares_memory(out, y_pred_2)
- out = np.empty_like(y_pred)
- raw_prediction_2 = link.link(y_pred, out=out)
- assert_allclose(raw_prediction, out)
- assert_array_equal(out, raw_prediction_2)
- assert np.shares_memory(out, raw_prediction_2)
|