forked from tensorlayer/TensorLayer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathshape.py
More file actions
123 lines (88 loc) · 3.6 KB
/
Copy pathshape.py
File metadata and controls
123 lines (88 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# -*- coding: utf-8 -*-
import tensorflow as tf
from .. import _logging as logging
from .core import *
from ..deprecation import deprecated_alias
# Public layer classes exported from this module by ``from ... import *``.
__all__ = ['FlattenLayer', 'ReshapeLayer', 'TransposeLayer']
class FlattenLayer(Layer):
    """Reshape a high-dimensional input into one flat vector per example.

    All dimensions after the batch dimension are collapsed into a single
    axis, so this layer is commonly placed in front of DenseLayer, RNNLayer,
    ConcatLayer and similar layers::

        [batch_size, mask_row, mask_col, n_mask] ---> [batch_size, mask_row * mask_col * n_mask]

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer.
    name : str
        A unique layer name.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
    >>> net = tl.layers.InputLayer(x, name='input')
    >>> net = tl.layers.FlattenLayer(net, name='flatten')
    >>> print(net.outputs.get_shape())
    (?, 784)
    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(self, prev_layer, name='flatten'):
        super(FlattenLayer, self).__init__(prev_layer=prev_layer, name=name)

        incoming = prev_layer.outputs
        self.inputs = incoming

        flat = flatten_reshape(incoming, name=name)
        self.outputs = flat

        # The width of the flattened (last) axis is exposed as the layer size.
        self.n_units = int(flat.get_shape()[-1])
        self.all_layers.append(flat)

        logging.info("FlattenLayer %s: %d" % (self.name, self.n_units))
class ReshapeLayer(Layer):
    """A layer that reshapes a given tensor.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer
    shape : tuple or list of int, or an int32/int64 Tensor
        The output shape, see ``tf.reshape``.
    name : str
        A unique layer name.

    Raises
    ------
    ValueError
        If ``shape`` is ``None`` or an empty list/tuple.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, shape=(None, 784))
    >>> net = tl.layers.InputLayer(x, name='input')
    >>> net = tl.layers.ReshapeLayer(net, [-1, 28, 28, 1], name='reshape')
    >>> print(net.outputs.get_shape())
    (?, 28, 28, 1)
    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(self, prev_layer, shape, name='reshape'):
        super(ReshapeLayer, self).__init__(prev_layer=prev_layer, name=name)
        self.inputs = prev_layer.outputs

        # ``not shape`` must not be used here: it raises for a tf.Tensor shape
        # (Tensors cannot be used as Python bools) and is ambiguous for a
        # multi-element numpy array, both of which tf.reshape accepts.
        # Only None and plain Python sequences are checked for emptiness;
        # any other shape-like value is validated by tf.reshape itself.
        if shape is None or (isinstance(shape, (list, tuple)) and len(shape) == 0):
            raise ValueError("Shape list can not be empty")

        self.outputs = tf.reshape(self.inputs, shape=shape, name=name)
        self.all_layers.append(self.outputs)

        logging.info("ReshapeLayer %s: %s" % (self.name, self.outputs.get_shape()))
class TransposeLayer(Layer):
    """A layer that transposes the dimensions of a tensor.

    See `tf.transpose() <https://www.tensorflow.org/api_docs/python/tf/transpose>`__ .

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer
    perm : list of int
        The permutation of the dimensions, similar to ``numpy.transpose``.
    name : str
        A unique layer name.

    Raises
    ------
    ValueError
        If ``perm`` is ``None``.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
    >>> net = tl.layers.InputLayer(x, name='input')
    >>> net = tl.layers.TransposeLayer(net, perm=[0, 1, 3, 2], name='trans')
    >>> print(net.outputs.get_shape())
    (?, 28, 1, 28)
    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(self, prev_layer, perm, name='transpose'):
        super(TransposeLayer, self).__init__(prev_layer=prev_layer, name=name)

        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and tf.transpose(perm=None) would then silently reverse
        # every dimension instead of failing.
        if perm is None:
            raise ValueError("perm can not be None")

        # Log self.name, as the other layers in this module do.
        logging.info("TransposeLayer %s: perm:%s" % (self.name, perm))

        self.inputs = prev_layer.outputs
        self.outputs = tf.transpose(self.inputs, perm=perm, name=name)
        self.all_layers.append(self.outputs)