Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8()},
utf8(), kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8",
NativeFunction::kNeedsContext |
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8(), int32()},
utf8(), kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8_int32",
NativeFunction::kNeedsContext |
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/gandiva/gdv_string_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ const char* gdv_fn_regexp_replace_utf8_utf8(
out_length);
}

const char* gdv_fn_regexp_extract_utf8_utf8(int64_t ptr, int64_t holder_ptr,
const char* data, int32_t data_len,
const char* /*pattern*/,
int32_t /*pattern_len*/,
int32_t* out_length) {
gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);
gandiva::ExtractHolder* holder = reinterpret_cast<gandiva::ExtractHolder*>(holder_ptr);
return (*holder)(context, data, data_len, 1, out_length);
}

const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_ptr,
const char* data, int32_t data_len,
const char* /*pattern*/,
Expand Down Expand Up @@ -855,6 +865,19 @@ arrow::Status ExportedStringFunctions::AddMappings(Engine* engine) const {
"gdv_fn_regexp_extract_utf8_utf8_int32", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_regexp_extract_utf8_utf8_int32));

// gdv_fn_regexp_extract_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i64_type(), // int64_t holder_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int data_len
types->i8_ptr_type(), // const char* pattern
types->i32_type(), // int pattern_len
types->i32_ptr_type()}; // int32_t* out_length

engine->AddGlobalMappingForFunc(
"gdv_fn_regexp_extract_utf8_utf8", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_regexp_extract_utf8_utf8));

// gdv_fn_castVARCHAR_int32_int64
args = {types->i64_type(), // int64_t execution_context
types->i32_type(), // int32_t value
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/regex_functions_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ void ReplaceHolder::return_error(ExecutionContext* context, std::string& data,
}

Result<std::shared_ptr<ExtractHolder>> ExtractHolder::Make(const FunctionNode& node) {
ARROW_RETURN_IF(node.children().size() != 3,
Status::Invalid("'extract' function requires three parameters"));
ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3,
Status::Invalid("'extract' function requires two or three parameters"));

auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
ARROW_RETURN_IF(
Expand Down
31 changes: 27 additions & 4 deletions cpp/src/gandiva/regex_functions_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,24 +604,47 @@ TEST_F(TestExtractHolder, TestExtractInvalidPattern) {
execution_context_.Reset();
}

TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) {
// Create function with incorrect number of params
TEST_F(TestExtractHolder, TestDefaultIndexExtract) {
// 2-arg form defaults to index 1 (first capture group)
auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
auto pattern_node = std::make_shared<LiteralNode>(
arrow::utf8(), LiteralHolder(R"((\w+) (\w+))"), false);
auto function_node =
FunctionNode("regexp_extract", {field, pattern_node}, arrow::utf8());

EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(function_node));

std::string input_string = "John Doe";
int32_t out_length = 0;

auto& extract = *extract_holder;
const char* ret =
extract(&execution_context_, input_string.c_str(),
static_cast<int32_t>(input_string.length()), 1, &out_length);
EXPECT_EQ(std::string(ret, out_length), "John");

input_string = "Ringo Beast";
ret = extract(&execution_context_, input_string.c_str(),
static_cast<int32_t>(input_string.length()), 1, &out_length);
Comment thread
lriggs marked this conversation as resolved.
EXPECT_EQ(std::string(ret, out_length), "Ringo");
}

TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) {
// Create function with incorrect number of params (one arg)
auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
NodeVector one_arg = {field};
auto function_node = FunctionNode("regexp_extract", one_arg, arrow::utf8());

auto extract_holder = ExtractHolder::Make(function_node);
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("'extract' function requires three parameters"),
Invalid, ::testing::HasSubstr("'extract' function requires two or three parameters"),
extract_holder.status());

execution_context_.Reset();

// Create function with non-utf8 literal parameter as pattern
field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
pattern_node = std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(2), false);
auto pattern_node = std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(2), false);
auto index_node = std::make_shared<FieldNode>(arrow::field("idx", arrow::int32()));
function_node =
FunctionNode("regexp_extract", {field, pattern_node, index_node}, arrow::utf8());
Expand Down
Loading