diff --git a/README.md b/README.md
index a0aae0c..709441b 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,11 @@
 
 An example of writing a C++/CUDA extension for PyTorch. See
 [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.
-This repo demonstrates how to write an example `extension_cpp.ops.lltm`
+This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
 custom op that has both custom CPU and CUDA kernels.
 
+The examples in this repo work with PyTorch 2.4+.
+
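+After building (see below), the custom op can be called from Python. A minimal
+usage sketch:
+
+```
+import torch
+import extension_cpp.ops
+
+a, b = torch.randn(8), torch.randn(8)
+out = extension_cpp.ops.mymuladd(a, b, 1.0)  # computes a * b + 1.0
+```
+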
 To build:
 ```
 pip install .
diff --git a/extension_cpp/csrc/cuda/lltm_cuda.cu b/extension_cpp/csrc/cuda/lltm_cuda.cu
deleted file mode 100644
index 7612b84..0000000
--- a/extension_cpp/csrc/cuda/lltm_cuda.cu
+++ /dev/null
@@ -1,183 +0,0 @@
-#include <torch/extension.h>
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#include <vector>
-
-namespace {
-template <typename scalar_t>
-__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
-  return 1.0 / (1.0 + exp(-z));
-}
-
-template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
-  const auto s = sigmoid(z);
-  return (1.0 - s) * s;
-}
-
-template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
-  const auto t = tanh(z);
-  return 1 - (t * t);
-}
-
-template <typename scalar_t>
-__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
-  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
-}
-
-template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
-  const auto e = exp(z);
-  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
-  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
-}
-
-template <typename scalar_t>
-__global__ void lltm_cuda_forward_kernel(
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
-  //batch index
-  const int n = blockIdx.y;
-  // column index
-  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < gates.size(2)){
-    input_gate[n][c] = sigmoid(gates[n][0][c]);
-    output_gate[n][c] = sigmoid(gates[n][1][c]);
-    candidate_cell[n][c] = elu(gates[n][2][c]);
-    new_cell[n][c] =
-        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
-    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
-  }
-}
-
-template <typename scalar_t>
-__global__ void lltm_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
-  //batch index
-  const int n = blockIdx.y;
-  // column index
-  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < d_gates.size(2)){
-    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
-    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
-    const auto d_new_cell =
-        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
-
-
-    d_old_cell[n][c] = d_new_cell;
-    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
-    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
-
-    d_gates[n][0][c] =
-        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
-    d_gates[n][1][c] =
-        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
-    d_gates[n][2][c] =
-        d_candidate_cell * d_elu(gate_weights[n][2][c]);
-  }
-}
-} // namespace
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_cuda_forward(
-    torch::Tensor input,
-    torch::Tensor weights,
-    torch::Tensor bias,
-    torch::Tensor old_h,
-    torch::Tensor old_cell) {
-  auto X = torch::cat({old_h, input}, /*dim=*/1);
-  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
-
-  const auto batch_size = old_cell.size(0);
-  const auto state_size = old_cell.size(1);
-
-  auto gates = gate_weights.reshape({batch_size, 3, state_size});
-  auto new_h = torch::zeros_like(old_cell);
-  auto new_cell = torch::zeros_like(old_cell);
-  auto input_gate = torch::zeros_like(old_cell);
-  auto output_gate = torch::zeros_like(old_cell);
-  auto candidate_cell = torch::zeros_like(old_cell);
-
-  const int threads = 1024;
-  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
-  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
-    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
-  }));
-
-  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
-}
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_cuda_backward(
-    torch::Tensor grad_h,
-    torch::Tensor grad_cell,
-    torch::Tensor new_cell,
-    torch::Tensor input_gate,
-    torch::Tensor output_gate,
-    torch::Tensor candidate_cell,
-    torch::Tensor X,
-    torch::Tensor gates,
-    torch::Tensor weights) {
-  auto d_old_cell = torch::zeros_like(new_cell);
-  auto d_gates = torch::zeros_like(gates);
-
-  auto grad_h_contig = grad_h.contiguous();
-  auto grad_cell_contig = grad_cell.contiguous();
-
-  const auto batch_size = new_cell.size(0);
-  const auto state_size = new_cell.size(1);
-
-  const int threads = 1024;
-  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
-  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
-    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        grad_h_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        grad_cell_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
-  }));
-
-  auto d_gate_weights = d_gates.flatten(1, 2);
-  auto d_weights = d_gate_weights.t().mm(X);
-  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
-
-  auto d_X = d_gate_weights.mm(weights);
-  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
-  auto d_input = d_X.slice(/*dim=*/1, state_size);
-
-  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
-}
-
-// Registers CUDA implementations for lltm_forward, lltm_backward
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("lltm_forward", &lltm_cuda_forward);
-  m.impl("lltm_backward", &lltm_cuda_backward);
-}
diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/csrc/cuda/muladd.cu
new file mode 100644
index 0000000..6700d08
--- /dev/null
+++ b/extension_cpp/csrc/cuda/muladd.cu
@@ -0,0 +1,85 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+namespace extension_cpp {
+
+__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
+}
+
+at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = result.data_ptr<float>();
+
+  int numel = a_contig.numel();
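+  // One thread per element: launch ceil(numel / 256) blocks of 256 threads.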
+  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+  return result;
+}
+
+__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel) result[idx] = a[idx] * b[idx];
+}
+
+at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = result.data_ptr<float>();
+  int numel = a_contig.numel();
+  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  return result;
+}
+
+__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel) result[idx] = a[idx] + b[idx];
+}
+
+void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(b.sizes() == out.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_CHECK(out.dtype() == at::kFloat);
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
+  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = out.data_ptr<float>();
+  int numel = a_contig.numel();
+  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+}
+
+
+// Registers CUDA implementations for mymuladd, mymul, myadd_out
+TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", &mymuladd_cuda);
+  m.impl("mymul", &mymul_cuda);
+  m.impl("myadd_out", &myadd_out_cuda);
+}
+
+}  // namespace extension_cpp
diff --git a/extension_cpp/csrc/lltm.cpp b/extension_cpp/csrc/lltm.cpp
deleted file mode 100644
index c915dd9..0000000
--- a/extension_cpp/csrc/lltm.cpp
+++ /dev/null
@@ -1,101 +0,0 @@
-#include <torch/extension.h>
-
-#include <vector>
-
-// s'(z) = (1 - s(z)) * s(z)
-torch::Tensor d_sigmoid(torch::Tensor z) {
-  auto s = torch::sigmoid(z);
-  return (1 - s) * s;
-}
-
-// tanh'(z) = 1 - tanh^2(z)
-torch::Tensor d_tanh(torch::Tensor z) {
-  return 1 - z.tanh().pow(2);
-}
-
-// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
-  auto e = z.exp();
-  auto mask = (alpha * (e - 1)) < 0;
-  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
-}
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_forward(
-    torch::Tensor input,
-    torch::Tensor weights,
-    torch::Tensor bias,
-    torch::Tensor old_h,
-    torch::Tensor old_cell) {
-  auto X = torch::cat({old_h, input}, /*dim=*/1);
-
-  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
-
-  auto input_gate = torch::sigmoid(gates[0]);
-  auto output_gate = torch::sigmoid(gates[1]);
-  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
-
-  auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = torch::tanh(new_cell) * output_gate;
-
-  return {new_h,
-          new_cell,
-          input_gate,
-          output_gate,
-          candidate_cell,
-          X,
-          gate_weights};
-}
-
-std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor> lltm_backward(
-    torch::Tensor grad_h,
-    torch::Tensor grad_cell,
-    torch::Tensor new_cell,
-    torch::Tensor input_gate,
-    torch::Tensor output_gate,
-    torch::Tensor candidate_cell,
-    torch::Tensor X,
-    torch::Tensor gate_weights,
-    torch::Tensor weights) {
-  auto d_output_gate = torch::tanh(new_cell) * grad_h;
-  auto d_tanh_new_cell = output_gate * grad_h;
-  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
-
-  auto d_old_cell = d_new_cell;
-  auto d_candidate_cell = input_gate * d_new_cell;
-  auto d_input_gate = candidate_cell * d_new_cell;
-
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
-  d_input_gate *= d_sigmoid(gates[0]);
-  d_output_gate *= d_sigmoid(gates[1]);
-  d_candidate_cell *= d_elu(gates[2]);
-
-  auto d_gates =
-      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
-
-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
-
-  auto d_X = d_gates.mm(weights);
-  const auto state_size = grad_h.size(1);
-  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
-  auto d_input = d_X.slice(/*dim=*/1, state_size);
-
-  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
-}
-
-// Registers _C as an extension module.
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
-
-// Defines the operators
-TORCH_LIBRARY(extension_cpp, m) {
-  m.impl_abstract_pystub("extension_cpp.ops");
-  m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
-  m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
-}
-
-// Registers CPU implementations for lltm_forward, lltm_backward
-TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-  m.impl("lltm_forward", &lltm_forward);
-  m.impl("lltm_backward", &lltm_backward);
-}
diff --git a/extension_cpp/csrc/muladd.cpp b/extension_cpp/csrc/muladd.cpp
new file mode 100644
index 0000000..73b8f18
--- /dev/null
+++ b/extension_cpp/csrc/muladd.cpp
@@ -0,0 +1,81 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+namespace extension_cpp {
+
+at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = result.data_ptr<float>();
+  for (int64_t i = 0; i < result.numel(); i++) {
+    result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
+  }
+  return result;
+}
+
+at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = result.data_ptr<float>();
+  for (int64_t i = 0; i < result.numel(); i++) {
+    result_ptr[i] = a_ptr[i] * b_ptr[i];
+  }
+  return result;
+}
+
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
+  TORCH_CHECK(a.sizes() == b.sizes());
+  TORCH_CHECK(b.sizes() == out.sizes());
+  TORCH_CHECK(a.dtype() == at::kFloat);
+  TORCH_CHECK(b.dtype() == at::kFloat);
+  TORCH_CHECK(out.dtype() == at::kFloat);
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
+  at::Tensor a_contig = a.contiguous();
+  at::Tensor b_contig = b.contiguous();
+  const float* a_ptr = a_contig.data_ptr<float>();
+  const float* b_ptr = b_contig.data_ptr<float>();
+  float* result_ptr = out.data_ptr<float>();
+  for (int64_t i = 0; i < out.numel(); i++) {
+    result_ptr[i] = a_ptr[i] + b_ptr[i];
+  }
+}
+
+// Registers _C as a Python extension module.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
+
+// Defines the operators
+TORCH_LIBRARY(extension_cpp, m) {
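+  // Schema notes: `float` in the schema language corresponds to a C++ double
+  // (hence the `double c` parameter above), and `Tensor(a!)` marks `out` as a
+  // tensor that the op mutates in place (the `-> ()` means it returns nothing).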
+  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+  m.def("mymul(Tensor a, Tensor b) -> Tensor");
+  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
+}
+
+// Registers CPU implementations for mymuladd, mymul, myadd_out
+TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("mymuladd", &mymuladd_cpu);
+  m.impl("mymul", &mymul_cpu);
+  m.impl("myadd_out", &myadd_out_cpu);
+}
+
+}  // namespace extension_cpp
diff --git a/extension_cpp/ops.py b/extension_cpp/ops.py
index 16c0311..4d8982e 100644
--- a/extension_cpp/ops.py
+++ b/extension_cpp/ops.py
@@ -1,78 +1,63 @@
-from typing import Tuple
 import torch
 from torch import Tensor
 
-__all__ = ["lltm", "reference_lltm"]
+__all__ = ["mymuladd", "myadd_out"]
 
 
-def lltm(
-    input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
-) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
+def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
+    """Performs a * b + c in an efficient fused kernel"""
+    return torch.ops.extension_cpp.mymuladd.default(a, b, c)
 
 
-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
+# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
+# that describes the properties of the output Tensor given the properties
+# of the input Tensors. The FakeTensor kernel is necessary
+# for the op to work performantly with torch.compile.
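+# (torch.library.register_fake requires PyTorch 2.4+, in line with the
+# requirement stated in the README.)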
+@torch.library.register_fake("extension_cpp::mymuladd")
+def _(a, b, c):
+    torch._check(a.shape == b.shape)
+    torch._check(a.dtype == torch.float)
+    torch._check(b.dtype == torch.float)
+    torch._check(a.device == b.device)
+    return torch.empty_like(a)
 
-        return new_h, new_cell
 
-    @staticmethod
-    @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_h, grad_cell):
-        (
-            d_old_h,
-            d_input,
-            d_weights,
-            d_bias,
-            d_old_cell,
-        ) = torch.ops.extension_cpp.lltm_backward.default(
-            grad_h, grad_cell, *ctx.saved_tensors
-        )
-        return d_input, d_weights, d_bias, d_old_h, d_old_cell
+def _backward(ctx, grad):
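+    # mymuladd computes a * b + c, so d/da = grad * b and d/db = grad * a;
+    # c is a plain float and gets no gradient (the trailing None below).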
+    a, b = ctx.saved_tensors
+    grad_a, grad_b = None, None
+    if ctx.needs_input_grad[0]:
+        grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
+    if ctx.needs_input_grad[1]:
+        grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
+    return grad_a, grad_b, None
 
 
-@torch.library.impl_abstract("extension_cpp::lltm_forward")
-def _(input, weights, bias, old_h, old_cell):
-    X = torch.cat([old_h, input], dim=1)
-    gate_weights = torch.nn.functional.linear(X, weights, bias)
-    gates = gate_weights.chunk(3, dim=1)
-    input_gate = torch.empty_like(gates[0])
-    output_gate = torch.empty_like(gates[1])
-    candidate_cell = torch.empty_like(gates[2])
-    new_cell = torch.empty_like(old_cell)
-    new_h = torch.empty_like(old_h)
-    if input.device.type == "cuda":
-        batch_size = old_cell.shape[0]
-        state_size = old_cell.shape[1]
-        gate_weights = gate_weights.reshape(batch_size, 3, state_size)
-    return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights
+def _setup_context(ctx, inputs, output):
+    a, b, c = inputs
+    saved_a, saved_b = None, None
+    if ctx.needs_input_grad[0]:
+        saved_b = b
+    if ctx.needs_input_grad[1]:
+        saved_a = a
+    ctx.save_for_backward(saved_a, saved_b)
 
 
-def reference_lltm(
-    input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
-) -> Tuple[Tensor, Tensor]:
-    X = torch.cat([old_h, input], dim=1)
+# This adds training support for the operator. You must provide the
+# backward formula for the operator and a `setup_context` function
+# that saves the values needed in the backward pass.
+torch.library.register_autograd(
+    "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
 
-    # Compute the input, output and candidate cell gates with one MM.
-    gate_weights = torch.nn.functional.linear(X, weights, bias)
-    # Split the combined gate weight matrix into its components.
-    gates = gate_weights.chunk(3, dim=1)
 
-    input_gate = torch.sigmoid(gates[0])
-    output_gate = torch.sigmoid(gates[1])
-    # Here we use an ELU instead of the usual tanh.
-    candidate_cell = torch.nn.functional.elu(gates[2])
+@torch.library.register_fake("extension_cpp::mymul")
+def _(a, b):
+    torch._check(a.shape == b.shape)
+    torch._check(a.dtype == torch.float)
+    torch._check(b.dtype == torch.float)
+    torch._check(a.device == b.device)
+    return torch.empty_like(a)
 
-    # Compute the new cell state.
-    new_cell = old_cell + candidate_cell * input_gate
-    # Compute the new hidden state and output.
-    new_h = torch.tanh(new_cell) * output_gate
 
-    return new_h, new_cell
+def myadd_out(a: Tensor, b: Tensor, out: Tensor) -> None:
+    """Writes a + b into out"""
+    torch.ops.extension_cpp.myadd_out.default(a, b, out)
diff --git a/test/benchmark.py b/test/benchmark.py
deleted file mode 100644
index e9f4799..0000000
--- a/test/benchmark.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import math
-import time
-
-import torch
-
-TIME_SCALES = {"s": 1, "ms": 1000, "us": 1000000}
-
-parser = argparse.ArgumentParser()
-parser.add_argument("example", choices=["py", "cpp", "cuda"])
-parser.add_argument("-b", "--batch-size", type=int, default=16)
-parser.add_argument("-f", "--features", type=int, default=32)
-parser.add_argument("-s", "--state-size", type=int, default=128)
-parser.add_argument("-r", "--runs", type=int, default=100)
-parser.add_argument("--scale", choices=["s", "ms", "us"], default="us")
-parser.add_argument("-c", "--cuda", action="store_true")
-parser.add_argument("-d", "--double", action="store_true")
-options = parser.parse_args()
-
-if options.example == "py":
-    from extension_cpp.ops import reference_lltm as LLTM
-else:
-    from extension_cpp.ops import lltm as LLTM
-if options.example == "cuda":
-    options.cuda = True
-
-device = torch.device("cuda") if options.cuda else torch.device("cpu")
-dtype = torch.float64 if options.double else torch.float32
-
-kwargs = {"dtype": dtype, "device": device, "requires_grad": True}
-batch_size = options.batch_size
-features = options.features
-state_size = options.state_size
-X = torch.randn(
-    batch_size,  # E: No overload variant of "randn" matches argument
-    features,
-    **kwargs
-)
-h = torch.randn(batch_size, state_size, **kwargs)  # E: No overload varia
-C = torch.randn(batch_size, state_size, **kwargs)  # E: No overload varia
-W = torch.randn(3 * state_size, features + state_size, **kwargs)
-b = torch.randn(1, 3 * state_size, **kwargs)  # E: No overload variant of "randn"
-
-# Force CUDA initialization
-new_h, new_C = LLTM(X, W, b, h, C)
-(new_h.sum() + new_C.sum()).backward()
-
-forward_min = math.inf
-forward_time = 0
-backward_min = math.inf
-backward_time = 0
-for _ in range(options.runs):
-    X.grad = None
-    h.grad = None
-    C.grad = None
-    W.grad = None
-    b.grad = None
-    start = time.time()
-    new_h, new_C = LLTM(X, W, b, h, C)
-    elapsed = time.time() - start
-    forward_min = min(forward_min, elapsed)
-    forward_time += elapsed
-
-    start = time.time()
-    (new_h.sum() + new_C.sum()).backward()
-    elapsed = time.time() - start
-    backward_min = min(backward_min, elapsed)
-    backward_time += elapsed
-
-scale = TIME_SCALES[options.scale]
-forward_min *= scale
-backward_min *= scale
-forward_average = forward_time / options.runs * scale
-backward_average = backward_time / options.runs * scale
-
-print(
-    "Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}".format(
-        forward_min, forward_average, backward_min, backward_average, options.scale
-    )
-)
diff --git a/test/test_extension.py b/test/test_extension.py
index 7bc29f4..618f00b 100644
--- a/test/test_extension.py
+++ b/test/test_extension.py
@@ -8,30 +8,31 @@
 import torch.nn.functional as F
 
 
-def sample_inputs(device, *, requires_grad=False):
-    batch_size = 3
-    features = 17
-    state_size = 5
-    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad}
-    X = torch.randn(
-        batch_size,  # E: No overload variant of "randn" matches argument
-        features,
-        **kwargs
-    )
-    h = torch.randn(batch_size, state_size, **kwargs)  # E: No overload varia
-    C = torch.randn(batch_size, state_size, **kwargs)  # E: No overload varia
-    W = torch.randn(3 * state_size, features + state_size, **kwargs)
-    b = torch.randn(1, 3 * state_size, **kwargs)  # E: No overload variant of "randn"
-    return X, W, b, h, C
-
-
-class TestLLTM(TestCase):
+def reference_muladd(a, b, c):
+    return a * b + c
+
+
+class TestMyMulAdd(TestCase):
+    def sample_inputs(self, device, *, requires_grad=False):
+        def make_tensor(*size):
+            return torch.randn(size, device=device, requires_grad=requires_grad)
+
+        def make_nondiff_tensor(*size):
+            return torch.randn(size, device=device, requires_grad=False)
+
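+        # Mix differentiable and non-differentiable tensors (plus plain Python
+        # numbers for `c`) so each gradient path is exercised independently.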
+        return [
+            [make_tensor(3), make_tensor(3), 1],
+            [make_tensor(20), make_tensor(20), 3.14],
+            [make_tensor(20), make_nondiff_tensor(20), -123],
+            [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
+        ]
+
     def _test_correctness(self, device):
-        args = sample_inputs(device)
-        result = extension_cpp.ops.lltm(*args)
-        expected = extension_cpp.ops.reference_lltm(*args)
-        self.assertEqual(len(result), len(expected))
-        torch.testing.assert_close(result, expected)
+        samples = self.sample_inputs(device)
+        for args in samples:
+            result = extension_cpp.ops.mymuladd(*args)
+            expected = reference_muladd(*args)
+            torch.testing.assert_close(result, expected)
 
     def test_correctness_cpu(self):
         self._test_correctness("cpu")
@@ -41,23 +42,67 @@ def test_correctness_cuda(self):
         self._test_correctness("cuda")
 
     def _test_gradients(self, device):
-        args = sample_inputs(device, requires_grad=True)
-        # Use torch.autograd.gradcheck to check that gradients are OK
-        torch.autograd.gradcheck(extension_cpp.ops.lltm, args)
+        samples = self.sample_inputs(device, requires_grad=True)
+        for args in samples:
+            diff_tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
+            out = extension_cpp.ops.mymuladd(*args)
+            grad_out = torch.randn_like(out)
+            result = torch.autograd.grad(out, diff_tensors, grad_out)
+
+            out = reference_muladd(*args)
+            expected = torch.autograd.grad(out, diff_tensors, grad_out)
+
+            torch.testing.assert_close(result, expected)
 
     def test_gradients_cpu(self):
         self._test_gradients("cpu")
 
-    # This is supposed to succeed, there's probably a bug in the CUDA kernel.
-    @unittest.expectedFailure
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_gradients_cuda(self):
         self._test_gradients("cuda")
 
     def _opcheck(self, device):
-        args = sample_inputs(device)
-        # Use opcheck to test that the operator was written correctly.
-        opcheck(torch.ops.extension_cpp.lltm_forward.default, args)
+        # Use opcheck to check for incorrect usage of operator registration APIs
+        samples = self.sample_inputs(device, requires_grad=True)
+        samples.extend(self.sample_inputs(device, requires_grad=False))
+        for args in samples:
+            opcheck(torch.ops.extension_cpp.mymuladd.default, args)
+
+    def test_opcheck_cpu(self):
+        self._opcheck("cpu")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_opcheck_cuda(self):
+        self._opcheck("cuda")
+
+
+class TestMyAddOut(TestCase):
+    def sample_inputs(self, device, *, requires_grad=False):
+        def make_tensor(*size):
+            return torch.randn(size, device=device, requires_grad=requires_grad)
+
+        def make_nondiff_tensor(*size):
+            return torch.randn(size, device=device, requires_grad=False)
+
+        return [
+            [make_tensor(3), make_tensor(3), make_tensor(3)],
+            [make_tensor(20), make_tensor(20), make_tensor(20)],
+        ]
+
+    def _test_correctness(self, device):
+        samples = self.sample_inputs(device)
+        for args in samples:
+            result = args[-1]
+            extension_cpp.ops.myadd_out(*args)
+            expected = torch.add(*args[:2])
+            torch.testing.assert_close(result, expected)
+
+    def _opcheck(self, device):
+        # Use opcheck to check for incorrect usage of operator registration APIs
+        samples = self.sample_inputs(device, requires_grad=True)
+        samples.extend(self.sample_inputs(device, requires_grad=False))
+        for args in samples:
+            opcheck(torch.ops.extension_cpp.myadd_out.default, args)
 
     def test_opcheck_cpu(self):
         self._opcheck("cpu")