Replace lltm with myaddmul; update to new custom ops APIs
We're replacing lltm with myaddmul(a: Tensor, b: Tensor, c: float), which just does a*b+c. This simplification lets us focus on the operator registration instead of getting lost in the details of the complicated lltm kernels.

Test Plan:
- tests

ghstack-source-id: dbf9af17fe0d66139320f45e5ed76a6331faefce
Pull Request resolved: #95
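For reference, the operator is just an elementwise multiply-add. A minimal sketch of the same computation in plain ATen (the name mymuladd_reference is hypothetical and not part of this commit; it only pins down the intended semantics):

#include <ATen/ATen.h>

// Illustration only: elementwise a * b + c using ATen's built-in operators.
at::Tensor mymuladd_reference(const at::Tensor& a, const at::Tensor& b, double c) {
  return a * b + c;
}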
Showing 8 changed files with 291 additions and 460 deletions.
@@ -0,0 +1,85 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace extension_cpp {

__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  int numel = a_contig.numel();
  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
  return result;
}

__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  int numel = a_contig.numel();
  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
  return result;
}

// Elementwise addition: result = a + b.
__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] + b[idx];
}

void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(b.sizes() == out.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_CHECK(out.dtype() == at::kFloat);
  TORCH_CHECK(out.is_contiguous());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = out.data_ptr<float>();
  int numel = a_contig.numel();
  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
}


// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
  m.impl("mymuladd", &mymuladd_cuda);
  m.impl("mymul", &mymul_cuda);
  m.impl("myadd_out", &myadd_out_cuda);
}

}  // namespace extension_cpp
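The TORCH_LIBRARY_IMPL block above only attaches CUDA kernels to operators whose schemas are declared elsewhere in the extension (that C++ source is not shown in this view). A sketch of what those declarations are expected to look like, assuming the signatures implied by the commit message and the checks above (in the custom-ops schema language, float corresponds to C++ double):

TORCH_LIBRARY(extension_cpp, m) {
  // Declares the operator schemas; CUDA (and CPU) kernels are registered separately via TORCH_LIBRARY_IMPL.
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
  m.def("mymul(Tensor a, Tensor b) -> Tensor");
  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}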