-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathTrainer.java
More file actions
executable file
·140 lines (123 loc) · 4.74 KB
/
Copy pathTrainer.java
File metadata and controls
executable file
·140 lines (123 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package trainer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import util.FileIO;
import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import loss.Loss;
import matrix.Matrix;
import model.Model;
/**
 * Static training loop for {@link Model}s over a {@link DataSet}.
 *
 * <p>Parameters are updated with an RMSProp-style adaptive step size, element-wise gradient
 * clipping and an L2 weight-decay term (see {@link #updateModelParams(Model, double)}).
 * The hyper-parameters below are public, mutable, static fields shared by every caller.
 */
public class Trainer {

    /** Decay factor of the running average of squared gradients (the per-weight step cache). */
    public static double decayRate = 0.999;

    /** Small constant added under the square root of the step cache to avoid division by zero. */
    public static double smoothEpsilon = 1e-8;

    /** Gradients are clamped to {@code [-gradientClipValue, gradientClipValue]} before each step. */
    public static double gradientClipValue = 5;

    /** L2 regularization strength: every step also subtracts {@code regularization * w}. */
    public static double regularization = 0.000001;

    /**
     * Trains {@code model} on {@code data} without loading or saving any checkpoint.
     *
     * @see #train(int, double, Model, DataSet, int, boolean, boolean, String, Random)
     */
    public static double train(int trainingEpochs, double learningRate, Model model, DataSet data, int reportEveryNthEpoch, Random rng) throws Exception {
        return train(trainingEpochs, learningRate, model, data, reportEveryNthEpoch, false, false, null, rng);
    }

    /**
     * Runs up to {@code trainingEpochs} epochs of training and prints the per-epoch losses.
     *
     * <p>Each epoch does one training pass over {@code data.training}, then evaluation-only
     * passes over {@code data.validation} and {@code data.testing} when those are non-null.
     * Training stops early once the training loss and the validation loss (taken as 0 when
     * there is no validation split) are both exactly zero.
     *
     * @param trainingEpochs      maximum number of epochs to run
     * @param learningRate        step size forwarded to {@link #updateModelParams(Model, double)}
     * @param model               model to train. If {@code initFromSaved} succeeds, training
     *                            continues on the deserialized model instead; the caller's
     *                            reference is NOT updated in that case
     * @param data                provides the splits, the two losses and the report hook
     * @param reportEveryNthEpoch call {@code data.DisplayReport} after every N-th epoch;
     *                            values {@code <= 0} disable the periodic report
     * @param initFromSaved       first try to deserialize a model from {@code savePath}; on any
     *                            failure a warning is printed and {@code model} is used as given
     * @param overwriteSaved      serialize the model to {@code savePath} after every epoch
     * @param savePath            checkpoint location; should be non-null when {@code overwriteSaved}
     * @param rng                 randomness source forwarded to {@code data.DisplayReport}
     * @return the loss of the last split evaluated in the final epoch that ran: testing if
     *         present, otherwise validation if present, otherwise training; {@code 1.0} if no
     *         epoch ran
     * @throws Exception if the training loss is NaN or infinite, or if a pass, report or save fails
     */
    public static double train(int trainingEpochs, double learningRate, Model model, DataSet data, int reportEveryNthEpoch, boolean initFromSaved, boolean overwriteSaved, String savePath, Random rng) throws Exception {
        System.out.println("--------------------------------------------------------------");
        if (initFromSaved) {
            System.out.println("initializing model from saved state...");
            try {
                Model loaded = (Model) FileIO.deserialize(savePath);
                data.DisplayReport(loaded, rng);
                // Adopt the checkpoint only once it loaded AND reported cleanly, so the fallback
                // message below is accurate whenever it is printed.
                model = loaded;
            } catch (Exception e) {
                // Deliberate best-effort: keep training the model the caller passed in.
                System.out.println("Oops. Unable to load from a saved state.");
                System.out.println("WARNING: " + e.getMessage());
                System.out.println("Continuing from freshly initialized model instead.");
            }
        }

        double result = 1.0;
        for (int epoch = 0; epoch < trainingEpochs; epoch++) {
            double reportedLossTrain = pass(learningRate, model, data.training, true, data.lossTraining, data.lossReporting);
            result = reportedLossTrain;
            if (Double.isNaN(reportedLossTrain) || Double.isInfinite(reportedLossTrain)) {
                throw new Exception("WARNING: invalid value for training loss. Try lowering learning rate.");
            }

            // Evaluation-only passes (applyTraining == false): no gradients, no parameter updates.
            double reportedLossValidation = 0;
            double reportedLossTesting = 0;
            if (data.validation != null) {
                reportedLossValidation = pass(learningRate, model, data.validation, false, data.lossTraining, data.lossReporting);
                result = reportedLossValidation;
            }
            if (data.testing != null) {
                reportedLossTesting = pass(learningRate, model, data.testing, false, data.lossTraining, data.lossReporting);
                result = reportedLossTesting;
            }

            StringBuilder show = new StringBuilder("epoch[" + (epoch + 1) + "/" + trainingEpochs + "]");
            show.append("\ttrain loss = ").append(String.format("%.5f", reportedLossTrain));
            if (data.validation != null) {
                show.append("\tvalid loss = ").append(String.format("%.5f", reportedLossValidation));
            }
            if (data.testing != null) {
                show.append("\ttest loss = ").append(String.format("%.5f", reportedLossTesting));
            }
            System.out.println(show.toString());

            // Guard the modulo: reportEveryNthEpoch == 0 would otherwise throw ArithmeticException.
            if (reportEveryNthEpoch > 0 && (epoch + 1) % reportEveryNthEpoch == 0) {
                data.DisplayReport(model, rng);
            }
            if (overwriteSaved) {
                FileIO.serialize(savePath, model);
            }
            // Perfect fit; reportedLossValidation stays 0 when there is no validation split.
            if (reportedLossTrain == 0 && reportedLossValidation == 0) {
                System.out.println("--------------------------------------------------------------");
                System.out.println("\nDONE.");
                break;
            }
        }
        return result;
    }

    /**
     * Runs {@code model} over every sequence and returns the mean reporting loss per target step.
     *
     * <p>For each sequence the model state is reset and a new {@link Graph} is created with
     * {@code applyTraining} as its constructor flag. Every step whose {@code targetOutput} is
     * non-null contributes {@code lossReporting.measure(output, target)} to the average. When
     * {@code applyTraining} is true, {@code lossTraining.backward(output, target)} is called
     * for each such step; after the sequence the graph is back-propagated and the parameters
     * are updated once via {@link #updateModelParams(Model, double)}.
     *
     * @param learningRate  step size used when {@code applyTraining} is true
     * @param model         model to run (and update, when training)
     * @param sequences     sequences to process, in order
     * @param applyTraining whether to back-propagate and update parameters
     * @param lossTraining  loss whose {@code backward} is invoked when training
     * @param lossReporting loss whose {@code measure} value is averaged and returned
     * @return the mean reporting loss over all target steps. A NaN or infinite step loss is
     *         returned immediately, aborting the pass with no update for the current sequence.
     *         NaN (0/0) is returned if no step had a target.
     * @throws Exception propagated from the model, the graph or the losses
     */
    public static double pass(double learningRate, Model model, List<DataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        double numerLoss = 0;
        double denomLoss = 0;
        for (DataSequence seq : sequences) {
            model.resetState();
            Graph g = new Graph(applyTraining);
            for (DataStep step : seq.steps) {
                Matrix output = model.forward(step.input, g);
                if (step.targetOutput == null) {
                    continue; // step without a target: forward only, no loss
                }
                double loss = lossReporting.measure(output, step.targetOutput);
                if (Double.isNaN(loss) || Double.isInfinite(loss)) {
                    return loss;
                }
                numerLoss += loss;
                denomLoss++;
                if (applyTraining) {
                    lossTraining.backward(output, step.targetOutput);
                }
            }
            if (applyTraining) {
                g.backward(); // backprop dw values
                updateModelParams(model, learningRate); // update params
            }
        }
        return numerLoss / denomLoss;
    }

    /**
     * Applies one RMSProp-style step to every parameter matrix of {@code model}, then clears
     * its gradients.
     *
     * <p>For each weight {@code i}, the update runs in four steps:
     * <ol>
     *   <li>The step cache is updated with the raw, unclipped gradient.</li>
     *   <li>The gradient is clamped via {@link #clipGradient(double)}.</li>
     *   <li>The weight moves by {@code -stepSize * g / sqrt(cache + smoothEpsilon)} minus the
     *       L2 term {@code regularization * w}.</li>
     *   <li>{@code dw[i]} is reset to 0.</li>
     * </ol>
     *
     * @param model    model whose {@code getParameters()} matrices are updated in place
     * @param stepSize learning rate
     * @throws Exception propagated from {@code model.getParameters()}
     */
    public static void updateModelParams(Model model, double stepSize) throws Exception {
        for (Matrix m : model.getParameters()) {
            for (int i = 0; i < m.w.length; i++) {
                double mdwi = m.dw[i];
                // rmsprop adaptive learning rate: the cache sees the raw gradient; clipping below
                // only limits the size of the step itself.
                m.stepCache[i] = m.stepCache[i] * decayRate + (1 - decayRate) * mdwi * mdwi;
                mdwi = clipGradient(mdwi);
                // update (and regularize)
                m.w[i] += -stepSize * mdwi / Math.sqrt(m.stepCache[i] + smoothEpsilon) - regularization * m.w[i];
                m.dw[i] = 0;
            }
        }
    }

    /** Clamps {@code value} to {@code [-gradientClipValue, gradientClipValue]}; NaN passes through. */
    private static double clipGradient(double value) {
        if (value > gradientClipValue) {
            value = gradientClipValue;
        }
        if (value < -gradientClipValue) {
            value = -gradientClipValue;
        }
        return value;
    }
}