[C++] How to use CallFunction() when arg is a ExtensionScalar(ExtensionType) #39024

Open
Cerdore opened this issue Dec 1, 2023 · 5 comments · May be fixed by #39200
Labels: Component: C++; Type: usage (issue is a user question)

Comments

@Cerdore
Contributor

Cerdore commented Dec 1, 2023

I have subclassed ExtensionType to create an extension type whose storage type is int16().
I want to pass this ExtensionType's Array and Scalar to a custom add function through CallFunction. The custom add function reuses the built-in kernel for two int16() arguments. I also wrote a TypeMatcher so that the kernel matches the extension type.

Currently, two ExtensionArrays can be added together successfully, but there is an issue when adding an ExtensionArray and an ExtensionScalar; it throws a std::bad_cast error.

After investigating, I found that the parent class of ExtensionScalar is Scalar, but it needs to be cast to PrimitiveScalarBase (whose parent class is also Scalar), which fails.
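
The failed cast can be illustrated with a minimal sketch. The classes below are simplified, hypothetical stand-ins that only mirror the inheritance relationships named in this issue; they are not Arrow's real definitions. The sketch assumes the cast is performed as a reference dynamic_cast, which would be consistent with the __cxa_bad_cast frame in the backtrace.

```cpp
#include <cstdint>
#include <typeinfo>

// Hypothetical stand-ins for the classes named in this issue (not Arrow's code).
struct Scalar { virtual ~Scalar() = default; };
struct PrimitiveScalarBase : Scalar {};
struct Int16Scalar : PrimitiveScalarBase { int16_t value = 0; };
struct ExtensionScalar : Scalar {};  // sibling of PrimitiveScalarBase, not a child

// Returns true when `s` can be viewed as a PrimitiveScalarBase.
// A reference dynamic_cast to an unrelated sibling throws std::bad_cast.
bool CastsToPrimitive(const Scalar& s) {
  try {
    const auto& prim = dynamic_cast<const PrimitiveScalarBase&>(s);
    (void)prim;  // only the success of the cast matters here
    return true;
  } catch (const std::bad_cast&) {
    return false;
  }
}
```

Calling CastsToPrimitive on an Int16Scalar succeeds, while calling it on an ExtensionScalar takes the catch branch. Without the try/catch, the exception would propagate and terminate the program.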

ExtensionType:

class SmallintTypecxs : public ExtensionType
{
public:
	SmallintTypecxs() : ExtensionType(int16()) {}

	std::string extension_name() const override { return "smallintcxs"; }

	bool ExtensionEquals(const ExtensionType &other) const override;

	std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

	Result<std::shared_ptr<DataType>> Deserialize(
		std::shared_ptr<DataType> storage_type,
		const std::string &serialized) const override;

	std::string Serialize() const override { return "smallintcxs"; }

	const std::shared_ptr<DataType>& storage_type() const { return storage_type_; }
private:
};
std::shared_ptr<DataType> smallintcxs() { return std::make_shared<SmallintTypecxs>(); }

TypeMatcher

//============ type matcher
template <typename ArrowType>
class ExtendMatcher : public TypeMatcher {
  using ThisType = ExtendMatcher<ArrowType>;

 public:
  explicit ExtendMatcher(Type::type oid)
      : type_id_(oid) {}

  bool Matches(const DataType& type) const override {
    if (type.id() != ArrowType::type_id) {
      return false;
    }

	const auto& ex_type = arrow::internal::checked_cast<const ArrowType&>(type);

	return ex_type.storage_type().get()->id() == type_id_;
  }

  bool Equals(const TypeMatcher& other) const override {
    if (this == &other) {
      return true;
    }
    auto casted = dynamic_cast<const ThisType*>(&other);
    if (casted == nullptr) {
      return false;
    }
    return this->type_id_ == casted->type_id_;
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << ArrowType::type_name() << "(" << type_id_
       << ")";
    return ss.str();
  }

 private:
	Type::type type_id_;
};

using SmallIntMatcher = ExtendMatcher<SmallintTypecxs>;

std::shared_ptr<TypeMatcher> SmallIntcxsTypeMatcher(Type::type type_id) {
  return std::make_shared<SmallIntMatcher>(type_id);
}
// Registers "cxsadd" by reusing the exec function of the built-in int16 "add" kernel.
Status RegisterExtensionFunc()
{
	auto registry = arrow::compute::GetFunctionRegistry();
	auto ty = smallintcxs();

	// Look up the built-in "add" function and its exact (int16, int16) kernel.
	std::shared_ptr<Function> fc;
	ARROW_ASSIGN_OR_RAISE(fc, registry->GetFunction("add"));

	std::vector<TypeHolder> args = {int16(), int16()};

	const Kernel* kernel;
	ARROW_ASSIGN_OR_RAISE(kernel, arrow::internal::checked_cast<const ScalarFunction&>(*fc)
	                                  .DispatchExact(args));

	const arrow::compute::ScalarKernel* derived_ptr = static_cast<const arrow::compute::ScalarKernel*>(kernel);

	std::cout << derived_ptr->signature->out_type().ToString() << std::endl;

	auto func = std::make_shared<ScalarFunction>("cxsadd", Arity::Binary(),
	                                             /*doc=*/FunctionDoc::Empty());

	// Match any SmallintTypecxs whose storage type id equals that of int16.
	InputType in_type(SmallIntcxsTypeMatcher(ty.get()->storage_id()));

	// The int16 exec is attached unchanged, so it still expects int16 arrays and scalars.
	ARROW_RETURN_NOT_OK(func->AddKernel({in_type, in_type}, int16(), derived_ptr->exec));

	ARROW_RETURN_NOT_OK(registry->AddFunction(func));

	return Status::OK();
}

CallFunction

arrow::Status testSmallIntType()
{

	ARROW_RETURN_NOT_OK(arrow::compute::RegisterExtensionFunc());

	auto ary1 = arrow::SmallintArraycxsFromJSON("[30000, 100, 200, 1, 2]");

	auto ary2 = arrow::SmallintArraycxsFromJSON("[30000, 100, 200, 1, 2]");
	
	std::shared_ptr<arrow::Scalar> scalar1;

	ARROW_ASSIGN_OR_RAISE(scalar1, arrow::MakeScalar(arrow::smallintcxs(), 1));

	arrow::Datum incremented_datum;


	//ARROW_ASSIGN_OR_RAISE(incremented_datum,
	//					  arrow::compute::CallFunction("cxsadd", {ary1, ary2})); // it's ok

	ARROW_ASSIGN_OR_RAISE(incremented_datum,
						  arrow::compute::CallFunction("cxsadd", {ary1, scalar1})); // not work
	std::shared_ptr<Array> incremented_array = std::move(incremented_datum).make_array();

	auto resarr = std::static_pointer_cast<arrow::Int16Array>(incremented_array);

	std::cout<<"begin output"<<std::endl;

	for(int i = 0; i < resarr->length(); i++)
	{
		std::cout << static_cast<int64_t>(resarr->Value(i)) << std::endl;
	}
  return arrow::Status::OK();
}

When I run 'ARROW_ASSIGN_OR_RAISE(incremented_datum, arrow::compute::CallFunction("cxsadd", {ary1, scalar1}));', the program aborts with this message:

terminate called after throwing an instance of 'std::bad_cast'
  what():  std::bad_cast

Backtrace from the core dump:

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007fd65aa73859 in __GI_abort () at abort.c:79
#2  0x00007fd65ae708d1 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x00007fd65ae7c37c in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#4  0x00007fd65ae7c3e7 in std::terminate() () from /lib/x86_64-linux-gnu/libstdc++.so.6
#5  0x00007fd65ae7c699 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
#6  0x00007fd65ae7033c in __cxa_bad_cast () from /lib/x86_64-linux-gnu/libstdc++.so.6
#7  0x0000559b33971f3f in arrow::internal::checked_cast<arrow::internal::PrimitiveScalarBase const&, arrow::Scalar const&> (value=...)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/util/checked_cast.h:38
#8  0x0000559b3395fbae in arrow::compute::internal::UnboxScalar<arrow::Int16Type, void>::Unbox (val=...) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h:341
#9  0x0000559b33cd3331 in arrow::compute::internal::applicator::ScalarBinary<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, arrow::compute::internal::Add>::ArrayScalar (
    ctx=0x559b37531198, arg0=..., arg1=..., out=0x7fffb3211140) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h:753
#10 0x0000559b33cc0df8 in arrow::compute::internal::applicator::ScalarBinary<arrow::Int16Type, arrow::Int16Type, arrow::Int16Type, arrow::compute::internal::Add>::Exec (ctx=0x559b37531198,
    batch=..., out=0x7fffb3211140) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h:780
#11 0x0000559b33686c08 in arrow::compute::detail::(anonymous namespace)::ScalarExecutor::ExecuteSingleSpan (this=0x559b37531090, input=..., out=0x7fffb3211140)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/exec.cc:891
#12 0x0000559b3368665f in arrow::compute::detail::(anonymous namespace)::ScalarExecutor::ExecuteSpans (this=0x559b37531090, listener=0x7fffb3211350)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/exec.cc:859
#13 0x0000559b33686001 in arrow::compute::detail::(anonymous namespace)::ScalarExecutor::Execute (this=0x559b37531090, batch=..., listener=0x7fffb3211350)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/exec.cc:808
#14 0x0000559b336d267d in arrow::compute::detail::FunctionExecutorImpl::Execute (this=0x559b37531170, args=std::vector of length 2, capacity 2 = {...}, passed_length=-1)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/function.cc:276
#15 0x0000559b336d0262 in arrow::compute::(anonymous namespace)::ExecuteInternal (func=..., args=std::vector of length 2, capacity 2 = {...}, passed_length=-1, options=0x0,
    ctx=0x559b359cf320 <arrow::compute::default_exec_context()::default_ctx>) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/function.cc:341
#16 0x0000559b336d039e in arrow::compute::Function::Execute (this=0x559b37530270, args=std::vector of length 2, capacity 2 = {...}, options=0x0,
    ctx=0x559b359cf320 <arrow::compute::default_exec_context()::default_ctx>) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/function.cc:348
#17 0x0000559b3368a320 in arrow::compute::CallFunction (func_name="cxsadd", args=std::vector of length 2, capacity 2 = {...}, options=0x0,
    ctx=0x559b359cf320 <arrow::compute::default_exec_context()::default_ctx>) at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/exec.cc:1369
#18 0x0000559b3368a3db in arrow::compute::CallFunction (func_name="cxsadd", args=std::vector of length 2, capacity 2 = {...}, ctx=0x0)
    at /home/gpadmin/gitRepo/arrow/cpp/src/arrow/compute/exec.cc:1374
#19 0x0000559b332837d8 in testSmallIntType () at /home/gpadmin/testProjects/arrow_test/main.cpp:88
#20 0x0000559b33283de8 in main () at /home/gpadmin/testProjects/arrow_test/main.cpp:140

Component(s)

C++

@Cerdore Cerdore added the Type: usage Issue is a user question label Dec 1, 2023
@mapleFU
Member

mapleFU commented Dec 1, 2023

Just curious, where would ExtendMatcher be used?

@Cerdore
Contributor Author

Cerdore commented Dec 1, 2023

Just curious, where would ExtendMatcher be used?

Sorry, I have added some code to the issue to fill in the missing information.

The function registration uses ExtendMatcher to construct the InputType.

using SmallIntMatcher = ExtendMatcher<SmallintTypecxs>;

std::shared_ptr<TypeMatcher> SmallIntcxsTypeMatcher(Type::type type_id) {
  return std::make_shared<SmallIntMatcher>(type_id);
}

InputType in_type(SmallIntcxsTypeMatcher(ty.get()->storage_id()));

ARROW_RETURN_NOT_OK(func->AddKernel({in_type,in_type}, int16(), derived_ptr->exec)); 

@mapleFU
Member

mapleFU commented Dec 1, 2023

Thanks! I think the problem is that RegisterExtensionFunc() dispatches to the scalar add kernel. However, I haven't found any established practice or example for using a Scalar of an extension type.

@js8544 @bkietz @felipecrv Would you mind taking a look?

@js8544
Collaborator

js8544 commented Dec 1, 2023

The problem occurs because the kernel you registered is for Int16, hence the checked_cast<arrow::internal::PrimitiveScalarBase&>. There is no generic support for ExtensionTypes in Arrow compute. Reusing the storage type's kernels may work (like your add(arr1, arr2)), but there's no guarantee. So for now you'll have to write your own kernel to make sure it works.
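
As a rough illustration of what a custom kernel has to do differently, the sketch below unwraps an extension scalar to its storage scalar before reading the int16 value. All types and function names here are hypothetical stand-ins, not Arrow's classes or kernel API; in particular, the `value` member holding the storage scalar is an assumption of this model.

```cpp
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-ins, not Arrow's classes.
struct Scalar { virtual ~Scalar() = default; };
struct Int16Scalar : Scalar {
  explicit Int16Scalar(int16_t v) : value(v) {}
  int16_t value;
};
struct ExtensionScalar : Scalar {
  explicit ExtensionScalar(std::shared_ptr<Scalar> storage) : value(std::move(storage)) {}
  std::shared_ptr<Scalar> value;  // the underlying storage scalar (assumed layout)
};

// Reads an int16 from either a plain scalar or an extension scalar wrapping one.
int16_t UnboxInt16(const Scalar& s) {
  if (const auto* ext = dynamic_cast<const ExtensionScalar*>(&s)) {
    if (!ext->value) throw std::invalid_argument("extension scalar has no storage");
    return UnboxInt16(*ext->value);  // unwrap, then retry on the storage scalar
  }
  if (const auto* prim = dynamic_cast<const Int16Scalar*>(&s)) {
    return prim->value;
  }
  throw std::invalid_argument("unsupported scalar type");
}

// Adds a scalar to every element; this sketch performs no overflow check.
std::vector<int16_t> AddArrayScalar(const std::vector<int16_t>& arr, const Scalar& s) {
  const int16_t rhs = UnboxInt16(s);
  std::vector<int16_t> out;
  out.reserve(arr.size());
  for (int16_t v : arr) {
    out.push_back(static_cast<int16_t>(v + rhs));
  }
  return out;
}
```

The point of the sketch is only the unwrapping step: a kernel written for the storage type never performs it, which is why the array-and-scalar case fails while the array-and-array case happens to work.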

However, I think it's worthwhile to provide general support for ExtensionTypes, and it's not hard to do so:

  1. First check if the function has a matching kernel for the extension types. If so, execute it and return.
  2. Replace the extension types with their storage types and dispatch again. If there is a match, replace the input ExtensionArray or ExtensionScalar with their storage values and execute the kernel.

This way users can simply define their own extension types and reuse the existing compute functions. No need to write extra kernels and registrations etc.
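
The two steps above can be sketched with a toy dispatch model. Everything in this block (TypeDesc, Function, DispatchWithStorageFallback, and the string-based signatures) is hypothetical and exists only to show the control flow; it is not how Arrow's registry or kernels are actually represented.

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: a type is a name plus an optional storage type name.
struct TypeDesc {
  std::string name;
  std::string storage;  // empty for non-extension types
  bool is_extension() const { return !storage.empty(); }
};

using Signature = std::vector<std::string>;

struct Function {
  std::map<Signature, std::string> kernels;  // signature -> kernel label

  std::optional<std::string> DispatchExact(const Signature& sig) const {
    auto it = kernels.find(sig);
    if (it == kernels.end()) return std::nullopt;
    return it->second;
  }
};

struct DispatchResult {
  std::string kernel;
  bool unwrap_to_storage;  // true: inputs must be replaced by their storage values
};

std::optional<DispatchResult> DispatchWithStorageFallback(
    const Function& fn, const std::vector<TypeDesc>& args) {
  Signature exact, storage;
  bool any_extension = false;
  for (const auto& t : args) {
    exact.push_back(t.name);
    storage.push_back(t.is_extension() ? t.storage : t.name);
    any_extension = any_extension || t.is_extension();
  }
  // Step 1: prefer a kernel registered for the extension types themselves.
  if (auto k = fn.DispatchExact(exact)) return DispatchResult{*k, false};
  // Step 2: retry with storage types; the caller then unwraps the inputs.
  if (any_extension) {
    if (auto k = fn.DispatchExact(storage)) return DispatchResult{*k, true};
  }
  return std::nullopt;
}
```

In this model, a function with only an ("int16", "int16") kernel would still serve two "smallintcxs" arguments via step 2, with unwrap_to_storage signalling that the executor must pass storage values to the kernel.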

Does this sound like a viable approach? Also cc @pitrou.

@mapleFU
Member

mapleFU commented Dec 1, 2023

(Also, I don't think writing a custom kernel for an ExtensionType is that hard, but it would be better to have some examples.)
