diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index e402c5888978..f258bffc3e8f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -50,7 +50,9 @@ def cuda_atomic_add_rule(op): def opencl_atomic_add_rule(op): if op.dtype == "int32": return tvm.tir.call_pure_extern("int32", "atomic_add", op.args[0], op.args[1]) - raise RuntimeError("only support int32") + elif op.dtype == "float32": + return tvm.tir.call_pure_extern("float32", "atomic_add", op.args[0], op.args[1]) + raise RuntimeError("only support int32, float32") register_intrin_lowering("tir.atomic_add", target="cuda", f=cuda_atomic_add_rule, level=99) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index f17a452d5c28..5933c9582cec 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -129,6 +129,16 @@ std::string CodeGenOpenCL::Finish() { if (enable_atomics_) { decl_stream << "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" "#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n\n"; + decl_stream << "__inline float atomic_add_float_emu(volatile __global float* sum, const float " + "toAdd) {\n" + "float next_value = 0;" + "float prev_value = 0;" + "do {\n" + "prev_value =*(sum);\n" + "next_value =prev_value + toAdd;\n" + "} while(atomic_cmpxchg((volatile global int *)(sum), *((int*)&prev_value), " + "*((int*)&next_value)) != *((int*)&prev_value));\n" + "return next_value;\n}\n"; } // Enable OpenCL 1.2 sampler-less texture reads, but utilize @@ -458,13 +468,21 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args.back(), os); os << "]"; } - } else if (op->op.same_as(builtin_call_extern_)) { + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { auto func = Downcast(op->args[0]); // Enable atomics extension if used. - if (func->value == "atomic_add") { + if (func->value == "atomic_add" && op->dtype.is_float()) { enable_atomics_ = true; + this->PrintCallExtern(GetType(GetRef(op)), "atomic_add_float_emu", op->args, true, + os); + } else if (func->value == "nearbyint") { + this->PrintCallExtern(GetType(GetRef(op)), "round", op->args, true, os); + } else { + if (func->value == "atomic_add") { + enable_atomics_ = true; + } + CodeGenC::VisitExpr_(op, os); } - CodeGenC::VisitExpr_(op, os); } else { CodeGenC::VisitExpr_(op, os); } @@ -534,6 +552,34 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { PrintBinaryExpr(op, "max", os, this); } +void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) + std::string opstr; + if (op->dtype.is_int() || op->dtype.is_uint()) { + opstr = "%"; + } else { + ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " + << op->dtype; + opstr = "fmod"; + } + if (op->dtype.lanes() == 1) { + if (isalpha(opstr.c_str()[0])) { + os << opstr.c_str() << '('; + this->PrintExpr(op->a, os); + os << ", "; + this->PrintExpr(op->b, os); + os << ')'; + } else { + os << '('; + this->PrintExpr(op->a, os); + os << ' ' << opstr.c_str() << ' '; + this->PrintExpr(op->b, os); + os << ')'; + } + } else { + this->PrintVecBinaryOp(opstr.c_str(), op->dtype, op->a, op->b, os); + } +} + void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { std::ostringstream oss; os << "("; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 8b365f85d6e6..e668f75b2ec2 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -74,6 +74,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const AndNode* op, std::ostream& os) final; void VisitExpr_(const OrNode* op, std::ostream& os) final; void VisitExpr_(const SelectNode* op, std::ostream& os) final; + void VisitExpr_(const ModNode* op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/tests/python/relay/opencl_texture/test_relay_ops.py b/tests/python/relay/opencl_texture/test_relay_ops.py new file mode 100644 index 000000000000..686a9a9b9e89 --- /dev/null +++ b/tests/python/relay/opencl_texture/test_relay_ops.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm + + +executor_type = tvm.testing.parameter("ge", "vm") +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_mod(remote, target, executor_type, dtype): + # NCHW + input_shape = (1, 25, 38, 64) + A = relay.var("data", shape=input_shape, dtype=dtype) + scale = relay.const(2.0, dtype=dtype) + op = relay.mod(A, scale) + mod = relay.Function([A], op) + + if executor_type == "ge": + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_scatter_nd_add(remote, target, executor_type, dtype): + # NCHW + + A = relay.var("data", shape=(6, 30, 30, 256), dtype=dtype) + indices = relay.const(tvm.nd.array(np.random.randint(0, 1, (2, 6, 30, 30))), dtype="int64") + update = relay.const( + tvm.nd.array(np.random.uniform(-1, 1, size=(50, 50, 256)).astype(dtype)), dtype=dtype + ) + op = relay.scatter_nd(update, indices, A, mode="add") + mod = relay.Function([A], op) + shape_dict = { + "data": (6, 30, 30, 256), + } + dtype_dict = { + "data": dtype, + } + + if executor_type == "ge": + build_run_compare(remote, mod, {}, shape_dict, dtype_dict, target) + else: + build_run_compare_vm(remote, mod, {}, shape_dict, dtype_dict, target) + + +if __name__ == "__main__": + tvm.testing.main()