@@ -31,125 +31,24 @@ int normalizeColIndex(const Node* node, int col_idx) {
31
31
std::unordered_set<std::string> getColNames (const Node* node) {
32
32
std::unordered_set<std::string> res;
33
33
res.reserve (node->size ());
34
- if (auto scan = node->as <Scan>()) {
35
- for (size_t col_idx = 0 ; col_idx < node->size (); ++col_idx) {
36
- if (!scan->isVirtualCol (col_idx)) {
37
- res.insert (scan->getFieldName (col_idx));
38
- }
34
+ auto scan = node->as <Scan>();
35
+ for (size_t col_idx = 0 ; col_idx < node->size (); ++col_idx) {
36
+ if (!scan || !scan->isVirtualCol (col_idx)) {
37
+ res.insert (node->getFieldName (col_idx));
39
38
}
40
- } else if (auto proj = node->as <Project>()) {
41
- auto fields = proj->getFields ();
42
- res.insert (fields.begin (), fields.end ());
43
- } else if (auto filter = node->as <Filter>()) {
44
- res = getColNames (filter->getInput (0 ));
45
- } else if (auto agg = node->as <Aggregate>()) {
46
- auto fields = agg->getFields ();
47
- res.insert (fields.begin (), fields.end ());
48
- } else if (auto sort = node->as <Sort>()) {
49
- res = getColNames (sort->getInput (0 ));
50
- } else if (auto join = node->as <Join>()) {
51
- res = getColNames (join->getInput (0 ));
52
- auto rhs_names = getColNames (join->getInput (1 ));
53
- res.insert (rhs_names.begin (), rhs_names.end ());
54
- } else {
55
- throw InvalidQueryError () << " getColNames error: unsupported node: "
56
- << node->toString ();
57
39
}
58
40
return res;
59
41
}
60
42
61
43
std::string getFieldName (const Node* node, int col_idx) {
62
44
col_idx = normalizeColIndex (node, col_idx);
63
- if (auto scan = node->as <Scan>()) {
64
- return scan->getFieldName (col_idx);
65
- }
66
- if (auto proj = node->as <Project>()) {
67
- return proj->getFieldName (col_idx);
68
- }
69
- if (auto filter = node->as <Filter>()) {
70
- return getFieldName (filter->getInput (0 ), col_idx);
71
- }
72
- if (auto agg = node->as <Aggregate>()) {
73
- return agg->getFieldName (col_idx);
74
- }
75
- if (auto sort = node->as <Sort>()) {
76
- return getFieldName (sort->getInput (0 ), col_idx);
77
- }
78
- if (auto join = node->as <Join>()) {
79
- if (col_idx < (int )join->getInput (0 )->size ()) {
80
- return getFieldName (join->getInput (0 ), col_idx);
81
- }
82
- return getFieldName (join->getInput (1 ), col_idx - join->getInput (0 )->size ());
83
- }
84
-
85
- throw InvalidQueryError () << " getFieldName error: unsupported node: "
86
- << node->toString ();
45
+ return node->getFieldName (col_idx);
87
46
}
88
47
89
48
ExprPtr getRefByName (const Node* node,
90
49
const std::string& col_name,
91
50
bool allow_null_res = false );
92
51
93
- int getRefIndexByName (const Node* node, const std::string& col_name) {
94
- if (node->is <Join>()) {
95
- auto lhs_ref_idx = getRefIndexByName (node->getInput (0 ), col_name);
96
- if (lhs_ref_idx >= 0 ) {
97
- return lhs_ref_idx;
98
- }
99
- auto rhs_ref_idx = getRefIndexByName (node->getInput (1 ), col_name);
100
- if (rhs_ref_idx >= 0 ) {
101
- return rhs_ref_idx + node->getInput (0 )->size ();
102
- }
103
- return -1 ;
104
- }
105
-
106
- auto ref = getRefByName (node, col_name, true );
107
- return ref ? ref->as <ir::ColumnRef>()->index () : -1 ;
108
- }
109
-
110
- ExprPtr getRefByName (const Scan* scan, const std::string& col_name) {
111
- for (size_t i = 0 ; i < scan->size (); ++i) {
112
- if (scan->getColumnInfo (i)->name == col_name) {
113
- return getNodeColumnRef (scan, (unsigned )i);
114
- }
115
- }
116
- return nullptr ;
117
- }
118
-
119
- ExprPtr getRefByName (const Project* proj, const std::string& col_name) {
120
- for (size_t i = 0 ; i < proj->size (); ++i) {
121
- if (proj->getFieldName (i) == col_name) {
122
- return getNodeColumnRef (proj, (unsigned )i);
123
- }
124
- }
125
- return nullptr ;
126
- }
127
-
128
- ExprPtr getRefByName (const Filter* filter, const std::string& col_name) {
129
- auto idx = getRefIndexByName (filter->getInput (0 ), col_name);
130
- if (idx >= 0 ) {
131
- return getNodeColumnRef (filter, idx);
132
- }
133
- return nullptr ;
134
- }
135
-
136
- ExprPtr getRefByName (const Aggregate* agg, const std::string& col_name) {
137
- for (size_t i = 0 ; i < agg->size (); ++i) {
138
- if (agg->getFieldName (i) == col_name) {
139
- return getNodeColumnRef (agg, (unsigned )i);
140
- }
141
- }
142
- return nullptr ;
143
- }
144
-
145
- ExprPtr getRefByName (const Sort* sort, const std::string& col_name) {
146
- auto idx = getRefIndexByName (sort->getInput (0 ), col_name);
147
- if (idx >= 0 ) {
148
- return getNodeColumnRef (sort, idx);
149
- }
150
- return nullptr ;
151
- }
152
-
153
52
ExprPtr getRefByName (const Join* join, const std::string& col_name) {
154
53
auto lhs_input_ref = getRefByName (join->getInput (0 ), col_name, true );
155
54
auto rhs_input_ref = getRefByName (join->getInput (1 ), col_name, true );
@@ -162,21 +61,14 @@ ExprPtr getRefByName(const Join* join, const std::string& col_name) {
162
61
163
62
ExprPtr getRefByName (const Node* node, const std::string& col_name, bool allow_null_res) {
164
63
ExprPtr res = nullptr ;
165
- if (auto scan = node->as <Scan>()) {
166
- res = getRefByName (scan, col_name);
167
- } else if (auto proj = node->as <Project>()) {
168
- res = getRefByName (proj, col_name);
169
- } else if (auto filter = node->as <Filter>()) {
170
- res = getRefByName (filter, col_name);
171
- } else if (auto agg = node->as <Aggregate>()) {
172
- res = getRefByName (agg, col_name);
173
- } else if (auto sort = node->as <Sort>()) {
174
- res = getRefByName (sort, col_name);
175
- } else if (auto join = node->as <Join>()) {
64
+ if (auto join = node->as <Join>()) {
176
65
res = getRefByName (join, col_name);
177
66
} else {
178
- throw InvalidQueryError () << " getRefByName error: unsupported node: "
179
- << node->toString ();
67
+ for (size_t i = 0 ; i < node->size (); ++i) {
68
+ if (node->getFieldName (i) == col_name) {
69
+ res = getNodeColumnRef (node, (unsigned )i);
70
+ }
71
+ }
180
72
}
181
73
182
74
if (!res && !allow_null_res) {
0 commit comments