Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions lib/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,8 @@ def __setitem__(self, key, val):
and val is rcsetup._auto_backend_sentinel
and "backend" in self):
return
valid_key = _api.check_getitem(
self.validate, rcParam=key, _error_cls=KeyError
)
valid_key = _api.getitem_checked(
self.validate, _error_cls=KeyError)("rcParam", key)
try:
cval = valid_key(val)
except ValueError as ve:
Expand Down
178 changes: 89 additions & 89 deletions lib/matplotlib/_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import difflib
import functools
import itertools
import operator
import pathlib
import re
import sys
Expand Down Expand Up @@ -69,95 +70,77 @@ def fget(self):
return self._fget


# In the following check_foo() functions, the first parameter is positional-only to make
# e.g. `_api.check_isinstance([...], types=foo)` work.

def check_isinstance(types, /, **kwargs):
@functools.cache
def check_isinstance(types, /):
"""
For each *key, value* pair in *kwargs*, check that *value* is an instance
of one of *types*; if not, raise an appropriate TypeError.
Return a function that checks isinstance, raising TypeError on mismatch.

As a special case, a ``None`` entry in *types* is treated as NoneType.

Examples
--------
>>> _api.check_isinstance((SomeClass, None), arg=arg)
>>> _api.check_isinstance((SomeClass, None))("arg", arg)
"""
none_type = type(None)
types = ((types,) if isinstance(types, type) else
(none_type,) if types is None else
tuple(none_type if tp is None else tp for tp in types))
if types is None:
types = (none_type,)
elif isinstance(types, type):
types = (types,)
else:
types = tuple(none_type if tp is None else tp for tp in types)

def type_name(tp):
return ("None" if tp is none_type
else tp.__qualname__ if tp.__module__ == "builtins"
else f"{tp.__module__}.{tp.__qualname__}")

for k, v in kwargs.items():
if not isinstance(v, types):
def check(name, value, _types=types):
if not isinstance(value, _types):
names = [*map(type_name, types)]
if "None" in names: # Move it to the end for better wording.
names.remove("None")
names.append("None")
raise TypeError(
"{!r} must be an instance of {}, not a {}".format(
k,
name,
", ".join(names[:-1]) + " or " + names[-1]
if len(names) > 1 else names[0],
type_name(type(v))))
type_name(type(value))))

return check


@functools.cache
def _check_in_tuple(values, /):
def check(name, value, _values=values):
if value not in _values:
raise ValueError(
f"{value!r} is not a valid value for {name}; "
f"supported values are {', '.join(map(repr, _values))}")

return check

def check_in_list(values, /, *, _print_supported_values=True, **kwargs):

def check_in_list(values, /):
"""
For each *key, value* pair in *kwargs*, check that *value* is in *values*;
if not, raise an appropriate ValueError.
Return a function that checks if a value is in the allowed list.

Parameters
----------
values : iterable
Sequence of values to check on.

Note: All values must support == comparisons.
This means in particular the entries must not be numpy arrays.
_print_supported_values : bool, default: True
Whether to print *values* when raising ValueError.
**kwargs : dict
*key, value* pairs as keyword arguments to find in *values*.

Raises
------
ValueError
If any *value* in *kwargs* is not found in *values*.
Sequence of allowed values.

Examples
--------
>>> _api.check_in_list(["foo", "bar"], arg=arg, other_arg=other_arg)
>>> _api.check_in_list(["foo", "bar"])("arg", arg)
"""
if not kwargs:
raise TypeError("No argument to check!")
for key, val in kwargs.items():
try:
exists = val in values
except ValueError:
# `in` internally uses `val == values[i]`. There are some objects
# that do not support == to arbitrary other objects, in particular
# numpy arrays.
# Since such objects are not allowed in values, we can gracefully
# handle the case that val (typically provided by users) is of such
# type and directly state it's not in the list instead of letting
# the individual `val == values[i]` ValueError surface.
exists = False
if not exists:
msg = f"{val!r} is not a valid value for {key}"
if _print_supported_values:
msg += f"; supported values are {', '.join(map(repr, values))}"
raise ValueError(msg)
return _check_in_tuple(tuple(values))


def check_shape(shape, /, **kwargs):
@functools.cache
def check_shape(shape, /):
"""
For each *key, value* pair in *kwargs*, check that *value* has the shape *shape*;
if not, raise an appropriate ValueError.
Return a function that checks array shape, raising ValueError on mismatch.

*None* in the shape is treated as a "free" size that can have any length.
e.g. (None, 2) -> (N, 2)
Expand All @@ -168,56 +151,73 @@ def check_shape(shape, /, **kwargs):
--------
To check for (N, 2) shaped arrays

>>> _api.check_shape((None, 2), arg=arg, other_arg=other_arg)
>>> _api.check_shape((None, 2))("arg", arg)
"""
for k, v in kwargs.items():
data_shape = v.shape

if (len(data_shape) != len(shape)
or any(s != t and t is not None for s, t in zip(data_shape, shape))):
dim_labels = iter(itertools.chain(
'NMLKJIH',
(f"D{i}" for i in itertools.count())))
text_shape = ", ".join([str(n) if n is not None else next(dim_labels)
for n in shape[::-1]][::-1])
if len(shape) == 1:
text_shape += ","
def raise_error(name, vshape):
dim_labels = iter(itertools.chain(
'NMLKJIH', (f"D{i}" for i in itertools.count())))
text_shape = ", ".join([str(n) if n is not None else next(dim_labels)
for n in shape[::-1]][::-1])
if len(shape) == 1:
text_shape += ","
raise ValueError(
f"{name!r} must be {len(shape)}D with shape ({text_shape}), "
f"but your input has shape {vshape}"
)

ndim = len(shape)
fixed = tuple((i, n) for i, n in enumerate(shape) if n is not None)

if not fixed:
# All dimensions are None, only check ndim
def check(name, value, _ndim=ndim, _raise=raise_error):
if len(value.shape) != _ndim:
_raise(name, value.shape)
else:
# Use itemgetter for fixed dimension extraction
get_dims = operator.itemgetter(*(i for i, _ in fixed))
expected = get_dims(shape)

raise ValueError(
f"{k!r} must be {len(shape)}D with shape ({text_shape}), "
f"but your input has shape {v.shape}"
)
def check(name, value,
_ndim=ndim, _get=get_dims, _exp=expected, _raise=raise_error):
vshape = value.shape
if len(vshape) != _ndim or _get(vshape) != _exp:
_raise(name, vshape)

return check


def check_getitem(mapping, /, _error_cls=ValueError, **kwargs):
def getitem_checked(mapping, /, _error_cls=ValueError):
"""
*kwargs* must consist of a single *key, value* pair. If *key* is in
*mapping*, return ``mapping[value]``; else, raise an appropriate
ValueError.
Return a function that looks up a value in *mapping*, raising on invalid keys.

Parameters
----------
mapping : dict
The mapping to look up values in.
_error_cls :
Class of error to raise.
Class of error to raise on invalid key.

Examples
--------
>>> _api.check_getitem({"foo": "bar"}, arg=arg)
>>> _api.getitem_checked({"foo": "bar"})("arg", arg)
"""
if len(kwargs) != 1:
raise ValueError("check_getitem takes a single keyword argument")
(k, v), = kwargs.items()
try:
return mapping[v]
except KeyError:
if len(mapping) > 5:
if len(best := difflib.get_close_matches(v, mapping.keys(), cutoff=0.5)):
suggestion = f"Did you mean one of {best}?"
def check(name, value, _mapping=mapping, _err=_error_cls):
try:
return _mapping[value]
except KeyError:
if len(_mapping) > 5:
best = difflib.get_close_matches(value, _mapping.keys(), cutoff=0.5)
if best:
suggestion = f"Did you mean one of {best}?"
else:
suggestion = ""
else:
suggestion = ""
else:
suggestion = f"Supported values are {', '.join(map(repr, mapping))}"
raise _error_cls(f"{v!r} is not a valid value for {k}. {suggestion}") from None
suggestion = f"Supported values are {', '.join(map(repr, _mapping))}"
raise _err(f"{value!r} is not a valid value for {name}. "
f"{suggestion}") from None

return check


def caching_module_getattr(cls):
Expand Down
16 changes: 9 additions & 7 deletions lib/matplotlib/_api/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ class classproperty(Any):
def fget(self) -> Callable[[_T], Any]: ...

def check_isinstance(
types: type | tuple[type | None, ...], /, **kwargs: Any
) -> None: ...
types: type | tuple[type | None, ...], /
) -> Callable[[str, Any], None]: ...
def check_in_list(
values: Sequence[Any], /, *, _print_supported_values: bool = ..., **kwargs: Any
) -> None: ...
def check_shape(shape: tuple[int | None, ...], /, **kwargs: NDArray) -> None: ...
values: Sequence[Any], /
) -> Callable[[str, Any], None]: ...
def check_shape(
shape: tuple[int | None, ...], /
) -> Callable[[str, NDArray], None]: ...
def check_getitem(
mapping: Mapping[Any, _T], /, _error_cls: type[Exception], **kwargs: Any
) -> _T: ...
mapping: Mapping[Any, _T], /, _error_cls: type[Exception] = ...
) -> Callable[[str, Any], _T]: ...
def caching_module_getattr(cls: type) -> Callable[[str], Any]: ...
@overload
def define_aliases(
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/_type1font.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def _decrypt(ciphertext, key, ndiscard=4):
That number of bytes is discarded from the beginning of plaintext.
"""

key = _api.check_getitem({'eexec': 55665, 'charstring': 4330}, key=key)
key = _api.getitem_checked({'eexec': 55665, 'charstring': 4330})("key", key)
plaintext = []
for byte in ciphertext:
plaintext.append(byte ^ (key >> 8))
Expand All @@ -483,7 +483,7 @@ def _encrypt(plaintext, key, ndiscard=4):
cryptanalysis.
"""

key = _api.check_getitem({'eexec': 55665, 'charstring': 4330}, key=key)
key = _api.getitem_checked({'eexec': 55665, 'charstring': 4330})("key", key)
ciphertext = []
for byte in b'\0' * ndiscard + plaintext:
c = byte ^ (key >> 8)
Expand Down
6 changes: 3 additions & 3 deletions lib/matplotlib/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ def __init__(self, fps=30, codec=None, bitrate=None, extra_args=None,
extra_args = () # Don't lookup nonexistent rcParam[args_key].
self.embed_frames = embed_frames
self.default_mode = default_mode.lower()
_api.check_in_list(['loop', 'once', 'reflect'],
default_mode=self.default_mode)
_api.check_in_list(('loop', 'once', 'reflect'))(
"default_mode", self.default_mode)

# Save embed limit, which is given in MB
self._bytes_limit = mpl._val_or_rc(embed_limit, 'animation.embed_limit')
Expand All @@ -771,7 +771,7 @@ def __init__(self, fps=30, codec=None, bitrate=None, extra_args=None,

def setup(self, fig, outfile, dpi=None, frame_dir=None):
outfile = Path(outfile)
_api.check_in_list(['.html', '.htm'], outfile_extension=outfile.suffix)
_api.check_in_list(('.html', '.htm'))("outfile_extension", outfile.suffix)

self._saved_frames = []
self._total_bytes = 0
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ def set_clip_box(self, clipbox):
clipping for an artist added to an Axes.

"""
_api.check_isinstance((BboxBase, None), clipbox=clipbox)
_api.check_isinstance((BboxBase, None))("clipbox", clipbox)
if clipbox != self.clipbox:
self.clipbox = clipbox
self.pchanged()
Expand Down
Loading
Loading