-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathGraphInfSolver.java
More file actions
98 lines (90 loc) · 3.87 KB
/
GraphInfSolver.java
File metadata and controls
98 lines (90 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package graph;
import com.google.common.collect.MinMaxPriorityQueue;
import edu.illinois.cs.cogcomp.core.datastructures.Pair;
import edu.illinois.cs.cogcomp.sl.core.AbstractInferenceSolver;
import edu.illinois.cs.cogcomp.sl.core.IInstance;
import edu.illinois.cs.cogcomp.sl.core.IStructure;
import edu.illinois.cs.cogcomp.sl.util.WeightVector;
import structure.PairComparator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
/**
 * Beam-search inference over node and edge labels of a {@link GraphX}.
 *
 * <p>The produced {@link GraphY} wraps a flat label list:
 * <ul>
 *   <li>the first {@code n} entries are node labels ({@code "RATE"} or {@code "NOT_RATE"}), one per
 *       index of {@code x.relevantQuantIndices};</li>
 *   <li>they are followed by one edge label for every pair {@code (i, j)} with {@code i < j}, in
 *       loop order {@code (0,1), (0,2), ..., (1,2), ...}.</li>
 * </ul>
 */
public class GraphInfSolver extends AbstractInferenceSolver {

    // Kept as private static compile-time constants rather than instance fields so that, if this
    // solver is serialized (e.g. as part of a saved model), its default serialVersionUID does not
    // change.
    /** Maximum number of partial hypotheses retained in each beam. */
    private static final int BEAM_SIZE = 200;
    /** Whether edge labels are filtered through {@code constraints.GraphInfSolver.satisfyConstraints}. */
    private static final boolean USE_CONSTRAINTS = true;

    private GraphFeatGen featGen;

    /**
     * @param featGen feature generator used to score node and edge labels
     */
    public GraphInfSolver(GraphFeatGen featGen) {
        this.featGen = featGen;
    }

    /**
     * Returns the highest-scoring labeling found by beam search.
     * Delegates to {@link #getLossAugmentedBestStructure} with a {@code null} gold structure.
     */
    @Override
    public IStructure getBestStructure(WeightVector weightVector, IInstance iInstance)
            throws Exception {
        return getLossAugmentedBestStructure(weightVector, iInstance, null);
    }

    /**
     * Runs beam search over node labels, then over edge labels, and returns the best complete
     * labeling that survived the beam.
     *
     * <p>NOTE(review): no loss term is added here. {@code iStructure} (the gold structure) is never
     * read, so the result equals plain inference. Confirm this is intended for the trainer in use.
     *
     * @param weightVector model weights used to score feature vectors
     * @param iInstance instance to label; must be a {@link GraphX}
     * @param iStructure gold structure (unused)
     * @return a {@link GraphY} holding {@code n} node labels followed by {@code n(n-1)/2} edge labels
     * @throws NoSuchElementException if, for some pair {@code (i, j)}, the constraints reject every
     *     edge label for every hypothesis in the beam
     * @throws Exception as declared by the superclass contract
     */
    @Override
    public IStructure getLossAugmentedBestStructure(
            WeightVector weightVector, IInstance iInstance, IStructure iStructure)
            throws Exception {
        GraphX x = (GraphX) iInstance;
        int n = x.relevantQuantIndices.size();

        // Assumes PairComparator orders higher scores first. MinMaxPriorityQueue evicts its
        // greatest element on overflow and element() returns its least one, so the beam then keeps
        // the best hypotheses and element() is the best -- TODO confirm against PairComparator.
        PairComparator<List<String>> nodePairComparator =
                new PairComparator<List<String>>() {};
        MinMaxPriorityQueue<Pair<List<String>, Double>> beam1 =
                MinMaxPriorityQueue.orderedBy(nodePairComparator)
                        .maximumSize(BEAM_SIZE).create();
        MinMaxPriorityQueue<Pair<List<String>, Double>> beam2 =
                MinMaxPriorityQueue.orderedBy(nodePairComparator)
                        .maximumSize(BEAM_SIZE).create();

        // Stage 1: seed the beam with node labelings having at most two RATE nodes. The
        // all-NOT_RATE labeling scores 0, i.e. NOT_RATE node labels contribute no score here.
        List<String> init = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            init.add("NOT_RATE");
        }
        List<String> noRates = new ArrayList<>(init);
        beam1.add(new Pair<>(noRates, 0.0));
        for (int i = 0; i < n; ++i) {
            List<String> oneRate = new ArrayList<>(init);
            oneRate.set(i, "RATE");
            beam1.add(new Pair<>(oneRate, 1.0 * weightVector.dotProduct(
                    featGen.getRateFeatureVector(x, i, "RATE"))));
            for (int j = i + 1; j < n; ++j) {
                List<String> twoRates = new ArrayList<>(init);
                twoRates.set(i, "RATE");
                twoRates.set(j, "RATE");
                // NOTE(review): this two-RATE hypothesis is scored with the edge feature function
                // getRunFeatureVector(x, i, j, "RATE"). Single-RATE hypotheses use the node
                // features from getRateFeatureVector(...). This looks like a copy-paste slip;
                // verify against GraphFeatGen before changing it, since it alters model scores.
                beam1.add(new Pair<>(twoRates, 1.0 * weightVector.dotProduct(
                        featGen.getRunFeatureVector(x, i, j, "RATE"))));
            }
        }

        // Stage 2: for each pair (i, j), i < j, extend every hypothesis with one edge label.
        // Edge labels are appended in the same (i, j) order as these loops.
        List<String> edgeLabels = Arrays.asList(
                "SAME_UNIT", "1_RATE_1", "1_RATE_2",
                "2_RATE_1", "2_RATE_2", "NO_REL");
        for (int i = 0; i < n - 1; ++i) {
            for (int j = i + 1; j < n; ++j) {
                for (String label : edgeLabels) {
                    for (Pair<List<String>, Double> pair : beam1) {
                        List<String> prefix = pair.getFirst();
                        // Indices i and j (< n) address node labels, which every prefix has.
                        if (USE_CONSTRAINTS && !constraints.GraphInfSolver.satisfyConstraints(
                                prefix.get(i),
                                prefix.get(j),
                                label)) {
                            continue;
                        }
                        List<String> extended = new ArrayList<>(prefix.size() + 1);
                        extended.addAll(prefix);
                        extended.add(label);
                        beam2.add(new Pair<>(extended,
                                pair.getSecond() + 1.0 * weightVector.dotProduct(
                                        featGen.getRunFeatureVector(x, i, j, label))));
                    }
                }
                // Fail fast with context instead of letting an empty beam surface later as a
                // message-less NoSuchElementException from element().
                if (beam2.isEmpty()) {
                    throw new NoSuchElementException(
                            "Beam search found no constraint-satisfying labeling for pair ("
                                    + i + ", " + j + ") of " + n + " nodes");
                }
                beam1.clear();
                beam1.addAll(beam2);
                beam2.clear();
            }
        }

        List<String> best = beam1.element().getFirst();
        assert best.size() == n + n * (n - 1) / 2;
        return new GraphY(best);
    }

    /**
     * Returns the loss between two structures by delegating to {@code GraphY.getLoss}.
     * Both structures must be {@link GraphY} instances.
     */
    @Override
    public float getLoss(IInstance iInstance, IStructure iStructure, IStructure iStructure1) {
        return GraphY.getLoss((GraphY) iStructure, (GraphY) iStructure1);
    }
}