forked from alibaba/AliSQL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path column_binding_resolver.cpp
More file actions
238 lines (229 loc) · 9.22 KB
/
column_binding_resolver.cpp
File metadata and controls
238 lines (229 loc) · 9.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#include "duckdb/execution/column_binding_resolver.hpp"
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/planner/operator/logical_any_join.hpp"
#include "duckdb/planner/operator/logical_comparison_join.hpp"
#include "duckdb/planner/operator/logical_create_index.hpp"
#include "duckdb/planner/operator/logical_extension_operator.hpp"
#include "duckdb/planner/operator/logical_insert.hpp"
#include "duckdb/planner/operator/logical_recursive_cte.hpp"
namespace duckdb {
ColumnBindingResolver::ColumnBindingResolver(bool verify_only) : verify_only(verify_only) {
}
// Resolves the column references in `op` and its subtree against the column bindings produced below it.
// In most cases the method leaves the `bindings`/`types` members set to op.GetColumnBindings() / op.types,
// so the parent operator can resolve its own expressions. Exceptions are noted per case below.
// Operators whose expressions must see something other than "the concatenated output of the children"
// are special-cased: joins, CREATE INDEX, GET, INSERT with a non-THROW conflict action, extension
// operators and recursive CTEs. Every other operator takes the general path at the bottom.
// NOTE: the order of "set bindings" vs. "visit expressions" in each case is significant; do not reorder.
void ColumnBindingResolver::VisitOperator(LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_ASOF_JOIN:
case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: {
// special case: comparison join
// The children are visited one at a time (not via VisitOperatorChildren) so that each side's
// condition expressions are resolved against that side's bindings only.
auto &comp_join = op.Cast<LogicalComparisonJoin>();
// first get the bindings of the LHS and resolve the LHS expressions
VisitOperator(*comp_join.children[0]);
for (auto &cond : comp_join.conditions) {
VisitExpression(&cond.left);
}
// visit the duplicate eliminated columns on the LHS, if any
for (auto &expr : comp_join.duplicate_eliminated_columns) {
VisitExpression(&expr);
}
// then get the bindings of the RHS and resolve the RHS expressions
VisitOperator(*comp_join.children[1]);
for (auto &cond : comp_join.conditions) {
VisitExpression(&cond.right);
}
// finally update the bindings with the result bindings of the join
bindings = op.GetColumnBindings();
types = op.types;
// resolve any mixed predicates
// for now, only ASOF supports this.
// The predicate is resolved against the join's output bindings assigned just above,
// not against either child individually.
if (comp_join.predicate) {
D_ASSERT(op.type == LogicalOperatorType::LOGICAL_ASOF_JOIN);
VisitExpression(&comp_join.predicate);
}
return;
}
case LogicalOperatorType::LOGICAL_DELIM_JOIN: {
// Like the comparison join above, except that the duplicate-eliminated side is visited first.
// delim_flipped selects which child that is: children[1] when flipped, otherwise children[0].
auto &comp_join = op.Cast<LogicalComparisonJoin>();
// get bindings from the duplicate-eliminated side
auto &delim_side = comp_join.delim_flipped ? *comp_join.children[1] : *comp_join.children[0];
VisitOperator(delim_side);
for (auto &cond : comp_join.conditions) {
// pick the condition operand that belongs to the side just visited
auto &expr = comp_join.delim_flipped ? cond.right : cond.left;
VisitExpression(&expr);
}
// visit the duplicate eliminated columns
// (resolved while `bindings` still holds the duplicate-eliminated side's bindings)
for (auto &expr : comp_join.duplicate_eliminated_columns) {
VisitExpression(&expr);
}
// now the other side
auto &other_side = comp_join.delim_flipped ? *comp_join.children[0] : *comp_join.children[1];
VisitOperator(other_side);
for (auto &cond : comp_join.conditions) {
auto &expr = comp_join.delim_flipped ? cond.left : cond.right;
VisitExpression(&expr);
}
// finally update the bindings with the result bindings of the join
bindings = op.GetColumnBindings();
types = op.types;
return;
}
case LogicalOperatorType::LOGICAL_ANY_JOIN: {
// ANY join, this join is different because we evaluate the expression on the bindings of BOTH join sides at
// once i.e. we set the bindings first to the bindings of the entire join, and then resolve the expressions of
// this operator
VisitOperatorChildren(op);
bindings = op.GetColumnBindings();
types = op.types;
auto &any_join = op.Cast<LogicalAnyJoin>();
// For SEMI/ANTI, op.GetColumnBindings() does not include the right child's columns, yet the join
// expression may reference them, so the right child's bindings and types are appended.
if (any_join.join_type == JoinType::SEMI || any_join.join_type == JoinType::ANTI) {
auto right_bindings = op.children[1]->GetColumnBindings();
bindings.insert(bindings.end(), right_bindings.begin(), right_bindings.end());
auto &right_types = op.children[1]->types;
types.insert(types.end(), right_types.begin(), right_types.end());
}
if (any_join.join_type == JoinType::RIGHT_SEMI || any_join.join_type == JoinType::RIGHT_ANTI) {
throw InternalException("RIGHT SEMI/ANTI any join not supported yet");
}
VisitOperatorExpressions(op);
// NOTE(review): unlike the other join cases, bindings/types are not reset to op's output here, so for
// SEMI/ANTI the appended right-side entries remain on return. The join's own bindings come first,
// so the parent still resolves them to the same indexes.
return;
}
case LogicalOperatorType::LOGICAL_CREATE_INDEX: {
// CREATE INDEX statement, add the columns of the table with table index 0 to the binding set
// afterwards bind the expressions of the CREATE INDEX statement
// The children are not visited in this case, and the bindings are left as the generated table-0
// bindings on return.
auto &create_index = op.Cast<LogicalCreateIndex>();
bindings = LogicalOperator::GenerateColumnBindings(0, create_index.table.GetColumns().LogicalColumnCount());
// TODO: fill types in too (clearing skips type checks)
types.clear();
VisitOperatorExpressions(op);
return;
}
case LogicalOperatorType::LOGICAL_GET: {
//! We first need to update the current set of bindings and then visit operator expressions
// i.e. the GET's own expressions are resolved against the GET's output bindings, not against
// whatever bindings were active before this call.
bindings = op.GetColumnBindings();
types = op.types;
VisitOperatorExpressions(op);
return;
}
case LogicalOperatorType::LOGICAL_INSERT: {
//! We want to execute the normal path, but also add a dummy 'excluded' binding if there is a
// ON CONFLICT DO UPDATE clause
// NOTE(review): the guard below is `action_type != THROW`, so this path runs for every non-THROW
// conflict action, not only DO UPDATE. With THROW, control breaks out to the general case.
auto &insert_op = op.Cast<LogicalInsert>();
if (insert_op.action_type != OnConflictAction::THROW) {
// Get the bindings from the children
VisitOperatorChildren(op);
// one dummy binding per physical column, under table index excluded_table_index
auto column_count = insert_op.table.GetColumns().PhysicalColumnCount();
auto dummy_bindings = LogicalOperator::GenerateColumnBindings(insert_op.excluded_table_index, column_count);
// Now insert our dummy bindings at the start of the bindings,
// so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns
bindings.insert(bindings.begin(), dummy_bindings.begin(), dummy_bindings.end());
// TODO: fill types in too (clearing skips type checks)
types.clear();
if (insert_op.on_conflict_condition) {
VisitExpression(&insert_op.on_conflict_condition);
}
if (insert_op.do_update_condition) {
VisitExpression(&insert_op.do_update_condition);
}
VisitOperatorExpressions(op);
// restore the operator's own output bindings for the parent
bindings = op.GetColumnBindings();
types = op.types;
return;
}
break;
}
case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: {
// The extension operator resolves its own bindings. It receives this resolver and the current
// bindings; whatever it leaves in `bindings` is what the parent sees.
auto &ext_op = op.Cast<LogicalExtensionOperator>();
// Just to be very sure, we clear before and after resolving extension operator column bindings
// This skips checks, but makes sure we don't break any extension operators with type verification
types.clear();
ext_op.ResolveColumnBindings(*this, bindings);
types.clear();
return;
}
case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: {
// Only key_targets are resolved here (VisitOperatorExpressions is not called), and they are resolved
// against the CTE's own output bindings, not against the children's.
auto &rec = op.Cast<LogicalRecursiveCTE>();
VisitOperatorChildren(op);
bindings = op.GetColumnBindings();
types = op.types;
for (auto &expr : rec.key_targets) {
VisitExpression(&expr);
}
return;
}
default:
break;
}
// general case
// first visit the children of this operator
// (after this, `bindings`/`types` reflect what the children left behind)
VisitOperatorChildren(op);
// now visit the expressions of this operator to resolve any bound column references
VisitOperatorExpressions(op);
// finally update the current set of bindings to the current set of column bindings
bindings = op.GetColumnBindings();
types = op.types;
}
// Maps a BoundColumnRefExpression onto its position in the currently active `bindings`.
// The first binding equal to expr.binding wins. When `types` is populated, the binding/type vectors
// must have equal length and the type at that position must equal expr.return_type.
// Returns nullptr in verify-only mode; otherwise returns a BoundReferenceExpression carrying the position.
// Throws InternalException when the binding cannot be found or a type check fails.
unique_ptr<Expression> ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr,
                                                           unique_ptr<Expression> *expr_ptr) {
	D_ASSERT(expr.depth == 0);

	// Phase 1: locate the first slot holding this binding.
	idx_t slot = 0;
	while (slot < bindings.size() && !(expr.binding == bindings[slot])) {
		++slot;
	}

	if (slot == bindings.size()) {
		// LCOV_EXCL_START
		// Unresolvable reference: this should never happen and indicates a planner bug.
		throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)", expr.GetAlias(),
		                        expr.binding.table_index, expr.binding.column_index,
		                        LogicalOperator::ColumnBindingsToString(bindings));
		// LCOV_EXCL_STOP
	}

	// Phase 2: type verification, only possible when the caller supplied types alongside the bindings.
	if (!types.empty()) {
		if (bindings.size() != types.size()) {
			throw InternalException(
			    "Failed to bind column reference \"%s\" [%d.%d]: inequal num bindings/types (%llu != %llu)",
			    expr.GetAlias(), expr.binding.table_index, expr.binding.column_index, bindings.size(),
			    types.size());
		}
		auto &slot_type = types[slot];
		if (expr.return_type != slot_type) {
			throw InternalException("Failed to bind column reference \"%s\" [%d.%d]: inequal types (%s != %s)",
			                        expr.GetAlias(), expr.binding.table_index, expr.binding.column_index,
			                        expr.return_type.ToString(), slot_type.ToString());
		}
	}

	// Phase 3: in verification mode nothing is rewritten; otherwise emit a positional reference.
	if (verify_only) {
		return nullptr;
	}
	return make_uniq<BoundReferenceExpression>(expr.GetAlias(), expr.return_type, slot);
}
// Collects the table indexes used in the subtree rooted at `op`: first every child's (recursively), then
// op's own indexes from GetTableIndex(). Throws InternalException if any index is encountered twice.
// Returns the set of all collected indexes.
unordered_set<idx_t> ColumnBindingResolver::VerifyInternal(LogicalOperator &op) {
	unordered_set<idx_t> seen;

	// Adds one table index to `seen`, rejecting an index that has already been claimed in this subtree.
	auto register_index = [&seen](idx_t table_index) {
		D_ASSERT(table_index != DConstants::INVALID_INDEX);
		if (!seen.insert(table_index).second) {
			throw InternalException("Duplicate table index \"%lld\" found", table_index);
		}
	};

	for (auto &child : op.children) {
		for (auto table_index : VerifyInternal(*child)) {
			register_index(table_index);
		}
	}
	for (auto table_index : op.GetTableIndex()) {
		register_index(table_index);
	}
	return seen;
}
// Debug-only sanity check of a logical plan; compiled to an empty function unless DEBUG is defined.
// It refreshes the operator types, runs a verify-only resolver over the plan (column references are
// checked, and in that mode VisitReplace returns nullptr rather than a BoundReferenceExpression), and
// then checks that no table index appears more than once in the tree.
void ColumnBindingResolver::Verify(LogicalOperator &op) {
#ifdef DEBUG
	op.ResolveOperatorTypes();
	ColumnBindingResolver checker(true);
	checker.VisitOperator(op);
	// result set is not needed; VerifyInternal throws on a duplicate table index
	VerifyInternal(op);
#endif
}
} // namespace duckdb