# test_factor_analysis.py (4.1 KB)
# Author: Christian Osendorfer <osendorf@gmail.com>
#         Alexandre Gramfort <alexandre.gramfort@inria.fr>
# License: BSD3
  4. from itertools import combinations
  5. import numpy as np
  6. import pytest
  7. from sklearn.decomposition import FactorAnalysis
  8. from sklearn.decomposition._factor_analysis import _ortho_rotation
  9. from sklearn.exceptions import ConvergenceWarning
  10. from sklearn.utils._testing import (
  11. assert_almost_equal,
  12. assert_array_almost_equal,
  13. ignore_warnings,
  14. )
  15. # Ignore warnings from switching to more power iterations in randomized_svd
  16. @ignore_warnings
  17. def test_factor_analysis():
  18. # Test FactorAnalysis ability to recover the data covariance structure
  19. rng = np.random.RandomState(0)
  20. n_samples, n_features, n_components = 20, 5, 3
  21. # Some random settings for the generative model
  22. W = rng.randn(n_components, n_features)
  23. # latent variable of dim 3, 20 of it
  24. h = rng.randn(n_samples, n_components)
  25. # using gamma to model different noise variance
  26. # per component
  27. noise = rng.gamma(1, size=n_features) * rng.randn(n_samples, n_features)
  28. # generate observations
  29. # wlog, mean is 0
  30. X = np.dot(h, W) + noise
  31. fas = []
  32. for method in ["randomized", "lapack"]:
  33. fa = FactorAnalysis(n_components=n_components, svd_method=method)
  34. fa.fit(X)
  35. fas.append(fa)
  36. X_t = fa.transform(X)
  37. assert X_t.shape == (n_samples, n_components)
  38. assert_almost_equal(fa.loglike_[-1], fa.score_samples(X).sum())
  39. assert_almost_equal(fa.score_samples(X).mean(), fa.score(X))
  40. diff = np.all(np.diff(fa.loglike_))
  41. assert diff > 0.0, "Log likelihood dif not increase"
  42. # Sample Covariance
  43. scov = np.cov(X, rowvar=0.0, bias=1.0)
  44. # Model Covariance
  45. mcov = fa.get_covariance()
  46. diff = np.sum(np.abs(scov - mcov)) / W.size
  47. assert diff < 0.1, "Mean absolute difference is %f" % diff
  48. fa = FactorAnalysis(
  49. n_components=n_components, noise_variance_init=np.ones(n_features)
  50. )
  51. with pytest.raises(ValueError):
  52. fa.fit(X[:, :2])
  53. def f(x, y):
  54. return np.abs(getattr(x, y)) # sign will not be equal
  55. fa1, fa2 = fas
  56. for attr in ["loglike_", "components_", "noise_variance_"]:
  57. assert_almost_equal(f(fa1, attr), f(fa2, attr))
  58. fa1.max_iter = 1
  59. fa1.verbose = True
  60. with pytest.warns(ConvergenceWarning):
  61. fa1.fit(X)
  62. # Test get_covariance and get_precision with n_components == n_features
  63. # with n_components < n_features and with n_components == 0
  64. for n_components in [0, 2, X.shape[1]]:
  65. fa.n_components = n_components
  66. fa.fit(X)
  67. cov = fa.get_covariance()
  68. precision = fa.get_precision()
  69. assert_array_almost_equal(np.dot(cov, precision), np.eye(X.shape[1]), 12)
  70. # test rotation
  71. n_components = 2
  72. results, projections = {}, {}
  73. for method in (None, "varimax", "quartimax"):
  74. fa_var = FactorAnalysis(n_components=n_components, rotation=method)
  75. results[method] = fa_var.fit_transform(X)
  76. projections[method] = fa_var.get_covariance()
  77. for rot1, rot2 in combinations([None, "varimax", "quartimax"], 2):
  78. assert not np.allclose(results[rot1], results[rot2])
  79. assert np.allclose(projections[rot1], projections[rot2], atol=3)
  80. # test against R's psych::principal with rotate="varimax"
  81. # (i.e., the values below stem from rotating the components in R)
  82. # R's factor analysis returns quite different values; therefore, we only
  83. # test the rotation itself
  84. factors = np.array(
  85. [
  86. [0.89421016, -0.35854928, -0.27770122, 0.03773647],
  87. [-0.45081822, -0.89132754, 0.0932195, -0.01787973],
  88. [0.99500666, -0.02031465, 0.05426497, -0.11539407],
  89. [0.96822861, -0.06299656, 0.24411001, 0.07540887],
  90. ]
  91. )
  92. r_solution = np.array(
  93. [[0.962, 0.052], [-0.141, 0.989], [0.949, -0.300], [0.937, -0.251]]
  94. )
  95. rotated = _ortho_rotation(factors[:, :n_components], method="varimax").T
  96. assert_array_almost_equal(np.abs(rotated), np.abs(r_solution), decimal=3)