Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {

for (size_t i = 0; i < op.getNumOperands(); i++) {
auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
if (op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
if (dyn_cast<RankedTensorType>(op.getOperand(i).getType())) {
llvm_unreachable("Not implemented for tensor types");
}

Expand All @@ -61,7 +61,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
std::string getFormatSubstr(Value value, bool hex = false,
std::optional<int> width = std::nullopt) const {
Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
if (isa<LLVM::LLVMPointerType>(type)) {
return "%p";
}
// Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType(
triton::PointerType type) {
auto ctx = type.getContext();
auto pointeeType = type.getPointeeType();
if (pointeeType.isa<RankedTensorType>()) {
if (isa<RankedTensorType>(pointeeType)) {
llvm_unreachable("Not implemented");
}
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context)
addConversion([this](triton::PointerType ptrType) -> triton::PointerType {
// Check whether tensor pointer `tt.ptr<tensor<>>`
auto pointeeTensorType =
ptrType.getPointeeType().dyn_cast<RankedTensorType>();
dyn_cast<RankedTensorType>(ptrType.getPointeeType());
if (pointeeTensorType == nullptr)
return ptrType;

Expand Down Expand Up @@ -99,9 +99,9 @@ TritonCPUConversionTarget::TritonCPUConversionTarget(
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
Attribute aEncoding =
dotOp.getA().getType().cast<RankedTensorType>().getEncoding();
cast<RankedTensorType>(dotOp.getA().getType()).getEncoding();
Attribute bEncoding =
dotOp.getB().getType().cast<RankedTensorType>().getEncoding();
cast<RankedTensorType>(dotOp.getB().getType()).getEncoding();
// TODO:
return false;
});
Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def run(self, *args, grid, warmup, **kwargs):

# The CPU launcher will provide the grid ids directly to the kernel.
# Note that this design is interim and subject to change.
if target[0] == 'cpu':
if target.backend == 'cpu':
signature["__grid0"] = 'i32'
signature["__grid1"] = 'i32'
signature["__grid2"] = 'i32'
Expand Down
6 changes: 3 additions & 3 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any

from triton._C.libtriton import cpu, ir, llvm, passes
from triton.backends.compiler import BaseBackend
from triton.backends.compiler import BaseBackend, GPUTarget


@dataclass(frozen=True)
Expand Down Expand Up @@ -35,8 +35,8 @@ def hash(self):
class CPUBackend(BaseBackend):

@staticmethod
def supports_target(target: tuple):
return target[0] == "cpu"
def supports_target(target: GPUTarget):
return target.backend == "cpu"

def __init__(self, target: tuple) -> None:
super().__init__(target)
Expand Down
4 changes: 3 additions & 1 deletion third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from triton.backends.compiler import GPUTarget
from triton.backends.driver import CPUDriverBase

# ------------------------
Expand Down Expand Up @@ -60,7 +61,8 @@ def __init__(self):

def get_current_target(self):
# Capability and warp size are zeros for CPU.
return ("cpu", 0, 0)
# TODO: GPUTarget naming isn't obviously good.
return GPUTarget("cpu", 0, 0)

@staticmethod
def is_active():
Expand Down