-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathLstmLayer.java
More file actions
108 lines (90 loc) · 3.31 KB
/
Copy pathLstmLayer.java
File metadata and controls
108 lines (90 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package model;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import autodiff.Graph;
/**
 * A single LSTM layer that processes one time step per {@link #forward(Matrix, Graph)} call
 * and carries its hidden and cell activations over to the next call.
 *
 * <p>The layer has an input gate, a forget gate and an output gate (all squashed with
 * {@link SigmoidUnit}) plus a cell-input path and a cell-output squashing step (both
 * {@link TanhUnit}). Each of the four paths owns an input weight matrix ({@code W?x}),
 * a recurrent weight matrix ({@code W?h}) and a bias ({@code b?}).
 *
 * <p>The recurrent state is allocated by the constructor and can be cleared at any time
 * (typically between independent sequences) with {@link #resetState()}.
 */
public class LstmLayer implements Model {

    private static final long serialVersionUID = 1L;

    int inputDimension;
    int outputDimension;

    // Parameters, grouped per path: input weights, recurrent weights, bias.
    Matrix Wix, Wih, bi; // input gate
    Matrix Wfx, Wfh, bf; // forget gate
    Matrix Wox, Woh, bo; // output gate
    Matrix Wcx, Wch, bc; // cell input (candidate values written to the cell)

    // Recurrent state carried from one forward() call to the next.
    Matrix hiddenContext;
    Matrix cellContext;

    Nonlinearity fInputGate = new SigmoidUnit();
    Nonlinearity fForgetGate = new SigmoidUnit();
    Nonlinearity fOutputGate = new SigmoidUnit();
    Nonlinearity fCellInput = new TanhUnit();
    Nonlinearity fCellOutput = new TanhUnit();

    /**
     * Creates the layer and initializes all of its parameters and its recurrent state.
     *
     * <p>Weight matrices come from {@code Matrix.rand}; the random generator is consumed in the
     * order Wix, Wih, Wfx, Wfh, Wox, Woh, Wcx, Wch. The input-gate, output-gate and cell-input
     * biases are created with {@code new Matrix(outputDimension)} (presumably a zero-filled
     * column vector -- confirm against {@code Matrix}); the forget-gate bias is created with
     * {@code Matrix.ones}.
     *
     * @param inputDimension   size of the vectors passed to {@link #forward(Matrix, Graph)}
     * @param outputDimension  size of the hidden/cell state and of the returned output
     * @param initParamsStdDev spread parameter forwarded to {@code Matrix.rand}
     * @param rng              random source forwarded to {@code Matrix.rand}
     */
    public LstmLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;

        // input gate
        Wix = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wih = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bi = new Matrix(outputDimension);

        // forget gate
        Wfx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wfh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        //set forget bias to 1.0, as described here: http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf
        bf = Matrix.ones(outputDimension, 1);

        // output gate
        Wox = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Woh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bo = new Matrix(outputDimension);

        // cell input
        Wcx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wch = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bc = new Matrix(outputDimension);

        // Allocate the recurrent state up front, exactly as resetState() does. Without this
        // both contexts stay null until a caller remembers to invoke resetState(), and the
        // first forward() call would hand null to the graph operations.
        hiddenContext = new Matrix(outputDimension);
        cellContext = new Matrix(outputDimension);
    }

    /**
     * Runs one time step of the layer and advances the recurrent state.
     *
     * <p>All arithmetic is routed through {@code g} (presumably so the graph can record the
     * operations for backpropagation -- confirm against {@code Graph}). After the call,
     * {@code hiddenContext} refers to the returned output and {@code cellContext} to the new
     * cell activation.
     *
     * @param input the input for this time step
     * @param g     the graph used to perform every matrix operation
     * @return the new hidden state: output gate times the tanh-squashed cell activation
     * @throws Exception if any graph operation fails
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {

        //input gate
        Matrix sum0 = g.mul(Wix, input);
        Matrix sum1 = g.mul(Wih, hiddenContext);
        Matrix inputGate = g.nonlin(fInputGate, g.add(g.add(sum0, sum1), bi));

        //forget gate
        Matrix sum2 = g.mul(Wfx, input);
        Matrix sum3 = g.mul(Wfh, hiddenContext);
        Matrix forgetGate = g.nonlin(fForgetGate, g.add(g.add(sum2, sum3), bf));

        //output gate
        Matrix sum4 = g.mul(Wox, input);
        Matrix sum5 = g.mul(Woh, hiddenContext);
        Matrix outputGate = g.nonlin(fOutputGate, g.add(g.add(sum4, sum5), bo));

        //write operation on cells
        Matrix sum6 = g.mul(Wcx, input);
        Matrix sum7 = g.mul(Wch, hiddenContext);
        Matrix cellInput = g.nonlin(fCellInput, g.add(g.add(sum6, sum7), bc));

        //compute new cell activation: keep part of the old cell, add the gated new input
        Matrix retainCell = g.elmul(forgetGate, cellContext);
        Matrix writeCell = g.elmul(inputGate, cellInput);
        Matrix cellAct = g.add(retainCell, writeCell);

        //compute hidden state as gated, saturated cell activations
        Matrix output = g.elmul(outputGate, g.nonlin(fCellOutput, cellAct));

        //rollover activations for next iteration
        hiddenContext = output;
        cellContext = cellAct;

        return output;
    }

    /**
     * Discards the carried-over hidden and cell activations by replacing both with a freshly
     * allocated {@code Matrix} of size {@code outputDimension}.
     */
    @Override
    public void resetState() {
        hiddenContext = new Matrix(outputDimension);
        cellContext = new Matrix(outputDimension);
    }

    /**
     * Returns the layer's trainable parameters in a new list.
     *
     * <p>The order is fixed: input gate (Wix, Wih, bi), forget gate (Wfx, Wfh, bf),
     * output gate (Wox, Woh, bo), cell input (Wcx, Wch, bc). The list is a fresh
     * {@link ArrayList} on every call, but its elements are the layer's own matrices.
     *
     * @return the twelve parameter matrices of this layer
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        result.add(Wix);
        result.add(Wih);
        result.add(bi);
        result.add(Wfx);
        result.add(Wfh);
        result.add(bf);
        result.add(Wox);
        result.add(Woh);
        result.add(bo);
        result.add(Wcx);
        result.add(Wch);
        result.add(bc);
        return result;
    }
}