# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to represent a device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
class DeviceSpec(object):
"""Represents a (possibly partial) specification for a TensorFlow device.
`DeviceSpec`s are used throughout TensorFlow to describe where state is stored
and computations occur. Using `DeviceSpec` allows you to parse device spec
strings to verify their validity, merge them or compose them programmatically.
Example:
```python
# Place the operations on device "GPU:0" in the "ps" job.
device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
with tf.device(device_spec):
  # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
  my_var = tf.Variable(..., name="my_variable")
  squared_var = tf.square(my_var)
```
If a `DeviceSpec` is partially specified, it will be merged with other
`DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
components defined in inner scopes take precedence over those defined in
outer scopes.
```python
with tf.device(DeviceSpec(job="train")):
  with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
    # Nodes created here will be assigned to /job:ps/device:GPU:0.
  with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
    # Nodes created here will be assigned to /job:train/device:GPU:1.
```
A `DeviceSpec` consists of 5 components -- each of
which is optionally specified:
* Job: The job name.
* Replica: The replica index.
* Task: The task index.
* Device type: The device type string (e.g. "CPU" or "GPU").
* Device index: The device index.
"""
def __init__(self, job=None, replica=None, task=None, device_type=None,
device_index=None):
"""Create a new `DeviceSpec` object.
Args:
job: string. Optional job name.
replica: int. Optional replica index.
task: int. Optional task index.
device_type: Optional device type string (e.g. "CPU" or "GPU")
device_index: int. Optional device index. If left
unspecified, device represents 'any' device_index.
"""
self.job = job
self.replica = replica
self.task = task
if device_type == "cpu" or device_type == "gpu":
# For backwards compatibility only, we support lowercase variants of
# cpu and gpu but turn them into uppercase here.
self.device_type = device_type.upper()
else:
self.device_type = device_type
self.device_index = device_index
def _clear(self):
self._job = None
self._replica = None
self._task = None
self.device_type = None
self.device_index = None
@property
def job(self):
  """The job name as a string, or `None` if unspecified."""
  return self._job

@job.setter
def job(self, job):
  # Any non-None value is coerced with str(); None means "unspecified".
  self._job = None if job is None else str(job)
@property
def replica(self):
  """The replica index as an int, or `None` if unspecified."""
  return self._replica

@replica.setter
def replica(self, replica):
  # int() is applied to any non-None value, so numeric strings such as "3"
  # are accepted; anything int() rejects propagates int()'s own error.
  self._replica = None if replica is None else int(replica)
@property
def task(self):
  """The task index as an int, or `None` if unspecified."""
  return self._task

@task.setter
def task(self, task):
  # int() is applied to any non-None value, so numeric strings such as "2"
  # are accepted; anything int() rejects propagates int()'s own error.
  self._task = None if task is None else int(task)
def parse_from_string(self, spec):
"""Parse a `DeviceSpec` name into its components.
Args:
spec: a string of the form
/job: