Replace lltm with mymuladd; update to new custom ops APIs #94

Merged
2 commits, merged May 29, 2024
Changes from 1 commit
Replace lltm with mymuladd; update to new custom ops APIs
We're replacing lltm with mymuladd(a: Tensor, b: Tensor, c: float),
which just computes a*b+c. This simplification lets us focus on the
operator registration instead of getting lost in the details of the
complicated lltm kernels.

Test Plan:
- tests

[ghstack-poisoned]
zou3519 committed May 24, 2024

commit 1979f756d9cbbd33b5b64b35e7d7b5b58a6b7c0e
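
For orientation before the diff: in pure PyTorch the new operator's semantics reduce to a one-liner, mirrored by the reference_muladd helper added to the tests below. A minimal sketch (plain PyTorch only, no extension build required):

import torch

def reference_muladd(a: torch.Tensor, b: torch.Tensor, c: float) -> torch.Tensor:
    # Elementwise multiply-add: what extension_cpp's mymuladd computes.
    return a * b + c

print(reference_muladd(torch.ones(3), torch.full((3,), 2.0), 0.5))  # tensor([2.5000, 2.5000, 2.5000])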
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

An example of writing a C++/CUDA extension for PyTorch. See
[here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.
-This repo demonstrates how to write an example `extension_cpp.ops.lltm`
+This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
custom op that has both custom CPU and CUDA kernels.

To build:
183 changes: 0 additions & 183 deletions extension_cpp/csrc/cuda/lltm_cuda.cu

This file was deleted.

59 changes: 59 additions & 0 deletions extension_cpp/csrc/cuda/muladd.cu
@@ -0,0 +1,59 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace extension_cpp {

__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(at::Tensor a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();

int numel = a_contig.numel();
muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}

__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
int numel = a_contig.numel();
mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
return result;
}

// Registers CUDA implementations for mymuladd, mymul
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", &mymuladd_cuda);
m.impl("mymul", &mymul_cuda);
}

}
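
A note on the launch configuration above: (numel+255)/256 is integer ceiling division, so each kernel launches one thread per element in blocks of 256, and the idx < numel guard masks off the tail of the last block. As an aside (not part of this PR), the same kernels are often written with a grid-stride loop so that any grid size is valid; a sketch:

// Sketch only (not in this PR): a grid-stride variant of muladd_kernel.
// Each thread processes idx, idx + stride, idx + 2*stride, ... so the launch
// no longer has to cover every element with a dedicated thread.
__global__ void muladd_kernel_grid_stride(
    int numel, const float* a, const float* b, float c, float* result) {
  int stride = gridDim.x * blockDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
    result[idx] = a[idx] * b[idx] + c;
  }
}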
101 changes: 0 additions & 101 deletions extension_cpp/csrc/lltm.cpp

This file was deleted.

58 changes: 58 additions & 0 deletions extension_cpp/csrc/muladd.cpp
@@ -0,0 +1,58 @@
#include <torch/extension.h>

#include <vector>

namespace extension_cpp {

at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
}
return result;
}

at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i];
}
return result;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
}

// Registers CPU implementations for mymuladd, mymul
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
m.impl("mymul", &mymul_cpu);
}

}
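
For context on the registration calls above: TORCH_LIBRARY(extension_cpp, m) declares the operator schemas once, and each TORCH_LIBRARY_IMPL(extension_cpp, <backend>, m) attaches a backend-specific kernel; dispatch then picks the CPU or CUDA kernel based on the input tensors' device. A hedged usage sketch, assuming the extension has been built and that importing the extension_cpp package loads the compiled _C module (as the tests below do):

import torch
import extension_cpp  # assumption: importing the package loads _C and registers the ops

a = torch.randn(8)  # the kernels above require float32 tensors...
b = torch.randn(8)  # ...with matching shapes
out = torch.ops.extension_cpp.mymuladd.default(a, b, 2.0)
assert torch.allclose(out, a * b + 2.0)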
130 changes: 55 additions & 75 deletions extension_cpp/ops.py
@@ -1,78 +1,58 @@
from typing import Tuple
import torch
from torch import Tensor

__all__ = ["lltm", "reference_lltm"]


def lltm(
input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
) -> Tuple[Tensor, Tensor]:
return LLTMFunction.apply(input, weights, bias, old_h, old_cell)


class LLTMFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weights, bias, old_h, old_cell):
outputs = torch.ops.extension_cpp.lltm_forward.default(
input, weights, bias, old_h, old_cell
)
new_h, new_cell = outputs[:2]
variables = list(outputs[1:]) + [weights]
ctx.save_for_backward(*variables)

return new_h, new_cell

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_h, grad_cell):
(
d_old_h,
d_input,
d_weights,
d_bias,
d_old_cell,
) = torch.ops.extension_cpp.lltm_backward.default(
grad_h, grad_cell, *ctx.saved_tensors
)
return d_input, d_weights, d_bias, d_old_h, d_old_cell


@torch.library.impl_abstract("extension_cpp::lltm_forward")
def _(input, weights, bias, old_h, old_cell):
X = torch.cat([old_h, input], dim=1)
gate_weights = torch.nn.functional.linear(X, weights, bias)
gates = gate_weights.chunk(3, dim=1)
input_gate = torch.empty_like(gates[0])
output_gate = torch.empty_like(gates[1])
candidate_cell = torch.empty_like(gates[2])
new_cell = torch.empty_like(old_cell)
new_h = torch.empty_like(old_h)
if input.device.type == "cuda":
batch_size = old_cell.shape[0]
state_size = old_cell.shape[1]
gate_weights = gate_weights.reshape(batch_size, 3, state_size)
return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights


def reference_lltm(
input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
) -> Tuple[Tensor, Tensor]:
X = torch.cat([old_h, input], dim=1)

# Compute the input, output and candidate cell gates with one MM.
gate_weights = torch.nn.functional.linear(X, weights, bias)
# Split the combined gate weight matrix into its components.
gates = gate_weights.chunk(3, dim=1)

input_gate = torch.sigmoid(gates[0])
output_gate = torch.sigmoid(gates[1])
# Here we use an ELU instead of the usual tanh.
candidate_cell = torch.nn.functional.elu(gates[2])

# Compute the new cell state.
new_cell = old_cell + candidate_cell * input_gate
# Compute the new hidden state and output.
new_h = torch.tanh(new_cell) * output_gate

return new_h, new_cell
__all__ = ["mymuladd"]


def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
    """Computes a * b + c elementwise."""
    return torch.ops.extension_cpp.mymuladd.default(a, b, c)


# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
# that describes what the properties of the output Tensor are given
# the properties of the input Tensor. The FakeTensor kernel is necessary
# for the op to work performantly with torch.compile.
@torch.library.register_fake("extension_cpp::mymuladd")
def _(a, b, c):
torch._check(a.shape == b.shape)
torch._check(a.dtype == torch.float)
torch._check(b.dtype == torch.float)
torch._check(a.device == b.device)
return torch.empty_like(a)


def _backward(ctx, grad):
a, b = ctx.saved_tensors
grad_a, grad_b = None, None
if ctx.needs_input_grad[0]:
grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
if ctx.needs_input_grad[1]:
grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
return grad_a, grad_b, None


def _setup_context(ctx, inputs, output):
a, b, c = inputs
saved_a, saved_b = None, None
if ctx.needs_input_grad[0]:
saved_b = b
if ctx.needs_input_grad[1]:
saved_a = a
ctx.save_for_backward(saved_a, saved_b)


# This adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)


@torch.library.register_fake("extension_cpp::mymul")
def _(a, b):
torch._check(a.shape == b.shape)
torch._check(a.dtype == torch.float)
torch._check(b.dtype == torch.float)
torch._check(a.device == b.device)
return torch.empty_like(a)
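
Taken together, the fake kernel (which lets torch.compile trace through the op) and register_autograd (which adds training support) should make mymuladd usable like a built-in op. A hedged end-to-end sketch, assuming the extension is installed and a recent PyTorch with torch.compile; the gradient check follows from out = a*b + c, so d(out.sum())/da = b and d(out.sum())/db = a:

import torch
import extension_cpp

a = torch.randn(5, requires_grad=True)
b = torch.randn(5, requires_grad=True)
out = extension_cpp.ops.mymuladd(a, b, 1.0)
out.sum().backward()
assert torch.allclose(a.grad, b) and torch.allclose(b.grad, a)  # c contributes no gradient

# The registered fake kernel is what allows torch.compile to trace the custom op.
compiled_muladd = torch.compile(extension_cpp.ops.mymuladd)
_ = compiled_muladd(a.detach(), b.detach(), 1.0)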
83 changes: 0 additions & 83 deletions test/benchmark.py

This file was deleted.

70 changes: 40 additions & 30 deletions test/test_extension.py
@@ -9,29 +9,31 @@


def sample_inputs(device, *, requires_grad=False):
batch_size = 3
features = 17
state_size = 5
kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad}
X = torch.randn(
batch_size, # E: No overload variant of "randn" matches argument
features,
**kwargs
)
h = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia
C = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia
W = torch.randn(3 * state_size, features + state_size, **kwargs)
b = torch.randn(1, 3 * state_size, **kwargs) # E: No overload variant of "randn"
return X, W, b, h, C


class TestLLTM(TestCase):
def make_tensor(*size):
return torch.randn(size, device=device, requires_grad=requires_grad)

def make_nondiff_tensor(*size):
return torch.randn(size, device=device, requires_grad=False)

return [
[make_tensor(3), make_tensor(3), 1],
[make_tensor(20), make_tensor(20), 3.14],
[make_tensor(20), make_nondiff_tensor(20), -123],
[make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
]


def reference_muladd(a, b, c):
return a * b + c


class TestMyMulAdd(TestCase):
def _test_correctness(self, device):
args = sample_inputs(device)
result = extension_cpp.ops.lltm(*args)
expected = extension_cpp.ops.reference_lltm(*args)
self.assertEqual(len(result), len(expected))
torch.testing.assert_close(result, expected)
samples = sample_inputs(device)
for args in samples:
result = extension_cpp.ops.mymuladd(*args)
expected = reference_muladd(*args)
torch.testing.assert_close(result, expected)

def test_correctness_cpu(self):
self._test_correctness("cpu")
@@ -41,23 +43,31 @@ def test_correctness_cuda(self):
self._test_correctness("cuda")

def _test_gradients(self, device):
args = sample_inputs(device, requires_grad=True)
# Use torch.autograd.gradcheck to check that gradients are OK
torch.autograd.gradcheck(extension_cpp.ops.lltm, args)
samples = sample_inputs(device, requires_grad=True)
for args in samples:
diff_tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
out = extension_cpp.ops.mymuladd(*args)
grad_out = torch.randn_like(out)
result = torch.autograd.grad(out, diff_tensors, grad_out)

out = reference_muladd(*args)
expected = torch.autograd.grad(out, diff_tensors, grad_out)

torch.testing.assert_close(result, expected)

def test_gradients_cpu(self):
self._test_gradients("cpu")

# This is supposed to succeed; there's probably a bug in the CUDA kernel.
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_gradients_cuda(self):
self._test_gradients("cuda")

def _opcheck(self, device):
args = sample_inputs(device)
# Use opcheck to test that the operator was written correctly.
opcheck(torch.ops.extension_cpp.lltm_forward.default, args)
# Use opcheck to check for incorrect usage of operator registration APIs
samples = sample_inputs(device, requires_grad=True)
samples.extend(sample_inputs(device, requires_grad=False))
for args in samples:
opcheck(torch.ops.extension_cpp.mymuladd.default, args)

def test_opcheck_cpu(self):
self._opcheck("cpu")