# test_utils.py
  1. import warnings
  2. import sys
  3. import os
  4. import itertools
  5. import pytest
  6. import weakref
  7. import numpy as np
  8. from numpy.testing import (
  9. assert_equal, assert_array_equal, assert_almost_equal,
  10. assert_array_almost_equal, assert_array_less, build_err_msg, raises,
  11. assert_raises, assert_warns, assert_no_warnings, assert_allclose,
  12. assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
  13. clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
  14. tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
  15. )
  16. from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
  17. class _GenericTest:
  18. def _test_equal(self, a, b):
  19. self._assert_func(a, b)
  20. def _test_not_equal(self, a, b):
  21. with assert_raises(AssertionError):
  22. self._assert_func(a, b)
  23. def test_array_rank1_eq(self):
  24. """Test two equal array of rank 1 are found equal."""
  25. a = np.array([1, 2])
  26. b = np.array([1, 2])
  27. self._test_equal(a, b)
  28. def test_array_rank1_noteq(self):
  29. """Test two different array of rank 1 are found not equal."""
  30. a = np.array([1, 2])
  31. b = np.array([2, 2])
  32. self._test_not_equal(a, b)
  33. def test_array_rank2_eq(self):
  34. """Test two equal array of rank 2 are found equal."""
  35. a = np.array([[1, 2], [3, 4]])
  36. b = np.array([[1, 2], [3, 4]])
  37. self._test_equal(a, b)
  38. def test_array_diffshape(self):
  39. """Test two arrays with different shapes are found not equal."""
  40. a = np.array([1, 2])
  41. b = np.array([[1, 2], [1, 2]])
  42. self._test_not_equal(a, b)
  43. def test_objarray(self):
  44. """Test object arrays."""
  45. a = np.array([1, 1], dtype=object)
  46. self._test_equal(a, 1)
  47. def test_array_likes(self):
  48. self._test_equal([1, 2, 3], (1, 2, 3))
  49. class TestArrayEqual(_GenericTest):
  50. def setup_method(self):
  51. self._assert_func = assert_array_equal
  52. def test_generic_rank1(self):
  53. """Test rank 1 array for all dtypes."""
  54. def foo(t):
  55. a = np.empty(2, t)
  56. a.fill(1)
  57. b = a.copy()
  58. c = a.copy()
  59. c.fill(0)
  60. self._test_equal(a, b)
  61. self._test_not_equal(c, b)
  62. # Test numeric types and object
  63. for t in '?bhilqpBHILQPfdgFDG':
  64. foo(t)
  65. # Test strings
  66. for t in ['S1', 'U1']:
  67. foo(t)
  68. def test_0_ndim_array(self):
  69. x = np.array(473963742225900817127911193656584771)
  70. y = np.array(18535119325151578301457182298393896)
  71. assert_raises(AssertionError, self._assert_func, x, y)
  72. y = x
  73. self._assert_func(x, y)
  74. x = np.array(43)
  75. y = np.array(10)
  76. assert_raises(AssertionError, self._assert_func, x, y)
  77. y = x
  78. self._assert_func(x, y)
  79. def test_generic_rank3(self):
  80. """Test rank 3 array for all dtypes."""
  81. def foo(t):
  82. a = np.empty((4, 2, 3), t)
  83. a.fill(1)
  84. b = a.copy()
  85. c = a.copy()
  86. c.fill(0)
  87. self._test_equal(a, b)
  88. self._test_not_equal(c, b)
  89. # Test numeric types and object
  90. for t in '?bhilqpBHILQPfdgFDG':
  91. foo(t)
  92. # Test strings
  93. for t in ['S1', 'U1']:
  94. foo(t)
  95. def test_nan_array(self):
  96. """Test arrays with nan values in them."""
  97. a = np.array([1, 2, np.nan])
  98. b = np.array([1, 2, np.nan])
  99. self._test_equal(a, b)
  100. c = np.array([1, 2, 3])
  101. self._test_not_equal(c, b)
  102. def test_string_arrays(self):
  103. """Test two arrays with different shapes are found not equal."""
  104. a = np.array(['floupi', 'floupa'])
  105. b = np.array(['floupi', 'floupa'])
  106. self._test_equal(a, b)
  107. c = np.array(['floupipi', 'floupa'])
  108. self._test_not_equal(c, b)
  109. def test_recarrays(self):
  110. """Test record arrays."""
  111. a = np.empty(2, [('floupi', float), ('floupa', float)])
  112. a['floupi'] = [1, 2]
  113. a['floupa'] = [1, 2]
  114. b = a.copy()
  115. self._test_equal(a, b)
  116. c = np.empty(2, [('floupipi', float),
  117. ('floupi', float), ('floupa', float)])
  118. c['floupipi'] = a['floupi'].copy()
  119. c['floupa'] = a['floupa'].copy()
  120. with pytest.raises(TypeError):
  121. self._test_not_equal(c, b)
  122. def test_masked_nan_inf(self):
  123. # Regression test for gh-11121
  124. a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
  125. b = np.array([3., np.nan, 6.5])
  126. self._test_equal(a, b)
  127. self._test_equal(b, a)
  128. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
  129. b = np.array([np.inf, 4., 6.5])
  130. self._test_equal(a, b)
  131. self._test_equal(b, a)
  132. def test_subclass_that_overrides_eq(self):
  133. # While we cannot guarantee testing functions will always work for
  134. # subclasses, the tests should ideally rely only on subclasses having
  135. # comparison operators, not on them being able to store booleans
  136. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  137. class MyArray(np.ndarray):
  138. def __eq__(self, other):
  139. return bool(np.equal(self, other).all())
  140. def __ne__(self, other):
  141. return not self == other
  142. a = np.array([1., 2.]).view(MyArray)
  143. b = np.array([2., 3.]).view(MyArray)
  144. assert_(type(a == a), bool)
  145. assert_(a == a)
  146. assert_(a != b)
  147. self._test_equal(a, a)
  148. self._test_not_equal(a, b)
  149. self._test_not_equal(b, a)
  150. @pytest.mark.skipif(
  151. not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
  152. def test_subclass_that_does_not_implement_npall(self):
  153. class MyArray(np.ndarray):
  154. def __array_function__(self, *args, **kwargs):
  155. return NotImplemented
  156. a = np.array([1., 2.]).view(MyArray)
  157. b = np.array([2., 3.]).view(MyArray)
  158. with assert_raises(TypeError):
  159. np.all(a)
  160. self._test_equal(a, a)
  161. self._test_not_equal(a, b)
  162. self._test_not_equal(b, a)
  163. def test_suppress_overflow_warnings(self):
  164. # Based on issue #18992
  165. with pytest.raises(AssertionError):
  166. with np.errstate(all="raise"):
  167. np.testing.assert_array_equal(
  168. np.array([1, 2, 3], np.float32),
  169. np.array([1, 1e-40, 3], np.float32))
  170. def test_array_vs_scalar_is_equal(self):
  171. """Test comparing an array with a scalar when all values are equal."""
  172. a = np.array([1., 1., 1.])
  173. b = 1.
  174. self._test_equal(a, b)
  175. def test_array_vs_scalar_not_equal(self):
  176. """Test comparing an array with a scalar when not all values equal."""
  177. a = np.array([1., 2., 3.])
  178. b = 1.
  179. self._test_not_equal(a, b)
  180. def test_array_vs_scalar_strict(self):
  181. """Test comparing an array with a scalar with strict option."""
  182. a = np.array([1., 1., 1.])
  183. b = 1.
  184. with pytest.raises(AssertionError):
  185. assert_array_equal(a, b, strict=True)
  186. def test_array_vs_array_strict(self):
  187. """Test comparing two arrays with strict option."""
  188. a = np.array([1., 1., 1.])
  189. b = np.array([1., 1., 1.])
  190. assert_array_equal(a, b, strict=True)
  191. def test_array_vs_float_array_strict(self):
  192. """Test comparing two arrays with strict option."""
  193. a = np.array([1, 1, 1])
  194. b = np.array([1., 1., 1.])
  195. with pytest.raises(AssertionError):
  196. assert_array_equal(a, b, strict=True)
  197. class TestBuildErrorMessage:
  198. def test_build_err_msg_defaults(self):
  199. x = np.array([1.00001, 2.00002, 3.00003])
  200. y = np.array([1.00002, 2.00003, 3.00004])
  201. err_msg = 'There is a mismatch'
  202. a = build_err_msg([x, y], err_msg)
  203. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  204. '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
  205. '2.00003, 3.00004])')
  206. assert_equal(a, b)
  207. def test_build_err_msg_no_verbose(self):
  208. x = np.array([1.00001, 2.00002, 3.00003])
  209. y = np.array([1.00002, 2.00003, 3.00004])
  210. err_msg = 'There is a mismatch'
  211. a = build_err_msg([x, y], err_msg, verbose=False)
  212. b = '\nItems are not equal: There is a mismatch'
  213. assert_equal(a, b)
  214. def test_build_err_msg_custom_names(self):
  215. x = np.array([1.00001, 2.00002, 3.00003])
  216. y = np.array([1.00002, 2.00003, 3.00004])
  217. err_msg = 'There is a mismatch'
  218. a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
  219. b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
  220. '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
  221. '3.00004])')
  222. assert_equal(a, b)
  223. def test_build_err_msg_custom_precision(self):
  224. x = np.array([1.000000001, 2.00002, 3.00003])
  225. y = np.array([1.000000002, 2.00003, 3.00004])
  226. err_msg = 'There is a mismatch'
  227. a = build_err_msg([x, y], err_msg, precision=10)
  228. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  229. '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array(['
  230. '1.000000002, 2.00003 , 3.00004 ])')
  231. assert_equal(a, b)
  232. class TestEqual(TestArrayEqual):
  233. def setup_method(self):
  234. self._assert_func = assert_equal
  235. def test_nan_items(self):
  236. self._assert_func(np.nan, np.nan)
  237. self._assert_func([np.nan], [np.nan])
  238. self._test_not_equal(np.nan, [np.nan])
  239. self._test_not_equal(np.nan, 1)
  240. def test_inf_items(self):
  241. self._assert_func(np.inf, np.inf)
  242. self._assert_func([np.inf], [np.inf])
  243. self._test_not_equal(np.inf, [np.inf])
  244. def test_datetime(self):
  245. self._test_equal(
  246. np.datetime64("2017-01-01", "s"),
  247. np.datetime64("2017-01-01", "s")
  248. )
  249. self._test_equal(
  250. np.datetime64("2017-01-01", "s"),
  251. np.datetime64("2017-01-01", "m")
  252. )
  253. # gh-10081
  254. self._test_not_equal(
  255. np.datetime64("2017-01-01", "s"),
  256. np.datetime64("2017-01-02", "s")
  257. )
  258. self._test_not_equal(
  259. np.datetime64("2017-01-01", "s"),
  260. np.datetime64("2017-01-02", "m")
  261. )
  262. def test_nat_items(self):
  263. # not a datetime
  264. nadt_no_unit = np.datetime64("NaT")
  265. nadt_s = np.datetime64("NaT", "s")
  266. nadt_d = np.datetime64("NaT", "ns")
  267. # not a timedelta
  268. natd_no_unit = np.timedelta64("NaT")
  269. natd_s = np.timedelta64("NaT", "s")
  270. natd_d = np.timedelta64("NaT", "ns")
  271. dts = [nadt_no_unit, nadt_s, nadt_d]
  272. tds = [natd_no_unit, natd_s, natd_d]
  273. for a, b in itertools.product(dts, dts):
  274. self._assert_func(a, b)
  275. self._assert_func([a], [b])
  276. self._test_not_equal([a], b)
  277. for a, b in itertools.product(tds, tds):
  278. self._assert_func(a, b)
  279. self._assert_func([a], [b])
  280. self._test_not_equal([a], b)
  281. for a, b in itertools.product(tds, dts):
  282. self._test_not_equal(a, b)
  283. self._test_not_equal(a, [b])
  284. self._test_not_equal([a], [b])
  285. self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
  286. self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
  287. self._test_not_equal([a], np.timedelta64(123, "s"))
  288. self._test_not_equal([b], np.timedelta64(123, "s"))
  289. def test_non_numeric(self):
  290. self._assert_func('ab', 'ab')
  291. self._test_not_equal('ab', 'abb')
  292. def test_complex_item(self):
  293. self._assert_func(complex(1, 2), complex(1, 2))
  294. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  295. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  296. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  297. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  298. def test_negative_zero(self):
  299. self._test_not_equal(np.PZERO, np.NZERO)
  300. def test_complex(self):
  301. x = np.array([complex(1, 2), complex(1, np.nan)])
  302. y = np.array([complex(1, 2), complex(1, 2)])
  303. self._assert_func(x, x)
  304. self._test_not_equal(x, y)
  305. def test_object(self):
  306. #gh-12942
  307. import datetime
  308. a = np.array([datetime.datetime(2000, 1, 1),
  309. datetime.datetime(2000, 1, 2)])
  310. self._test_not_equal(a, a[::-1])
  311. class TestArrayAlmostEqual(_GenericTest):
  312. def setup_method(self):
  313. self._assert_func = assert_array_almost_equal
  314. def test_closeness(self):
  315. # Note that in the course of time we ended up with
  316. # `abs(x - y) < 1.5 * 10**(-decimal)`
  317. # instead of the previously documented
  318. # `abs(x - y) < 0.5 * 10**(-decimal)`
  319. # so this check serves to preserve the wrongness.
  320. # test scalars
  321. self._assert_func(1.499999, 0.0, decimal=0)
  322. assert_raises(AssertionError,
  323. lambda: self._assert_func(1.5, 0.0, decimal=0))
  324. # test arrays
  325. self._assert_func([1.499999], [0.0], decimal=0)
  326. assert_raises(AssertionError,
  327. lambda: self._assert_func([1.5], [0.0], decimal=0))
  328. def test_simple(self):
  329. x = np.array([1234.2222])
  330. y = np.array([1234.2223])
  331. self._assert_func(x, y, decimal=3)
  332. self._assert_func(x, y, decimal=4)
  333. assert_raises(AssertionError,
  334. lambda: self._assert_func(x, y, decimal=5))
  335. def test_nan(self):
  336. anan = np.array([np.nan])
  337. aone = np.array([1])
  338. ainf = np.array([np.inf])
  339. self._assert_func(anan, anan)
  340. assert_raises(AssertionError,
  341. lambda: self._assert_func(anan, aone))
  342. assert_raises(AssertionError,
  343. lambda: self._assert_func(anan, ainf))
  344. assert_raises(AssertionError,
  345. lambda: self._assert_func(ainf, anan))
  346. def test_inf(self):
  347. a = np.array([[1., 2.], [3., 4.]])
  348. b = a.copy()
  349. a[0, 0] = np.inf
  350. assert_raises(AssertionError,
  351. lambda: self._assert_func(a, b))
  352. b[0, 0] = -np.inf
  353. assert_raises(AssertionError,
  354. lambda: self._assert_func(a, b))
  355. def test_subclass(self):
  356. a = np.array([[1., 2.], [3., 4.]])
  357. b = np.ma.masked_array([[1., 2.], [0., 4.]],
  358. [[False, False], [True, False]])
  359. self._assert_func(a, b)
  360. self._assert_func(b, a)
  361. self._assert_func(b, b)
  362. # Test fully masked as well (see gh-11123).
  363. a = np.ma.MaskedArray(3.5, mask=True)
  364. b = np.array([3., 4., 6.5])
  365. self._test_equal(a, b)
  366. self._test_equal(b, a)
  367. a = np.ma.masked
  368. b = np.array([3., 4., 6.5])
  369. self._test_equal(a, b)
  370. self._test_equal(b, a)
  371. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  372. b = np.array([1., 2., 3.])
  373. self._test_equal(a, b)
  374. self._test_equal(b, a)
  375. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  376. b = np.array(1.)
  377. self._test_equal(a, b)
  378. self._test_equal(b, a)
  379. def test_subclass_that_cannot_be_bool(self):
  380. # While we cannot guarantee testing functions will always work for
  381. # subclasses, the tests should ideally rely only on subclasses having
  382. # comparison operators, not on them being able to store booleans
  383. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  384. class MyArray(np.ndarray):
  385. def __eq__(self, other):
  386. return super().__eq__(other).view(np.ndarray)
  387. def __lt__(self, other):
  388. return super().__lt__(other).view(np.ndarray)
  389. def all(self, *args, **kwargs):
  390. raise NotImplementedError
  391. a = np.array([1., 2.]).view(MyArray)
  392. self._assert_func(a, a)
  393. class TestAlmostEqual(_GenericTest):
  394. def setup_method(self):
  395. self._assert_func = assert_almost_equal
  396. def test_closeness(self):
  397. # Note that in the course of time we ended up with
  398. # `abs(x - y) < 1.5 * 10**(-decimal)`
  399. # instead of the previously documented
  400. # `abs(x - y) < 0.5 * 10**(-decimal)`
  401. # so this check serves to preserve the wrongness.
  402. # test scalars
  403. self._assert_func(1.499999, 0.0, decimal=0)
  404. assert_raises(AssertionError,
  405. lambda: self._assert_func(1.5, 0.0, decimal=0))
  406. # test arrays
  407. self._assert_func([1.499999], [0.0], decimal=0)
  408. assert_raises(AssertionError,
  409. lambda: self._assert_func([1.5], [0.0], decimal=0))
  410. def test_nan_item(self):
  411. self._assert_func(np.nan, np.nan)
  412. assert_raises(AssertionError,
  413. lambda: self._assert_func(np.nan, 1))
  414. assert_raises(AssertionError,
  415. lambda: self._assert_func(np.nan, np.inf))
  416. assert_raises(AssertionError,
  417. lambda: self._assert_func(np.inf, np.nan))
  418. def test_inf_item(self):
  419. self._assert_func(np.inf, np.inf)
  420. self._assert_func(-np.inf, -np.inf)
  421. assert_raises(AssertionError,
  422. lambda: self._assert_func(np.inf, 1))
  423. assert_raises(AssertionError,
  424. lambda: self._assert_func(-np.inf, np.inf))
  425. def test_simple_item(self):
  426. self._test_not_equal(1, 2)
  427. def test_complex_item(self):
  428. self._assert_func(complex(1, 2), complex(1, 2))
  429. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  430. self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
  431. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  432. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  433. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  434. def test_complex(self):
  435. x = np.array([complex(1, 2), complex(1, np.nan)])
  436. z = np.array([complex(1, 2), complex(np.nan, 1)])
  437. y = np.array([complex(1, 2), complex(1, 2)])
  438. self._assert_func(x, x)
  439. self._test_not_equal(x, y)
  440. self._test_not_equal(x, z)
  441. def test_error_message(self):
  442. """Check the message is formatted correctly for the decimal value.
  443. Also check the message when input includes inf or nan (gh12200)"""
  444. x = np.array([1.00000000001, 2.00000000002, 3.00003])
  445. y = np.array([1.00000000002, 2.00000000003, 3.00004])
  446. # Test with a different amount of decimal digits
  447. with pytest.raises(AssertionError) as exc_info:
  448. self._assert_func(x, y, decimal=12)
  449. msgs = str(exc_info.value).split('\n')
  450. assert_equal(msgs[3], 'Mismatched elements: 3 / 3 (100%)')
  451. assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
  452. assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
  453. assert_equal(
  454. msgs[6],
  455. ' x: array([1.00000000001, 2.00000000002, 3.00003 ])')
  456. assert_equal(
  457. msgs[7],
  458. ' y: array([1.00000000002, 2.00000000003, 3.00004 ])')
  459. # With the default value of decimal digits, only the 3rd element
  460. # differs. Note that we only check for the formatting of the arrays
  461. # themselves.
  462. with pytest.raises(AssertionError) as exc_info:
  463. self._assert_func(x, y)
  464. msgs = str(exc_info.value).split('\n')
  465. assert_equal(msgs[3], 'Mismatched elements: 1 / 3 (33.3%)')
  466. assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
  467. assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
  468. assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])')
  469. assert_equal(msgs[7], ' y: array([1. , 2. , 3.00004])')
  470. # Check the error message when input includes inf
  471. x = np.array([np.inf, 0])
  472. y = np.array([np.inf, 1])
  473. with pytest.raises(AssertionError) as exc_info:
  474. self._assert_func(x, y)
  475. msgs = str(exc_info.value).split('\n')
  476. assert_equal(msgs[3], 'Mismatched elements: 1 / 2 (50%)')
  477. assert_equal(msgs[4], 'Max absolute difference: 1.')
  478. assert_equal(msgs[5], 'Max relative difference: 1.')
  479. assert_equal(msgs[6], ' x: array([inf, 0.])')
  480. assert_equal(msgs[7], ' y: array([inf, 1.])')
  481. # Check the error message when dividing by zero
  482. x = np.array([1, 2])
  483. y = np.array([0, 0])
  484. with pytest.raises(AssertionError) as exc_info:
  485. self._assert_func(x, y)
  486. msgs = str(exc_info.value).split('\n')
  487. assert_equal(msgs[3], 'Mismatched elements: 2 / 2 (100%)')
  488. assert_equal(msgs[4], 'Max absolute difference: 2')
  489. assert_equal(msgs[5], 'Max relative difference: inf')
  490. def test_error_message_2(self):
  491. """Check the message is formatted correctly when either x or y is a scalar."""
  492. x = 2
  493. y = np.ones(20)
  494. with pytest.raises(AssertionError) as exc_info:
  495. self._assert_func(x, y)
  496. msgs = str(exc_info.value).split('\n')
  497. assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
  498. assert_equal(msgs[4], 'Max absolute difference: 1.')
  499. assert_equal(msgs[5], 'Max relative difference: 1.')
  500. y = 2
  501. x = np.ones(20)
  502. with pytest.raises(AssertionError) as exc_info:
  503. self._assert_func(x, y)
  504. msgs = str(exc_info.value).split('\n')
  505. assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
  506. assert_equal(msgs[4], 'Max absolute difference: 1.')
  507. assert_equal(msgs[5], 'Max relative difference: 0.5')
  508. def test_subclass_that_cannot_be_bool(self):
  509. # While we cannot guarantee testing functions will always work for
  510. # subclasses, the tests should ideally rely only on subclasses having
  511. # comparison operators, not on them being able to store booleans
  512. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  513. class MyArray(np.ndarray):
  514. def __eq__(self, other):
  515. return super().__eq__(other).view(np.ndarray)
  516. def __lt__(self, other):
  517. return super().__lt__(other).view(np.ndarray)
  518. def all(self, *args, **kwargs):
  519. raise NotImplementedError
  520. a = np.array([1., 2.]).view(MyArray)
  521. self._assert_func(a, a)
  522. class TestApproxEqual:
  523. def setup_method(self):
  524. self._assert_func = assert_approx_equal
  525. def test_simple_0d_arrays(self):
  526. x = np.array(1234.22)
  527. y = np.array(1234.23)
  528. self._assert_func(x, y, significant=5)
  529. self._assert_func(x, y, significant=6)
  530. assert_raises(AssertionError,
  531. lambda: self._assert_func(x, y, significant=7))
  532. def test_simple_items(self):
  533. x = 1234.22
  534. y = 1234.23
  535. self._assert_func(x, y, significant=4)
  536. self._assert_func(x, y, significant=5)
  537. self._assert_func(x, y, significant=6)
  538. assert_raises(AssertionError,
  539. lambda: self._assert_func(x, y, significant=7))
  540. def test_nan_array(self):
  541. anan = np.array(np.nan)
  542. aone = np.array(1)
  543. ainf = np.array(np.inf)
  544. self._assert_func(anan, anan)
  545. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  546. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  547. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  548. def test_nan_items(self):
  549. anan = np.array(np.nan)
  550. aone = np.array(1)
  551. ainf = np.array(np.inf)
  552. self._assert_func(anan, anan)
  553. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  554. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  555. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  556. class TestArrayAssertLess:
  557. def setup_method(self):
  558. self._assert_func = assert_array_less
  559. def test_simple_arrays(self):
  560. x = np.array([1.1, 2.2])
  561. y = np.array([1.2, 2.3])
  562. self._assert_func(x, y)
  563. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  564. y = np.array([1.0, 2.3])
  565. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  566. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  567. def test_rank2(self):
  568. x = np.array([[1.1, 2.2], [3.3, 4.4]])
  569. y = np.array([[1.2, 2.3], [3.4, 4.5]])
  570. self._assert_func(x, y)
  571. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  572. y = np.array([[1.0, 2.3], [3.4, 4.5]])
  573. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  574. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  575. def test_rank3(self):
  576. x = np.ones(shape=(2, 2, 2))
  577. y = np.ones(shape=(2, 2, 2))+1
  578. self._assert_func(x, y)
  579. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  580. y[0, 0, 0] = 0
  581. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  582. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  583. def test_simple_items(self):
  584. x = 1.1
  585. y = 2.2
  586. self._assert_func(x, y)
  587. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  588. y = np.array([2.2, 3.3])
  589. self._assert_func(x, y)
  590. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  591. y = np.array([1.0, 3.3])
  592. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  593. def test_nan_noncompare(self):
  594. anan = np.array(np.nan)
  595. aone = np.array(1)
  596. ainf = np.array(np.inf)
  597. self._assert_func(anan, anan)
  598. assert_raises(AssertionError, lambda: self._assert_func(aone, anan))
  599. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  600. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  601. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  602. def test_nan_noncompare_array(self):
  603. x = np.array([1.1, 2.2, 3.3])
  604. anan = np.array(np.nan)
  605. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  606. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  607. x = np.array([1.1, 2.2, np.nan])
  608. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  609. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  610. y = np.array([1.0, 2.0, np.nan])
  611. self._assert_func(y, x)
  612. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  613. def test_inf_compare(self):
  614. aone = np.array(1)
  615. ainf = np.array(np.inf)
  616. self._assert_func(aone, ainf)
  617. self._assert_func(-ainf, aone)
  618. self._assert_func(-ainf, ainf)
  619. assert_raises(AssertionError, lambda: self._assert_func(ainf, aone))
  620. assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf))
  621. assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf))
  622. assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf))
  623. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf))
  624. def test_inf_compare_array(self):
  625. x = np.array([1.1, 2.2, np.inf])
  626. ainf = np.array(np.inf)
  627. assert_raises(AssertionError, lambda: self._assert_func(x, ainf))
  628. assert_raises(AssertionError, lambda: self._assert_func(ainf, x))
  629. assert_raises(AssertionError, lambda: self._assert_func(x, -ainf))
  630. assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf))
  631. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
  632. self._assert_func(-ainf, x)
  633. @pytest.mark.skip(reason="The raises decorator depends on Nose")
  634. class TestRaises:
  635. def setup_method(self):
  636. class MyException(Exception):
  637. pass
  638. self.e = MyException
  639. def raises_exception(self, e):
  640. raise e
  641. def does_not_raise_exception(self):
  642. pass
  643. def test_correct_catch(self):
  644. raises(self.e)(self.raises_exception)(self.e) # raises?
  645. def test_wrong_exception(self):
  646. try:
  647. raises(self.e)(self.raises_exception)(RuntimeError) # raises?
  648. except RuntimeError:
  649. return
  650. else:
  651. raise AssertionError("should have caught RuntimeError")
  652. def test_catch_no_raise(self):
  653. try:
  654. raises(self.e)(self.does_not_raise_exception)() # raises?
  655. except AssertionError:
  656. return
  657. else:
  658. raise AssertionError("should have raised an AssertionError")
  659. class TestWarns:
  660. def test_warn(self):
  661. def f():
  662. warnings.warn("yo")
  663. return 3
  664. before_filters = sys.modules['warnings'].filters[:]
  665. assert_equal(assert_warns(UserWarning, f), 3)
  666. after_filters = sys.modules['warnings'].filters
  667. assert_raises(AssertionError, assert_no_warnings, f)
  668. assert_equal(assert_no_warnings(lambda x: x, 1), 1)
  669. # Check that the warnings state is unchanged
  670. assert_equal(before_filters, after_filters,
  671. "assert_warns does not preserver warnings state")
  672. def test_context_manager(self):
  673. before_filters = sys.modules['warnings'].filters[:]
  674. with assert_warns(UserWarning):
  675. warnings.warn("yo")
  676. after_filters = sys.modules['warnings'].filters
  677. def no_warnings():
  678. with assert_no_warnings():
  679. warnings.warn("yo")
  680. assert_raises(AssertionError, no_warnings)
  681. assert_equal(before_filters, after_filters,
  682. "assert_warns does not preserver warnings state")
  683. def test_warn_wrong_warning(self):
  684. def f():
  685. warnings.warn("yo", DeprecationWarning)
  686. failed = False
  687. with warnings.catch_warnings():
  688. warnings.simplefilter("error", DeprecationWarning)
  689. try:
  690. # Should raise a DeprecationWarning
  691. assert_warns(UserWarning, f)
  692. failed = True
  693. except DeprecationWarning:
  694. pass
  695. if failed:
  696. raise AssertionError("wrong warning caught by assert_warn")
  697. class TestAssertAllclose:
  698. def test_simple(self):
  699. x = 1e-3
  700. y = 1e-9
  701. assert_allclose(x, y, atol=1)
  702. assert_raises(AssertionError, assert_allclose, x, y)
  703. a = np.array([x, y, x, y])
  704. b = np.array([x, y, x, x])
  705. assert_allclose(a, b, atol=1)
  706. assert_raises(AssertionError, assert_allclose, a, b)
  707. b[-1] = y * (1 + 1e-8)
  708. assert_allclose(a, b)
  709. assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
  710. assert_allclose(6, 10, rtol=0.5)
  711. assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
  712. def test_min_int(self):
  713. a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
  714. # Should not raise:
  715. assert_allclose(a, a)
  716. def test_report_fail_percentage(self):
  717. a = np.array([1, 1, 1, 1])
  718. b = np.array([1, 1, 1, 2])
  719. with pytest.raises(AssertionError) as exc_info:
  720. assert_allclose(a, b)
  721. msg = str(exc_info.value)
  722. assert_('Mismatched elements: 1 / 4 (25%)\n'
  723. 'Max absolute difference: 1\n'
  724. 'Max relative difference: 0.5' in msg)
  725. def test_equal_nan(self):
  726. a = np.array([np.nan])
  727. b = np.array([np.nan])
  728. # Should not raise:
  729. assert_allclose(a, b, equal_nan=True)
  730. def test_not_equal_nan(self):
  731. a = np.array([np.nan])
  732. b = np.array([np.nan])
  733. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  734. def test_equal_nan_default(self):
  735. # Make sure equal_nan default behavior remains unchanged. (All
  736. # of these functions use assert_array_compare under the hood.)
  737. # None of these should raise.
  738. a = np.array([np.nan])
  739. b = np.array([np.nan])
  740. assert_array_equal(a, b)
  741. assert_array_almost_equal(a, b)
  742. assert_array_less(a, b)
  743. assert_allclose(a, b)
  744. def test_report_max_relative_error(self):
  745. a = np.array([0, 1])
  746. b = np.array([0, 2])
  747. with pytest.raises(AssertionError) as exc_info:
  748. assert_allclose(a, b)
  749. msg = str(exc_info.value)
  750. assert_('Max relative difference: 0.5' in msg)
  751. def test_timedelta(self):
  752. # see gh-18286
  753. a = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
  754. assert_allclose(a, a)
  755. def test_error_message_unsigned(self):
  756. """Check the the message is formatted correctly when overflow can occur
  757. (gh21768)"""
  758. # Ensure to test for potential overflow in the case of:
  759. # x - y
  760. # and
  761. # y - x
  762. x = np.asarray([0, 1, 8], dtype='uint8')
  763. y = np.asarray([4, 4, 4], dtype='uint8')
  764. with pytest.raises(AssertionError) as exc_info:
  765. assert_allclose(x, y, atol=3)
  766. msgs = str(exc_info.value).split('\n')
  767. assert_equal(msgs[4], 'Max absolute difference: 4')
  768. class TestArrayAlmostEqualNulp:
  769. def test_float64_pass(self):
  770. # The number of units of least precision
  771. # In this case, use a few places above the lowest level (ie nulp=1)
  772. nulp = 5
  773. x = np.linspace(-20, 20, 50, dtype=np.float64)
  774. x = 10**x
  775. x = np.r_[-x, x]
  776. # Addition
  777. eps = np.finfo(x.dtype).eps
  778. y = x + x*eps*nulp/2.
  779. assert_array_almost_equal_nulp(x, y, nulp)
  780. # Subtraction
  781. epsneg = np.finfo(x.dtype).epsneg
  782. y = x - x*epsneg*nulp/2.
  783. assert_array_almost_equal_nulp(x, y, nulp)
  784. def test_float64_fail(self):
  785. nulp = 5
  786. x = np.linspace(-20, 20, 50, dtype=np.float64)
  787. x = 10**x
  788. x = np.r_[-x, x]
  789. eps = np.finfo(x.dtype).eps
  790. y = x + x*eps*nulp*2.
  791. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  792. x, y, nulp)
  793. epsneg = np.finfo(x.dtype).epsneg
  794. y = x - x*epsneg*nulp*2.
  795. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  796. x, y, nulp)
  797. def test_float64_ignore_nan(self):
  798. # Ignore ULP differences between various NAN's
  799. # Note that MIPS may reverse quiet and signaling nans
  800. # so we use the builtin version as a base.
  801. offset = np.uint64(0xffffffff)
  802. nan1_i64 = np.array(np.nan, dtype=np.float64).view(np.uint64)
  803. nan2_i64 = nan1_i64 ^ offset # nan payload on MIPS is all ones.
  804. nan1_f64 = nan1_i64.view(np.float64)
  805. nan2_f64 = nan2_i64.view(np.float64)
  806. assert_array_max_ulp(nan1_f64, nan2_f64, 0)
  807. def test_float32_pass(self):
  808. nulp = 5
  809. x = np.linspace(-20, 20, 50, dtype=np.float32)
  810. x = 10**x
  811. x = np.r_[-x, x]
  812. eps = np.finfo(x.dtype).eps
  813. y = x + x*eps*nulp/2.
  814. assert_array_almost_equal_nulp(x, y, nulp)
  815. epsneg = np.finfo(x.dtype).epsneg
  816. y = x - x*epsneg*nulp/2.
  817. assert_array_almost_equal_nulp(x, y, nulp)
  818. def test_float32_fail(self):
  819. nulp = 5
  820. x = np.linspace(-20, 20, 50, dtype=np.float32)
  821. x = 10**x
  822. x = np.r_[-x, x]
  823. eps = np.finfo(x.dtype).eps
  824. y = x + x*eps*nulp*2.
  825. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  826. x, y, nulp)
  827. epsneg = np.finfo(x.dtype).epsneg
  828. y = x - x*epsneg*nulp*2.
  829. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  830. x, y, nulp)
  831. def test_float32_ignore_nan(self):
  832. # Ignore ULP differences between various NAN's
  833. # Note that MIPS may reverse quiet and signaling nans
  834. # so we use the builtin version as a base.
  835. offset = np.uint32(0xffff)
  836. nan1_i32 = np.array(np.nan, dtype=np.float32).view(np.uint32)
  837. nan2_i32 = nan1_i32 ^ offset # nan payload on MIPS is all ones.
  838. nan1_f32 = nan1_i32.view(np.float32)
  839. nan2_f32 = nan2_i32.view(np.float32)
  840. assert_array_max_ulp(nan1_f32, nan2_f32, 0)
  841. def test_float16_pass(self):
  842. nulp = 5
  843. x = np.linspace(-4, 4, 10, dtype=np.float16)
  844. x = 10**x
  845. x = np.r_[-x, x]
  846. eps = np.finfo(x.dtype).eps
  847. y = x + x*eps*nulp/2.
  848. assert_array_almost_equal_nulp(x, y, nulp)
  849. epsneg = np.finfo(x.dtype).epsneg
  850. y = x - x*epsneg*nulp/2.
  851. assert_array_almost_equal_nulp(x, y, nulp)
  852. def test_float16_fail(self):
  853. nulp = 5
  854. x = np.linspace(-4, 4, 10, dtype=np.float16)
  855. x = 10**x
  856. x = np.r_[-x, x]
  857. eps = np.finfo(x.dtype).eps
  858. y = x + x*eps*nulp*2.
  859. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  860. x, y, nulp)
  861. epsneg = np.finfo(x.dtype).epsneg
  862. y = x - x*epsneg*nulp*2.
  863. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  864. x, y, nulp)
  865. def test_float16_ignore_nan(self):
  866. # Ignore ULP differences between various NAN's
  867. # Note that MIPS may reverse quiet and signaling nans
  868. # so we use the builtin version as a base.
  869. offset = np.uint16(0xff)
  870. nan1_i16 = np.array(np.nan, dtype=np.float16).view(np.uint16)
  871. nan2_i16 = nan1_i16 ^ offset # nan payload on MIPS is all ones.
  872. nan1_f16 = nan1_i16.view(np.float16)
  873. nan2_f16 = nan2_i16.view(np.float16)
  874. assert_array_max_ulp(nan1_f16, nan2_f16, 0)
  875. def test_complex128_pass(self):
  876. nulp = 5
  877. x = np.linspace(-20, 20, 50, dtype=np.float64)
  878. x = 10**x
  879. x = np.r_[-x, x]
  880. xi = x + x*1j
  881. eps = np.finfo(x.dtype).eps
  882. y = x + x*eps*nulp/2.
  883. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  884. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  885. # The test condition needs to be at least a factor of sqrt(2) smaller
  886. # because the real and imaginary parts both change
  887. y = x + x*eps*nulp/4.
  888. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  889. epsneg = np.finfo(x.dtype).epsneg
  890. y = x - x*epsneg*nulp/2.
  891. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  892. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  893. y = x - x*epsneg*nulp/4.
  894. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  895. def test_complex128_fail(self):
  896. nulp = 5
  897. x = np.linspace(-20, 20, 50, dtype=np.float64)
  898. x = 10**x
  899. x = np.r_[-x, x]
  900. xi = x + x*1j
  901. eps = np.finfo(x.dtype).eps
  902. y = x + x*eps*nulp*2.
  903. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  904. xi, x + y*1j, nulp)
  905. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  906. xi, y + x*1j, nulp)
  907. # The test condition needs to be at least a factor of sqrt(2) smaller
  908. # because the real and imaginary parts both change
  909. y = x + x*eps*nulp
  910. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  911. xi, y + y*1j, nulp)
  912. epsneg = np.finfo(x.dtype).epsneg
  913. y = x - x*epsneg*nulp*2.
  914. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  915. xi, x + y*1j, nulp)
  916. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  917. xi, y + x*1j, nulp)
  918. y = x - x*epsneg*nulp
  919. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  920. xi, y + y*1j, nulp)
  921. def test_complex64_pass(self):
  922. nulp = 5
  923. x = np.linspace(-20, 20, 50, dtype=np.float32)
  924. x = 10**x
  925. x = np.r_[-x, x]
  926. xi = x + x*1j
  927. eps = np.finfo(x.dtype).eps
  928. y = x + x*eps*nulp/2.
  929. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  930. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  931. y = x + x*eps*nulp/4.
  932. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  933. epsneg = np.finfo(x.dtype).epsneg
  934. y = x - x*epsneg*nulp/2.
  935. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  936. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  937. y = x - x*epsneg*nulp/4.
  938. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  939. def test_complex64_fail(self):
  940. nulp = 5
  941. x = np.linspace(-20, 20, 50, dtype=np.float32)
  942. x = 10**x
  943. x = np.r_[-x, x]
  944. xi = x + x*1j
  945. eps = np.finfo(x.dtype).eps
  946. y = x + x*eps*nulp*2.
  947. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  948. xi, x + y*1j, nulp)
  949. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  950. xi, y + x*1j, nulp)
  951. y = x + x*eps*nulp
  952. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  953. xi, y + y*1j, nulp)
  954. epsneg = np.finfo(x.dtype).epsneg
  955. y = x - x*epsneg*nulp*2.
  956. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  957. xi, x + y*1j, nulp)
  958. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  959. xi, y + x*1j, nulp)
  960. y = x - x*epsneg*nulp
  961. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  962. xi, y + y*1j, nulp)
  963. class TestULP:
  964. def test_equal(self):
  965. x = np.random.randn(10)
  966. assert_array_max_ulp(x, x, maxulp=0)
  967. def test_single(self):
  968. # Generate 1 + small deviation, check that adding eps gives a few UNL
  969. x = np.ones(10).astype(np.float32)
  970. x += 0.01 * np.random.randn(10).astype(np.float32)
  971. eps = np.finfo(np.float32).eps
  972. assert_array_max_ulp(x, x+eps, maxulp=20)
  973. def test_double(self):
  974. # Generate 1 + small deviation, check that adding eps gives a few UNL
  975. x = np.ones(10).astype(np.float64)
  976. x += 0.01 * np.random.randn(10).astype(np.float64)
  977. eps = np.finfo(np.float64).eps
  978. assert_array_max_ulp(x, x+eps, maxulp=200)
  979. def test_inf(self):
  980. for dt in [np.float32, np.float64]:
  981. inf = np.array([np.inf]).astype(dt)
  982. big = np.array([np.finfo(dt).max])
  983. assert_array_max_ulp(inf, big, maxulp=200)
  984. def test_nan(self):
  985. # Test that nan is 'far' from small, tiny, inf, max and min
  986. for dt in [np.float32, np.float64]:
  987. if dt == np.float32:
  988. maxulp = 1e6
  989. else:
  990. maxulp = 1e12
  991. inf = np.array([np.inf]).astype(dt)
  992. nan = np.array([np.nan]).astype(dt)
  993. big = np.array([np.finfo(dt).max])
  994. tiny = np.array([np.finfo(dt).tiny])
  995. zero = np.array([np.PZERO]).astype(dt)
  996. nzero = np.array([np.NZERO]).astype(dt)
  997. assert_raises(AssertionError,
  998. lambda: assert_array_max_ulp(nan, inf,
  999. maxulp=maxulp))
  1000. assert_raises(AssertionError,
  1001. lambda: assert_array_max_ulp(nan, big,
  1002. maxulp=maxulp))
  1003. assert_raises(AssertionError,
  1004. lambda: assert_array_max_ulp(nan, tiny,
  1005. maxulp=maxulp))
  1006. assert_raises(AssertionError,
  1007. lambda: assert_array_max_ulp(nan, zero,
  1008. maxulp=maxulp))
  1009. assert_raises(AssertionError,
  1010. lambda: assert_array_max_ulp(nan, nzero,
  1011. maxulp=maxulp))
  1012. class TestStringEqual:
  1013. def test_simple(self):
  1014. assert_string_equal("hello", "hello")
  1015. assert_string_equal("hello\nmultiline", "hello\nmultiline")
  1016. with pytest.raises(AssertionError) as exc_info:
  1017. assert_string_equal("foo\nbar", "hello\nbar")
  1018. msg = str(exc_info.value)
  1019. assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
  1020. assert_raises(AssertionError,
  1021. lambda: assert_string_equal("foo", "hello"))
  1022. def test_regex(self):
  1023. assert_string_equal("a+*b", "a+*b")
  1024. assert_raises(AssertionError,
  1025. lambda: assert_string_equal("aaa", "a+b"))
  1026. def assert_warn_len_equal(mod, n_in_context):
  1027. try:
  1028. mod_warns = mod.__warningregistry__
  1029. except AttributeError:
  1030. # the lack of a __warningregistry__
  1031. # attribute means that no warning has
  1032. # occurred; this can be triggered in
  1033. # a parallel test scenario, while in
  1034. # a serial test scenario an initial
  1035. # warning (and therefore the attribute)
  1036. # are always created first
  1037. mod_warns = {}
  1038. num_warns = len(mod_warns)
  1039. if 'version' in mod_warns:
  1040. # Python 3 adds a 'version' entry to the registry,
  1041. # do not count it.
  1042. num_warns -= 1
  1043. assert_equal(num_warns, n_in_context)
  1044. def test_warn_len_equal_call_scenarios():
  1045. # assert_warn_len_equal is called under
  1046. # varying circumstances depending on serial
  1047. # vs. parallel test scenarios; this test
  1048. # simply aims to probe both code paths and
  1049. # check that no assertion is uncaught
  1050. # parallel scenario -- no warning issued yet
  1051. class mod:
  1052. pass
  1053. mod_inst = mod()
  1054. assert_warn_len_equal(mod=mod_inst,
  1055. n_in_context=0)
  1056. # serial test scenario -- the __warningregistry__
  1057. # attribute should be present
  1058. class mod:
  1059. def __init__(self):
  1060. self.__warningregistry__ = {'warning1':1,
  1061. 'warning2':2}
  1062. mod_inst = mod()
  1063. assert_warn_len_equal(mod=mod_inst,
  1064. n_in_context=2)
  1065. def _get_fresh_mod():
  1066. # Get this module, with warning registry empty
  1067. my_mod = sys.modules[__name__]
  1068. try:
  1069. my_mod.__warningregistry__.clear()
  1070. except AttributeError:
  1071. # will not have a __warningregistry__ unless warning has been
  1072. # raised in the module at some point
  1073. pass
  1074. return my_mod
  1075. def test_clear_and_catch_warnings():
  1076. # Initial state of module, no warnings
  1077. my_mod = _get_fresh_mod()
  1078. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1079. with clear_and_catch_warnings(modules=[my_mod]):
  1080. warnings.simplefilter('ignore')
  1081. warnings.warn('Some warning')
  1082. assert_equal(my_mod.__warningregistry__, {})
  1083. # Without specified modules, don't clear warnings during context.
  1084. # catch_warnings doesn't make an entry for 'ignore'.
  1085. with clear_and_catch_warnings():
  1086. warnings.simplefilter('ignore')
  1087. warnings.warn('Some warning')
  1088. assert_warn_len_equal(my_mod, 0)
  1089. # Manually adding two warnings to the registry:
  1090. my_mod.__warningregistry__ = {'warning1': 1,
  1091. 'warning2': 2}
  1092. # Confirm that specifying module keeps old warning, does not add new
  1093. with clear_and_catch_warnings(modules=[my_mod]):
  1094. warnings.simplefilter('ignore')
  1095. warnings.warn('Another warning')
  1096. assert_warn_len_equal(my_mod, 2)
  1097. # Another warning, no module spec it clears up registry
  1098. with clear_and_catch_warnings():
  1099. warnings.simplefilter('ignore')
  1100. warnings.warn('Another warning')
  1101. assert_warn_len_equal(my_mod, 0)
  1102. def test_suppress_warnings_module():
  1103. # Initial state of module, no warnings
  1104. my_mod = _get_fresh_mod()
  1105. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1106. def warn_other_module():
  1107. # Apply along axis is implemented in python; stacklevel=2 means
  1108. # we end up inside its module, not ours.
  1109. def warn(arr):
  1110. warnings.warn("Some warning 2", stacklevel=2)
  1111. return arr
  1112. np.apply_along_axis(warn, 0, [0])
  1113. # Test module based warning suppression:
  1114. assert_warn_len_equal(my_mod, 0)
  1115. with suppress_warnings() as sup:
  1116. sup.record(UserWarning)
  1117. # suppress warning from other module (may have .pyc ending),
  1118. # if apply_along_axis is moved, had to be changed.
  1119. sup.filter(module=np.lib.shape_base)
  1120. warnings.warn("Some warning")
  1121. warn_other_module()
  1122. # Check that the suppression did test the file correctly (this module
  1123. # got filtered)
  1124. assert_equal(len(sup.log), 1)
  1125. assert_equal(sup.log[0].message.args[0], "Some warning")
  1126. assert_warn_len_equal(my_mod, 0)
  1127. sup = suppress_warnings()
  1128. # Will have to be changed if apply_along_axis is moved:
  1129. sup.filter(module=my_mod)
  1130. with sup:
  1131. warnings.warn('Some warning')
  1132. assert_warn_len_equal(my_mod, 0)
  1133. # And test repeat works:
  1134. sup.filter(module=my_mod)
  1135. with sup:
  1136. warnings.warn('Some warning')
  1137. assert_warn_len_equal(my_mod, 0)
  1138. # Without specified modules
  1139. with suppress_warnings():
  1140. warnings.simplefilter('ignore')
  1141. warnings.warn('Some warning')
  1142. assert_warn_len_equal(my_mod, 0)
  1143. def test_suppress_warnings_type():
  1144. # Initial state of module, no warnings
  1145. my_mod = _get_fresh_mod()
  1146. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1147. # Test module based warning suppression:
  1148. with suppress_warnings() as sup:
  1149. sup.filter(UserWarning)
  1150. warnings.warn('Some warning')
  1151. assert_warn_len_equal(my_mod, 0)
  1152. sup = suppress_warnings()
  1153. sup.filter(UserWarning)
  1154. with sup:
  1155. warnings.warn('Some warning')
  1156. assert_warn_len_equal(my_mod, 0)
  1157. # And test repeat works:
  1158. sup.filter(module=my_mod)
  1159. with sup:
  1160. warnings.warn('Some warning')
  1161. assert_warn_len_equal(my_mod, 0)
  1162. # Without specified modules
  1163. with suppress_warnings():
  1164. warnings.simplefilter('ignore')
  1165. warnings.warn('Some warning')
  1166. assert_warn_len_equal(my_mod, 0)
  1167. def test_suppress_warnings_decorate_no_record():
  1168. sup = suppress_warnings()
  1169. sup.filter(UserWarning)
  1170. @sup
  1171. def warn(category):
  1172. warnings.warn('Some warning', category)
  1173. with warnings.catch_warnings(record=True) as w:
  1174. warnings.simplefilter("always")
  1175. warn(UserWarning) # should be supppressed
  1176. warn(RuntimeWarning)
  1177. assert_equal(len(w), 1)
  1178. def test_suppress_warnings_record():
  1179. sup = suppress_warnings()
  1180. log1 = sup.record()
  1181. with sup:
  1182. log2 = sup.record(message='Some other warning 2')
  1183. sup.filter(message='Some warning')
  1184. warnings.warn('Some warning')
  1185. warnings.warn('Some other warning')
  1186. warnings.warn('Some other warning 2')
  1187. assert_equal(len(sup.log), 2)
  1188. assert_equal(len(log1), 1)
  1189. assert_equal(len(log2),1)
  1190. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1191. # Do it again, with the same context to see if some warnings survived:
  1192. with sup:
  1193. log2 = sup.record(message='Some other warning 2')
  1194. sup.filter(message='Some warning')
  1195. warnings.warn('Some warning')
  1196. warnings.warn('Some other warning')
  1197. warnings.warn('Some other warning 2')
  1198. assert_equal(len(sup.log), 2)
  1199. assert_equal(len(log1), 1)
  1200. assert_equal(len(log2), 1)
  1201. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1202. # Test nested:
  1203. with suppress_warnings() as sup:
  1204. sup.record()
  1205. with suppress_warnings() as sup2:
  1206. sup2.record(message='Some warning')
  1207. warnings.warn('Some warning')
  1208. warnings.warn('Some other warning')
  1209. assert_equal(len(sup2.log), 1)
  1210. assert_equal(len(sup.log), 1)
  1211. def test_suppress_warnings_forwarding():
  1212. def warn_other_module():
  1213. # Apply along axis is implemented in python; stacklevel=2 means
  1214. # we end up inside its module, not ours.
  1215. def warn(arr):
  1216. warnings.warn("Some warning", stacklevel=2)
  1217. return arr
  1218. np.apply_along_axis(warn, 0, [0])
  1219. with suppress_warnings() as sup:
  1220. sup.record()
  1221. with suppress_warnings("always"):
  1222. for i in range(2):
  1223. warnings.warn("Some warning")
  1224. assert_equal(len(sup.log), 2)
  1225. with suppress_warnings() as sup:
  1226. sup.record()
  1227. with suppress_warnings("location"):
  1228. for i in range(2):
  1229. warnings.warn("Some warning")
  1230. warnings.warn("Some warning")
  1231. assert_equal(len(sup.log), 2)
  1232. with suppress_warnings() as sup:
  1233. sup.record()
  1234. with suppress_warnings("module"):
  1235. for i in range(2):
  1236. warnings.warn("Some warning")
  1237. warnings.warn("Some warning")
  1238. warn_other_module()
  1239. assert_equal(len(sup.log), 2)
  1240. with suppress_warnings() as sup:
  1241. sup.record()
  1242. with suppress_warnings("once"):
  1243. for i in range(2):
  1244. warnings.warn("Some warning")
  1245. warnings.warn("Some other warning")
  1246. warn_other_module()
  1247. assert_equal(len(sup.log), 2)
  1248. def test_tempdir():
  1249. with tempdir() as tdir:
  1250. fpath = os.path.join(tdir, 'tmp')
  1251. with open(fpath, 'w'):
  1252. pass
  1253. assert_(not os.path.isdir(tdir))
  1254. raised = False
  1255. try:
  1256. with tempdir() as tdir:
  1257. raise ValueError()
  1258. except ValueError:
  1259. raised = True
  1260. assert_(raised)
  1261. assert_(not os.path.isdir(tdir))
  1262. def test_temppath():
  1263. with temppath() as fpath:
  1264. with open(fpath, 'w'):
  1265. pass
  1266. assert_(not os.path.isfile(fpath))
  1267. raised = False
  1268. try:
  1269. with temppath() as fpath:
  1270. raise ValueError()
  1271. except ValueError:
  1272. raised = True
  1273. assert_(raised)
  1274. assert_(not os.path.isfile(fpath))
  1275. class my_cacw(clear_and_catch_warnings):
  1276. class_modules = (sys.modules[__name__],)
  1277. def test_clear_and_catch_warnings_inherit():
  1278. # Test can subclass and add default modules
  1279. my_mod = _get_fresh_mod()
  1280. with my_cacw():
  1281. warnings.simplefilter('ignore')
  1282. warnings.warn('Some warning')
  1283. assert_equal(my_mod.__warningregistry__, {})
  1284. @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
  1285. class TestAssertNoGcCycles:
  1286. """ Test assert_no_gc_cycles """
  1287. def test_passes(self):
  1288. def no_cycle():
  1289. b = []
  1290. b.append([])
  1291. return b
  1292. with assert_no_gc_cycles():
  1293. no_cycle()
  1294. assert_no_gc_cycles(no_cycle)
  1295. def test_asserts(self):
  1296. def make_cycle():
  1297. a = []
  1298. a.append(a)
  1299. a.append(a)
  1300. return a
  1301. with assert_raises(AssertionError):
  1302. with assert_no_gc_cycles():
  1303. make_cycle()
  1304. with assert_raises(AssertionError):
  1305. assert_no_gc_cycles(make_cycle)
  1306. @pytest.mark.slow
  1307. def test_fails(self):
  1308. """
  1309. Test that in cases where the garbage cannot be collected, we raise an
  1310. error, instead of hanging forever trying to clear it.
  1311. """
  1312. class ReferenceCycleInDel:
  1313. """
  1314. An object that not only contains a reference cycle, but creates new
  1315. cycles whenever it's garbage-collected and its __del__ runs
  1316. """
  1317. make_cycle = True
  1318. def __init__(self):
  1319. self.cycle = self
  1320. def __del__(self):
  1321. # break the current cycle so that `self` can be freed
  1322. self.cycle = None
  1323. if ReferenceCycleInDel.make_cycle:
  1324. # but create a new one so that the garbage collector has more
  1325. # work to do.
  1326. ReferenceCycleInDel()
  1327. try:
  1328. w = weakref.ref(ReferenceCycleInDel())
  1329. try:
  1330. with assert_raises(RuntimeError):
  1331. # this will be unable to get a baseline empty garbage
  1332. assert_no_gc_cycles(lambda: None)
  1333. except AssertionError:
  1334. # the above test is only necessary if the GC actually tried to free
  1335. # our object anyway, which python 2.7 does not.
  1336. if w() is not None:
  1337. pytest.skip("GC does not call __del__ on cyclic objects")
  1338. raise
  1339. finally:
  1340. # make sure that we stop creating reference cycles
  1341. ReferenceCycleInDel.make_cycle = False