|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local |
| 18 | +"""Dispatch sort and scan operators to platform dependent implementation.""" |
| 19 | + |
| 20 | +from tvm import topi |
| 21 | +from tvm.ir import Op |
| 22 | +from tvm.ir.module import IRModule |
| 23 | +from tvm.ir.transform import PassContext, module_pass |
| 24 | +from tvm.target import Target |
| 25 | +from tvm.contrib.thrust import can_use_thrust |
| 26 | +from tvm.relax import Expr, Function, Call, PyExprMutator, expr_functor, TensorStructInfo |
| 27 | + |
| 28 | + |
@expr_functor.mutator
class SortScanDispatcher(PyExprMutator):
    """Expression mutator that rewrites ``relax.sort`` and ``relax.cumsum`` calls.

    Each matching call is replaced by a ``call_te`` emitted through the
    mutator's block builder, using a TOPI implementation chosen from the
    target associated with the call. All other calls are handled by the
    base-class visitor.
    """

    def __init__(self, mod):
        super().__init__(mod)

    def _get_target(self, expr: Expr) -> Target:
        """Resolve the target that ``expr`` should be dispatched for.

        Parameters
        ----------
        expr : Expr
            Expression whose ``struct_info`` is inspected.

        Returns
        -------
        Target
            The target of the tensor's ``vdevice`` when the struct info is a
            ``TensorStructInfo`` carrying one, otherwise ``Target.current()``.

        Raises
        ------
        ValueError
            If no ``vdevice`` is annotated and there is no current target.
        """
        struct_info = expr.struct_info
        # Only tensor struct info is consulted for a virtual device.
        vdevice = struct_info.vdevice if isinstance(struct_info, TensorStructInfo) else None
        if vdevice is not None:
            return vdevice.target

        # No annotation on the expression: fall back to the target in the current context.
        current = Target.current()
        if current is not None:
            return current
        raise ValueError(
            "Target not found. Please ensure that the target is annotated within the module, "
            "or alternatively, execute this within a specified target context."
        )

    def visit_call_(self, call: Call) -> Expr:
        """Dispatch sort/cumsum calls; delegate every other call to the base class.

        Parameters
        ----------
        call : Call
            The call node being visited.

        Returns
        -------
        Expr
            The result of ``builder_.call_te`` for ``relax.sort`` and
            ``relax.cumsum``, otherwise the base class's result for ``call``.
        """
        # Only calls whose callee is an ``Op`` have a name to dispatch on.
        op_name = call.op.name if isinstance(call.op, Op) else None

        if op_name == "relax.sort":
            target = self._get_target(call)
            with target:
                # Pick the kernel first; the thrust check takes priority over the target kind.
                if can_use_thrust(target, "tvm.contrib.thrust.sort"):
                    sort_impl = topi.cuda.sort_thrust
                elif target.kind.name == "cuda":
                    sort_impl = topi.cuda.sort
                else:
                    sort_impl = topi.sort
                # The third argument is the negation of the op's ``descending`` attribute.
                return self.builder_.call_te(
                    sort_impl,
                    call.args[0],
                    call.attrs.axis,
                    not call.attrs.descending,
                )

        if op_name == "relax.cumsum":
            target = self._get_target(call)
            axis = call.attrs.axis
            if axis is not None:
                # Pass a plain Python int; a missing axis stays ``None``.
                axis = int(axis)
            with target:
                scan_impl = topi.cuda.cumsum if target.kind.name == "cuda" else topi.cumsum
                return self.builder_.call_te(
                    scan_impl,
                    call.args[0],
                    axis,
                    call.attrs.dtype,
                )

        return super().visit_call_(call)
| 88 | + |
| 89 | + |
@module_pass(opt_level=0, name="DispatchSortScan")
class DispatchSortScan:
    """Module pass that runs ``SortScanDispatcher`` over the Relax functions of a module."""

    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
        """Rewrite every Relax ``Function`` in ``mod`` with the sort/scan dispatcher.

        Parameters
        ----------
        mod : IRModule
            The module whose functions are visited.
        ctx : PassContext
            The pass context; not used by this pass.

        Returns
        -------
        IRModule
            The module obtained from the dispatcher's block builder after all
            visited functions have been updated.
        """
        dispatcher = SortScanDispatcher(mod)
        builder = dispatcher.builder_
        for gvar, func in mod.functions_items():
            # Entries that are not Relax functions are left untouched.
            if not isinstance(func, Function):
                continue
            builder.update_func(gvar, dispatcher.visit_expr(func))
        return builder.get()
0 commit comments