-
Notifications
You must be signed in to change notification settings - Fork 683
Expand file tree
/
Copy pathconftest.py
More file actions
140 lines (118 loc) · 4.47 KB
/
conftest.py
File metadata and controls
140 lines (118 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
import sys
from typing import Optional
import pytest
class ProgressRecorder:
    """Callable progress sink for index build tests.

    Every event passed to the instance is stored in ``events``. The recorder
    can optionally raise ``RuntimeError`` to simulate a failing callback:

    * ``fail_on_tag``: raise when the event's ``"<event>:<stage>"`` tag equals
      this string.
    * ``fail_after``: raise once at least this many events have been recorded.

    The event is appended before either check, so a failing event is still
    present in ``events``.
    """

    def __init__(
        self,
        fail_after: Optional[int] = None,
        fail_on_tag: Optional[str] = None,
    ):
        self.fail_on_tag = fail_on_tag
        self.fail_after = fail_after
        self.events = []

    def __call__(self, event):
        # Record first so tests can inspect the event that triggered a failure.
        self.events.append(event)
        tag = f"{event.event}:{event.stage}"

        tag_matches = self.fail_on_tag is not None and tag == self.fail_on_tag
        if tag_matches or (
            self.fail_after is not None and len(self.events) >= self.fail_after
        ):
            raise RuntimeError("progress callback failure")
def progress_event_tags(events):
    """Return the ``"<event>:<stage>"`` tag of each event, in input order."""
    tags = []
    for ev in events:
        tags.append(f"{ev.event}:{ev.stage}")
    return tags
def stage_progress_values(events, stage):
    """Collect the ``completed`` values reported for one stage.

    Only events whose ``event`` is ``"progress"`` and whose ``stage`` equals
    ``stage`` are considered; those with ``completed`` set to ``None`` are
    dropped. Input order is preserved.
    """
    collected = []
    for ev in events:
        if ev.event == "progress" and ev.stage == stage:
            done = ev.completed
            if done is not None:
                collected.append(done)
    return collected
@pytest.fixture(params=(True, False))
def provide_pandas(request, monkeypatch):
    """Run the requesting test twice: with pandas left as-is and with it hidden.

    When the parameter is ``False`` the ``"pandas"`` entry of ``sys.modules``
    is set to ``None`` through ``monkeypatch``, so any ``import pandas``
    executed during the test raises ``ImportError``. ``monkeypatch`` restores
    the previous entry at teardown. Names already bound to the pandas module
    before the fixture runs are not affected.

    Returns:
        The fixture parameter: ``True`` when pandas was left untouched,
        ``False`` when it was hidden.
    """
    if not request.param:
        # The key has to be the real module name. The conventional "pd" alias is
        # never looked up by the import system, so patching it hides nothing.
        monkeypatch.setitem(sys.modules, "pandas", None)
    return request.param
def disable_items_with_mark(items, mark, reason):
    """Attach a skip marker to every collected item that carries ``mark``.

    Args:
        items: Collected test items; each must expose ``keywords`` and
            ``add_marker``.
        mark: Marker name to look for in ``item.keywords``.
        reason: Skip reason attached to the skip marker.
    """
    skip_marker = pytest.mark.skip(reason=reason)
    for marked_item in (candidate for candidate in items if mark in candidate.keywords):
        marked_item.add_marker(skip_marker)
# These are initialization hooks and must have an exact name for pytest to pick them up
# https://docs.pytest.org/en/7.1.x/reference/reference.html
def pytest_addoption(parser):
    """Register the opt-in command line flags for optional test groups.

    Every flag is a ``store_true`` switch that defaults to ``False``.
    """
    opt_in_flags = (
        (
            "--run-integration",
            "Run integration tests (requires S3 buckets to be setup with access)",
        ),
        (
            "--run-slow",
            "Run slow tests",
        ),
        (
            "--run-forward",
            "Run forward compatibility tests (requires files to be generated already)",
        ),
        (
            "--run-compat",
            "Run upgrade/downgrade compatibility tests (creates virtual environments)",
        ),
    )
    for flag, help_text in opt_in_flags:
        parser.addoption(flag, action="store_true", default=False, help=help_text)
def pytest_configure(config):
    """Register the custom markers that gate the optional test groups."""
    marker_lines = (
        "forward: mark tests that require forward compatibility datagen files",
        "integration: mark test that requires object storage integration",
        "slow: mark tests that require large CPU or RAM resources",
        "compat: mark tests that run upgrade/downgrade compatibility checks",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)
def pytest_collection_modifyitems(config, items):
    """Skip collected tests whose opt-in flag or accelerator requirement is unmet.

    Tests marked ``integration``, ``slow``, ``forward`` or ``compat`` are
    skipped unless the matching ``--run-*`` option was given. Tests marked
    ``cuda`` or ``gpu`` are skipped when torch reports that the backend is
    unusable, and ``torch``, ``cuda`` and ``gpu`` tests are all skipped when
    torch cannot be imported.

    Args:
        config: The pytest config object, used to read the ``--run-*`` options.
        items: The collected test items; skip markers are added in place.
    """
    if not config.getoption("--run-integration"):
        disable_items_with_mark(items, "integration", "--run-integration not specified")
    if not config.getoption("--run-slow"):
        disable_items_with_mark(items, "slow", "--run-slow not specified")
    if not config.getoption("--run-forward"):
        disable_items_with_mark(items, "forward", "--run-forward not specified")
    if not config.getoption("--run-compat"):
        disable_items_with_mark(items, "compat", "--run-compat not specified")
    try:
        import torch

        # torch.cuda.is_available will return True on some CI machines even though any
        # attempt to use CUDA will then fail. torch.cuda.device_count seems to be more
        # reliable, so both signals are consulted.
        # is_available must be *called*: the bare function object is always truthy,
        # which would make the first half of this condition permanently False.
        if (
            torch.backends.cuda.is_built() and not torch.cuda.is_available()
        ) or torch.cuda.device_count() <= 0:
            disable_items_with_mark(
                items, "cuda", "torch is installed but cuda is not available"
            )
        # NOTE(review): "gpu" is gated on the MPS backend only, so these tests are
        # also skipped on CUDA-only machines - presumably intentional; confirm.
        if (
            not torch.backends.mps.is_available()
            or not torch.backends.mps.is_built()
        ):
            disable_items_with_mark(
                items, "gpu", "torch is installed but no gpu is available"
            )
    except ImportError as err:
        # Without torch none of the torch-backed test groups can run.
        reason = f"torch not installed ({err})"
        disable_items_with_mark(items, "torch", reason)
        disable_items_with_mark(items, "cuda", reason)
        disable_items_with_mark(items, "gpu", reason)