forked from tensorlayer/TensorLayer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_layers_shape.py
More file actions
96 lines (65 loc) · 2.55 KB
/
Copy pathtest_layers_shape.py
File metadata and controls
96 lines (65 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import unittest
import tensorflow as tf
import tensorlayer as tl
class Layer_Shape_Test(unittest.TestCase):
    """Shape and parameter checks for TensorLayer's reshaping layers.

    ``setUpClass`` builds a single chain on a float32 placeholder of shape
    ``[None, 28, 28, 1]``::

        InputLayer -> FlattenLayer
                   -> ReshapeLayer(shape=(-1, 28, 28, 1))
                   -> TransposeLayer(perm=[0, 1, 3, 2])

    For each of the three layers after the input it stores, as class
    attributes, the static output shape (``netK_shape``), ``all_layers``
    (``netK_layers``), ``all_params`` (``netK_params``) and
    ``count_params()`` (``netK_n_params``) for K = 1, 2, 3. The test methods
    only read those attributes; they assert the expected output shapes and
    that none of these layers introduces parameters.
    """

    @staticmethod
    def _summarize(net):
        """Print ``net``'s layers/params and return what the tests inspect.

        Returns a 4-tuple ``(shape, layers, params, n_params)``: the static
        shape of ``net.outputs`` as a list, ``net.all_layers``,
        ``net.all_params`` and ``net.count_params()``.
        """
        net.print_layers()
        # Positional ``False`` kept from the original call; presumably this
        # turns off per-parameter details -- confirm against the TL API.
        net.print_params(False)
        return (
            net.outputs.get_shape().as_list(),
            net.all_layers,
            net.all_params,
            net.count_params(),
        )

    @classmethod
    def setUpClass(cls):
        x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
        net = tl.layers.InputLayer(x, name='input')

        ## Flatten
        net1 = tl.layers.FlattenLayer(net, name='flatten')
        (cls.net1_shape, cls.net1_layers,
         cls.net1_params, cls.net1_n_params) = cls._summarize(net1)

        ## Reshape (undoes the flatten)
        net2 = tl.layers.ReshapeLayer(net1, shape=(-1, 28, 28, 1), name='reshape')
        (cls.net2_shape, cls.net2_layers,
         cls.net2_params, cls.net2_n_params) = cls._summarize(net2)

        ## TransposeLayer (swaps the last two axes)
        net3 = tl.layers.TransposeLayer(net2, perm=[0, 1, 3, 2], name='trans')
        (cls.net3_shape, cls.net3_layers,
         cls.net3_params, cls.net3_n_params) = cls._summarize(net3)

    @classmethod
    def tearDownClass(cls):
        # Discard the graph built in setUpClass so later tests start clean.
        tf.reset_default_graph()

    # --- FlattenLayer ---------------------------------------------------
    def test_net1_shape(self):
        # Compare the full non-batch shape (as for net2/net3), not just the
        # last axis: all of 28 * 28 * 1 must collapse into one axis.
        self.assertEqual(self.net1_shape[1:], [784])

    def test_net1_layers(self):
        self.assertEqual(len(self.net1_layers), 1)

    def test_net1_params(self):
        self.assertEqual(len(self.net1_params), 0)

    def test_net1_n_params(self):
        self.assertEqual(self.net1_n_params, 0)

    # --- ReshapeLayer ---------------------------------------------------
    def test_net2_shape(self):
        self.assertEqual(self.net2_shape[1:], [28, 28, 1])

    def test_net2_layers(self):
        self.assertEqual(len(self.net2_layers), 2)

    def test_net2_params(self):
        self.assertEqual(len(self.net2_params), 0)

    def test_net2_n_params(self):
        self.assertEqual(self.net2_n_params, 0)

    # --- TransposeLayer -------------------------------------------------
    def test_net3_shape(self):
        # perm=[0, 1, 3, 2] swaps the last two axes of [N, 28, 28, 1].
        self.assertEqual(self.net3_shape[1:], [28, 1, 28])

    def test_net3_layers(self):
        self.assertEqual(len(self.net3_layers), 3)

    def test_net3_params(self):
        self.assertEqual(len(self.net3_params), 0)

    def test_net3_n_params(self):
        self.assertEqual(self.net3_n_params, 0)
if __name__ == '__main__':
    # When this module is run directly, show TensorFlow's DEBUG-level log
    # output while the tests execute (use tf.logging.INFO for less noise).
    tf.logging.set_verbosity(tf.logging.DEBUG)
    unittest.main()