-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathtorch_wrapper.py
More file actions
39 lines (31 loc) · 1.11 KB
/
torch_wrapper.py
File metadata and controls
39 lines (31 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import importlib
# Module-level record of the torch import attempt. These names are rebound by
# _import_torch() and read by get_torch().
_torch = None  # the imported torch module, once an import has succeeded
_torch_import_error = None  # the exception raised by a failed import attempt
TORCH_AVAILABLE = False  # set to True only after a successful import
def _import_torch():
    """Try to import ``torch`` and record the outcome in module globals.

    On success the module object is stored in ``_torch`` and
    ``TORCH_AVAILABLE`` is set to True. On failure the raised exception is
    stored in ``_torch_import_error`` and ``TORCH_AVAILABLE`` is set to False.
    Nothing is returned and no exception derived from ``Exception`` escapes.
    """
    global _torch, TORCH_AVAILABLE, _torch_import_error
    try:
        module = importlib.import_module("torch")
    except Exception as exc:
        # Deliberately broad: besides ImportError, this is meant to cover
        # failures such as missing CUDA libraries raised during the import.
        _torch_import_error = exc
        TORCH_AVAILABLE = False
    else:
        _torch = module
        TORCH_AVAILABLE = True
# Run the import attempt once at module load so that TORCH_AVAILABLE, _torch
# and _torch_import_error already describe the outcome when get_torch() is
# called.
_import_torch()
def get_torch():
    """Return the imported ``torch`` module, or raise if it is unavailable.

    This accessor only reads the outcome recorded in the module globals
    (``TORCH_AVAILABLE``, ``_torch``, ``_torch_import_error``), so a broken
    torch installation surfaces here, with installation advice, rather than
    crashing when this module is imported.

    Returns:
        The module object stored in ``_torch``.

    Raises:
        ImportError: If ``TORCH_AVAILABLE`` is false. The message embeds the
            text of the recorded import failure, and that exception is
            chained as ``__cause__``.
    """
    if TORCH_AVAILABLE:
        return _torch

    error_message = (
        "Torch is not available or failed to import.\n"
        "Original error:\n"
        f"{_torch_import_error}\n\n"
        "If you are on a CPU-only system, make sure you install the CPU-only torch wheel:\n"
        # --index-url rather than -f/--find-links: the URL is a package
        # index, not a page of direct wheel links, so -f would not select
        # the CPU-only build.
        " pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu\n"
        "Or check your CUDA installation if using GPU torch.\n"
    )
    raise ImportError(error_message) from _torch_import_error