-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathext_test_case.py
More file actions
185 lines (152 loc) · 5.26 KB
/
ext_test_case.py
File metadata and controls
185 lines (152 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Callable, List, Optional
import numpy
from numpy.testing import assert_allclose
def unit_test_going():
    """
    Tells whether the environment variable ``UNITTEST_GOING`` is set to ``1``.
    Scripts can check this flag to shorten their work while unit tests run,
    which keeps unit tests from being too long.

    :return: True if ``UNITTEST_GOING`` parses to the integer 1
    """
    flag = os.environ.get("UNITTEST_GOING", 0)
    return int(flag) == 1
def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Decorator for test methods: ignores the given warning categories
    while the decorated method runs.

    :param warns: a warning class, or a list/tuple/set of warning classes
        to ignore
    :return: the decorator
    :raises AssertionError: at decoration time if *warns* is None
    """
    import functools

    # warnings.simplefilter stores the category unchanged and later matches
    # it with issubclass(), which accepts a class or a tuple of classes but
    # raises TypeError for a list: normalise collections to a tuple.
    if isinstance(warns, (list, set, frozenset)):
        categories = tuple(warns)
    else:
        categories = warns

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        # functools.wraps keeps __name__/__doc__ so test reports stay readable.
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", categories)
                return fct(self, *args, **kwargs)

        return call_f

    return wrapper
def hide_stdout(f: Optional[Callable] = None) -> Callable:
    """
    Decorator for test methods: silences standard output and ignores
    ``UserWarning`` / ``DeprecationWarning`` while the method runs.

    Setting the environment variable ``UNHIDE`` to a non-empty value
    disables the redirection so the output stays visible.

    :param f: optional callback receiving the captured standard output
        once the decorated method has returned
    :return: the decorator
    """

    def decorate(method):
        def runner(self):
            # Debug switch: run the method untouched, output visible.
            if os.environ.get("UNHIDE", ""):
                method(self)
                return

            buffer = StringIO()
            with redirect_stdout(buffer), warnings.catch_warnings():
                warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))
                try:
                    method(self)
                except AssertionError as exc:
                    message = str(exc)
                    if "torch is not recent enough, file" not in message:
                        raise
                    # This specific failure is reported as a skipped test.
                    raise unittest.SkipTest(message)  # noqa: B904
            if f is not None:
                f(buffer.getvalue())
            return None

        # Keep the original method name for readable test reports.
        try:  # noqa: SIM105
            runner.__name__ = method.__name__
        except AttributeError:
            pass
        return runner

    return decorate
class sys_path_append:
    """
    Context manager which adds paths to :epkg:`*py:sys:path` and
    restores its previous content afterwards.
    """

    def __init__(self, paths, position=-1):
        """
        :param paths: a path, or a list/tuple of paths, to add
        :param position: ``-1`` appends the paths at the end, any other
            value inserts them at that index
        """
        # A tuple is a collection of paths too, not a single path entry.
        self.to_add = list(paths) if isinstance(paths, (list, tuple)) else [paths]
        self.position = position

    def __enter__(self):
        """
        Modifies ``sys.path`` and returns this object.
        """
        self.store = sys.path.copy()
        if self.position == -1:
            sys.path.extend(self.to_add)
        else:
            # For a non-negative position, inserting in reverse order at a
            # fixed index keeps the paths in the order they were given.
            for p in reversed(self.to_add):
                sys.path.insert(self.position, p)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Restores ``sys.path``. The restore happens in place, so every
        reference to the list object sees the restored content.
        Exceptions raised inside the block are not suppressed.
        """
        sys.path[:] = self.store
class ExtTestCase(unittest.TestCase):
    """
    :class:`unittest.TestCase` with extra assertions for files, numpy
    arrays, exceptions, containers, string prefixes and captured output.
    """

    # (name, line, warning) triples re-emitted by tearDownClass.
    _warns = []

    def assertExists(self, name):
        """
        Fails if the file or folder *name* does not exist.
        """
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exist.")

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Checks that two arrays share dtype and shape, then compares their
        values with :func:`numpy.testing.assert_allclose`.
        """
        self.assertEqual(expected.dtype, value.dtype)
        self.assertEqual(expected.shape, value.shape)
        assert_allclose(expected, value, atol=atol, rtol=rtol)

    def assertAlmostEqual(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Same as :meth:`assertEqualArray` but accepts non-array inputs.
        A non-array *value* is cast to the dtype of *expected*.
        Note: this overrides :meth:`unittest.TestCase.assertAlmostEqual`
        with a different signature.
        """
        if not isinstance(expected, numpy.ndarray):
            expected = numpy.array(expected)
        if not isinstance(value, numpy.ndarray):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

    def assertRaise(self, fct: Callable, exc_type: type):
        """
        Checks that calling *fct* raises an exception of type *exc_type*.

        :param fct: callable taking no argument
        :param exc_type: expected exception class (or tuple of classes)
        :raises AssertionError: if no exception is raised or if an
            exception of another type is raised
        """
        try:
            fct()
        except exc_type:
            return
        except Exception as e:
            # Any other exception type is reported as an assertion failure.
            raise AssertionError(f"Unexpected exception {type(e)!r}.") from e
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """
        Fails unless *value* is None or has a length of zero.
        """
        if value is None:
            return
        if len(value) == 0:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """
        Fails if *value* is None, or if it is an empty list, dict, tuple
        or set. The length of other types is not checked.
        """
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        if isinstance(value, (list, dict, tuple, set)):
            if len(value) == 0:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """
        Fails if the string *full* does not start with *prefix*.
        """
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """
        Re-emits the warnings stored in ``cls._warns``, then runs the
        parent class teardown.
        """
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}", stacklevel=0)
        super().tearDownClass()

    def capture(self, fct: Callable):
        """
        Runs a function and captures standard output and error.

        :param fct: function to run, called without arguments
        :return: result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            res = fct()
        return res, sout.getvalue(), serr.getvalue()