forked from tqchen/tinyflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn.py
More file actions
20 lines (18 loc) · 651 Bytes
/
nn.py
File metadata and controls
20 lines (18 loc) · 651 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from nnvm.symbol import *
from nnvm import symbol as _sym
def conv2d(data, weight=None,
           strides=[1, 1, 1, 1],
           padding='VALID',
           data_format='NCHW',
           **kwargs):
    """Create a conv2d symbol by forwarding to ``nnvm.symbol.conv2d``.

    Parameters
    ----------
    data
        Input forwarded to the operator as the ``data`` keyword.
    weight : optional
        Forwarded as the ``weight`` keyword when it is not ``None``.
        When left as ``None`` no ``weight`` keyword is passed at all.
        NOTE(review): presumably the operator then creates its own
        weight variable -- confirm against the nnvm operator definition.
    strides, padding, data_format
        Forwarded unchanged as keywords of the same name.  This wrapper
        never mutates the default list used for ``strides``.
    **kwargs
        Any further keyword arguments, forwarded unchanged.

    Returns
    -------
    Whatever ``nnvm.symbol.conv2d`` returns for these arguments.
    """
    # ``kwargs`` is already a fresh dict owned by this call, so it can be
    # extended in place without touching anything the caller holds.
    kwargs['data'] = data
    # Test for "not supplied" explicitly rather than relying on the
    # truthiness of the weight object, which is unrelated to whether the
    # caller passed one.
    if weight is not None:
        kwargs['weight'] = weight
    return _sym.conv2d(strides=strides, padding=padding,
                       data_format=data_format, **kwargs)
def max_pool(data,
             strides=[1, 1, 1, 1],
             padding='VALID',
             data_format='NCHW', **kwargs):
    """Create a max-pool symbol by forwarding to ``nnvm.symbol.max_pool``.

    ``data`` is passed positionally; ``strides``, ``padding`` and
    ``data_format`` are passed as keywords of the same name, followed by
    any extra ``**kwargs``.  Returns whatever ``nnvm.symbol.max_pool``
    returns.
    """
    # Gather the named options first, then the caller's extras, so the
    # keyword order seen by the operator stays the same as before.
    options = {
        'strides': strides,
        'padding': padding,
        'data_format': data_format,
    }
    options.update(kwargs)
    return _sym.max_pool(data, **options)