-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathLossArgMax.java
More file actions
45 lines (38 loc) · 979 Bytes
/
Copy pathLossArgMax.java
File metadata and controls
45 lines (38 loc) · 979 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package loss;
import matrix.Matrix;
/**
 * Zero-one loss based on argmax agreement.
 *
 * <p>{@link #measure} returns {@code 0} when the index of the largest value in the actual
 * output's {@code w} equals the index of the largest value in the target output's {@code w},
 * and {@code 1} otherwise. Gradient computation ({@link #backward}) is not supported.
 */
public class LossArgMax implements Loss {

    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = 1L;

    /**
     * Not supported: this loss does not implement backpropagation.
     *
     * @param actualOutput ignored
     * @param targetOutput ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void backward(Matrix actualOutput, Matrix targetOutput) throws Exception {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Compares the argmax positions of the actual and target outputs.
     *
     * <p>Ties resolve to the lowest index, and NaN values are never selected as the maximum
     * (see {@link #argMax(Matrix)}).
     *
     * @param actualOutput the produced output; only its {@code w} values are read
     * @param targetOutput the expected output; must have as many {@code w} values as
     *     {@code actualOutput}
     * @return {@code 0} if both argmax indices agree, {@code 1} otherwise
     * @throws IllegalArgumentException if the two outputs have different element counts
     */
    @Override
    public double measure(Matrix actualOutput, Matrix targetOutput) throws Exception {
        int actualLength = actualOutput.w.length;
        int targetLength = targetOutput.w.length;
        if (actualLength != targetLength) {
            throw new IllegalArgumentException(
                    "mismatch: actualOutput has " + actualLength
                            + " elements but targetOutput has " + targetLength);
        }
        // NOTE(review): argMax returns -1 for empty / all-NaN / all -Infinity inputs, so two
        // such outputs are scored as a match (0). Confirm this is the intended behaviour.
        return argMax(actualOutput) == argMax(targetOutput) ? 0.0 : 1.0;
    }

    /**
     * Returns the index of the first strictly largest value in {@code m.w}.
     *
     * @param m the matrix whose {@code w} values are scanned
     * @return the index of the first maximum, or {@code -1} if no value is greater than
     *     negative infinity (empty array, only NaN, or only negative infinity)
     */
    private static int argMax(Matrix m) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        for (int i = 0; i < m.w.length; i++) {
            // Strict '>' keeps the earliest index on ties and skips NaN (NaN > x is false).
            if (m.w[i] > best) {
                best = m.w[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }
}