forked from alibaba/AliSQL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwindow_rownumber_function.cpp
More file actions
209 lines (182 loc) · 8.81 KB
/
window_rownumber_function.cpp
File metadata and controls
209 lines (182 loc) · 8.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#include "duckdb/function/window/window_rownumber_function.hpp"
#include "duckdb/function/window/window_shared_expressions.hpp"
#include "duckdb/function/window/window_token_tree.hpp"
#include "duckdb/planner/expression/bound_window_expression.hpp"
namespace duckdb {
//===--------------------------------------------------------------------===//
// WindowRowNumberGlobalState
//===--------------------------------------------------------------------===//
//! Shared (global) state for ROW_NUMBER and NTILE.
//! Decides once, at construction, whether ORDER BY arguments require framing and,
//! if so, whether a token tree must be built to rank rows by those arguments.
class WindowRowNumberGlobalState : public WindowExecutorGlobalState {
public:
	WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count,
	                           const ValidityMask &partition_mask, const ValidityMask &order_mask)
	    : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask),
	      ntile_idx(executor.ntile_idx) {
		// No ORDER BY arguments: the partition bounds are all we need.
		if (executor.arg_order_idx.empty()) {
			return;
		}
		use_framing = true;

		auto &wexpr = executor.wexpr;
		auto &arg_orders = wexpr.arg_orders;
		// If the argument ordering is a prefix of the window ordering, the rows are already
		// sorted by the arguments and positions can be used directly (only when optimizing).
		const bool optimizer_enabled = ClientConfig::GetConfig(executor.context).enable_optimizer;
		const bool reuses_window_order =
		    optimizer_enabled && BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) == arg_orders.size();
		if (reuses_window_order) {
			return;
		}
		// "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their
		// position in the input data, such that two elements never compare as equal."
		token_tree =
		    make_uniq<WindowTokenTree>(executor.context, arg_orders, executor.arg_order_idx, payload_count, true);
	}

	//! True when ORDER BY arguments are present: evaluate against the frame instead of the partition
	bool use_framing = false;
	//! Token tree ranking rows by the ORDER BY arguments (null when the window order can be reused)
	unique_ptr<WindowTokenTree> token_tree;
	//! The evaluation column index of the NTILE argument (copied from the executor)
	const column_t ntile_idx;
};
//===--------------------------------------------------------------------===//
// WindowRowNumberLocalState
//===--------------------------------------------------------------------===//
//! Per-local-state companion of WindowRowNumberGlobalState.
//! Holds a local token-tree state only when the global state built a token tree.
class WindowRowNumberLocalState : public WindowExecutorBoundsState {
public:
	explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate)
	    : WindowExecutorBoundsState(grstate), grstate(grstate) {
		const auto &shared_tree = grstate.token_tree;
		if (!shared_tree) {
			// No secondary sort needed: nothing to accumulate locally.
			return;
		}
		local_tree = shared_tree->GetLocalState();
	}

	//! Forward the chunk to the bounds state and, if present, the local token tree
	void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk,
	          idx_t input_idx) override;
	//! Finalize the bounds state and, if present, sort and build the local token tree
	void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override;

	//! The global state this local state was created from
	const WindowRowNumberGlobalState &grstate;
	//! Local token-tree state for ORDER BY arguments (null when not needed)
	unique_ptr<WindowAggregatorState> local_tree;
};
void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk,
                                     idx_t input_idx) {
	// The bounds state always sees the chunk first.
	WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx);
	if (!local_tree) {
		return;
	}
	// Feed the ORDER BY argument values into the local token tree.
	local_tree->Cast<WindowMergeSortTreeLocalState>().SinkChunk(sink_chunk, input_idx, nullptr, 0);
}
void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) {
	WindowExecutorBoundsState::Finalize(gstate, collection);
	if (!local_tree) {
		return;
	}
	// Sort the locally sunk tokens, then build the local merge sort tree from them.
	auto &tokens = local_tree->Cast<WindowMergeSortTreeLocalState>();
	tokens.Sort();
	tokens.window_tree.Build();
}
//===--------------------------------------------------------------------===//
// WindowRowNumberExecutor
//===--------------------------------------------------------------------===//
WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context,
                                                 WindowSharedExpressions &shared)
    : WindowExecutor(wexpr, context, shared) {
	// Register each ORDER BY argument as a sink column and remember its index, in order.
	for (auto &arg_order : wexpr.arg_orders) {
		const auto sink_idx = shared.RegisterSink(arg_order.expression);
		arg_order_idx.push_back(sink_idx);
	}
}
unique_ptr<WindowExecutorGlobalState> WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count,
                                                                              const ValidityMask &partition_mask,
                                                                              const ValidityMask &order_mask) const {
	// The concrete state decides at construction whether framing / a token tree is needed.
	unique_ptr<WindowExecutorGlobalState> gstate =
	    make_uniq<WindowRowNumberGlobalState>(*this, payload_count, partition_mask, order_mask);
	return gstate;
}
unique_ptr<WindowExecutorLocalState>
WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const {
	// The local state needs the concrete global state to find the (optional) token tree.
	const auto &grstate = gstate.Cast<WindowRowNumberGlobalState>();
	return make_uniq<WindowRowNumberLocalState>(grstate);
}
//! Compute ROW_NUMBER for `count` rows starting at `row_idx`.
//! Without ORDER BY arguments the number is the 1-based position within the partition.
//! With ORDER BY arguments it is 1 + the number of frame rows ordered before the row:
//! via the token tree when one was built, otherwise directly from positions (the rows are
//! already sorted by the arguments). Both paths must agree even when the row lies outside
//! its frame, so the position-based path clamps the row into [frame_begin, frame_end].
void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
                                               DataChunk &eval_chunk, Vector &result, idx_t count,
                                               idx_t row_idx) const {
	auto &grstate = gstate.Cast<WindowRowNumberGlobalState>();
	auto &lrstate = lstate.Cast<WindowRowNumberLocalState>();
	auto rdata = FlatVector::GetData<uint64_t>(result);

	if (grstate.use_framing) {
		auto frame_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]);
		auto frame_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]);
		if (grstate.token_tree) {
			for (idx_t i = 0; i < count; ++i, ++row_idx) {
				// Row numbers are unique ranks
				rdata[i] = grstate.token_tree->Rank(frame_begin[i], frame_end[i], row_idx);
			}
		} else {
			for (idx_t i = 0; i < count; ++i, ++row_idx) {
				// Count the frame rows that precede this row. Rows before (or an empty) frame have
				// none; rows past the frame have all of them. The unclamped `row_idx - frame_begin`
				// wrapped around for rows before the frame and disagreed with the token-tree path.
				idx_t preceding = 0;
				if (row_idx <= frame_begin[i] || frame_end[i] <= frame_begin[i]) {
					preceding = 0;
				} else if (row_idx >= frame_end[i]) {
					preceding = frame_end[i] - frame_begin[i];
				} else {
					preceding = row_idx - frame_begin[i];
				}
				rdata[i] = preceding + 1;
			}
		}
		return;
	}

	// No ORDER BY arguments: the row always lies inside its partition.
	auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]);
	for (idx_t i = 0; i < count; ++i, ++row_idx) {
		rdata[i] = row_idx - partition_begin[i] + 1;
	}
}
//===--------------------------------------------------------------------===//
// WindowNtileExecutor
//===--------------------------------------------------------------------===//
WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context,
                                         WindowSharedExpressions &shared)
    : WindowRowNumberExecutor(wexpr, context, shared) {
	// NTILE takes exactly one argument (the number of tiles); register it for evaluation.
	auto &tile_count_expr = wexpr.children[0];
	ntile_idx = shared.RegisterEvaluate(tile_count_expr);
}
//! Compute NTILE(n) for `count` rows starting at `row_idx`.
//! The rows of the partition (or, with ORDER BY arguments, of the frame) are split into n
//! tiles whose sizes differ by at most one, larger tiles first (SQLite's algorithm).
//! NULL argument -> NULL. Argument < 1 -> InvalidInputException.
//! Empty frame -> NULL (previously this divided by zero).
//! A row outside its frame is clamped to the nearest frame position, so the result always
//! stays within [1, n].
void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
                                           DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const {
	auto &grstate = gstate.Cast<WindowRowNumberGlobalState>();
	auto &lrstate = lstate.Cast<WindowRowNumberLocalState>();
	auto partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_BEGIN]);
	auto partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[PARTITION_END]);
	if (grstate.use_framing) {
		// With secondary sorts, we restrict to the frame boundaries, but everything else should compute the same.
		partition_begin = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_BEGIN]);
		partition_end = FlatVector::GetData<const idx_t>(lrstate.bounds.data[FRAME_END]);
	}
	auto rdata = FlatVector::GetData<int64_t>(result);
	WindowInputExpression ntile_col(eval_chunk, ntile_idx);
	for (idx_t i = 0; i < count; ++i, ++row_idx) {
		if (ntile_col.CellIsNull(i)) {
			FlatVector::SetNull(result, i, true);
			continue;
		}
		auto n_param = ntile_col.GetCell<int64_t>(i);
		if (n_param < 1) {
			throw InvalidInputException("Argument for ntile must be greater than zero");
		}
		const auto begin = partition_begin[i];
		const auto end = partition_end[i];
		if (end <= begin) {
			// An empty frame has no rows to tile: clamping n_param to n_total would make it 0
			// and the division below would divide by zero.
			FlatVector::SetNull(result, i, true);
			continue;
		}
		// With thanks from SQLite's ntileValueFunc()
		auto n_total = NumericCast<int64_t>(end - begin);
		if (n_param > n_total) {
			// more groups allowed than we have values
			// map every entry to a unique group
			n_param = n_total;
		}
		int64_t n_size = (n_total / n_param);
		// find the 0-based position of the row within the partition/frame
		idx_t partition_idx = 0;
		if (grstate.token_tree) {
			partition_idx = grstate.token_tree->Rank(begin, end, row_idx) - 1;
		} else if (row_idx > begin) {
			// rows before the frame stay at position 0 instead of underflowing
			partition_idx = row_idx - begin;
		}
		// rows after the frame map to its last position, keeping the tile within [1, n_param]
		const idx_t last_idx = end - begin - 1;
		if (partition_idx > last_idx) {
			partition_idx = last_idx;
		}
		auto adjusted_row_idx = NumericCast<int64_t>(partition_idx);
		// now compute the ntile: the first n_large tiles get (n_size + 1) rows, the rest n_size
		int64_t n_large = n_total - n_param * n_size;
		int64_t i_small = n_large * (n_size + 1);
		int64_t result_ntile;
		D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total);
		if (adjusted_row_idx < i_small) {
			result_ntile = 1 + adjusted_row_idx / (n_size + 1);
		} else {
			result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size;
		}
		// result has to be between [1, NTILE]
		D_ASSERT(result_ntile >= 1 && result_ntile <= n_param);
		rdata[i] = result_ntile;
	}
}
} // namespace duckdb