Skip to content

Commit 857bf57

Browse files
committed
GDV-55: [C++] Added validation to projector build. (apache#33)
Validating the input schema and expressions during the projector build.
1 parent 1b1ac90 commit 857bf57

16 files changed

+561
-31
lines changed

cpp/src/gandiva/codegen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_library(gandiva SHARED
3232
projector.cc
3333
status.cc
3434
tree_expr_builder.cc
35+
expr_validator.cc
3536
${BC_FILE_PATH_CC})
3637

3738
# For users of gandiva library (including integ tests), include-dir is :

cpp/src/gandiva/codegen/expr_decomposer.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,18 @@
2828
namespace gandiva {
2929

3030
// Decompose a field node - simply seperate out validity & value arrays.
31-
void ExprDecomposer::Visit(const FieldNode &node) {
31+
Status ExprDecomposer::Visit(const FieldNode &node) {
3232
auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field());
3333

3434
DexPtr validity_dex = std::make_shared<VectorReadValidityDex>(desc);
3535
DexPtr value_dex = std::make_shared<VectorReadValueDex>(desc);
3636
result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
37+
return Status::OK();
3738
}
3839

3940
// Decompose a field node - wherever possible, merge the validity vectors of the
4041
// child nodes.
41-
void ExprDecomposer::Visit(const FunctionNode &node) {
42+
Status ExprDecomposer::Visit(const FunctionNode &node) {
4243
auto desc = node.descriptor();
4344
FunctionSignature signature(desc->name(),
4445
desc->params(),
@@ -84,10 +85,11 @@ void ExprDecomposer::Visit(const FunctionNode &node) {
8485
local_bitmap_idx);
8586
result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
8687
}
88+
return Status::OK();
8789
}
8890

8991
// Decompose an IfNode
90-
void ExprDecomposer::Visit(const IfNode &node) {
92+
Status ExprDecomposer::Visit(const IfNode &node) {
9193
// Add a local bitmap to track the output validity.
9294
node.condition()->Accept(*this);
9395
auto condition_vv = result();
@@ -111,11 +113,13 @@ void ExprDecomposer::Visit(const IfNode &node) {
111113
is_terminal_else);
112114

113115
result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
116+
return Status::OK();
114117
}
115118

116-
void ExprDecomposer::Visit(const LiteralNode &node) {
119+
Status ExprDecomposer::Visit(const LiteralNode &node) {
117120
auto value_dex = std::make_shared<LiteralDex>(node.return_type(), node.holder());
118121
result_ = std::make_shared<ValueValidityPair>(value_dex);
122+
return Status::OK();
119123
}
120124

121125
// The bolow functions use a stack to detect :

cpp/src/gandiva/codegen/expr_decomposer.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ class ExprDecomposer : public NodeVisitor {
4949
FRIEND_TEST(TestExprDecomposer, TestInternalIf);
5050
FRIEND_TEST(TestExprDecomposer, TestParallelIf);
5151

52-
void Visit(const FieldNode &node) override;
53-
void Visit(const FunctionNode &node) override;
54-
void Visit(const IfNode &node) override;
55-
void Visit(const LiteralNode &node) override;
52+
Status Visit(const FieldNode &node) override;
53+
Status Visit(const FunctionNode &node) override;
54+
Status Visit(const IfNode &node) override;
55+
Status Visit(const LiteralNode &node) override;
5656

5757
// stack of if nodes.
5858
class IfStackEntry {
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Copyright (C) 2017-2018 Dremio Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <string>
16+
#include <sstream>
17+
#include <vector>
18+
19+
#include "codegen/expr_validator.h"
20+
21+
namespace gandiva {
22+
23+
Status ExprValidator::Validate(const ExpressionPtr &expr) {
24+
if (expr == nullptr) {
25+
return Status::ExpressionValidationError("Expression cannot be null.");
26+
}
27+
Node &root = *expr->root();
28+
Status status = root.Accept(*this);
29+
if (!status.ok()) {
30+
return status;
31+
}
32+
// validate return type matches
33+
// no need to check if type is supported
34+
// since root type has been validated.
35+
if (!root.return_type()->Equals(*expr->result()->type())) {
36+
std::stringstream ss;
37+
ss << "Return type of root node " << root.return_type()->name()
38+
<< " does not match that of expression " << *expr->result()->type();
39+
return Status::ExpressionValidationError(ss.str());
40+
}
41+
return Status::OK();
42+
}
43+
44+
Status ExprValidator::Visit(const FieldNode &node) {
45+
auto llvm_type = types_->IRType(node.return_type()->id());
46+
if (llvm_type == nullptr) {
47+
std::stringstream ss;
48+
ss << "Field "<< node.field()->name() << " has unsupported data type "
49+
<< node.return_type()->name();
50+
return Status::ExpressionValidationError(ss.str());
51+
}
52+
53+
auto field_in_schema_entry = field_map_.find(node.field()->name());
54+
55+
// validate that field is in schema.
56+
if (field_in_schema_entry == field_map_.end()) {
57+
std::stringstream ss;
58+
ss << "Field " << node.field()->name() << " not in schema.";
59+
return Status::ExpressionValidationError(ss.str());
60+
}
61+
62+
FieldPtr field_in_schema = field_in_schema_entry->second;
63+
// validate that field matches the definition in schema.
64+
if (!field_in_schema->Equals(node.field())) {
65+
std::stringstream ss;
66+
ss << "Field definition in schema " << field_in_schema->ToString()
67+
<< " different from field in expression " << node.field()->ToString();
68+
return Status::ExpressionValidationError(ss.str());
69+
}
70+
return Status::OK();
71+
}
72+
73+
Status ExprValidator::Visit(const FunctionNode &node) {
74+
auto desc = node.descriptor();
75+
FunctionSignature signature(desc->name(),
76+
desc->params(),
77+
desc->return_type());
78+
const NativeFunction *native_function = registry_.LookupSignature(signature);
79+
if (native_function == nullptr) {
80+
std::stringstream ss;
81+
ss << "Function "<< signature.ToString() << " not supported yet. ";
82+
return Status::ExpressionValidationError(ss.str());
83+
}
84+
85+
for (auto &child : node.children()) {
86+
Status status = child->Accept(*this);
87+
GANDIVA_RETURN_NOT_OK(status);
88+
}
89+
return Status::OK();
90+
}
91+
92+
Status ExprValidator::Visit(const IfNode &node) {
93+
Status status = node.condition()->Accept(*this);
94+
GANDIVA_RETURN_NOT_OK(status);
95+
status = node.then_node()->Accept(*this);
96+
GANDIVA_RETURN_NOT_OK(status);
97+
status = node.else_node()->Accept(*this);
98+
GANDIVA_RETURN_NOT_OK(status);
99+
100+
auto if_node_ret_type = node.return_type();
101+
auto then_node_ret_type = node.then_node()->return_type();
102+
auto else_node_ret_type = node.else_node()->return_type();
103+
104+
if (if_node_ret_type != then_node_ret_type) {
105+
std::stringstream ss;
106+
ss << "Return type of if "<< *if_node_ret_type << " and then "
107+
<< then_node_ret_type->name() << " not matching.";
108+
return Status::ExpressionValidationError(ss.str());
109+
}
110+
111+
if (if_node_ret_type != else_node_ret_type) {
112+
std::stringstream ss;
113+
ss << "Return type of if "<< *if_node_ret_type << " and else "
114+
<< else_node_ret_type->name() << " not matching.";
115+
return Status::ExpressionValidationError(ss.str());
116+
}
117+
118+
return Status::OK();
119+
}
120+
121+
Status ExprValidator::Visit(const LiteralNode &node) {
122+
auto llvm_type = types_->IRType(node.return_type()->id());
123+
if (llvm_type == nullptr) {
124+
std::stringstream ss;
125+
ss << "Value "<< node.holder() << " has unsupported data type "
126+
<< node.return_type()->name();
127+
return Status::ExpressionValidationError(ss.str());
128+
}
129+
return Status::OK();
130+
}
131+
132+
} // namespace gandiva
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (C) 2017-2018 Dremio Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef GANDIVA_EXPR_VALIDATOR_H
16+
#define GANDIVA_EXPR_VALIDATOR_H
17+
18+
#include <string>
19+
#include <unordered_map>
20+
21+
#include "boost/functional/hash.hpp"
22+
#include "codegen/function_registry.h"
23+
#include "codegen/node_visitor.h"
24+
#include "codegen/node.h"
25+
#include "codegen/llvm_types.h"
26+
#include "gandiva/arrow.h"
27+
#include "gandiva/expression.h"
28+
#include "gandiva/status.h"
29+
30+
namespace gandiva {
31+
32+
class FunctionRegistry;
33+
34+
/// \brief Validates the entire expression tree including
35+
/// data types, signatures and return types
36+
class ExprValidator : public NodeVisitor {
37+
public:
38+
explicit ExprValidator(LLVMTypes * types, SchemaPtr schema)
39+
: types_(types),
40+
schema_(schema) {
41+
for (auto &field : schema_->fields()) {
42+
field_map_[field->name()] = field;
43+
}
44+
}
45+
46+
/// \brief Validates the root node
47+
/// of an expression.
48+
/// 1. Data type of fields and literals.
49+
/// 2. Function signature is supported.
50+
/// 3. For if nodes that return types match
51+
/// for if, then and else nodes.
52+
Status Validate(const ExpressionPtr &expr);
53+
54+
private:
55+
Status Visit(const FieldNode &node) override;
56+
Status Visit(const FunctionNode &node) override;
57+
Status Visit(const IfNode &node) override;
58+
Status Visit(const LiteralNode &node) override;
59+
60+
FunctionRegistry registry_;
61+
62+
LLVMTypes *types_;
63+
64+
SchemaPtr schema_;
65+
66+
using FieldMap = std::unordered_map<std::string,
67+
FieldPtr,
68+
boost::hash<std::string>>;
69+
FieldMap field_map_;
70+
};
71+
72+
} // namespace gandiva
73+
74+
#endif //GANDIVA_EXPR_VALIDATOR_H

cpp/src/gandiva/codegen/llvm_generator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class LLVMGenerator {
4949
Status Execute(const arrow::RecordBatch &record_batch,
5050
const ArrayDataVector &output_vector);
5151

52+
LLVMTypes *types() { return types_; }
53+
5254
private:
5355
LLVMGenerator();
5456

@@ -59,7 +61,6 @@ class LLVMGenerator {
5961
llvm::Module *module() { return engine_->module(); }
6062
llvm::LLVMContext &context() { return *(engine_->context()); }
6163
llvm::IRBuilder<> &ir_builder() { return engine_->ir_builder(); }
62-
LLVMTypes *types() { return types_; }
6364

6465
/// Visitor to generate the code for a decomposed expression.
6566
class Visitor : public DexVisitor {

cpp/src/gandiva/codegen/node.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "codegen/node_visitor.h"
2424
#include "gandiva/arrow.h"
2525
#include "gandiva/gandiva_aliases.h"
26+
#include "gandiva/status.h"
2627

2728
namespace gandiva {
2829

@@ -36,7 +37,7 @@ class Node {
3637
const DataTypePtr &return_type() const { return return_type_; }
3738

3839
/// Derived classes should simply invoke the Visit api of the visitor.
39-
virtual void Accept(NodeVisitor &visitor) const = 0;
40+
virtual Status Accept(NodeVisitor &visitor) const = 0;
4041

4142
protected:
4243
DataTypePtr return_type_;
@@ -49,8 +50,8 @@ class LiteralNode : public Node {
4950
: Node(type),
5051
holder_(holder) {}
5152

52-
void Accept(NodeVisitor &visitor) const override {
53-
visitor.Visit(*this);
53+
Status Accept(NodeVisitor &visitor) const override {
54+
return visitor.Visit(*this);
5455
}
5556

5657
const LiteralHolder &holder() const { return holder_; }
@@ -65,8 +66,8 @@ class FieldNode : public Node {
6566
explicit FieldNode(FieldPtr field)
6667
: Node(field->type()), field_(field) {}
6768

68-
void Accept(NodeVisitor &visitor) const override {
69-
visitor.Visit(*this);
69+
Status Accept(NodeVisitor &visitor) const override {
70+
return visitor.Visit(*this);
7071
}
7172

7273
const FieldPtr &field() const { return field_; }
@@ -83,8 +84,8 @@ class FunctionNode : public Node {
8384
DataTypePtr retType)
8485
: Node(retType), descriptor_(descriptor), children_(children) { }
8586

86-
void Accept(NodeVisitor &visitor) const override {
87-
visitor.Visit(*this);
87+
Status Accept(NodeVisitor &visitor) const override {
88+
return visitor.Visit(*this);
8889
}
8990

9091
const FuncDescriptorPtr &descriptor() const { return descriptor_; }
@@ -125,8 +126,8 @@ class IfNode : public Node {
125126
then_node_(then_node),
126127
else_node_(else_node) {}
127128

128-
void Accept(NodeVisitor &visitor) const override {
129-
visitor.Visit(*this);
129+
Status Accept(NodeVisitor &visitor) const override {
130+
return visitor.Visit(*this);
130131
}
131132

132133
const NodePtr &condition() const { return condition_; }

cpp/src/gandiva/codegen/node_visitor.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define GANDIVA_NODE_VISITOR_H
1717

1818
#include "gandiva/logging.h"
19+
#include "gandiva/status.h"
1920

2021
namespace gandiva {
2122

@@ -27,10 +28,10 @@ class LiteralNode;
2728
/// \brief Visitor for nodes in the expression tree.
2829
class NodeVisitor {
2930
public:
30-
virtual void Visit(const FieldNode &node) = 0;
31-
virtual void Visit(const FunctionNode &node) = 0;
32-
virtual void Visit(const IfNode &node) = 0;
33-
virtual void Visit(const LiteralNode &node) = 0;
31+
virtual Status Visit(const FieldNode &node) = 0;
32+
virtual Status Visit(const FunctionNode &node) = 0;
33+
virtual Status Visit(const IfNode &node) = 0;
34+
virtual Status Visit(const LiteralNode &node) = 0;
3435
};
3536

3637
} // namespace gandiva

0 commit comments

Comments
 (0)