diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 839bd81d9c..2086c61b80 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -885,6 +885,58 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
     let hasVerifier = 1;
 }
 
+def TTIR_ConvTranspose2dOp : TTIR_DPSOp<"conv_transpose2d"> {
+    let summary = "ConvTranspose2d operation.";
+    let description = [{
+      Applies a 2D transposed convolution operator over an input image composed of several input planes.
+
+      Inputs:
+      - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+      - `weight` AnyRankedTensor: IOHW format (input_channels x output_channels/groups x height x width)
+      - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
+      - `output` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+
+      Attributes:
+      - `stride` (i32 | array<i32>): Controls the stride for the cross-correlation.
+      - `padding` (i32 | array<i32>): Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
+      - `output_padding` (i32 | array<i32>): Controls the additional size added to one side of the output shape.
+      - `dilation` (i32 | array<i32>): Controls the spacing between the kernel points.
+      - `groups` i32: Controls the connections between inputs and outputs. Both the input and output channels must be divisible by `groups`.
+
+      Example:
+        %input = tensor.empty() : tensor<1x8x8x256xbf16>
+        %weight = tensor.empty() : tensor<256x256x3x3xbf16>
+        %bias = tensor.empty() : tensor<1x1x1x256xbf16>
+        %output = tensor.empty() : tensor<1x10x10x256xbf16>
+        %0 = "ttir.conv_transpose2d"(%input, %weight, %bias, %output)
+          <{
+            stride = array<i32: 1, 1>,
+            padding = 0: i32,
+            output_padding = 0: i32,
+            dilation = 1: i32,
+            groups = 1: i32
+          }> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    }];
+
+    let arguments = (ins AnyRankedTensor:$input,
+                         AnyRankedTensor:$weight,
+                         Optional<AnyRankedTensor>:$bias,
+                         AnyRankedTensor:$output,
+                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$stride,
+                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$padding,
+                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$output_padding,
+                         AnyAttrOf<[I32Attr, DenseI32ArrayAttr]>:$dilation,
+                         I32Attr:$groups);
+
+    let results = (outs AnyRankedTensor:$result);
+
+    let extraClassDeclaration = [{
+      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+    }];
+
+    let hasVerifier = 1;
+}
+
 def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
   let summary = "Generalized convolution op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index ba2484ac5f..5609810f12 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -861,6 +861,57 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
     let hasVerifier = 1;
 }
 
+def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
+    let summary = "ConvTranspose2d operation.";
+    let description = [{
+      Applies a 2D transposed convolution operator over an input image composed of several input planes.
+
+      Inputs:
+      - `input` AnyRankedTensor: NHWC format (batch_size x height x width x channels)
+      - `weight` AnyRankedTensor: IOHW format (input_channels x output_channels/groups x height x width)
+      - `bias` Optional<AnyRankedTensor>: (1 x 1 x 1 x output_channels)
+      - `output` AnyRankedTensor: (1 x 1 x (batch_size * height * width) x channels)
+
+      Attributes:
+      - `in_channels` i32: The number of input channels.
+      - `out_channels` i32: The number of output channels.
+      - `batch_size` i32: The batch size.
+      - `input_height` i32: The input height.
+      - `input_width` i32: The input width.
+      - `kernel_size` array<i32>: The kernel size.
+      - `stride` array<i32>: Controls the stride for the cross-correlation.
+      - `padding` array<i32>: Controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points.
+      - `output_padding` array<i32>: Controls the additional size added to one side of the output shape.
+      - `dilation` array<i32>: Controls the spacing between the kernel points.
+      - `groups` i32: Controls the connections between inputs and outputs. Both the input and output channels must be divisible by `groups`.
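+
+      Example (illustrative only; the exact device operand type and any layout
+      attributes depend on earlier pipeline passes). A 1x8x8x256 input with a
+      3x3 kernel yields a 1x10x10x256 output, flattened to 1x1x100x256:
+        %0 = "ttnn.conv_transpose2d"(%input, %weight, %bias, %output, %device)
+          <{
+            in_channels = 256: i32,
+            out_channels = 256: i32,
+            batch_size = 1: i32,
+            input_height = 8: i32,
+            input_width = 8: i32,
+            kernel_size = array<i32: 3, 3>,
+            stride = array<i32: 1, 1>,
+            padding = array<i32: 0, 0>,
+            output_padding = array<i32: 0, 0>,
+            dilation = array<i32: 1, 1>,
+            groups = 1: i32
+          }> : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x1x100x256xbf16>, !tt.device<#device>) -> tensor<1x1x100x256xbf16>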
+    }];
+
+    let arguments = (ins AnyRankedTensor:$input,
+                         AnyRankedTensor:$weight,
+                         Optional<AnyRankedTensor>:$bias,
+                         AnyRankedTensor:$output,
+                         TT_Device:$device,
+                         I32Attr:$in_channels,
+                         I32Attr:$out_channels,
+                         I32Attr:$batch_size,
+                         I32Attr:$input_height,
+                         I32Attr:$input_width,
+                         DenseI32ArrayAttr:$kernel_size,
+                         DenseI32ArrayAttr:$stride,
+                         DenseI32ArrayAttr:$padding,
+                         DenseI32ArrayAttr:$output_padding,
+                         DenseI32ArrayAttr:$dilation,
+                         I32Attr:$groups);
+
+    let results = (outs AnyRankedTensor:$result);
+
+    let extraClassDeclaration = [{
+      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+    }];
+
+    let hasVerifier = 1;
+}
+
 def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
     let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
     let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index b56cdb39ab..461cea532e 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -269,6 +269,25 @@ table Conv2dOp {
   groups: uint32;
 }
 
+table ConvTranspose2dOp {
+  input: tt.target.TensorRef;
+  weight: tt.target.TensorRef;
+  bias: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
+  in_channels: uint32;
+  out_channels: uint32;
+  batch_size: uint32;
+  input_height: uint32;
+  input_width: uint32;
+  kernel_size: [int32];
+  stride: [int32];
+  padding: [int32];
+  output_padding: [int32];
+  dilation: [int32];
+  groups: uint32;
+}
+
 table MaxPool2dOp {
   in: tt.target.TensorRef;
   out: tt.target.TensorRef;
@@ -346,6 +365,7 @@ union OpType {
   SoftmaxOp,
   TransposeOp,
   Conv2dOp,
+  ConvTranspose2dOp,
   ConcatOp,
   ReshapeOp,
   SliceOp,
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 2e84eb3471..3dbeff1aa2 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -12,6 +12,7 @@
 #include "ttmlir/Dialect/TTNN/Types/Types.h"
 #include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
+#include "ttmlir/Utils.h"
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
@@ -26,6 +27,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
 
 #include <cstdint>
 
@@ -883,6 +885,105 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
   }
 };
 
+class ConvTranspose2dOpConversionPattern
+    : public OpConversionPattern<ttir::ConvTranspose2dOp> {
+public:
+  using OpConversionPattern<ttir::ConvTranspose2dOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::ConvTranspose2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
+
+    auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
+    auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
+    auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
+
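+    // getLastDim(ty, offset) reads dimensions from the innermost one out:
+    // for an NHWC tensor, offset 1 is C, 2 is W, 3 is H, and 4 is N.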
+    auto getLastDim = [](const RankedTensorType &ty, int offset) {
+      return ty.getShape()[ty.getRank() - offset];
+    };
+
+    auto inChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 1));
+    auto outChannelsAttr = rewriter.getI32IntegerAttr(getLastDim(outputTy, 1));
+    auto batchSizeAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 4));
+    auto inputHeightAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 3));
+    auto inputWidthAttr = rewriter.getI32IntegerAttr(getLastDim(inputTy, 2));
+
+    auto kernelSizeAttr = rewriter.getDenseI32ArrayAttr(
+        {static_cast<int32_t>(getLastDim(kernelTy, 2)),
+         static_cast<int32_t>(getLastDim(kernelTy, 1))});
+
+    auto strideAttr = attrToDenseI32ArrayAttr(adaptor.getStride(), rewriter);
+    if (auto error = strideAttr.takeError()) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::toString(std::move(error)));
+    }
+
+    auto paddingAttr = attrToDenseI32ArrayAttr(adaptor.getPadding(), rewriter);
+    if (auto error = paddingAttr.takeError()) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::toString(std::move(error)));
+    }
+
+    auto outputPaddingAttr =
+        attrToDenseI32ArrayAttr(adaptor.getOutputPadding(), rewriter);
+    if (auto error = outputPaddingAttr.takeError()) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::toString(std::move(error)));
+    }
+
+    auto dilationAttr =
+        attrToDenseI32ArrayAttr(adaptor.getDilation(), rewriter);
+    if (auto error = dilationAttr.takeError()) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::toString(std::move(error)));
+    }
+
+    auto groupsAttr = rewriter.getI32IntegerAttr(adaptor.getGroups());
+
+    // Transposed convolution in ttnn returns a tensor in a flattened shape
+    // (1 x 1 x N * H * W x C)
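+    // e.g. a (1 x 10 x 10 x 256) output becomes (1 x 1 x 100 x 256).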
+    llvm::ArrayRef<std::int64_t> outputShape = outputTy.getShape();
+    llvm::SmallVector<std::int64_t, 4> flattenedOutputShape = {
+        1, 1, outputShape[0] * outputShape[1] * outputShape[2],
+        outputShape[3]};
+    outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
+        outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));
+
+    // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
+    // attribute determination
+    auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
+        outputTy.getElementType());
+
+    // Must set the type to the output type to maintain the layout attributes
+    convDPSOutput.getResult().setType(outputTy);
+
+    ttnn::ConvTranspose2dOp newConv = rewriter.create<ttnn::ConvTranspose2dOp>(
+        op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
+        adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
+        outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
+        kernelSizeAttr, *strideAttr, *paddingAttr, *outputPaddingAttr,
+        *dilationAttr, groupsAttr);
+
+    // Restore the normal shape (N x H x W x C)
+    Value output =
+        ttir_to_ttnn::utils::generateReshape(newConv, outputShape, rewriter);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+
+private:
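+  // Normalizes the TTIR attribute form (a scalar i32 or a two-element i32
+  // array) into the two-element DenseI32ArrayAttr that the TTNN op expects.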
+  llvm::Expected<DenseI32ArrayAttr>
+  attrToDenseI32ArrayAttr(mlir::Attribute attr,
+                          ConversionPatternRewriter &rewriter) const {
+    auto pair = ttmlir::utils::getPairOfInteger<int32_t>(attr);
+    if (auto error = pair.takeError()) {
+      return error;
+    }
+
+    return rewriter.getDenseI32ArrayAttr({pair->first, pair->second});
+  }
+};
+
 class MaxPool2dOpConversionPattern
     : public OpConversionPattern<ttir::MaxPool2dOp> {
 public:
@@ -1223,6 +1324,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            LinearOpConversionPattern,
            MatmulOpConversionPattern,
            Conv2dOpConversionPattern,
+           ConvTranspose2dOpConversionPattern,
            MaxPool2dOpConversionPattern,
            SubtractOpConversionPattern,
            MeshShardOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index f92e730baf..93940c0d5f 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -851,6 +851,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Conv ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::Conv2dOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::ConvTranspose2dOp>>(
+      typeConverter, ctx);
   patterns.add<DefaultOpConversionPattern<ttnn::MaxPool2dOp>>(typeConverter,
                                                               ctx);
 
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 73daad713e..b46a3873c1 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -141,6 +141,154 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvTranspose2dOp
+//===----------------------------------------------------------------------===//
+
+// ConvTranspose2dOp verification
+::mlir::LogicalResult mlir::tt::ttir::ConvTranspose2dOp::verify() {
+  mlir::RankedTensorType inputType = getInput().getType();
+  mlir::RankedTensorType weightType = getWeight().getType();
+  mlir::RankedTensorType outputType = getOutput().getType();
+  std::optional<mlir::RankedTensorType> bias =
+      getBias() ? std::make_optional(getBias().getType()) : std::nullopt;
+
+  if (inputType.getRank() != 4) {
+    return emitOpError("Input must be a 4D tensor");
+  }
+
+  if (outputType.getRank() != 4) {
+    return emitOpError("Output must be a 4D tensor");
+  }
+
+  if (weightType.getRank() != 4) {
+    return emitOpError("Weight must be a 4D tensor");
+  }
+
+  if (bias.has_value()) {
+    if (bias->getRank() != 4) {
+      return emitOpError("Bias must be a 4D tensor");
+    }
+  }
+
+  if (inputType.getShape()[0] != outputType.getShape()[0]) {
+    return emitOpError("Batch size of input and output tensors must match");
+  }
+
+  auto stride = ttmlir::utils::getPairOfInteger<int32_t>(getStride());
+  if (auto error = stride.takeError()) {
+    return emitOpError() << llvm::toString(std::move(error)) << " for stride";
+  }
+  if (stride->first < 1 || stride->second < 1) {
+    return emitOpError("Stride values must be greater than 0");
+  }
+
+  auto padding = ttmlir::utils::getPairOfInteger<int32_t>(getPadding());
+  if (auto error = padding.takeError()) {
+    return emitOpError() << llvm::toString(std::move(error)) << " for padding";
+  }
+  if (padding->first < 0 || padding->second < 0) {
+    return emitOpError("Padding values must be greater or equal than 0");
+  }
+
+  auto outputPadding =
+      ttmlir::utils::getPairOfInteger<int32_t>(getOutputPadding());
+  if (auto error = outputPadding.takeError()) {
+    return emitOpError() << llvm::toString(std::move(error))
+                         << " for output padding";
+  }
+  if (outputPadding->first < 0 || outputPadding->second < 0) {
+    return emitOpError("Output padding values must be greater or equal than 0");
+  }
+
+  auto dilation = ttmlir::utils::getPairOfInteger<int32_t>(getDilation());
+  if (auto error = dilation.takeError()) {
+    return emitOpError() << llvm::toString(std::move(error)) << " for dilation";
+  }
+  if (dilation->first < 1 || dilation->second < 1) {
+    return emitOpError("Dilation values must be greater than 0");
+  }
+
+  llvm::ArrayRef<std::int64_t> kernelShape = weightType.getShape();
+
+  int32_t inputChannels = inputType.getDimSize(inputType.getRank() - 1);
+  int32_t outputChannels = outputType.getDimSize(outputType.getRank() - 1);
+  uint32_t groups = getGroups();
+
+  if (inputChannels % groups != 0) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << inputChannels << " input channels and "
+                         << groups << " groups.";
+  }
+
+  if (outputChannels % groups != 0) {
+    return emitOpError() << "Number of output channels from output tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << outputChannels << " output channels and "
+                         << groups << " groups.";
+  }
+
+  if (inputChannels != kernelShape[0]) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "match the first dimension of the weight tensor. "
+                         << "Got " << inputChannels << " input channels and "
+                         << kernelShape[0] << " in the weight tensor.";
+  }
+
+  if (outputChannels / groups != kernelShape[1]) {
+    return emitOpError() << "Number of output channels per group must match "
+                            "the second dimension of the weight tensor. "
+                         << "Got " << (outputChannels / groups)
+                         << " output channels per group and " << kernelShape[1]
+                         << " in the weight tensor.";
+  }
+
+  if (bias) {
+    if (bias->getDimSize(bias->getRank() - 1) != outputChannels) {
+      return emitOpError() << "Mismatch in bias tensor dimensions. "
+                           << "Bias tensor has "
+                           << bias->getDimSize(bias->getRank() - 1)
+                           << " channels, "
+                           << "but the output tensor has " << outputChannels
+                           << " channels.";
+    }
+  }
+
+  int32_t kernelHeight = kernelShape[2];
+  int32_t kernelWidth = kernelShape[3];
+
+  int32_t Hin = inputType.getDimSize(inputType.getRank() - 3);
+  int32_t Win = inputType.getDimSize(inputType.getRank() - 2);
+
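+  // Expected output size of a transposed convolution (same formula as
+  // PyTorch's ConvTranspose2d):
+  //   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
+  //         + output_padding + 1
+  // e.g. in = 8, stride = 1, padding = 0, dilation = 1, kernel = 3 -> out = 10.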
+  int32_t expectedHOut = (Hin - 1) * stride->first - 2 * padding->first +
+                         dilation->first * (kernelHeight - 1) +
+                         outputPadding->first + 1;
+  int32_t expectedWOut = (Win - 1) * stride->second - 2 * padding->second +
+                         dilation->second * (kernelWidth - 1) +
+                         outputPadding->second + 1;
+  if (expectedHOut < 0 || expectedWOut < 0) {
+    return emitOpError() << "Given input size per channel: (" << Hin << " x "
+                         << Win << "). "
+                         << "Calculated output size per channel: ("
+                         << expectedHOut << " x " << expectedWOut << "). "
+                         << "Output size is too small";
+  }
+
+  int32_t HOut = outputType.getDimSize(outputType.getRank() - 3);
+  int32_t WOut = outputType.getDimSize(outputType.getRank() - 2);
+  if (HOut != expectedHOut || WOut != expectedWOut) {
+    return emitOpError() << "Mismatch between expected output size per channel "
+                            "and got output tensor dimensions. "
+                         << "Expected: (" << expectedHOut << " x "
+                         << expectedWOut << "), "
+                         << "got: (" << HOut << " x " << WOut << ").";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConvolutionOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index eccb1e9ba7..2560170528 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -81,6 +81,168 @@ ::mlir::LogicalResult mlir::tt::ttnn::Conv2dOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvTranspose2dOp
+//===----------------------------------------------------------------------===//
+
+// ConvTranspose2dOp verification
+::mlir::LogicalResult mlir::tt::ttnn::ConvTranspose2dOp::verify() {
+  mlir::RankedTensorType inputType = getInput().getType();
+  mlir::RankedTensorType weightType = getWeight().getType();
+  mlir::RankedTensorType outputType = getOutput().getType();
+  std::optional<mlir::RankedTensorType> bias =
+      getBias() ? std::make_optional(getBias().getType()) : std::nullopt;
+
+  if (inputType.getRank() != 4) {
+    return emitOpError("Input must be a 4D tensor");
+  }
+
+  if (outputType.getRank() != 4) {
+    return emitOpError("Output must be a 4D tensor");
+  }
+
+  if (weightType.getRank() != 4) {
+    return emitOpError("Weight must be a 4D tensor");
+  }
+
+  if (bias.has_value()) {
+    if (bias->getRank() != 4) {
+      return emitOpError("Bias must be a 4D tensor");
+    }
+  }
+
+  auto checkBiggerThan = [&](llvm::ArrayRef<int32_t> values, const char *name,
+                             int32_t minValue) -> mlir::LogicalResult {
+    for (int32_t value : values) {
+      if (value < minValue) {
+        return emitOpError() << "Attribute '" << name
+                             << "' contains a value less than: " << minValue;
+      }
+    }
+    return mlir::success();
+  };
+
+  uint32_t inChannels = getInChannels();
+  if (inChannels != inputType.getDimSize(inputType.getRank() - 1)) {
+    return emitOpError("Input channels attribute must match "
+                       "the last dimension of the input tensor");
+  }
+
+  uint32_t outChannels = getOutChannels();
+  if (outChannels != outputType.getDimSize(outputType.getRank() - 1)) {
+    return emitOpError("Output channels attribute match "
+                       "the last dimension of the output tensor");
+  }
+
+  uint32_t batchSize = getBatchSize();
+  if (batchSize != inputType.getDimSize(0)) {
+    return emitOpError("Batch size attribute must match the first "
+                       "dimension of the input tensor");
+  }
+
+  uint32_t inputHeight = getInputHeight();
+  if (inputHeight != inputType.getDimSize(inputType.getRank() - 3)) {
+    return emitOpError("Input height attribute must match the third "
+                       "dimension of the input tensor");
+  }
+
+  uint32_t inputWidth = getInputWidth();
+  if (inputWidth != inputType.getDimSize(inputType.getRank() - 2)) {
+    return emitOpError("Input width attribute must match the second "
+                       "dimension of the input tensor");
+  }
+
+  llvm::ArrayRef<int32_t> stride = getStride();
+  if (failed(checkBiggerThan(stride, "stride", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> padding = getPadding();
+  if (failed(checkBiggerThan(padding, "padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> outputPadding = getOutputPadding();
+  if (failed(checkBiggerThan(outputPadding, "output padding", 0))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int32_t> dilation = getDilation();
+  if (failed(checkBiggerThan(dilation, "dilation", 1))) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<std::int64_t> kernelShape = weightType.getShape();
+
+  int32_t inputChannels = inputType.getDimSize(inputType.getRank() - 1);
+  int32_t outputChannels = outputType.getDimSize(outputType.getRank() - 1);
+  uint32_t groups = getGroups();
+
+  if (inputChannels % groups != 0) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << inputChannels << " input channels and "
+                         << groups << " groups.";
+  }
+
+  if (outputChannels % groups != 0) {
+    return emitOpError() << "Number of output channels from output tensor must "
+                            "be divisible by the number of groups. "
+                         << "Got " << outputChannels << " output channels and "
+                         << groups << " groups.";
+  }
+
+  if (inputChannels != kernelShape[0]) {
+    return emitOpError() << "Number of input channels from input tensor must "
+                            "match the first dimension of the weight tensor. "
+                         << "Got " << inputChannels << " input channels and "
+                         << kernelShape[0] << " in the weight tensor.";
+  }
+
+  if (outputChannels / groups != kernelShape[1]) {
+    return emitOpError() << "Number of output channels per group must match "
+                            "the second dimension of the weight tensor. "
+                         << "Got " << (outputChannels / groups)
+                         << " output channels per group and " << kernelShape[1]
+                         << " in the weight tensor.";
+  }
+
+  if (bias) {
+    if (bias->getDimSize(bias->getRank() - 1) != outputChannels) {
+      return emitOpError() << "Mismatch in bias tensor dimensions. "
+                           << "Bias tensor has "
+                           << bias->getDimSize(bias->getRank() - 1)
+                           << " channels, "
+                           << "but the output tensor has " << outputChannels
+                           << " channels.";
+    }
+  }
+
+  int32_t kernelHeight = kernelShape[2];
+  int32_t kernelWidth = kernelShape[3];
+
+  int32_t Hin = inputType.getDimSize(inputType.getRank() - 3);
+  int32_t Win = inputType.getDimSize(inputType.getRank() - 2);
+
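+  // Same output-size formula as in the TTIR ConvTranspose2dOp verifier.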
+  int32_t expectedHOut = (Hin - 1) * stride[0] - 2 * padding[0] +
+                         dilation[0] * (kernelHeight - 1) + outputPadding[0] +
+                         1;
+  int32_t expectedWOut = (Win - 1) * stride[1] - 2 * padding[1] +
+                         dilation[1] * (kernelWidth - 1) + outputPadding[1] + 1;
+  if (expectedHOut < 0 || expectedWOut < 0) {
+    return emitOpError() << "Given input size per channel: (" << Hin << " x "
+                         << Win << "). "
+                         << "Calculated output size per channel: ("
+                         << expectedHOut << " x " << expectedWOut << "). "
+                         << "Output size is too small";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MaxPool2dOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
index e148b575fb..14923c0a7c 100644
--- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
+++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -283,11 +283,13 @@ class TTNNLayoutDPSOperandsRewriter
     bool modified = false;
     for (OpOperand &operand : op->getOpOperands()) {
       // Check if the operand is a dps result
-      bool isResult = op.isDpsInit(&operand);
+      bool isDPSResult = op.isDpsInit(&operand);
 
       // TTNN Conv2d moves input, weight, and bias from host to device
       // itself. Inserting the ToLayoutOp on these operands is thus problematic.
-      if (mlir::isa<ttir::Conv2dOp>(op.getOperation()) && !isResult) {
+      if (!isDPSResult &&
+          (mlir::isa<ttir::Conv2dOp>(op.getOperation()) ||
+           mlir::isa<ttir::ConvTranspose2dOp>(op.getOperation()))) {
         // For the weight input of the conv2d op, it specifically needs to be on
         // host, so we create a host to layout op (issue
         // https://github.com/tenstorrent/tt-mlir/issues/1528).
@@ -319,7 +321,7 @@ class TTNNLayoutDPSOperandsRewriter
           modified = true;
           op->setOperand(operand.getOperandNumber(), *desiredLayout);
           // If operand is dps result, update the result type on current op
-          if (isResult) {
+          if (isDPSResult) {
             op->getResult(0).setType(desiredLayout->getType());
           }
         });
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 055566c243..e720840c7f 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -467,6 +467,40 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
       op.getGroups());
 }
 
+::flatbuffers::Offset<::tt::target::ttnn::ConvTranspose2dOp>
+createOp(FlatbufferObjectCache &cache, ConvTranspose2dOp op) {
+  auto in0 =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto in1 = cache.at<::tt::target::TensorRef>(
+      getOperandThroughDPSOps(op.getWeight()));
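+  // Bias is optional; a default-constructed (null) offset marks it as absent.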
+  auto in2 = op.getODSOperands(2).empty()
+                 ? flatbuffers::Offset<::tt::target::TensorRef>()
+                 : cache.at<::tt::target::TensorRef>(
+                       getOperandThroughDPSOps(op.getBias()));
+  auto output = cache.at<::tt::target::TensorRef>(
+      getOperandThroughDPSOps(op.getResult()));
+
+  auto device = getOperandThroughDPSOps(op.getDevice());
+
+  ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> kernelSize =
+      toFlatbuffer(cache, op.getKernelSize());
+  ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride =
+      toFlatbuffer(cache, op.getStride());
+  ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> padding =
+      toFlatbuffer(cache, op.getPadding());
+  ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputPadding =
+      toFlatbuffer(cache, op.getOutputPadding());
+  ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> dilation =
+      toFlatbuffer(cache, op.getDilation());
+
+  return ::tt::target::ttnn::CreateConvTranspose2dOp(
+      *cache.fbb, in0, in1, in2, output,
+      cache.at<::tt::target::DeviceRef>(device), op.getInChannels(),
+      op.getOutChannels(), op.getBatchSize(), op.getInputHeight(),
+      op.getInputWidth(), kernelSize, stride, padding, outputPadding, dilation,
+      op.getGroups());
+}
+
 ::flatbuffers::Offset<::tt::target::ttnn::AllGatherOp>
 createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
   auto input =
@@ -1122,6 +1156,11 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createOp(cache, conv2dOp), debugString,
                            locInfo);
   }
+  if (auto convTranspose2dOp = dyn_cast<ConvTranspose2dOp>(op);
+      convTranspose2dOp) {
+    return createOperation(cache, createOp(cache, convTranspose2dOp),
+                           debugString, locInfo);
+  }
   if (auto allGatherOp = dyn_cast<AllGatherOp>(op); allGatherOp) {
     return createOperation(cache, createOp(cache, allGatherOp), debugString,
                            locInfo);
diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt
index fa5cd3c06b..40b9de913d 100644
--- a/runtime/lib/ttnn/operations/CMakeLists.txt
+++ b/runtime/lib/ttnn/operations/CMakeLists.txt
@@ -6,6 +6,7 @@ set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv_transpose2d.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/ones.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
diff --git a/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp b/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp
new file mode 100644
index 0000000000..eee3d9eb01
--- /dev/null
+++ b/runtime/lib/ttnn/operations/conv/conv_transpose2d.cpp
@@ -0,0 +1,58 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "operations/conv/conv_transpose2d.h"
+#include "tt/runtime/detail/logger.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+#include "tt/runtime/ttnn/utils.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp"
+#include "ttnn/types.hpp"
+
+namespace tt::runtime::ttnn::operations::conv {
+void run(const ::tt::target::ttnn::ConvTranspose2dOp *op,
+         ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id());
+  const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id());
+  DEBUG_ASSERT(input.is_allocated());
+  DEBUG_ASSERT(weight.is_allocated());
+
+  std::optional<::ttnn::Tensor> bias =
+      op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id()))
+                 : std::nullopt;
+
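+  // The flatbuffer stores these parameters as int32 vectors, while the ttnn
+  // API takes fixed-size std::array<uint32_t, 2> values, so copy each pair.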
+  std::array<uint32_t, 2> kernelSize, stride, padding, outputPadding, dilation;
+  std::copy(op->kernel_size()->begin(), op->kernel_size()->end(),
+            kernelSize.begin());
+  std::copy(op->stride()->begin(), op->stride()->end(), stride.begin());
+  std::copy(op->padding()->begin(), op->padding()->end(), padding.begin());
+  std::copy(op->output_padding()->begin(), op->output_padding()->end(),
+            outputPadding.begin());
+  std::copy(op->dilation()->begin(), op->dilation()->end(), dilation.begin());
+
+  auto config = ::ttnn::operations::conv::Conv2dConfig();
+  config.dtype = utils::getDataType(op->input());
+  config.weights_dtype = utils::getDataType(op->weight());
+  config.shard_layout = ::ttnn::TensorMemoryLayout::WIDTH_SHARDED;
+  ::ttnn::MemoryConfig outMemConfig =
+      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
+
+  DeviceVariant targetDevice =
+      context.getTargetDevice(op->device()->global_id());
+  ::ttnn::Tensor out = std::visit(
+      [&](auto &&targetDevice) -> ::ttnn::Tensor {
+        return std::get<0>(::ttnn::conv_transpose2d(
+            ::ttnn::DefaultQueueId, input, weight, &(targetDevice.get()),
+            op->in_channels(), op->out_channels(), op->batch_size(),
+            op->input_height(), op->input_width(), kernelSize, stride, padding,
+            outputPadding, dilation, op->groups(), bias, config));
+      },
+      targetDevice);
+
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
+} // namespace tt::runtime::ttnn::operations::conv
diff --git a/runtime/lib/ttnn/operations/conv/conv_transpose2d.h b/runtime/lib/ttnn/operations/conv/conv_transpose2d.h
new file mode 100644
index 0000000000..a3be8431c9
--- /dev/null
+++ b/runtime/lib/ttnn/operations/conv/conv_transpose2d.h
@@ -0,0 +1,17 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONVTRANSPOSE2D_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONVTRANSPOSE2D_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::conv {
+void run(const ::tt::target::ttnn::ConvTranspose2dOp *op,
+         ProgramContext &context);
+
+} // namespace tt::runtime::ttnn::operations::conv
+
+#endif
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index a1176253aa..773358c765 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -4,6 +4,7 @@
 #include "operations/ccl/all_gather.h"
 #include "operations/context/get_device.h"
 #include "operations/conv/conv2d.h"
+#include "operations/conv/conv_transpose2d.h"
 #include "operations/creation/arange.h"
 #include "operations/creation/empty.h"
 #include "operations/creation/full.h"
@@ -219,6 +220,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
   case ::tt::target::ttnn::OpType::Conv2dOp: {
     return operations::conv::run(op->type_as_Conv2dOp(), context);
   }
+  case ::tt::target::ttnn::OpType::ConvTranspose2dOp: {
+    return operations::conv::run(op->type_as_ConvTranspose2dOp(), context);
+  }
   case ::tt::target::ttnn::OpType::DeallocateOp: {
     return operations::deletion::run(op->type_as_DeallocateOp(), context);
   }
diff --git a/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir
new file mode 100644
index 0000000000..f29180ead2
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_negative.mlir
@@ -0,0 +1,363 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative tests for conv_transpose2d operation
+
+// Verify that the verifier fails if the tensors are not four-dimensional
+module attributes {} {
+  func.func @conv_transpose2d_invalid_input_shape(%arg0: tensor<8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Input must be a 4D tensor
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_weight_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x8x8x256xbf16> {
+    %0 = tensor.empty() : tensor<1x8x8x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Weight must be a 4D tensor
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x8x8x256xbf16>) -> tensor<1x8x8x256xbf16>
+    return %1 : tensor<1x8x8x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_bias_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<256xbf16>) -> tensor<1x8x8x256xbf16> {
+    %0 = tensor.empty() : tensor<1x8x8x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Bias must be a 4D tensor
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<256xbf16>, tensor<1x8x8x256xbf16>) -> tensor<1x8x8x256xbf16>
+    return %1 : tensor<1x8x8x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_output_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Output must be a 4D tensor
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<10x10x256xbf16>) -> tensor<10x10x256xbf16>
+    return %1 : tensor<10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_batch_size_mismatch(%arg0: tensor<4x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<2x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<2x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Batch size of input and output tensors must match
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<4x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<2x10x10x256xbf16>) -> tensor<2x10x10x256xbf16>
+    return %1 : tensor<2x10x10x256xbf16>
+  }
+}
+
+// Verify that the verifier fails if attributes are neither integers nor pairs of integers
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_stride_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Expected integer or pair of integers, got tuple of size 3 for stride
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = array<i32: 1, 2, 3>,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_padding_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Expected integer or pair of integers, got tuple of size 3 for padding
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = array<i32: 5, 6, 7>,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_output_padding_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Expected integer or pair of integers, got tuple of size 3 for output padding
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = array<i32: 8, 9, 10>,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_dilation_shape(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Expected integer or pair of integers, got tuple of size 3 for dilation
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = array<i32: 11, 12, 13>,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// Verify that the verifier fails if attributes have invalid values
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_stride_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Stride values must be greater than 0
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = array<i32: 2, -2>,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_padding_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Padding values must be greater than or equal to 0
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = array<i32: -1, 0>,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_output_padding_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Output padding values must be greater than or equal to 0
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = -6: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_invalid_dilation_values(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Dilation values must be greater than 0
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = array<i32: -2, -2>,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// Verify that the verifier fails if the number of channels is incorrect
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_input_channels_not_divisible_by_groups(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Number of input channels from input tensor must be divisible by the number of groups. Got 256 input channels and 3 groups
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 3: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_output_channels_not_divisible_by_groups(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x350x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x350xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x350xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Number of output channels from output tensor must be divisible by the number of groups. Got 350 output channels and 4 groups.
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 4: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x350x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x350xbf16>) -> tensor<1x10x10x350xbf16>
+    return %1 : tensor<1x10x10x350xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_input_channels_mismatch_with_weight(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<128x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Number of input channels from input tensor must match the first dimension of the weight tensor. Got 256 input channels and 128 in the weight tensor.
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<128x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_output_channels_mismatch_with_weight(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Number of output channels per group must match the second dimension of the weight tensor. Got 64 output channels per group and 256 in the weight tensor.
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 4: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_output_channels_mismatch_with_bias(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x128xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Mismatch in bias tensor dimensions. Bias tensor has 128 channels, but the output tensor has 256 channels.
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x128xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// Verify that the verifier fails if the calculated output size per channel is negative or does not match the output tensor
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_calculated_output_size_per_channel_below_zero(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<1x10x10x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Given input size per channel: (8 x 8). Calculated output size per channel: (-2 x -4). Output size is too small
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = array<i32: 6, 7>,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x10x10x256xbf16>) -> tensor<1x10x10x256xbf16>
+    return %1 : tensor<1x10x10x256xbf16>
+  }
+}
+
+// -----
+module attributes {} {
+  func.func @conv_transpose2d_calculated_output_size_per_channel_mismatch_with_output_tensor(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x2x2x256xbf16> {
+    %0 = tensor.empty() : tensor<1x2x2x256xbf16>
+    // CHECK: error: 'ttir.conv_transpose2d' op Mismatch between expected output size per channel and got output tensor dimensions. Expected: (10 x 10), got: (2 x 2).
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x2x2x256xbf16>) -> tensor<1x2x2x256xbf16>
+    return %1 : tensor<1x2x2x256xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir
new file mode 100644
index 0000000000..bf1d52f0d1
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/conv_transpose2d/conv_transpose2d_tests_positive.mlir
@@ -0,0 +1,101 @@
+// RUN: ttmlir-opt %s | FileCheck %s
+
+module attributes {} {
+  func.func @conv_transpose2d_simple(%arg0: tensor<4x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<4x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<4x10x10x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<4x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<4x10x10x256xbf16>) -> tensor<4x10x10x256xbf16>
+    return %1 : tensor<4x10x10x256xbf16>
+  }
+
+  func.func @conv_transpose2d_stride(%arg0: tensor<1x16x32x256xbf16>, %arg1: tensor<256x256x8x8xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x38x132x256xbf16> {
+    %0 = tensor.empty() : tensor<1x38x132x256xbf16>
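+    // Expected output size: H = (16 - 1) * 2 + (8 - 1) + 1 = 38,
+    //                       W = (32 - 1) * 4 + (8 - 1) + 1 = 132.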
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = array<i32: 2, 4>,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x16x32x256xbf16>, tensor<256x256x8x8xbf16>, tensor<1x1x1x256xbf16>, tensor<1x38x132x256xbf16>) -> tensor<1x38x132x256xbf16>
+    return %1 : tensor<1x38x132x256xbf16>
+  }
+
+  func.func @conv_transpose2d_padding(%arg0: tensor<1x64x64x256xbf16>, %arg1: tensor<256x256x16x16xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x73x67x256xbf16> {
+    %0 = tensor.empty() : tensor<1x73x67x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = array<i32: 3, 6>,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x64x64x256xbf16>, tensor<256x256x16x16xbf16>, tensor<1x1x1x256xbf16>, tensor<1x73x67x256xbf16>) -> tensor<1x73x67x256xbf16>
+    return %1 : tensor<1x73x67x256xbf16>
+  }
+
+  func.func @conv_transpose2d_output_padding(%arg0: tensor<1x32x32x128xbf16>, %arg1: tensor<128x256x8x8xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x45x47x256xbf16> {
+    %0 = tensor.empty() : tensor<1x45x47x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = array<i32: 6, 8>,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<1x32x32x128xbf16>, tensor<128x256x8x8xbf16>, tensor<1x1x1x256xbf16>, tensor<1x45x47x256xbf16>) -> tensor<1x45x47x256xbf16>
+    return %1 : tensor<1x45x47x256xbf16>
+  }
+
+  func.func @conv_transpose2d_dilation(%arg0: tensor<1x32x32x128xbf16>, %arg1: tensor<128x256x16x32xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x77x94x256xbf16> {
+    %0 = tensor.empty() : tensor<1x77x94x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = array<i32: 3, 2>,
+              groups = 1: i32}
+            > : (tensor<1x32x32x128xbf16>, tensor<128x256x16x32xbf16>, tensor<1x1x1x256xbf16>, tensor<1x77x94x256xbf16>) -> tensor<1x77x94x256xbf16>
+    return %1 : tensor<1x77x94x256xbf16>
+  }
+
+  func.func @conv_transpose2d_groups(%arg0: tensor<1x16x32x192xbf16>, %arg1: tensor<192x126x8x8xbf16>, %arg2: tensor<1x1x1x252xbf16>) -> tensor<1x23x39x252xbf16> {
+    %0 = tensor.empty() : tensor<1x23x39x252xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 2: i32}
+            > : (tensor<1x16x32x192xbf16>, tensor<192x126x8x8xbf16>, tensor<1x1x1x252xbf16>, tensor<1x23x39x252xbf16>) -> tensor<1x23x39x252xbf16>
+    return %1 : tensor<1x23x39x252xbf16>
+  }
+
+  func.func @conv_transpose2d(%arg0: tensor<1x8x8x256xbf16>, %arg1: tensor<256x64x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<1x21x38x256xbf16> {
+    %0 = tensor.empty() : tensor<1x21x38x256xbf16>
+    // CHECK: %[[C:.*]] = "ttir.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = array<i32: 2, 3>,
+              padding = array<i32: 6, 4>,
+              output_padding = array<i32: 10, 12>,
+              dilation = array<i32: 4, 6>,
+              groups = 4: i32}
+            > : (tensor<1x8x8x256xbf16>, tensor<256x64x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<1x21x38x256xbf16>) -> tensor<1x21x38x256xbf16>
+    return %1 : tensor<1x21x38x256xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir
new file mode 100644
index 0000000000..a268c7bab7
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv_transpose2d.mlir
@@ -0,0 +1,19 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+
+module attributes {} {
+  func.func @forward(%arg0: tensor<3x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<3x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<3x10x10x256xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<3x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<3x10x10x256xbf16>) -> tensor<3x10x10x256xbf16>
+    return %1 : tensor<3x10x10x256xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir b/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir
new file mode 100644
index 0000000000..a268c7bab7
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/simple_conv_transpose2d.mlir
@@ -0,0 +1,19 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+
+module attributes {} {
+  func.func @forward(%arg0: tensor<3x8x8x256xbf16>, %arg1: tensor<256x256x3x3xbf16>, %arg2: tensor<1x1x1x256xbf16>) -> tensor<3x10x10x256xbf16> {
+    %0 = tensor.empty() : tensor<3x10x10x256xbf16>
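+    // The lowering emits a flattened (1 x 1 x N*H*W x C) ttnn.conv_transpose2d
+    // result followed by a reshape back to (N x H x W x C).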
+    // CHECK: %[[C:.*]] = "ttnn.conv_transpose2d"[[C:.*]]
+    %1 = "ttir.conv_transpose2d"(%arg0, %arg1, %arg2, %0)
+            <{
+              stride = 1: i32,
+              padding = 0: i32,
+              output_padding = 0: i32,
+              dilation = 1: i32,
+              groups = 1: i32}
+            > : (tensor<3x8x8x256xbf16>, tensor<256x256x3x3xbf16>, tensor<1x1x1x256xbf16>, tensor<3x10x10x256xbf16>) -> tensor<3x10x10x256xbf16>
+    return %1 : tensor<3x10x10x256xbf16>
+  }
+}