diff --git a/README.md b/README.md index a0aae0c..709441b 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,11 @@ An example of writing a C++/CUDA extension for PyTorch. See [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial. -This repo demonstrates how to write an example `extension_cpp.ops.lltm` +This repo demonstrates how to write an example `extension_cpp.ops.mymuladd` custom op that has both custom CPU and CUDA kernels. +The examples in this repo work with PyTorch 2.4+. + To build: ``` pip install . diff --git a/extension_cpp/csrc/cuda/lltm_cuda.cu b/extension_cpp/csrc/cuda/lltm_cuda.cu deleted file mode 100644 index 7612b84..0000000 --- a/extension_cpp/csrc/cuda/lltm_cuda.cu +++ /dev/null @@ -1,183 +0,0 @@ -#include <torch/extension.h> - -#include <cuda.h> -#include <cuda_runtime.h> - -#include <vector> - -namespace { -template <typename scalar_t> -__device__ __forceinline__ scalar_t sigmoid(scalar_t z) { - return 1.0 / (1.0 + exp(-z)); -} - -template <typename scalar_t> -__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) { - const auto s = sigmoid(z); - return (1.0 - s) * s; -} - -template <typename scalar_t> -__device__ __forceinline__ scalar_t d_tanh(scalar_t z) { - const auto t = tanh(z); - return 1 - (t * t); -} - -template <typename scalar_t> -__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) { - return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0)); -} - -template <typename scalar_t> -__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) { - const auto e = exp(z); - const auto d_relu = z < 0.0 ? 0.0 : 1.0; - return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0); -} - -template <typename scalar_t> -__global__ void lltm_cuda_forward_kernel( - const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell, - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h, - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell, - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate, - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate, - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) { - //batch index - const int n = blockIdx.y; - // column index - const int c = blockIdx.x * blockDim.x + threadIdx.x; - if (c < gates.size(2)){ - input_gate[n][c] = sigmoid(gates[n][0][c]); - output_gate[n][c] = sigmoid(gates[n][1][c]); - candidate_cell[n][c] = elu(gates[n][2][c]); - new_cell[n][c] = - old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c]; - new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c]; - } -} - -template <typename scalar_t> -__global__ void lltm_cuda_backward_kernel( - torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell, - torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate, - const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate, - const 
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell, - const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) { - //batch index - const int n = blockIdx.y; - // column index - const int c = blockIdx.x * blockDim.x + threadIdx.x; - if (c < d_gates.size(2)){ - const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c]; - const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c]; - const auto d_new_cell = - d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c]; - - - d_old_cell[n][c] = d_new_cell; - const auto d_candidate_cell = input_gate[n][c] * d_new_cell; - const auto d_input_gate = candidate_cell[n][c] * d_new_cell; - - d_gates[n][0][c] = - d_input_gate * d_sigmoid(gate_weights[n][0][c]); - d_gates[n][1][c] = - d_output_gate * d_sigmoid(gate_weights[n][1][c]); - d_gates[n][2][c] = - d_candidate_cell * d_elu(gate_weights[n][2][c]); - } -} -} // namespace - -std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_cuda_forward( - torch::Tensor input, - torch::Tensor weights, - torch::Tensor bias, - torch::Tensor old_h, - torch::Tensor old_cell) { - auto X = torch::cat({old_h, input}, /*dim=*/1); - auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); - - const auto batch_size = old_cell.size(0); - const auto state_size = old_cell.size(1); - - auto gates = gate_weights.reshape({batch_size, 3, state_size}); - auto new_h = torch::zeros_like(old_cell); - auto new_cell = torch::zeros_like(old_cell); - auto input_gate = torch::zeros_like(old_cell); - auto output_gate = torch::zeros_like(old_cell); - auto candidate_cell = torch::zeros_like(old_cell); - - const int threads = 1024; - const dim3 blocks((state_size + threads - 1) / threads, batch_size); - - AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] { - lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>( - gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(), - old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()); - })); - - return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates}; -} - -std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_cuda_backward( - torch::Tensor grad_h, - torch::Tensor grad_cell, - torch::Tensor new_cell, - torch::Tensor input_gate, - torch::Tensor output_gate, - torch::Tensor candidate_cell, - torch::Tensor X, - torch::Tensor gates, - torch::Tensor weights) { - auto d_old_cell = torch::zeros_like(new_cell); - auto d_gates = torch::zeros_like(gates); - - auto grad_h_contig = grad_h.contiguous(); - auto grad_cell_contig = grad_cell.contiguous(); - - const auto batch_size = new_cell.size(0); - const auto state_size = new_cell.size(1); - - const int threads = 1024; - const dim3 blocks((state_size + threads - 1) / threads, batch_size); - - AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] { - lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>( - d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - 
d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(), - grad_h_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - grad_cell_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>()); - })); - - auto d_gate_weights = d_gates.flatten(1, 2); - auto d_weights = d_gate_weights.t().mm(X); - auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true); - - auto d_X = d_gate_weights.mm(weights); - auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); - auto d_input = d_X.slice(/*dim=*/1, state_size); - - return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; -} - -// Registers CUDA implementations for lltm_forward, lltm_backward -TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { - m.impl("lltm_forward", &lltm_cuda_forward); - m.impl("lltm_backward", &lltm_cuda_backward); -} diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu new file mode 100644 index 0000000..6700d08 --- /dev/null +++ b/extension_cpp/csrc/cuda/muladd.cu @@ -0,0 +1,85 @@ +#include <torch/extension.h> + +#include <cuda.h> +#include <cuda_runtime.h> + +namespace extension_cpp { + +__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) result[idx] = a[idx] * b[idx] + c; +} + +at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr<float>(); + const float* b_ptr = b_contig.data_ptr<float>(); + float* result_ptr = result.data_ptr<float>(); + + int numel = a_contig.numel(); + muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr); + return result; +} + +__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) result[idx] = a[idx] * b[idx]; +} + +at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr<float>(); + const float* b_ptr = b_contig.data_ptr<float>(); + float* result_ptr = result.data_ptr<float>(); + int numel = a_contig.numel(); + mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr); + return result; +} + +__global__ void add_kernel(int numel, const float* a, const float* b, float* result) { + int idx = blockIdx.x * 
blockDim.x + threadIdx.x;
+  if (idx < numel) result[idx] = a[idx] + b[idx];
+}
+
+void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(b.sizes() == out.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_CHECK(out.dtype() == at::kFloat);
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = out.data_ptr<float>();
+  int numel = a_contig.numel();
+  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+}
+
+
+// Registers CUDA implementations for mymuladd, mymul, myadd_out
+TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", &mymuladd_cuda);
+  m.impl("mymul", &mymul_cuda);
+  m.impl("myadd_out", &myadd_out_cuda);
+}
+
+}
diff --git a/extension_cpp/csrc/lltm.cpp b/extension_cpp/csrc/lltm.cpp
deleted file mode 100644
index c915dd9..0000000
--- a/extension_cpp/csrc/lltm.cpp
+++ /dev/null
@@ -1,101 +0,0 @@
-#include <torch/extension.h>
-
-#include <vector>
-
-// s'(z) = (1 - s(z)) * s(z)
-torch::Tensor d_sigmoid(torch::Tensor z) {
-  auto s = torch::sigmoid(z);
-  return (1 - s) * s;
-}
-
-// tanh'(z) = 1 - tanh^2(z)
-torch::Tensor d_tanh(torch::Tensor z) {
-  return 1 - z.tanh().pow(2);
-}
-
-// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
-  auto e = z.exp();
-  auto mask = (alpha * (e - 1)) < 0;
-  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
-}
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_forward(
-    torch::Tensor input,
-    torch::Tensor weights,
-    torch::Tensor bias,
-    torch::Tensor old_h,
-    torch::Tensor old_cell) {
-  auto X = torch::cat({old_h, input}, /*dim=*/1);
-
-  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
-
-  auto input_gate = torch::sigmoid(gates[0]);
-  auto output_gate = torch::sigmoid(gates[1]);
-  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
-
-  auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = torch::tanh(new_cell) * output_gate;
-
-  return {new_h,
-          new_cell,
-          input_gate,
-          output_gate,
-          candidate_cell,
-          X,
-          gate_weights};
-}
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_backward(
-    torch::Tensor grad_h,
-    torch::Tensor grad_cell,
-    torch::Tensor new_cell,
-    torch::Tensor input_gate,
-    torch::Tensor output_gate,
-    torch::Tensor candidate_cell,
-    torch::Tensor X,
-    torch::Tensor gate_weights,
-    torch::Tensor weights) {
-  auto d_output_gate = torch::tanh(new_cell) * grad_h;
-  auto d_tanh_new_cell = output_gate * grad_h;
-  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
-
-  auto d_old_cell = d_new_cell;
-  auto d_candidate_cell = input_gate * d_new_cell;
-  auto d_input_gate = candidate_cell * d_new_cell;
-
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
-  d_input_gate *= d_sigmoid(gates[0]);
-  d_output_gate *= d_sigmoid(gates[1]);
-  d_candidate_cell *= d_elu(gates[2]);
-
-  auto d_gates =
-
torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); - - auto d_weights = d_gates.t().mm(X); - auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true); - - auto d_X = d_gates.mm(weights); - const auto state_size = grad_h.size(1); - auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); - auto d_input = d_X.slice(/*dim=*/1, state_size); - - return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; -} - -// Registers _C as an extension module. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} - -// Defines the operators -TORCH_LIBRARY(extension_cpp, m) { - m.impl_abstract_pystub("extension_cpp.ops"); - m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); -} - -// Registers CPU implementations for lltm_forward, lltm_backward -TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { - m.impl("lltm_forward", &lltm_forward); - m.impl("lltm_backward", &lltm_backward); -} diff --git a/extension_cpp/csrc/muladd.cpp b/extension_cpp/csrc/muladd.cpp new file mode 100644 index 0000000..73b8f18 --- /dev/null +++ b/extension_cpp/csrc/muladd.cpp @@ -0,0 +1,81 @@ +#include <torch/extension.h> + +#include <vector> + +namespace extension_cpp { + +at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr<float>(); + const float* b_ptr = b_contig.data_ptr<float>(); + float* result_ptr = result.data_ptr<float>(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + return result; +} + +at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr<float>(); + const float* b_ptr = b_contig.data_ptr<float>(); + float* result_ptr = result.data_ptr<float>(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i]; + } + return result; +} + +// An example of an operator that mutates one of its inputs. 
+void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(b.sizes() == out.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_CHECK(out.dtype() == at::kFloat);
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = out.data_ptr<float>();
+  for (int64_t i = 0; i < out.numel(); i++) {
+    result_ptr[i] = a_ptr[i] + b_ptr[i];
+  }
+}
+
+// Registers _C as a Python extension module.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
+
+// Defines the operators
+TORCH_LIBRARY(extension_cpp, m) {
+  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+  m.def("mymul(Tensor a, Tensor b) -> Tensor");
+  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
+}
+
+// Registers CPU implementations for mymuladd, mymul, myadd_out
+TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("mymuladd", &mymuladd_cpu);
+  m.impl("mymul", &mymul_cpu);
+  m.impl("myadd_out", &myadd_out_cpu);
+}
+
+}
diff --git a/extension_cpp/ops.py b/extension_cpp/ops.py
index 16c0311..4d8982e 100644
--- a/extension_cpp/ops.py
+++ b/extension_cpp/ops.py
@@ -1,78 +1,63 @@
-from typing import Tuple
 import torch
 from torch import Tensor

-__all__ = ["lltm", "reference_lltm"]
+__all__ = ["mymuladd", "myadd_out"]


-def lltm(
-    input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
-) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
+def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
+    """Performs a * b + c in an efficient fused kernel"""
+    return torch.ops.extension_cpp.mymuladd.default(a, b, c)


-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
+# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
+# that describes what the properties of the output Tensor are given
+# the properties of the input Tensor. The FakeTensor kernel is necessary
+# for the op to work performantly with torch.compile.
+@torch.library.register_fake("extension_cpp::mymuladd") +def _(a, b, c): + torch._check(a.shape == b.shape) + torch._check(a.dtype == torch.float) + torch._check(b.dtype == torch.float) + torch._check(a.device == b.device) + return torch.empty_like(a) - return new_h, new_cell - @staticmethod - @torch.autograd.function.once_differentiable - def backward(ctx, grad_h, grad_cell): - ( - d_old_h, - d_input, - d_weights, - d_bias, - d_old_cell, - ) = torch.ops.extension_cpp.lltm_backward.default( - grad_h, grad_cell, *ctx.saved_tensors - ) - return d_input, d_weights, d_bias, d_old_h, d_old_cell +def _backward(ctx, grad): + a, b = ctx.saved_tensors + grad_a, grad_b = None, None + if ctx.needs_input_grad[0]: + grad_a = torch.ops.extension_cpp.mymul.default(grad, b) + if ctx.needs_input_grad[1]: + grad_b = torch.ops.extension_cpp.mymul.default(grad, a) + return grad_a, grad_b, None -@torch.library.impl_abstract("extension_cpp::lltm_forward") -def _(input, weights, bias, old_h, old_cell): - X = torch.cat([old_h, input], dim=1) - gate_weights = torch.nn.functional.linear(X, weights, bias) - gates = gate_weights.chunk(3, dim=1) - input_gate = torch.empty_like(gates[0]) - output_gate = torch.empty_like(gates[1]) - candidate_cell = torch.empty_like(gates[2]) - new_cell = torch.empty_like(old_cell) - new_h = torch.empty_like(old_h) - if input.device.type == "cuda": - batch_size = old_cell.shape[0] - state_size = old_cell.shape[1] - gate_weights = gate_weights.reshape(batch_size, 3, state_size) - return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights +def _setup_context(ctx, inputs, output): + a, b, c = inputs + saved_a, saved_b = None, None + if ctx.needs_input_grad[0]: + saved_b = b + if ctx.needs_input_grad[1]: + saved_a = a + ctx.save_for_backward(saved_a, saved_b) -def reference_lltm( - input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor -) -> Tuple[Tensor, Tensor]: - X = torch.cat([old_h, input], dim=1) +# This adds training support for the operator. You must provide us +# the backward formula for the operator and a `setup_context` function +# to save values to be used in the backward. +torch.library.register_autograd( + "extension_cpp::mymuladd", _backward, setup_context=_setup_context) - # Compute the input, output and candidate cell gates with one MM. - gate_weights = torch.nn.functional.linear(X, weights, bias) - # Split the combined gate weight matrix into its components. - gates = gate_weights.chunk(3, dim=1) - input_gate = torch.sigmoid(gates[0]) - output_gate = torch.sigmoid(gates[1]) - # Here we use an ELU instead of the usual tanh. - candidate_cell = torch.nn.functional.elu(gates[2]) +@torch.library.register_fake("extension_cpp::mymul") +def _(a, b): + torch._check(a.shape == b.shape) + torch._check(a.dtype == torch.float) + torch._check(b.dtype == torch.float) + torch._check(a.device == b.device) + return torch.empty_like(a) - # Compute the new cell state. - new_cell = old_cell + candidate_cell * input_gate - # Compute the new hidden state and output. 
- new_h = torch.tanh(new_cell) * output_gate - return new_h, new_cell +def myadd_out(a: Tensor, b: Tensor, out: Tensor) -> None: + """Writes a + b into out""" + torch.ops.extension_cpp.myadd_out.default(a, b, out) diff --git a/test/benchmark.py b/test/benchmark.py deleted file mode 100644 index e9f4799..0000000 --- a/test/benchmark.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import division -from __future__ import print_function - -import argparse -import math -import time - -import torch - -TIME_SCALES = {"s": 1, "ms": 1000, "us": 1000000} - -parser = argparse.ArgumentParser() -parser.add_argument("example", choices=["py", "cpp", "cuda"]) -parser.add_argument("-b", "--batch-size", type=int, default=16) -parser.add_argument("-f", "--features", type=int, default=32) -parser.add_argument("-s", "--state-size", type=int, default=128) -parser.add_argument("-r", "--runs", type=int, default=100) -parser.add_argument("--scale", choices=["s", "ms", "us"], default="us") -parser.add_argument("-c", "--cuda", action="store_true") -parser.add_argument("-d", "--double", action="store_true") -options = parser.parse_args() - -if options.example == "py": - from extension_cpp.ops import reference_lltm as LLTM -else: - from extension_cpp.ops import lltm as LLTM -if options.example == "cuda": - options.cuda = True - -device = torch.device("cuda") if options.cuda else torch.device("cpu") -dtype = torch.float64 if options.double else torch.float32 - -kwargs = {"dtype": dtype, "device": device, "requires_grad": True} -batch_size = options.batch_size -features = options.features -state_size = options.state_size -X = torch.randn( - batch_size, # E: No overload variant of "randn" matches argument - features, - **kwargs -) -h = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia -C = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia -W = torch.randn(3 * state_size, features + state_size, **kwargs) -b = torch.randn(1, 3 * state_size, **kwargs) # E: No overload variant of "randn" - -# Force CUDA initialization -new_h, new_C = LLTM(X, W, b, h, C) -(new_h.sum() + new_C.sum()).backward() - -forward_min = math.inf -forward_time = 0 -backward_min = math.inf -backward_time = 0 -for _ in range(options.runs): - X.grad = None - h.grad = None - C.grad = None - W.grad = None - b.grad = None - start = time.time() - new_h, new_C = LLTM(X, W, b, h, C) - elapsed = time.time() - start - forward_min = min(forward_min, elapsed) - forward_time += elapsed - - start = time.time() - (new_h.sum() + new_C.sum()).backward() - elapsed = time.time() - start - backward_min = min(backward_min, elapsed) - backward_time += elapsed - -scale = TIME_SCALES[options.scale] -forward_min *= scale -backward_min *= scale -forward_average = forward_time / options.runs * scale -backward_average = backward_time / options.runs * scale - -print( - "Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}".format( - forward_min, forward_average, backward_min, backward_average, options.scale - ) -) diff --git a/test/test_extension.py b/test/test_extension.py index 7bc29f4..618f00b 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -8,30 +8,31 @@ import torch.nn.functional as F -def sample_inputs(device, *, requires_grad=False): - batch_size = 3 - features = 17 - state_size = 5 - kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad} - X = torch.randn( - batch_size, # E: No overload variant of "randn" matches argument - features, - **kwargs - ) - h = torch.randn(batch_size, 
state_size, **kwargs) # E: No overload varia - C = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia - W = torch.randn(3 * state_size, features + state_size, **kwargs) - b = torch.randn(1, 3 * state_size, **kwargs) # E: No overload variant of "randn" - return X, W, b, h, C - - -class TestLLTM(TestCase): +def reference_muladd(a, b, c): + return a * b + c + + +class TestMyMulAdd(TestCase): + def sample_inputs(self, device, *, requires_grad=False): + def make_tensor(*size): + return torch.randn(size, device=device, requires_grad=requires_grad) + + def make_nondiff_tensor(*size): + return torch.randn(size, device=device, requires_grad=False) + + return [ + [make_tensor(3), make_tensor(3), 1], + [make_tensor(20), make_tensor(20), 3.14], + [make_tensor(20), make_nondiff_tensor(20), -123], + [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3], + ] + def _test_correctness(self, device): - args = sample_inputs(device) - result = extension_cpp.ops.lltm(*args) - expected = extension_cpp.ops.reference_lltm(*args) - self.assertEqual(len(result), len(expected)) - torch.testing.assert_close(result, expected) + samples = self.sample_inputs(device) + for args in samples: + result = extension_cpp.ops.mymuladd(*args) + expected = reference_muladd(*args) + torch.testing.assert_close(result, expected) def test_correctness_cpu(self): self._test_correctness("cpu") @@ -41,23 +42,67 @@ def test_correctness_cuda(self): self._test_correctness("cuda") def _test_gradients(self, device): - args = sample_inputs(device, requires_grad=True) - # Use torch.autograd.gradcheck to check that gradients are OK - torch.autograd.gradcheck(extension_cpp.ops.lltm, args) + samples = self.sample_inputs(device, requires_grad=True) + for args in samples: + diff_tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + out = extension_cpp.ops.mymuladd(*args) + grad_out = torch.randn_like(out) + result = torch.autograd.grad(out, diff_tensors, grad_out) + + out = reference_muladd(*args) + expected = torch.autograd.grad(out, diff_tensors, grad_out) + + torch.testing.assert_close(result, expected) def test_gradients_cpu(self): self._test_gradients("cpu") - # This is supposed to succeed, there's probably a bug in the CUDA kernel. - @unittest.expectedFailure @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_gradients_cuda(self): self._test_gradients("cuda") def _opcheck(self, device): - args = sample_inputs(device) - # Use opcheck to test that the operator was written correctly. 
- opcheck(torch.ops.extension_cpp.lltm_forward.default, args) + # Use opcheck to check for incorrect usage of operator registration APIs + samples = self.sample_inputs(device, requires_grad=True) + samples.extend(self.sample_inputs(device, requires_grad=False)) + for args in samples: + opcheck(torch.ops.extension_cpp.mymuladd.default, args) + + def test_opcheck_cpu(self): + self._opcheck("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_opcheck_cuda(self): + self._opcheck("cuda") + + +class TestMyAddOut(TestCase): + def sample_inputs(self, device, *, requires_grad=False): + def make_tensor(*size): + return torch.randn(size, device=device, requires_grad=requires_grad) + + def make_nondiff_tensor(*size): + return torch.randn(size, device=device, requires_grad=False) + + return [ + [make_tensor(3), make_tensor(3), make_tensor(3)], + [make_tensor(20), make_tensor(20), make_tensor(20)], + ] + + def _test_correctness(self, device): + samples = self.sample_inputs(device) + for args in samples: + result = args[-1] + extension_cpp.ops.myadd_out(*args) + expected = torch.add(*args[:2]) + torch.testing.assert_close(result, expected) + + def _opcheck(self, device): + # Use opcheck to check for incorrect usage of operator registration APIs + samples = self.sample_inputs(device, requires_grad=True) + samples.extend(self.sample_inputs(device, requires_grad=False)) + for args in samples: + opcheck(torch.ops.extension_cpp.myadd_out.default, args) def test_opcheck_cpu(self): self._opcheck("cpu")
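For readers trying out the new ops after applying the patch above, here is a minimal, hypothetical smoke-test sketch (not part of the patch itself). It exercises the Python entry points this diff defines — `extension_cpp.ops.mymuladd`, `extension_cpp.ops.myadd_out`, the registered backward formula, the FakeTensor kernel via `torch.compile`, and `torch.library.opcheck` — and assumes the extension was built with `pip install .` on PyTorch 2.4+ with float32 CPU tensors; the shapes and the constant 0.5 are arbitrary.

```python
# Hypothetical smoke test (not part of the patch). Assumes `pip install .`
# has built the extension and PyTorch >= 2.4 is installed.
import torch
import extension_cpp.ops

a = torch.randn(16, requires_grad=True)
b = torch.randn(16, requires_grad=True)

# Fused multiply-add: a * b + c. Move the inputs to "cuda" to hit the kernels in muladd.cu.
out = extension_cpp.ops.mymuladd(a, b, 0.5)
torch.testing.assert_close(out, a * b + 0.5)

# The backward formula registered in ops.py (built on extension_cpp::mymul)
# gives d(out)/da = b and d(out)/db = a.
grad_a, grad_b = torch.autograd.grad(out.sum(), [a, b])
torch.testing.assert_close(grad_a, b)
torch.testing.assert_close(grad_b, a)

# The FakeTensor kernel registered with torch.library.register_fake is what
# lets torch.compile trace through the custom op.
compiled_muladd = torch.compile(extension_cpp.ops.mymuladd)
torch.testing.assert_close(compiled_muladd(a, b, 0.5), out)

# myadd_out mutates a preallocated, contiguous float32 output tensor in place.
buf = torch.empty(16)
extension_cpp.ops.myadd_out(a.detach(), b.detach(), buf)
torch.testing.assert_close(buf, (a + b).detach())

# opcheck validates the operator registration (schema, fake kernel, autograd),
# mirroring what test/test_extension.py does.
torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, (a, b, 0.5))
```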