forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path predictor.py
More file actions
122 lines (107 loc) · 5.03 KB
/
predictor.py
File metadata and controls
122 lines (107 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import print_function, absolute_import
from typing import Optional
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint
from sagemaker.session import Session
# base_predictor was refactored from predictor.
# this import ensures backward compatibility.
from sagemaker.base_predictor import ( # noqa: F401 # pylint: disable=W0611
Predictor,
PredictorBase,
RealTimePredictor,
)
def retrieve_default(
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> Predictor:
    """Build the default JumpStart predictor for a model served by an existing endpoint.

    If ``model_id`` is given, no endpoint lookup is performed. In that case a
    missing ``model_version`` becomes ``"*"``.

    If ``model_id`` is omitted, the model ID, model version, inference
    component name and config name are looked up from the endpoint with
    ``get_model_info_from_endpoint``. Any of these values that the caller
    passed explicitly takes precedence over the inferred one.

    Args:
        endpoint_name (str): Name of the endpoint the predictor is created for.
        inference_component_name (Optional[str]): Name of the Amazon SageMaker
            inference component to target. If omitted and ``model_id`` is also
            omitted, the component name inferred from the endpoint is used.
            (Default: None).
        sagemaker_session (Session): SageMaker session attached to the predictor.
            It is also used for the endpoint lookup. (Default:
            sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        region (Optional[str]): AWS Region used when retrieving the default
            predictor. (Default: None).
        model_id (Optional[str]): JumpStart model ID. If omitted, it is inferred
            from the endpoint. (Default: None).
        model_version (Optional[str]): Version of the model. If omitted, the
            version inferred from the endpoint is used; that inference happens
            only when ``model_id`` is also omitted. If no version is available
            either way, ``"*"`` is used. (Default: None).
        hub_arn (Optional[str]): ARN of the SageMaker Hub from which model
            details are retrieved. (Default: None).
        tolerate_vulnerable_model (bool): If True, accept model specifications
            whose scripts have dependencies with known security
            vulnerabilities. If False, raise an exception for them.
            (Default: False).
        tolerate_deprecated_model (bool): If True, accept deprecated models.
            If False, raise an exception for them. (Default: False).
        model_type (JumpStartModelType): Type of the JumpStart model, forwarded
            to the default-predictor lookup.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the configuration to use for the
            predictor. If omitted and ``model_id`` is also omitted, the config
            name inferred from the endpoint is used. (Default: None).

    Returns:
        Predictor: The predictor returned by ``get_default_predictor``. It is
        built around a base ``Predictor`` bound to ``endpoint_name`` and the
        resolved inference component.

    Raises:
        ValueError: If ``model_id`` is omitted and no model ID can be inferred
            from the endpoint, or if the combination of arguments is not
            supported.
    """
    if model_id is not None:
        # Caller named the model explicitly: skip the endpoint lookup entirely.
        resolved_version = model_version or "*"
    else:
        endpoint_info = get_model_info_from_endpoint(
            endpoint_name, inference_component_name, sagemaker_session
        )
        found_id, found_version, found_component, found_config, _ = endpoint_info

        if not found_id:
            raise ValueError(
                f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
                "Please specify JumpStart `model_id` when retrieving default "
                "predictor for this endpoint."
            )

        model_id = found_id
        # Explicit arguments win; inferred values fill the gaps.
        resolved_version = model_version or found_version or "*"
        inference_component_name = inference_component_name or found_component
        # Trailing `or None` turns any falsy config name (e.g. "") into None.
        config_name = config_name or found_config or None

    base_predictor = Predictor(
        endpoint_name=endpoint_name,
        component_name=inference_component_name,
        sagemaker_session=sagemaker_session,
    )

    return get_default_predictor(
        predictor=base_predictor,
        model_id=model_id,
        model_version=resolved_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )