# test_perceptron.py (2.5 KB)
  1. import numpy as np
  2. import pytest
  3. import scipy.sparse as sp
  4. from sklearn.datasets import load_iris
  5. from sklearn.linear_model import Perceptron
  6. from sklearn.utils import check_random_state
  7. from sklearn.utils._testing import assert_allclose, assert_array_almost_equal
  8. iris = load_iris()
  9. random_state = check_random_state(12)
  10. indices = np.arange(iris.data.shape[0])
  11. random_state.shuffle(indices)
  12. X = iris.data[indices]
  13. y = iris.target[indices]
  14. X_csr = sp.csr_matrix(X)
  15. X_csr.sort_indices()
  16. class MyPerceptron:
  17. def __init__(self, n_iter=1):
  18. self.n_iter = n_iter
  19. def fit(self, X, y):
  20. n_samples, n_features = X.shape
  21. self.w = np.zeros(n_features, dtype=np.float64)
  22. self.b = 0.0
  23. for t in range(self.n_iter):
  24. for i in range(n_samples):
  25. if self.predict(X[i])[0] != y[i]:
  26. self.w += y[i] * X[i]
  27. self.b += y[i]
  28. def project(self, X):
  29. return np.dot(X, self.w) + self.b
  30. def predict(self, X):
  31. X = np.atleast_2d(X)
  32. return np.sign(self.project(X))
  33. def test_perceptron_accuracy():
  34. for data in (X, X_csr):
  35. clf = Perceptron(max_iter=100, tol=None, shuffle=False)
  36. clf.fit(data, y)
  37. score = clf.score(data, y)
  38. assert score > 0.7
  39. def test_perceptron_correctness():
  40. y_bin = y.copy()
  41. y_bin[y != 1] = -1
  42. clf1 = MyPerceptron(n_iter=2)
  43. clf1.fit(X, y_bin)
  44. clf2 = Perceptron(max_iter=2, shuffle=False, tol=None)
  45. clf2.fit(X, y_bin)
  46. assert_array_almost_equal(clf1.w, clf2.coef_.ravel())
  47. def test_undefined_methods():
  48. clf = Perceptron(max_iter=100)
  49. for meth in ("predict_proba", "predict_log_proba"):
  50. with pytest.raises(AttributeError):
  51. getattr(clf, meth)
  52. def test_perceptron_l1_ratio():
  53. """Check that `l1_ratio` has an impact when `penalty='elasticnet'`"""
  54. clf1 = Perceptron(l1_ratio=0, penalty="elasticnet")
  55. clf1.fit(X, y)
  56. clf2 = Perceptron(l1_ratio=0.15, penalty="elasticnet")
  57. clf2.fit(X, y)
  58. assert clf1.score(X, y) != clf2.score(X, y)
  59. # check that the bounds of elastic net which should correspond to an l1 or
  60. # l2 penalty depending of `l1_ratio` value.
  61. clf_l1 = Perceptron(penalty="l1").fit(X, y)
  62. clf_elasticnet = Perceptron(l1_ratio=1, penalty="elasticnet").fit(X, y)
  63. assert_allclose(clf_l1.coef_, clf_elasticnet.coef_)
  64. clf_l2 = Perceptron(penalty="l2").fit(X, y)
  65. clf_elasticnet = Perceptron(l1_ratio=0, penalty="elasticnet").fit(X, y)
  66. assert_allclose(clf_l2.coef_, clf_elasticnet.coef_)