-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathTextGenerationUnbroken.java
More file actions
165 lines (144 loc) · 4.91 KB
/
Copy pathTextGenerationUnbroken.java
File metadata and controls
165 lines (144 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package datasets;
import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import util.Util;
import loss.LossSoftmax;
import matrix.Matrix;
import model.LinearUnit;
import model.Model;
import model.Nonlinearity;
/**
 * Character-level text-generation data set.
 *
 * <p>The constructor reads a text file, assigns every distinct character an
 * index in order of first appearance, and samples randomly positioned windows
 * of the text. Each window becomes a {@link DataSequence} whose steps pair the
 * one-hot vector of a character with the one-hot vector of the character that
 * follows it.
 *
 * <p>The vocabulary ({@code charToIndex}, {@code indexToChar},
 * {@code dimension}) is held in static fields, so it is shared by all
 * instances: constructing a new data set replaces the vocabulary that
 * {@link #generateText} uses.
 */
public class TextGenerationUnbroken extends DataSet {

    private static final long serialVersionUID = 1L;

    /** Number of characters generated per sample printed by {@link #DisplayReport}. */
    public static int reportSequenceLength = 100;

    /** Whether {@link #DisplayReport} also computes the median perplexity over the training sequences. */
    public static boolean reportPerplexity = true;

    // Vocabulary of the most recently constructed data set (static, hence shared).
    private static Map<String, Integer> charToIndex = new HashMap<>();
    private static Map<Integer, String> indexToChar = new HashMap<>();
    private static int dimension;

    /**
     * Generates {@code steps} characters by repeatedly feeding the model its
     * own previous output as a one-hot vector.
     *
     * <p>The first input is a fresh {@code Matrix(dimension)} (presumably all
     * zeros; confirm against {@code Matrix}). The model state is reset before
     * generation starts.
     *
     * @param model       model to sample from; its state is reset
     * @param steps       number of characters to generate
     * @param argmax      if {@code true}, always pick the most probable
     *                    character; otherwise sample an index via
     *                    {@code Util.pickIndexFromRandomVector}
     * @param temperature temperature handed to {@code LossSoftmax.getSoftmaxProbs}
     * @param rng         random source used when {@code argmax} is {@code false}
     * @return the generated text, with every newline replaced by
     *         {@code "\n\t"} (closing quote, newline, tab, opening quote) so
     *         that the quoted, tab-indented report output stays aligned
     * @throws IllegalStateException if no valid character index could be chosen
     *                               (for example when every probability is NaN)
     * @throws Exception             if the model's forward pass fails
     */
    public static String generateText(Model model, int steps, boolean argmax, double temperature, Random rng) throws Exception {
        Matrix start = new Matrix(dimension);
        model.resetState();
        Graph g = new Graph(false);
        Matrix input = start.clone();
        // StringBuilder instead of String += keeps generation linear in steps.
        StringBuilder generated = new StringBuilder();
        for (int s = 0; s < steps; s++) {
            Matrix logprobs = model.forward(input, g);
            Matrix probs = LossSoftmax.getSoftmaxProbs(logprobs, temperature);
            int indxChosen = -1;
            if (argmax) {
                double high = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < probs.w.length; i++) {
                    if (probs.w[i] > high) {
                        high = probs.w[i];
                        indxChosen = i;
                    }
                }
            }
            else {
                indxChosen = Util.pickIndexFromRandomVector(probs, rng);
            }
            // Comparisons with NaN are always false, so argmax can leave the
            // index at -1; fail with a clear message instead of an array error.
            if (indxChosen < 0 || indxChosen >= input.w.length) {
                throw new IllegalStateException(
                        "No valid character index could be chosen at step " + s
                        + " (got " + indxChosen + ", vocabulary size " + input.w.length + ")");
            }
            generated.append(indexToChar.get(indxChosen));
            // Re-encode the chosen character as the next one-hot input.
            for (int i = 0; i < input.w.length; i++) {
                input.w[i] = 0;
            }
            input.w[indxChosen] = 1.0;
        }
        return generated.toString().replace("\n", "\"\n\t\"");
    }

    /**
     * Loads the text at {@code path}, builds the character vocabulary and
     * samples the training sequences.
     *
     * <p>Each sequence length is drawn uniformly from
     * {@code [sequenceMinLength, sequenceMaxLength]} and its start offset
     * uniformly from {@code [0, textLength - length)}. A window of fewer than
     * two characters yields a sequence without steps.
     *
     * @param path              text file to load (read with the platform default charset)
     * @param totalSequences    number of sequences to sample; must be at least 1
     * @param sequenceMinLength smallest window length in characters
     * @param sequenceMaxLength largest window length in characters
     * @param rng               random source for window lengths and offsets
     * @throws IllegalArgumentException if {@code totalSequences < 1}, if the
     *                                  length range is empty, or if the text is
     *                                  not longer than a drawn window length
     * @throws Exception                if the file cannot be read
     */
    public TextGenerationUnbroken(String path, int totalSequences, int sequenceMinLength, int sequenceMaxLength, Random rng) throws Exception {
        if (totalSequences < 1) {
            throw new IllegalArgumentException("totalSequences must be at least 1, got " + totalSequences);
        }
        int lengthChoices = sequenceMaxLength - sequenceMinLength + 1;
        if (lengthChoices <= 0) {
            throw new IllegalArgumentException(
                    "sequenceMinLength (" + sequenceMinLength + ") must not exceed sequenceMaxLength ("
                    + sequenceMaxLength + ")");
        }

        System.out.println("Text generation task");
        System.out.println("loading " + path + "...");
        String text = readText(path);

        buildVocabulary(text);

        List<DataSequence> sequences = new ArrayList<>();
        for (int s = 0; s < totalSequences; s++) {
            // Draw order (length first, then offset) is kept so a seeded rng
            // reproduces the same data set as before.
            int len = rng.nextInt(lengthChoices) + sequenceMinLength;
            int startChoices = text.length() - len;
            if (startChoices <= 0) {
                throw new IllegalArgumentException(
                        "Text from " + path + " has " + text.length()
                        + " characters, which is too short for a sequence of length " + len);
            }
            int start = rng.nextInt(startChoices);
            sequences.add(sampleSequence(text, start, len));
        }
        System.out.println("Total unique chars = " + dimension);

        training = sequences;
        lossTraining = new LossSoftmax();
        lossReporting = new LossSoftmax();
        inputDimension = sequences.get(0).steps.get(0).input.w.length;
        // The output size is read from the first step of the first sequence
        // that actually carries a target.
        int loc = 0;
        while (sequences.get(0).steps.get(loc).targetOutput == null) {
            loc++;
        }
        outputDimension = sequences.get(0).steps.get(loc).targetOutput.w.length;
    }

    /** Reads the whole file and returns its lines joined, each terminated by {@code '\n'}. */
    private static String readText(String path) throws Exception {
        File file = new File(path);
        List<String> lines = Files.readAllLines(file.toPath(), Charset.defaultCharset());
        // StringBuilder avoids the quadratic copying of String += on large files.
        StringBuilder text = new StringBuilder();
        for (String line : lines) {
            text.append(line).append("\n");
        }
        return text.toString();
    }

    /**
     * Rebuilds the static vocabulary from {@code text}, numbering characters in
     * order of first appearance, and prints each new character as it is found.
     */
    private static void buildVocabulary(String text) {
        // Drop entries left behind by a previously constructed data set.
        charToIndex.clear();
        indexToChar.clear();
        int id = 0;
        System.out.println("Characters:");
        System.out.print("\t");
        for (int i = 0; i < text.length(); i++) {
            String ch = text.charAt(i) + "";
            if (charToIndex.containsKey(ch)) {
                continue;
            }
            if (ch.equals("\n")) {
                System.out.print("\\n");
            }
            else {
                System.out.print(ch);
            }
            charToIndex.put(ch, id);
            indexToChar.put(id, ch);
            id++;
        }
        System.out.println("");
        dimension = charToIndex.size();
    }

    /**
     * Builds one sequence from {@code text[start, start + len)}: step {@code i}
     * maps the one-hot vector of character {@code i} to that of character
     * {@code i + 1}, so the window yields {@code len - 1} steps.
     */
    private static DataSequence sampleSequence(String text, int start, int len) {
        List<double[]> vecs = new ArrayList<>();
        for (int i = 0; i < len; i++) {
            String ch = text.charAt(i + start) + "";
            int index = charToIndex.get(ch);
            double[] vec = new double[dimension];
            vec[index] = 1.0;
            vecs.add(vec);
        }
        DataSequence sequence = new DataSequence();
        for (int i = 0; i < vecs.size() - 1; i++) {
            sequence.steps.add(new DataStep(vecs.get(i), vecs.get(i + 1)));
        }
        return sequence;
    }

    /**
     * Prints a report for {@code model}: optionally the median perplexity over
     * the training sequences, then one sampled text per temperature
     * (1, 0.75, 0.5, 0.25, 0.1) and finally an argmax-decoded text.
     *
     * @param model model to evaluate and sample from
     * @param rng   random source used for sampling
     * @throws Exception if perplexity calculation or text generation fails
     */
    @Override
    public void DisplayReport(Model model, Random rng) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            double perplexity = LossSoftmax.calculateMedianPerplexity(model, training);
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        }
        double[] temperatures = {1, 0.75, 0.5, 0.25, 0.1};
        for (double temperature : temperatures) {
            System.out.println("\nTemperature "+temperature+" prediction:");
            String guess = TextGenerationUnbroken.generateText(model, reportSequenceLength, false, temperature, rng);
            System.out.println("\t\"..." + guess + "...\"");
        }
        System.out.println("\nArgmax prediction:");
        String guess = TextGenerationUnbroken.generateText(model, reportSequenceLength, true, 1.0, rng);
        System.out.println("\t\"..." + guess + "...\"");
        System.out.println("========================================");
    }

    /**
     * Returns the output nonlinearity for models trained on this data set: a
     * {@link LinearUnit}. In {@link #generateText} the model outputs are turned
     * into probabilities by {@code LossSoftmax.getSoftmaxProbs}.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}