10 changes: 6 additions & 4 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1628,10 +1628,12 @@ def _std(self, node: fx.Node) -> relax.Var:

def _sum(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
if len(args) == 1:
return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
return self.block_builder.emit(relax.op.sum(args[0], args[1]))
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
if isinstance(dim, (list, tuple)) and len(dim) == 0:
dim = None
keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))

def _var(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
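
For illustration (not part of this PR): the reworked _sum normalizes the different torch.sum calling conventions before emitting relax.op.sum. A minimal sketch of the call patterns it now covers, assuming standard torch.sum semantics:

import torch

x = torch.randn(2, 3, 4)
torch.sum(x)              # no dim: full reduction
torch.sum(x, dim=(1, 2))  # dim passed as a keyword tuple/list
torch.sum(x, 1, True)     # dim and keepdim passed positionally
# An empty dim list coming from the traced graph is mapped to None by the
# translator, i.e. treated as a full reduction.
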
@@ -1371,6 +1371,7 @@ def create_convert_map(
"any.dim": self._any,
"any.dims": self._any,
"mean.dim": self._mean,
"mean.default": self._mean,
"prod.default": self._prod,
"std.correction": self._std,
"sum.default": self._sum,
62 changes: 59 additions & 3 deletions src/relax/op/tensor/manipulate.cc
@@ -1805,12 +1805,70 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice);
}

// TODO(relax-team): implement FRelaxInferLayout for repeat
InferLayoutOutput InferLayoutRepeat(
const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
const VarLayoutMap& var_layout_map) {
ICHECK(NoDesiredLayout(call, desired_layouts));

const auto* attrs = call->attrs.as<RepeatAttrs>();
ICHECK(attrs != nullptr) << "Invalid Call";
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";

LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
int ndim = tensor_sinfo->ndim;

// Can't handle sub indexed layouts.
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
existing_layout = LayoutDecision(InitialLayout(ndim));
}

// When axis is not specified, the output is 1D (flattened)
if (!attrs->axis.has_value()) {
return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(1)}, Attrs(call->attrs));
}

// Transform the axis based on the layout
int axis = attrs->axis.value();
if (axis < 0) {
axis += ndim;
}

// Create a mapping from original layout to existing layout
std::string axis_str(ndim, '0');
axis_str[axis] = '1';
for (int i = 0, j = 0; i < ndim; ++i) {
if (axis_str[i] != '1') {
axis_str[i] = 'A' + j++;
}
}

ffi::String new_axis_str =
TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);

int64_t new_axis = -1;
for (size_t i = 0; i < new_axis_str.size(); ++i) {
if (new_axis_str.at(i) == '1') {
new_axis = i;
break;
}
}
ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";

ObjectPtr<RepeatAttrs> new_attrs = ffi::make_object<RepeatAttrs>(*attrs);
new_attrs->axis = new_axis;

// When axis is specified, the layout is preserved
return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs));
}

TVM_REGISTER_OP("relax.repeat")
.set_attrs_type<RepeatAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRepeat)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRepeat)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.tile */
@@ -2398,7 +2456,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) {
if (sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
LOG(WARNING) << "Data type of " << name
<< " has not been specified. Assume it has an integer type.";
}
@@ -2415,7 +2472,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
}

if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type.";
} else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(
33 changes: 30 additions & 3 deletions src/runtime/cuda/cuda_module.cc
@@ -200,9 +200,36 @@ class CUDAWrappedFunc {
}
}
CUstream strm = static_cast<CUstream>(TVMFFIEnvGetStream(kDLCUDA, device_id));
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
CUresult result;

if (launch_param_config_.use_programtic_dependent_launch()) {
CUlaunchConfig config{};
CUlaunchAttribute attribute[1]{};
attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
attribute[0].value.programmaticStreamSerializationAllowed = 1;

config.attrs = attribute;
config.numAttrs = 1;
config.hStream = strm;
config.gridDimX = wl.grid_dim(0);
config.gridDimY = wl.grid_dim(1);
config.gridDimZ = wl.grid_dim(2);
config.blockDimX = wl.block_dim(0);
config.blockDimY = wl.block_dim(1);
config.blockDimZ = wl.block_dim(2);
config.sharedMemBytes = wl.dyn_shmem_size;

result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr);
} else if (launch_param_config_.use_cooperative_launch()) {
result = cuLaunchCooperativeKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args);
} else {
result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2),
wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size,
strm, void_args, nullptr);
}

if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char* msg;
cuGetErrorName(result, &msg);
4 changes: 4 additions & 0 deletions src/runtime/meta_data.h
@@ -48,6 +48,10 @@ namespace launch_param {

/*! \brief A tag to specify whether or not dynamic shared memory is used */
constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";
/*! \brief A tag to specify whether or not to use programmatic dependent launch */
constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch";
/*! \brief A tag to specify whether or not to use cooperative launch */
constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch";

} // namespace launch_param

2 changes: 1 addition & 1 deletion src/runtime/opencl/opencl_device_api.cc
@@ -472,7 +472,7 @@ void OpenCLWorkspace::SetNativePtr(const tvm::runtime::Tensor& narr, void* host_
}

void OpenCLWorkspace::SetPerfHint(Device dev, cl_uint perf_hint) {
#ifdef CL_CONTEXT_PERF_HINT_QCOM
#if defined(USE_OPENCL_EXTN_QCOM) && defined(CL_CONTEXT_PERF_HINT_QCOM)
cl_device_id device_id = GetCLDeviceID(dev.device_id);
auto platform = device_info[device_id].platform_id;
OPENCL_CALL(clSetPerfHintQCOM(this->contexts[platform], perf_hint));
12 changes: 12 additions & 0 deletions src/runtime/thread_storage_scope.h
@@ -247,6 +247,10 @@ class LaunchParamConfig {
ICHECK_EQ(i, launch_param_tags.size() - 1)
<< "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags.";
use_dyn_shared_memory_ = true;
} else if (tag == launch_param::kUseProgramaticDependentLaunch) {
use_programmatic_dependent_launch_ = true;
} else if (tag == launch_param::kUseCooperativeLaunch) {
use_cooperative_launch_ = true;
} else {
ThreadScope ts = ThreadScope::Create(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
@@ -281,6 +285,10 @@ class LaunchParamConfig {
// return the work dim
size_t work_dim() const { return work_dim_; }

bool use_programtic_dependent_launch() const { return use_programmatic_dependent_launch_; }

bool use_cooperative_launch() const { return use_cooperative_launch_; }

private:
/*! \brief base axis */
size_t base_;
@@ -290,6 +298,10 @@ class LaunchParamConfig {
std::vector<uint32_t> arg_index_map_;
/*! \brief Whether or not use dynamic shared memory. */
bool use_dyn_shared_memory_{false};
/*! \brief Whether or not use programmatic dependent launch. */
bool use_programmatic_dependent_launch_{false};
/*! \brief Whether or not use cooperative launch. */
bool use_cooperative_launch_{false};
};

} // namespace runtime
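
For illustration (a hypothetical tag list, not taken from this PR): LaunchParamConfig consumes a flat list of launch-param tags in which thread-axis tags map to grid/block dimensions while the special tir.* tags only toggle flags; per the existing check, the dynamic-shared-memory tag must stay last.

# Hypothetical launch_param_tags showing where the two new tags could sit.
launch_param_tags = [
    "blockIdx.x",                           # grid dimension
    "threadIdx.x",                          # block dimension
    "tir.use_programtic_dependent_launch",  # new: route through cuLaunchKernelEx with PDL
    "tir.use_cooperative_launch",           # new: route through cuLaunchCooperativeKernel
    "tir.use_dyn_shared_memory",            # must remain the last tag
]
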
8 changes: 5 additions & 3 deletions src/target/llvm/codegen_aarch64.cc
@@ -59,9 +59,11 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
// Add vscale_range() function attribute when appropriate.
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
auto kVScaleValues = arith::GetVScaleValues(Target::Current());
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
func->addFnAttr(
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
if (!kVScaleValues.empty()) {
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
func->addFnAttr(
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
}
}
#endif
CodeGenCPU::SetTargetAttributes(func);
34 changes: 29 additions & 5 deletions src/target/llvm/intrin_rule_llvm.cc
@@ -167,9 +167,15 @@ TVM_REGISTER_OP("tir.sinh")
TVM_REGISTER_OP("tir.asin")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using namespace intrin;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];

PrimExpr threshold = make_const(x.dtype(), 0.5);
PrimExpr abs_x = tvm::abs(x);
PrimExpr use_lib = abs_x >= threshold;

PrimExpr x2 = x * x;
PrimExpr term1 = x;
PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6);
@@ -178,25 +184,43 @@ TVM_REGISTER_OP("tir.asin")
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11;
/* --- domain limit check --- */

PrimExpr lib_result =
::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);

PrimExpr lower = make_const(x.dtype(), -1.0);
PrimExpr upper = make_const(x.dtype(), 1.0);
PrimExpr out_range = tir::Or(x < lower, x > upper);
// Use a quiet NaN constant
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());
// select: if out of [-1,1] → NaN, else → series
return tir::Select(out_range, nan_const, series);

return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, series));
});

TVM_REGISTER_OP("tir.acos")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using namespace intrin;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr) << "Invalid call node in acos legalization";
const PrimExpr& x = call->args[0];

PrimExpr threshold = make_const(x.dtype(), 0.5);
PrimExpr abs_x = tvm::abs(x);
PrimExpr use_lib = abs_x >= threshold;

PrimExpr half_pi = make_const(x.dtype(), M_PI / 2);
PrimExpr asin_x = asin(x);
return half_pi - asin_x;
PrimExpr formula_result = half_pi - asin_x;

PrimExpr lib_result =
::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);

PrimExpr lower = make_const(x.dtype(), -1.0);
PrimExpr upper = make_const(x.dtype(), 1.0);
PrimExpr out_range = tir::Or(x < lower, x > upper);
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());

return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, formula_result));
});
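
For reference (not part of this PR): the nested Selects above amount to a three-way split: NaN outside [-1, 1], the extern asin for |x| >= 0.5, and the truncated odd-power series otherwise. A hypothetical plain-Python restatement using the coefficients from the diff:

import math

def asin_legalized(x):
    if x < -1.0 or x > 1.0:
        return float("nan")   # out-of-domain branch
    if abs(x) >= 0.5:
        return math.asin(x)   # extern/library branch (DispatchPureExtern)
    x2 = x * x                # truncated series branch
    t1 = x
    t3 = t1 * x2 / 6
    t5 = t3 * x2 * 9 / 20
    t7 = t5 * x2 * 25 / 42
    t9 = t7 * x2 * 1225 / 3456
    t11 = t9 * x2 * 3969 / 28160
    return t1 + t3 + t5 + t7 + t9 + t11

The acos legalization has the same shape, computing its small-argument branch as pi/2 - asin(x).
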

TVM_REGISTER_OP("tir.atan")
@@ -808,7 +808,7 @@ def main(


def test_residual_block():
"""
r"""
- some kind of residual block followed by convolution to have texture after residual block
- scalar data type verification which should be mapped to global memory scope
layout_transform (NCHW->NCHW4c)
@@ -874,7 +874,7 @@ def main(


def test_conv2d_conv2d_fallback_to_buffer_conv2d():
"""
r"""
layout_transform (NCHW->NCHW4c)
| <- texture
conv2d (1) <- textures as output
@@ -931,7 +931,7 @@ def main(


def test_conv2d_conv2d_conv2d_concat():
"""
r"""
layout_transform (NCHW->NCHW4c)
| <- texture
conv2d (1) <- textures as output
@@ -991,7 +991,7 @@ def main(


def test_pooling_branching_texture_params():
"""
r"""
Verification of the pooling and many branches having textures
layout_transform (NCHW->NCHW4c)
| <- texture
@@ -1066,7 +1066,7 @@ def main(


def test_injective_inputs1():
"""
r"""
Input
/ \
/ |
@@ -1133,7 +1133,7 @@ def main(


def test_injective_nwo_inputs2():
"""
r"""
Input
/ \
| \