forked from WinVector/RcppDynProg
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.R
More file actions
192 lines (168 loc) · 3.94 KB
/
utils.R
File metadata and controls
192 lines (168 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#' @importFrom utils combn
NULL
#' @importFrom wrapr stop_if_dot_args
NULL
#' Increasing whole-number sequence.
#'
#' Return the increasing sequence of whole numbers lying between \code{a} and
#' \code{b} inclusive, or \code{integer(0)} when there are none. Unlike
#' \code{a:b} the result never counts downward, which makes it safe to
#' iterate over.
#'
#' TODO: switch to wrapr version after next wrapr release.
#'
#' @param a scalar lower bound (rounded up to a whole number before use)
#' @param b scalar upper bound (rounded down to a whole number before use)
#' @return whole number sequence
#'
#' @examples
#'
#' # print 3, 4, and then 5
#' for(i in seqi(3, 5)) {
#'   print(i)
#' }
#'
#' # empty
#' for(i in seqi(5, 2)) {
#'   print(i)
#' }
#'
#' @noRd
#'
seqi <- function(a, b) {
  first <- ceiling(a)
  last <- floor(b)
  if (first > last) {
    # nothing to iterate over
    integer(0)
  } else {
    base::seq(first, last, by = 1L)
  }
}
#' Build all partitions into intervals.
#'
#' Enumerates every way of splitting the positions \code{1, ..., n} into at
#' most \code{kmax} contiguous, non-empty intervals. A partition is encoded as
#' an increasing vector of interval start positions that begins with \code{1}
#' and is closed by the sentinel \code{n + 1}; e.g. \code{c(1, 3, 5)} stands
#' for the two intervals 1..2 and 3..4 of a length-4 sequence.
#'
#' This is a brute-force enumeration: the number of partitions grows
#' combinatorially with \code{n}, so only use it on small problems.
#'
#' @param n integer, sequence length to choose from.
#' @param kmax int, maximum number of segments in solution.
#' @return list of all partitions.
#'
#' @examples
#'
#' all_partitions(4, 2)
#'
#' @keywords internal
#'
#' @export
#'
all_partitions <- function(n, kmax = n) {
  # a length-n sequence cannot be cut into more than n non-empty intervals
  kmax <- min(kmax, n)
  # interior positions at which a new interval may start: 2, ..., n
  cut_points <- seq_len(max(n - 1, 0)) + 1
  # the single-interval partition is always a candidate
  res <- list(c(1, n + 1))
  for (n_cuts in seq_len(max(kmax - 1, 0))) {
    # Pick *indices* into cut_points instead of calling
    # combn(cut_points, n_cuts): combn() treats a length-one numeric x as
    # seq_len(x), so for n == 2 (cut_points == 2) the direct call chose from
    # c(1, 2) and produced the invalid partition c(1, 1, 3).
    picks <- combn(length(cut_points), n_cuts)
    # combn() yields increasing index sets and cut_points is increasing, so
    # each assembled vector is already sorted.
    new_partitions <- lapply(
      seq_len(ncol(picks)),
      function(j) c(1, cut_points[picks[, j]], n + 1)
    )
    res <- c(res, new_partitions)
  }
  res
}
# Element-wise test for whether numeric values are whole numbers, allowing a
# small tolerance for floating-point error.
#
# x:   numeric vector to test.
# tol: a value counts as whole when its absolute distance to the nearest
#      whole number is strictly below tol (default: sqrt of machine epsilon).
# Returns a logical vector, TRUE where the value is within tol of a whole number.
is.wholenumber <- function(x, tol = .Machine$double.eps^0.5) {
  distance_to_nearest <- abs(x - round(x))
  distance_to_nearest < tol
}
#' Compute the price of a partition solution (and check it is valid).
#'
#' A solution is an increasing vector of whole-number interval start
#' positions that begins with \code{1} and ends with the sentinel
#' \code{nrow(x)+1}. Consecutive entries \code{solution[i]} and
#' \code{solution[i+1]} delimit the interval
#' \code{solution[i], ..., solution[i+1]-1}, whose cost is read from
#' \code{x[solution[i], solution[i+1]-1]}. The price is the sum of these costs.
#' An error is signalled if the solution is malformed.
#'
#' @param x NumericMatrix, for j>=i x(i,j) is the cost of partition element [i,...,j] (inclusive).
#' @param solution vector of indices
#' @return price
#'
#' @examples
#'
#' x <- matrix(c(1,1,5,1,1,0,5,0,1), nrow=3)
#' s <- c(1, 2, 4)
#' score_solution(x, s)
#'
#' @export
#'
score_solution <- function(x, solution) {
  # Fail early with a clear message instead of an obscure subscript error
  # when x is not two-dimensional or lacks the columns indexed below.
  if (length(dim(x)) != 2L) {
    stop("x must be a matrix")
  }
  n <- nrow(x)
  if (ncol(x) < n) {
    stop("x must have at least as many columns as rows")
  }
  n_idx <- length(solution)
  if (n_idx < 2) {
    stop("solutions must have length at least 2")
  }
  if (n_idx > (n + 1)) {
    stop("solutions must have length no more than nrow(x)+1")
  }
  # Reject missing / non-numeric indices up front: otherwise the comparisons
  # below evaluate to NA and if() fails with an uninformative message.
  if (!is.numeric(solution) || anyNA(solution)) {
    stop("solution must be numeric with no missing values")
  }
  if (solution[1] != 1) {
    stop("solution[1] must equal 1")
  }
  if (solution[n_idx] != (n + 1)) {
    stop("solution[length(solution)] must equal nrow(x)+1")
  }
  if (!all(solution[-1] > solution[-n_idx])) {
    stop("solution indices must be increasing")
  }
  # whole-number test with the same tolerance as is.wholenumber()
  if (!all(abs(solution - round(solution)) < .Machine$double.eps^0.5)) {
    stop("solution must be wholenumbers")
  }
  # accumulate the cost of each interval [solution[i], solution[i+1]-1]
  score <- 0
  for (i in seq_len(n_idx - 1)) {
    score <- score + x[solution[i], solution[i + 1] - 1]
  }
  score
}
# Cross-check the interval-partition solvers against brute-force enumeration.
#
# Every partition returned by all_partitions(nrow(x), k) is scored with
# score_solution(); the smallest score is the reference value. Each of
# solve_interval_partition_R(), solve_interval_partition_k() and
# solve_interval_partition() -- plus solve_interval_partition_no_k() when
# k >= nrow(x) -- must return a solution with at most k + 1 entries whose
# score is within 1e-5 of that reference.
#
# x: cost matrix, passed unchanged to score_solution() and the solvers.
# k: maximum number of intervals, passed to all_partitions() and the solvers.
# Returns TRUE when every check passes; otherwise the text of the first error
# raised, as produced by as.character() on the condition object.
#
# NOTE: that text includes the deparsed failing call, so the variable names
# appearing in the calls and conditions below are part of the observable output.
test_solvers <- function(x, k) {
  failure <- tryCatch(
    {
      # brute-force reference
      candidates <- all_partitions(nrow(x), k)
      if (length(candidates) < 1) {
        stop("brute force didn't return any solutions")
      }
      candidate_sizes <- vapply(candidates, length, integer(1))
      if (any(candidate_sizes > (k + 1))) {
        stop("brute solution too long")
      }
      brute_scores <- vapply(
        candidates,
        function(si) {
          score_solution(x, si)
        },
        numeric(1)
      )
      sm <- min(brute_scores)

      # pure R solver
      soln1 <- solve_interval_partition_R(x, k)
      score1 <- score_solution(x, soln1)
      if (length(soln1) > (k + 1)) {
        stop("soln1 too long")
      }
      if (!(abs(score1 - sm) <= 1e-5)) {
        stop("R solution has wrong score")
      }

      # C++ solver with explicit k
      soln2 <- solve_interval_partition_k(x, k)
      score2 <- score_solution(x, soln2)
      if (length(soln2) > (k + 1)) {
        stop("soln2 too long")
      }
      if (!(abs(score2 - sm) <= 1e-5)) {
        stop("C++ k solution has wrong score")
      }

      # C++ front-end solver
      soln3 <- solve_interval_partition(x, k)
      score3 <- score_solution(x, soln3)
      if (length(soln3) > (k + 1)) {
        stop("soln3 too long")
      }
      if (!(abs(score3 - sm) <= 1e-5)) {
        stop("C++ solution has wrong score")
      }

      # the unconstrained solver is only comparable when k does not bind
      if (k >= nrow(x)) {
        soln4 <- solve_interval_partition_no_k(x)
        score4 <- score_solution(x, soln4)
        if (length(soln4) > (k + 1)) {
          stop("soln4 too long")
        }
        if (!(abs(score4 - sm) <= 1e-5)) {
          stop("C++ no_k solution has wrong score")
        }
      }

      # reaching this point means no check failed
      NULL
    },
    error = function(e) {
      paste(as.character(e), sep = " ")
    }
  )
  if (is.null(failure)) {
    TRUE
  } else {
    failure
  }
}