-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathGraphDriver.java
More file actions
112 lines (103 loc) · 4.81 KB
/
GraphDriver.java
File metadata and controls
112 lines (103 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package graph;
import java.util.*;
import structure.Problem;
import utils.Folds;
import utils.Params;
import edu.illinois.cs.cogcomp.core.datastructures.Pair;
import edu.illinois.cs.cogcomp.sl.core.SLModel;
import edu.illinois.cs.cogcomp.sl.core.SLParameters;
import edu.illinois.cs.cogcomp.sl.core.SLProblem;
import edu.illinois.cs.cogcomp.sl.learner.Learner;
import edu.illinois.cs.cogcomp.sl.learner.LearnerFactory;
import edu.illinois.cs.cogcomp.sl.util.Lexiconer;
/**
 * Cross-validation, training and evaluation driver for the graph structured-prediction
 * model, built on the Illinois SL library ({@link SLModel}, {@link Learner}).
 */
public class GraphDriver {

    /** Iteration cap used by {@link #trainModel(String, SLProblem)}; overrides the SL config file. */
    private static final int DEFAULT_MAX_NUM_ITER = 5;

    /** A prediction counts as correct when its loss against the gold structure is below this. */
    private static final double LOSS_TOLERANCE = 0.0001;

    /**
     * Runs k-fold cross-validation. Each fold in turn is used as the test set and all
     * other folds are concatenated into the training set. Prints the mean, taken over
     * folds, of the two accuracies returned by {@link #doTrainTest}.
     *
     * @param probs       all problems; fold entries are passed to {@code Folds.getDataSplit}
     * @param foldIndices one list of problem indices per fold
     * @throws Exception propagated from training, model I/O or evaluation
     */
    public static void crossVal(List<Problem> probs, List<List<Integer>> foldIndices)
            throws Exception {
        int numFolds = foldIndices.size();
        double acc1 = 0.0, acc2 = 0.0;
        for (int i = 0; i < numFolds; i++) {
            List<Integer> train = new ArrayList<>();
            List<Integer> test = new ArrayList<>();
            for (int j = 0; j < numFolds; ++j) {
                if (i == j) {
                    test.addAll(foldIndices.get(j));
                } else {
                    train.addAll(foldIndices.get(j));
                }
            }
            Pair<Double, Double> pair = doTrainTest(probs, train, test, i);
            acc1 += pair.getFirst();
            acc2 += pair.getSecond();
        }
        // With zero folds both divisions are 0.0 / 0, so NaN is printed.
        System.out.println("CV : " + (acc1 / numFolds) + " " + (acc2 / numFolds));
    }

    /**
     * Trains a model on the given training indices, saves it, then evaluates it on the
     * test indices.
     *
     * @param probs        all problems
     * @param trainIndices indices of training problems
     * @param testIndices  indices of test problems
     * @param id           fold id, embedded in the saved model's file name
     * @return the pair returned by {@link #testModel}: (instance accuracy, strict accuracy)
     * @throws Exception propagated from feature extraction, training or model I/O
     */
    public static Pair<Double, Double> doTrainTest(List<Problem> probs, List<Integer> trainIndices,
            List<Integer> testIndices, int id) throws Exception {
        List<List<Problem>> split = Folds.getDataSplit(probs, trainIndices, testIndices, 0.0);
        List<Problem> trainProbs = split.get(0);
        // NOTE(review): index 1 is skipped, presumably a dev split that is empty for a
        // 0.0 fraction. Confirm against Folds.getDataSplit.
        List<Problem> testProbs = split.get(2);
        SLProblem train = getSP(trainProbs);
        SLProblem test = getSP(testProbs);
        System.out.println("Train : "+train.instanceList.size()+" Test : "+test.instanceList.size());
        // Build the path once so the model is written and read from the same location.
        String modelPath = Params.modelDir + Params.graphPrefix + id + Params.modelSuffix;
        trainModel(modelPath, train);
        return testModel(modelPath, test);
    }

    /**
     * Converts problems into an SL training/evaluation set. Each example pairs a
     * {@link GraphX} over the problem's relevant quantity indices with a {@link GraphY}
     * built from the gold labels. Both come from {@code constraints.GraphInfSolver},
     * which is a different class from the same-package {@code GraphInfSolver} used in
     * {@link #trainModel}.
     *
     * @param problemList problems to convert
     * @return an {@link SLProblem} with one example per problem, in input order
     * @throws Exception propagated from the helper calls
     */
    public static SLProblem getSP(List<Problem> problemList) throws Exception {
        SLProblem problem = new SLProblem();
        for (Problem prob : problemList) {
            List<Integer> indices = constraints.GraphInfSolver.createRelevantQuantIndexList(prob);
            GraphX x = new GraphX(prob, indices);
            GraphY y = new GraphY(constraints.GraphInfSolver.getGoldLabels(prob));
            problem.addExample(x, y);
        }
        return problem;
    }

    /**
     * Loads a saved model and evaluates it on {@code sp}. An instance is correct when
     * {@code GraphY.getLoss(gold, pred)} is below {@link #LOSS_TOLERANCE}.
     * <ul>
     *   <li>Instance accuracy is the fraction of instances predicted correctly.</li>
     *   <li>Strict accuracy is the fraction of distinct problem ids with no
     *       incorrectly predicted instance.</li>
     * </ul>
     * Both values are printed. Examples are also printed according to
     * {@code Params.printCorrect} and {@code Params.printMistakes}.
     *
     * @param modelPath path of a model saved by {@link #trainModel}
     * @param sp        evaluation examples
     * @return (instance accuracy, strict accuracy); both are NaN when {@code sp} is empty
     * @throws Exception propagated from model loading or inference
     */
    public static Pair<Double, Double> testModel(String modelPath, SLProblem sp)
            throws Exception {
        SLModel model = SLModel.loadModel(modelPath);
        Set<Integer> incorrect = new HashSet<>();
        Set<Integer> total = new HashSet<>();
        double acc = 0.0;
        int numInstances = sp.instanceList.size();
        for (int i = 0; i < numInstances; i++) {
            GraphX prob = (GraphX) sp.instanceList.get(i);
            GraphY gold = (GraphY) sp.goldStructureList.get(i);
            GraphY pred = (GraphY) model.infSolver.getBestStructure(model.wv, prob);
            total.add(prob.problemId);
            boolean correct = GraphY.getLoss(gold, pred) < LOSS_TOLERANCE;
            if (correct) {
                acc += 1;
            } else {
                // One wrong instance marks the whole problem wrong for strict accuracy.
                incorrect.add(prob.problemId);
            }
            if ((correct && Params.printCorrect) || (!correct && Params.printMistakes)) {
                printExample(prob, gold, pred);
            }
        }
        double instanceAcc = acc / numInstances;
        double strictAcc = 1 - 1.0 * incorrect.size() / total.size();
        System.out.println("Accuracy : = " + acc + " / " + numInstances
                + " = " + instanceAcc);
        System.out.println("Strict Accuracy : =" + strictAcc);
        return new Pair<>(instanceAcc, strictAcc);
    }

    /** Prints one evaluated instance: text, schema, quantities, gold and predicted structures. */
    private static void printExample(GraphX prob, GraphY gold, GraphY pred) {
        System.out.println(prob.problemId+" : "+prob.ta.getText());
        System.out.println();
        System.out.println("Schema : "+prob.schema);
        System.out.println();
        System.out.println("Quantities : "+prob.quantities);
        System.out.println("Gold : "+gold);
        System.out.println("Pred : "+pred);
        System.out.println();
    }

    /**
     * Trains a model with the default iteration cap ({@value #DEFAULT_MAX_NUM_ITER}) and
     * saves it to {@code modelPath}.
     *
     * @param modelPath where to save the trained model
     * @param train     training examples
     * @throws Exception propagated from config loading, training or saving
     */
    public static void trainModel(String modelPath, SLProblem train) throws Exception {
        trainModel(modelPath, train, DEFAULT_MAX_NUM_ITER);
    }

    /**
     * Trains a model and saves it to {@code modelPath}. Learner settings are read from
     * {@code Params.spConfigFile}, but {@code MAX_NUM_ITER} is always replaced by
     * {@code maxNumIter}.
     *
     * @param modelPath  where to save the trained model
     * @param train      training examples
     * @param maxNumIter learner iteration cap; must be positive
     * @throws IllegalArgumentException if {@code maxNumIter < 1}
     * @throws Exception propagated from config loading, training or saving
     */
    public static void trainModel(String modelPath, SLProblem train, int maxNumIter)
            throws Exception {
        if (maxNumIter < 1) {
            throw new IllegalArgumentException("maxNumIter must be positive, got " + maxNumIter);
        }
        SLModel model = new SLModel();
        Lexiconer lm = new Lexiconer();
        lm.setAllowNewFeatures(true);
        model.lm = lm;
        GraphFeatGen fg = new GraphFeatGen(lm);
        model.featureGenerator = fg;
        model.infSolver = new GraphInfSolver(fg);
        SLParameters para = new SLParameters();
        para.loadConfigFile(Params.spConfigFile);
        // Deliberately overrides whatever MAX_NUM_ITER the config file set.
        para.MAX_NUM_ITER = maxNumIter;
        Learner learner = LearnerFactory.getLearner(model.infSolver, fg, para);
        model.wv = learner.train(train);
        // Freeze the lexicon before saving. Presumably this stops test-time feature
        // extraction from adding new features; confirm against Lexiconer.
        lm.setAllowNewFeatures(false);
        model.saveModel(modelPath);
    }
}