forked from tensorlayer/TensorLayer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextend.py
More file actions
100 lines (81 loc) · 3.12 KB
/
Copy pathextend.py
File metadata and controls
100 lines (81 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding: utf-8 -*-
import tensorflow as tf
from .. import _logging as logging
from .core import *
from ..deprecation import deprecated_alias
# Public names of this module: the layer classes defined below.
__all__ = [
'ExpandDimsLayer',
'TileLayer',
]
class ExpandDimsLayer(Layer):
    """
    The :class:`ExpandDimsLayer` class inserts a dimension of 1 into a tensor's shape,
    see `tf.expand_dims() <https://www.tensorflow.org/api_docs/python/tf/expand_dims>`__ .

    Parameters
    ----------
    prev_layer : :class:`Layer`
        The previous layer.
    axis : int
        The dimension index at which to expand the shape of input.
    name : str
        A unique layer name.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, (None, 100))
    >>> n = tl.layers.InputLayer(x, name='in')
    >>> n = tl.layers.ExpandDimsLayer(n, 2)
    ... [None, 100, 1]

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            axis,
            name='expand_dims',
    ):
        super(ExpandDimsLayer, self).__init__(prev_layer=prev_layer, name=name)
        logging.info("ExpandDimsLayer %s: axis:%d" % (name, axis))

        self.inputs = prev_layer.outputs

        with tf.variable_scope(name):
            try:  # TF12 TF1.0
                self.outputs = tf.expand_dims(self.inputs, axis=axis)
            except TypeError:  # TF11
                # Only an unknown ``axis`` keyword (a TypeError) should trigger
                # the legacy ``dim`` spelling; any other failure is a genuine
                # error and must propagate unchanged.
                self.outputs = tf.expand_dims(self.inputs, dim=axis)

        # Shallow copies, so that this layer does not share the previous
        # layer's parameter list / dropout dict objects.
        self.all_params = list(prev_layer.all_params)
        self.all_drop = dict(prev_layer.all_drop)
        # This layer adds no variables; it only registers its output tensor.
        self.all_layers.append(self.outputs)
class TileLayer(Layer):
    """
    The :class:`TileLayer` class constructs a tensor by tiling a given tensor,
    see `tf.tile() <https://www.tensorflow.org/api_docs/python/tf/tile>`__ .

    Parameters
    ----------
    prev_layer : :class:`Layer`
        The previous layer.
    multiples: tensor
        Must be one of the following types: int32, int64.
        1-D. Length must be the same as the number of dimensions in input.
        Required; the ``None`` default exists only to keep the signature
        backward-compatible.
    name : str
        A unique layer name.

    Raises
    ------
    ValueError
        If ``multiples`` is None.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, (None, 100))
    >>> n = tl.layers.InputLayer(x, name='in')
    >>> n = tl.layers.ExpandDimsLayer(n, 2)
    >>> n = tl.layers.TileLayer(n, [1, 1, 3])
    ... [None, 100, 3]

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(self, prev_layer, multiples=None, name='tile'):
        # Fail early with a clear message: tf.tile cannot accept None, and
        # checking here avoids initialising the base layer for a call that
        # is bound to fail.
        if multiples is None:
            raise ValueError("TileLayer %s: `multiples` must be specified, got None" % name)

        super(TileLayer, self).__init__(prev_layer=prev_layer, name=name)
        logging.info("TileLayer %s: multiples:%s" % (name, multiples))

        self.inputs = prev_layer.outputs

        with tf.variable_scope(name):
            self.outputs = tf.tile(self.inputs, multiples=multiples)

        # This layer adds no variables; it only registers its output tensor.
        self.all_layers.append(self.outputs)