forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparameter.py
More file actions
173 lines (131 loc) · 5.59 KB
/
parameter.py
File metadata and controls
173 lines (131 loc) · 5.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import json
from typing import Union
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.utils import to_string
class ParameterRange(object):
    """Base class for a hyperparameter range.

    A range tells an Amazon SageMaker hyperparameter tuning job which
    hyperparameter values to search over. It is also used to verify
    hyperparameter values supplied to Marketplace Algorithms.
    """

    __all_types__ = ("Continuous", "Categorical", "Integer")

    def __init__(
        self,
        min_value: Union[int, float, PipelineVariable],
        max_value: Union[int, float, PipelineVariable],
        scaling_type: Union[str, PipelineVariable] = "Auto",
    ):
        """Record the bounds and scaling strategy of the range.

        Args:
            min_value (float or int or PipelineVariable): Lower bound of the range.
            max_value (float or int or PipelineVariable): Upper bound of the range.
            scaling_type (str or PipelineVariable): Scale the tuner uses when
                searching the range (default: 'Auto'). Valid values: 'Auto',
                'Linear', 'Logarithmic' and 'ReverseLogarithmic'.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.scaling_type = scaling_type

    def is_valid(self, value):
        """Report whether ``value`` falls within the range, bounds included.

        Args:
            value (float or int): The candidate value.

        Returns:
            bool: True when ``min_value <= value <= max_value``, otherwise False.
        """
        # Written out as the two comparisons that Python's chained form expands
        # to. The upper bound is only consulted when the lower-bound check passes.
        return self.min_value <= value and value <= self.max_value

    @classmethod
    def cast_to_type(cls, value):
        """Coerce ``value`` to the Python type of this range (``float`` by default)."""
        return float(value)

    def as_tuning_range(self, name):
        """Build the dictionary form of this range.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, str]: The hyperparameter name, its bounds converted with
            ``to_string``, and its scaling type (passed through unchanged).
        """
        tuning_range = {"Name": name}
        tuning_range["MinValue"] = to_string(self.min_value)
        tuning_range["MaxValue"] = to_string(self.max_value)
        tuning_range["ScalingType"] = self.scaling_type
        return tuning_range
class ContinuousParameter(ParameterRange):
    """Hyperparameter whose possible values span a continuous range.

    Args:
        min_value (float): Lower bound of the range.
        max_value (float): Upper bound of the range.
    """

    __name__ = "Continuous"

    @classmethod
    def cast_to_type(cls, value):
        """Coerce ``value`` to ``float``, the native type of a continuous range."""
        as_float = float(value)
        return as_float
class CategoricalParameter(ParameterRange):
    """Hyperparameter whose possible values form a discrete list."""

    __name__ = "Categorical"

    def __init__(self, values):  # pylint: disable=super-init-not-called
        """Create a ``CategoricalParameter``.

        ``ParameterRange.__init__`` is deliberately not invoked, so instances
        carry no ``min_value``, ``max_value`` or ``scaling_type`` attributes.

        Args:
            values (list or object): Candidate values for the hyperparameter.
                Any argument that is not a ``list`` is treated as a single
                candidate. Every candidate is stored as a string via
                ``to_string``.
        """
        if not isinstance(values, list):
            values = [values]
        self.values = list(map(to_string, values))

    def as_tuning_range(self, name):
        """Build the dictionary form of this range.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, list[str]]: The hyperparameter name and its candidate
            values.
        """
        return dict(Name=name, Values=self.values)

    def as_json_range(self, name):
        """Build the dictionary form of this range with JSON-encoded values.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job that uses one of the deep learning frameworks.
        Those framework images expect hyperparameters serialized as JSON.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, list[str]]: The hyperparameter name and its candidate
            values, each passed through ``json.dumps``. Because the candidates
            are already strings, each one comes back wrapped in double quotes.
        """
        encoded = [json.dumps(candidate) for candidate in self.values]
        return {"Name": name, "Values": encoded}

    def is_valid(self, value):
        """Report whether ``value`` is one of the stored candidates.

        Candidates are stored as strings, so a non-string ``value`` (for
        example ``1`` instead of ``"1"``) does not match.
        """
        candidates = self.values
        return value in candidates

    @classmethod
    def cast_to_type(cls, value):
        """Coerce ``value`` to ``str``, the storage type of categorical candidates."""
        return str(value)
class IntegerParameter(ParameterRange):
    """Hyperparameter whose possible values span an integer range.

    Args:
        min_value (int): Lower bound of the range.
        max_value (int): Upper bound of the range.
    """

    __name__ = "Integer"

    @classmethod
    def cast_to_type(cls, value):
        """Coerce ``value`` to ``int`` (a float input is truncated toward zero)."""
        as_int = int(value)
        return as_int