Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 1998256

Browse files
committed
Add mandatory getFieldName for all nodes.
Signed-off-by: ienkovich <[email protected]>
1 parent e760e09 commit 1998256

File tree

3 files changed

+40
-174
lines changed

3 files changed

+40
-174
lines changed

omniscidb/IR/Node.cpp

+1-50
Original file line numberDiff line numberDiff line change
@@ -63,41 +63,6 @@ bool is_one_of(const Node* node) {
6363
return dynamic_cast<const T1*>(node) || is_one_of<T2, Ts...>(node);
6464
}
6565

66-
bool isRenamedInput(const Node* node, const size_t index, const std::string& new_name) {
67-
CHECK_LT(index, node->size());
68-
if (auto join = dynamic_cast<const Join*>(node)) {
69-
CHECK_EQ(size_t(2), join->inputCount());
70-
const auto lhs_size = join->getInput(0)->size();
71-
if (index < lhs_size) {
72-
return isRenamedInput(join->getInput(0), index, new_name);
73-
}
74-
CHECK_GE(index, lhs_size);
75-
return isRenamedInput(join->getInput(1), index - lhs_size, new_name);
76-
}
77-
78-
if (auto scan = dynamic_cast<const Scan*>(node)) {
79-
return new_name != scan->getFieldName(index);
80-
}
81-
82-
if (auto aggregate = dynamic_cast<const Aggregate*>(node)) {
83-
return new_name != aggregate->getFieldName(index);
84-
}
85-
86-
if (auto project = dynamic_cast<const Project*>(node)) {
87-
return new_name != project->getFieldName(index);
88-
}
89-
90-
if (auto logical_values = dynamic_cast<const LogicalValues*>(node)) {
91-
const auto& tuple_type = logical_values->getTupleType();
92-
CHECK_LT(index, tuple_type.size());
93-
return new_name != tuple_type[index].get_resname();
94-
}
95-
96-
CHECK(dynamic_cast<const Sort*>(node) || dynamic_cast<const Filter*>(node) ||
97-
dynamic_cast<const LogicalUnion*>(node));
98-
return isRenamedInput(node->getInput(0), index, new_name);
99-
}
100-
10166
} // namespace
10267

10368
std::atomic<unsigned> Node::crt_id_ = FIRST_NODE_ID;
@@ -190,7 +155,7 @@ bool Project::isRenaming() const {
190155
for (size_t i = 0; i < fields_.size(); ++i) {
191156
auto col_ref = dynamic_cast<const ColumnRef*>(exprs_[i].get());
192157
CHECK(col_ref);
193-
if (isRenamedInput(col_ref->node(), col_ref->index(), fields_[i])) {
158+
if (col_ref->node()->getFieldName(col_ref->index()) != fields_[i]) {
194159
return true;
195160
}
196161
}
@@ -280,20 +245,6 @@ size_t LogicalUnion::toHash() const {
280245
return *hash_;
281246
}
282247

283-
std::string LogicalUnion::getFieldName(const size_t i) const {
284-
if (auto const* input = dynamic_cast<Project const*>(inputs_[0].get())) {
285-
return input->getFieldName(i);
286-
} else if (auto const* input = dynamic_cast<LogicalUnion const*>(inputs_[0].get())) {
287-
return input->getFieldName(i);
288-
} else if (auto const* input = dynamic_cast<Aggregate const*>(inputs_[0].get())) {
289-
return input->getFieldName(i);
290-
} else if (auto const* input = dynamic_cast<Scan const*>(inputs_[0].get())) {
291-
return input->getFieldName(i);
292-
}
293-
UNREACHABLE() << "Unhandled input type: " << ::toString(inputs_.front());
294-
return {};
295-
}
296-
297248
void LogicalUnion::checkForMatchingMetaInfoTypes() const {
298249
std::vector<TargetMetaInfo> const& tmis0 = inputs_[0]->getOutputMetainfo();
299250
std::vector<TargetMetaInfo> const& tmis1 = inputs_[1]->getOutputMetainfo();

omniscidb/IR/Node.h

+28-5
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ class Node {
174174

175175
virtual std::shared_ptr<Node> deepCopy() const = 0;
176176

177+
virtual const std::string& getFieldName(size_t i) const = 0;
178+
177179
/**
178180
* Clears the ptr to the result for this descriptor. Is only used for overriding step
179181
* results in distributed mode.
@@ -224,7 +226,7 @@ class Scan : public Node {
224226

225227
const size_t getNumFragments() const { return table_info_->fragments; }
226228

227-
const std::string& getFieldName(const size_t i) const {
229+
const std::string& getFieldName(size_t i) const override {
228230
CHECK_LT(i, column_infos_.size());
229231
return column_infos_[i]->name;
230232
}
@@ -325,7 +327,7 @@ class Project : public Node {
325327
const std::vector<std::string>& getFields() const { return fields_; }
326328
void setFields(std::vector<std::string> fields) { fields_ = std::move(fields); }
327329

328-
const std::string getFieldName(const size_t i) const {
330+
const std::string& getFieldName(size_t i) const override {
329331
CHECK_LT(i, fields_.size());
330332
return fields_[i];
331333
}
@@ -388,7 +390,7 @@ class Aggregate : public Node {
388390
const std::vector<std::string>& getFields() const { return fields_; }
389391
void setFields(std::vector<std::string> new_fields) { fields_ = std::move(new_fields); }
390392

391-
const std::string getFieldName(const size_t i) const {
393+
const std::string& getFieldName(size_t i) const override {
392394
CHECK_LT(i, fields_.size());
393395
return fields_[i];
394396
}
@@ -492,6 +494,13 @@ class Join : public Node {
492494
return std::make_shared<Join>(*this);
493495
}
494496

497+
const std::string& getFieldName(size_t i) const override {
498+
if (i < getInput(0)->size()) {
499+
return getInput(0)->getFieldName(i);
500+
}
501+
return getInput(1)->getFieldName(i - getInput(0)->size());
502+
}
503+
495504
private:
496505
ExprPtr condition_;
497506
const JoinType join_type_;
@@ -582,7 +591,7 @@ class TranslatedJoin : public Node {
582591
CHECK(false);
583592
return nullptr;
584593
}
585-
std::string getFieldName(const size_t i) const;
594+
const std::string& getFieldName(size_t i) const override { CHECK(false); }
586595
std::vector<const ColumnVar*> getJoinCols(bool lhs) const {
587596
if (lhs) {
588597
return lhs_join_cols_;
@@ -656,6 +665,10 @@ class Filter : public Node {
656665
return std::make_shared<Filter>(*this);
657666
}
658667
668+
const std::string& getFieldName(size_t i) const override {
669+
return getInput(0)->getFieldName(i);
670+
}
671+
659672
private:
660673
ExprPtr condition_;
661674
};
@@ -739,6 +752,10 @@ class Sort : public Node {
739752
return std::make_shared<Sort>(*this);
740753
}
741754
755+
const std::string& getFieldName(size_t i) const override {
756+
return getInput(0)->getFieldName(i);
757+
}
758+
742759
private:
743760
std::vector<SortField> collation_;
744761
const size_t limit_;
@@ -803,6 +820,10 @@ class LogicalValues : public Node {
803820
return std::make_shared<LogicalValues>(*this);
804821
}
805822
823+
const std::string& getFieldName(size_t i) const override {
824+
return tuple_type_[i].get_resname();
825+
}
826+
806827
private:
807828
std::vector<TargetMetaInfo> tuple_type_;
808829
std::vector<ExprPtrVector> values_;
@@ -818,7 +839,9 @@ class LogicalUnion : public Node {
818839
std::string toString() const override;
819840
size_t toHash() const override;
820841
821-
std::string getFieldName(const size_t i) const;
842+
const std::string& getFieldName(size_t i) const override {
843+
return getInput(0)->getFieldName(i);
844+
}
822845
823846
inline bool isAll() const { return is_all_; }
824847
// Will throw a std::runtime_error if MetaInfo types don't match.

omniscidb/QueryBuilder/QueryBuilder.cpp

+11-119
Original file line numberDiff line numberDiff line change
@@ -31,125 +31,24 @@ int normalizeColIndex(const Node* node, int col_idx) {
3131
std::unordered_set<std::string> getColNames(const Node* node) {
3232
std::unordered_set<std::string> res;
3333
res.reserve(node->size());
34-
if (auto scan = node->as<Scan>()) {
35-
for (size_t col_idx = 0; col_idx < node->size(); ++col_idx) {
36-
if (!scan->isVirtualCol(col_idx)) {
37-
res.insert(scan->getFieldName(col_idx));
38-
}
34+
auto scan = node->as<Scan>();
35+
for (size_t col_idx = 0; col_idx < node->size(); ++col_idx) {
36+
if (!scan || !scan->isVirtualCol(col_idx)) {
37+
res.insert(node->getFieldName(col_idx));
3938
}
40-
} else if (auto proj = node->as<Project>()) {
41-
auto fields = proj->getFields();
42-
res.insert(fields.begin(), fields.end());
43-
} else if (auto filter = node->as<Filter>()) {
44-
res = getColNames(filter->getInput(0));
45-
} else if (auto agg = node->as<Aggregate>()) {
46-
auto fields = agg->getFields();
47-
res.insert(fields.begin(), fields.end());
48-
} else if (auto sort = node->as<Sort>()) {
49-
res = getColNames(sort->getInput(0));
50-
} else if (auto join = node->as<Join>()) {
51-
res = getColNames(join->getInput(0));
52-
auto rhs_names = getColNames(join->getInput(1));
53-
res.insert(rhs_names.begin(), rhs_names.end());
54-
} else {
55-
throw InvalidQueryError() << "getColNames error: unsupported node: "
56-
<< node->toString();
5739
}
5840
return res;
5941
}
6042

6143
std::string getFieldName(const Node* node, int col_idx) {
6244
col_idx = normalizeColIndex(node, col_idx);
63-
if (auto scan = node->as<Scan>()) {
64-
return scan->getFieldName(col_idx);
65-
}
66-
if (auto proj = node->as<Project>()) {
67-
return proj->getFieldName(col_idx);
68-
}
69-
if (auto filter = node->as<Filter>()) {
70-
return getFieldName(filter->getInput(0), col_idx);
71-
}
72-
if (auto agg = node->as<Aggregate>()) {
73-
return agg->getFieldName(col_idx);
74-
}
75-
if (auto sort = node->as<Sort>()) {
76-
return getFieldName(sort->getInput(0), col_idx);
77-
}
78-
if (auto join = node->as<Join>()) {
79-
if (col_idx < (int)join->getInput(0)->size()) {
80-
return getFieldName(join->getInput(0), col_idx);
81-
}
82-
return getFieldName(join->getInput(1), col_idx - join->getInput(0)->size());
83-
}
84-
85-
throw InvalidQueryError() << "getFieldName error: unsupported node: "
86-
<< node->toString();
45+
return node->getFieldName(col_idx);
8746
}
8847

8948
ExprPtr getRefByName(const Node* node,
9049
const std::string& col_name,
9150
bool allow_null_res = false);
9251

93-
int getRefIndexByName(const Node* node, const std::string& col_name) {
94-
if (node->is<Join>()) {
95-
auto lhs_ref_idx = getRefIndexByName(node->getInput(0), col_name);
96-
if (lhs_ref_idx >= 0) {
97-
return lhs_ref_idx;
98-
}
99-
auto rhs_ref_idx = getRefIndexByName(node->getInput(1), col_name);
100-
if (rhs_ref_idx >= 0) {
101-
return rhs_ref_idx + node->getInput(0)->size();
102-
}
103-
return -1;
104-
}
105-
106-
auto ref = getRefByName(node, col_name, true);
107-
return ref ? ref->as<ir::ColumnRef>()->index() : -1;
108-
}
109-
110-
ExprPtr getRefByName(const Scan* scan, const std::string& col_name) {
111-
for (size_t i = 0; i < scan->size(); ++i) {
112-
if (scan->getColumnInfo(i)->name == col_name) {
113-
return getNodeColumnRef(scan, (unsigned)i);
114-
}
115-
}
116-
return nullptr;
117-
}
118-
119-
ExprPtr getRefByName(const Project* proj, const std::string& col_name) {
120-
for (size_t i = 0; i < proj->size(); ++i) {
121-
if (proj->getFieldName(i) == col_name) {
122-
return getNodeColumnRef(proj, (unsigned)i);
123-
}
124-
}
125-
return nullptr;
126-
}
127-
128-
ExprPtr getRefByName(const Filter* filter, const std::string& col_name) {
129-
auto idx = getRefIndexByName(filter->getInput(0), col_name);
130-
if (idx >= 0) {
131-
return getNodeColumnRef(filter, idx);
132-
}
133-
return nullptr;
134-
}
135-
136-
ExprPtr getRefByName(const Aggregate* agg, const std::string& col_name) {
137-
for (size_t i = 0; i < agg->size(); ++i) {
138-
if (agg->getFieldName(i) == col_name) {
139-
return getNodeColumnRef(agg, (unsigned)i);
140-
}
141-
}
142-
return nullptr;
143-
}
144-
145-
ExprPtr getRefByName(const Sort* sort, const std::string& col_name) {
146-
auto idx = getRefIndexByName(sort->getInput(0), col_name);
147-
if (idx >= 0) {
148-
return getNodeColumnRef(sort, idx);
149-
}
150-
return nullptr;
151-
}
152-
15352
ExprPtr getRefByName(const Join* join, const std::string& col_name) {
15453
auto lhs_input_ref = getRefByName(join->getInput(0), col_name, true);
15554
auto rhs_input_ref = getRefByName(join->getInput(1), col_name, true);
@@ -162,21 +61,14 @@ ExprPtr getRefByName(const Join* join, const std::string& col_name) {
16261

16362
ExprPtr getRefByName(const Node* node, const std::string& col_name, bool allow_null_res) {
16463
ExprPtr res = nullptr;
165-
if (auto scan = node->as<Scan>()) {
166-
res = getRefByName(scan, col_name);
167-
} else if (auto proj = node->as<Project>()) {
168-
res = getRefByName(proj, col_name);
169-
} else if (auto filter = node->as<Filter>()) {
170-
res = getRefByName(filter, col_name);
171-
} else if (auto agg = node->as<Aggregate>()) {
172-
res = getRefByName(agg, col_name);
173-
} else if (auto sort = node->as<Sort>()) {
174-
res = getRefByName(sort, col_name);
175-
} else if (auto join = node->as<Join>()) {
64+
if (auto join = node->as<Join>()) {
17665
res = getRefByName(join, col_name);
17766
} else {
178-
throw InvalidQueryError() << "getRefByName error: unsupported node: "
179-
<< node->toString();
67+
for (size_t i = 0; i < node->size(); ++i) {
68+
if (node->getFieldName(i) == col_name) {
69+
res = getNodeColumnRef(node, (unsigned)i);
70+
}
71+
}
18072
}
18173

18274
if (!res && !allow_null_res) {

0 commit comments

Comments
 (0)