test_cython_optimize.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. Test Cython optimize zeros API functions: ``bisect``, ``ridder``, ``brenth``,
  3. and ``brentq`` in `scipy.optimize.cython_optimize`, by finding the roots of a
  4. 3rd order polynomial given a sequence of constant terms, ``a0``, and fixed 1st,
  5. 2nd, and 3rd order terms in ``args``.
  6. .. math::
  7. f(x, a0, args) = ((args[2]*x + args[1])*x + args[0])*x + a0
  8. The 3rd order polynomial function is written in Cython and called in a Python
  9. wrapper named after the zero function. See the private ``_zeros`` Cython module
  10. in `scipy.optimize.cython_optimize` for more information.
  11. """
  12. import numpy.testing as npt
  13. from scipy.optimize.cython_optimize import _zeros
  # CONSTANTS
  # Solve x**3 + a0 = 0 (i.e. x = (-a0)**(1/3)) for each constant term
  # a0 in A0 = (-2.0, -2.1, ..., -2.9), so the real roots are the cube roots
  # of 2.0, 2.1, ..., 2.9. ARGS has 3 elements only to show how any cubic
  # polynomial could be expressed; here the 1st and 2nd order terms are zero.
  A0 = tuple(-2.0 - x/10.0 for x in range(10))  # constant terms a0
  ARGS = (0.0, 0.0, 1.0)  # 1st, 2nd, and 3rd order coefficients
  # Lower and upper bracket passed to the zeros functions; f(XLO) < 0 < f(XHI)
  # for every a0 in A0, so each root lies inside [XLO, XHI].
  XLO, XHI = 0.0, 2.0
  # absolute and relative tolerances and max iterations for zeros functions
  XTOL, RTOL, MITR = 0.001, 0.001, 10
  # Exact real roots, used as the reference ("desired") values in the tests.
  EXPECTED = [(-a0) ** (1.0/3.0) for a0 in A0]
  # = [1.2599210498948732,
  #    1.2805791649874942,
  #    1.300591446851387,
  #    1.3200061217959123,
  #    1.338865900164339,
  #    1.3572088082974532,
  #    1.375068867074141,
  #    1.3924766500838337,
  #    1.4094597464129783,
  #    1.4260431471424087]
  def test_bisect():
      """Check that ``bisect`` finds every root in EXPECTED within tolerance."""
      roots = list(
          _zeros.loop_example('bisect', A0, ARGS, XLO, XHI, XTOL, RTOL, MITR)
      )
      # assert_allclose(actual, desired): the relative tolerance is measured
      # against the exact reference roots, and failure messages label the
      # computed and expected sides correctly.
      npt.assert_allclose(roots, EXPECTED, rtol=RTOL, atol=XTOL)
  def test_ridder():
      """Check that ``ridder`` finds every root in EXPECTED within tolerance."""
      roots = list(
          _zeros.loop_example('ridder', A0, ARGS, XLO, XHI, XTOL, RTOL, MITR)
      )
      # assert_allclose(actual, desired): the relative tolerance is measured
      # against the exact reference roots, and failure messages label the
      # computed and expected sides correctly.
      npt.assert_allclose(roots, EXPECTED, rtol=RTOL, atol=XTOL)
  def test_brenth():
      """Check that ``brenth`` finds every root in EXPECTED within tolerance."""
      roots = list(
          _zeros.loop_example('brenth', A0, ARGS, XLO, XHI, XTOL, RTOL, MITR)
      )
      # assert_allclose(actual, desired): the relative tolerance is measured
      # against the exact reference roots, and failure messages label the
      # computed and expected sides correctly.
      npt.assert_allclose(roots, EXPECTED, rtol=RTOL, atol=XTOL)
  def test_brentq():
      """Check that ``brentq`` finds every root in EXPECTED within tolerance."""
      roots = list(
          _zeros.loop_example('brentq', A0, ARGS, XLO, XHI, XTOL, RTOL, MITR)
      )
      # assert_allclose(actual, desired): the relative tolerance is measured
      # against the exact reference roots, and failure messages label the
      # computed and expected sides correctly.
      npt.assert_allclose(roots, EXPECTED, rtol=RTOL, atol=XTOL)
  def test_brentq_full_output():
      """Check ``brentq`` full output for the first constant term A0[0].

      Verifies the returned root, the iteration and function-call counts, and
      the error number reported by ``_zeros.full_output_example``.
      """
      # The first argument packs the constant term a0 followed by ARGS.
      output = _zeros.full_output_example(
          (A0[0],) + ARGS, XLO, XHI, XTOL, RTOL, MITR)
      # numpy.testing convention is (actual, desired), so the computed values
      # come first and the tolerance is relative to the exact reference root.
      npt.assert_allclose(output['root'], EXPECTED[0], rtol=RTOL, atol=XTOL)
      npt.assert_equal(output['iterations'], 6)
      npt.assert_equal(output['funcalls'], 7)
      # NOTE(review): error_num == 0 presumably means "converged" — confirm
      # against the _zeros module.
      npt.assert_equal(output['error_num'], 0)