Replace lltm with mymuladd; update to new custom ops APIs
We're replacing lltm with mymuladd(a: Tensor, b: Tensor, c: float),
which just computes a*b+c. This simplification lets us focus on the
operator registration instead of getting lost in the details of the
complicated lltm kernels.
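For reference, the semantics described above amount to the following libtorch sketch (`mymuladd_reference` is a hypothetical name used here for illustration, not part of this PR):

```cpp
#include <torch/torch.h>

// Hypothetical reference implementation of the op's semantics:
// mymuladd(a, b, c) computes a * b + c elementwise.
at::Tensor mymuladd_reference(const at::Tensor& a, const at::Tensor& b, double c) {
  return a * b + c;  // elementwise multiply, then add the scalar c
}
```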

Test Plan:
- tests

ghstack-source-id: dbf9af17fe0d66139320f45e5ed76a6331faefce
Pull Request resolved: #95
zou3519 committed May 29, 2024
1 parent a5ed0b0 commit b75f3dc
Showing 8 changed files with 291 additions and 460 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -2,9 +2,11 @@

An example of writing a C++/CUDA extension for PyTorch. See
[here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.
-This repo demonstrates how to write an example `extension_cpp.ops.lltm`
+This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
custom op that has both custom CPU and CUDA kernels.

+The examples in this repo work with PyTorch 2.4+.

To build:
```
pip install .
```
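Beyond the Python entry point `extension_cpp.ops.mymuladd`, a registered op can also be called from C++ through the dispatcher. A minimal sketch (not part of this diff), assuming the `extension_cpp::mymuladd` schema this PR defines:

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

// Looks up the registered op by its qualified schema name and calls it.
// The dispatcher then routes to the CPU or CUDA kernel based on the inputs.
at::Tensor call_mymuladd(const at::Tensor& a, const at::Tensor& b, double c) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("extension_cpp::mymuladd", "")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&, double)>();
  return op.call(a, b, c);
}
```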
183 changes: 0 additions & 183 deletions extension_cpp/csrc/cuda/lltm_cuda.cu

This file was deleted.

85 changes: 85 additions & 0 deletions extension_cpp/csrc/cuda/muladd.cu
@@ -0,0 +1,85 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace extension_cpp {

__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();

int numel = a_contig.numel();
muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}

__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
int numel = a_contig.numel();
mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
return result;
}

__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] + b[idx];
}

void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(b.sizes() == out.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(out.dtype() == at::kFloat);
TORCH_CHECK(out.is_contiguous());
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
int numel = a_contig.numel();
add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
}


// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", &mymuladd_cuda);
m.impl("mymul", &mymul_cuda);
m.impl("myadd_out", &myadd_out_cuda);
}

}  // namespace extension_cpp
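The `TORCH_LIBRARY_IMPL` block above only attaches CUDA kernels to already-declared operators; the schema definitions themselves live elsewhere in the extension (the CPU source file, which this view collapses). A sketch of what such a definition block looks like, with schemas inferred from the signatures above:

```cpp
#include <torch/extension.h>

namespace extension_cpp {

// Declares the operator schemas once for the extension_cpp namespace.
// Backend-specific kernels (CPU, CUDA) are attached separately via
// TORCH_LIBRARY_IMPL, as in muladd.cu above. Note that schema "float"
// corresponds to double in C++.
TORCH_LIBRARY(extension_cpp, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
  m.def("mymul(Tensor a, Tensor b) -> Tensor");
  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}

}  // namespace extension_cpp
```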
101 changes: 0 additions & 101 deletions extension_cpp/csrc/lltm.cpp

This file was deleted.

