Typo in the addition kernel in muladd.cu #100
Comments
Yes, I also think this is an error.
Hi! I'm currently learning the new version of custom operators (PyTorch 2.4 or later). Tutorial link: PyTorch Custom Operators Tutorial. The following question isn't directly related to this issue, but I wanted to seek some advice here. The code below runs smoothly in my environment:

```python
import torch
import extension_cpp

a = torch.rand(10, 10)
b = torch.rand(10, 10)
c = extension_cpp.ops.mymuladd(a, b, 0)
print(c)
```

However, based on my understanding, it seems that I should also be able to call the op through `torch.ops.extension_cpp`, yet I can't find `extension_cpp` under `torch.ops`. Does anyone know the reason?
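For what it's worth: ops registered from C++ with `TORCH_LIBRARY(extension_cpp, ...)` should also be reachable through the dispatcher namespace once `import extension_cpp` has loaded the shared library. A minimal sketch, assuming the tutorial's `mymuladd(Tensor, Tensor, float)` schema:

```python
import torch
import extension_cpp  # importing loads the C++ library, which registers the ops

a = torch.rand(10, 10)
b = torch.rand(10, 10)
# Same op as extension_cpp.ops.mymuladd, addressed via the dispatcher namespace:
c = torch.ops.extension_cpp.mymuladd(a, b, 0.0)
print(c)
```

Note that `torch.ops` resolves op namespaces lazily on attribute access, so `extension_cpp` may not show up in `dir(torch.ops)` or tab completion even when the call above works.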
Maybe this interface has been generated, so you can call it directly?
@cyk2018 Did you mean that it's generated so you can call it directly? If I use the JIT (`load_inline`), I can call it directly from `torch.ops`:

```python
import torch
from torch.utils.cpp_extension import load_inline

cpp_source = """
#include <torch/script.h>

at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
  }
  return result;
}

TORCH_LIBRARY(extension_cpp, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
}

TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
  m.impl("mymuladd", &mymuladd_cpu);
}
"""

load_inline(
    name='extension_cpp',
    cpp_sources=[cpp_source],
    is_python_module=False,
    verbose=True,
)

a = torch.randn(3, dtype=torch.float32).rename(None)
b = torch.randn(3, dtype=torch.float32).rename(None)
result = torch.ops.extension_cpp.mymuladd(a, b, 1.0)
print("Input a:", a)
print("Input b:", b)
print("Result:", result)
```
I have solved it! I was missing …; now I can successfully run:

```python
import torch
import my_ops

a = torch.ones(3)
result = torch.ops.my_ops.ops1(a)
print("Result:", result)
```
There is a typo in the add_kernel routine in the CUDA file muladd.cu. I assumed that this kernel should compute the sum of two tensors, but it actually computes the product:
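A sketch of the buggy kernel, with the suspect line marked (the exact code in the repository may differ slightly):

```cuda
__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx];  // typo: multiplies; should be a[idx] + b[idx]
}
```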
This bug can be tested by running the following Python script:
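A minimal reproduction along these lines, assuming the kernel is exposed through the tutorial's `myadd_out` op (the op name and calling convention here are assumptions, not taken verbatim from the original report):

```python
import torch
import extension_cpp  # registers the ops under torch.ops.extension_cpp

a = torch.randn(8, device="cuda")
b = torch.randn(8, device="cuda")
out = torch.empty_like(a)

# Assumption: add_kernel backs the out-variant op myadd_out(a, b, out).
torch.ops.extension_cpp.myadd_out(a, b, out)

print(torch.allclose(out, a + b))  # False with the typo: the kernel computes a * b
```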
The CPU implementation gives the correct result (with `device="cpu"` in the script above).