Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ class CinnJitInstruction::FnPtrImpl {

// Pass real tensor data to cinn_buffer_t func args placeholder
for (size_t i = 0; i < kernel_tensor_args.size(); ++i) {
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory =
reinterpret_cast<uint8_t*>(kernel_tensor_args[i]->data());
if (!kernel_tensor_args[i]->has_allocation()) {
VLOG(2) << "WARNING! Access DenseTensor::data() without allocation, "
"return nullptr!";
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory = nullptr;
} else {
cinn_pod_value_to_buffer_p(&(func_args_[i]))->memory =
reinterpret_cast<uint8_t*>(kernel_tensor_args[i]->data());
}
}

// Launch host kernel
Expand Down Expand Up @@ -297,6 +303,7 @@ CinnJitInstruction::CinnJitInstruction(
ir_dims_.push_back(
result.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
tensor_args_.push_back(tensor);
alloc_tensors_.push_back(tensor);
auto alloc_tensor_type =
result.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
tensor->set_type(
Expand All @@ -321,6 +328,7 @@ CinnJitInstruction::CinnJitInstruction(
}
for (auto& tensor : temp_space_tensors_) {
tensor_args_.push_back(&tensor);
alloc_tensors_.push_back(&tensor);
}
output_tensor_size += temp_space_tensors_.size();
}
Expand All @@ -343,8 +351,8 @@ void CinnJitInstruction::Run() {
fn_ptr_impl_->InferShape(
tensor_args_, ir_dims_, input_tensor_size, output_tensor_size);
}
for (size_t i = 0; i < tensor_args_.size(); ++i) {
dev_ctx_->Alloc(tensor_args_[i], tensor_args_[i]->dtype());
for (size_t i = 0; i < alloc_tensors_.size(); ++i) {
dev_ctx_->Alloc(alloc_tensors_[i], alloc_tensors_[i]->dtype());
}

// 2. execute kernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CinnJitInstruction : public InstructionBase {

bool need_update_shape{false};
std::vector<phi::DenseTensor*> tensor_args_;
std::vector<phi::DenseTensor*> alloc_tensors_;
std::vector<phi::DDim> ir_dims_;

// Tensors that hold the temporary spaces used by the kernel. These tensors
Expand Down
Loading