Skip to content

Commit

Permalink
[Relay] add ShapeFunc for tanh (apache#6898)
Browse files Browse the repository at this point in the history
* add ShapeFunc for tanh

* _schedule_dense_small_batch turn autotvm off when dense's inner dim is unknown

* fix CI pylint
  • Loading branch information
monklof authored and trevor-m committed May 11, 2021
1 parent 91b2783 commit 668a004
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("clip", False, elemwise_shape_func)
register_shape_func("log2", False, elemwise_shape_func)
register_shape_func("sigmoid", False, elemwise_shape_func)
register_shape_func("tanh", False, elemwise_shape_func)
register_shape_func("isnan", False, elemwise_shape_func)
register_shape_func("isinf", False, elemwise_shape_func)
register_shape_func("logical_not", False, elemwise_shape_func)
Expand Down
25 changes: 19 additions & 6 deletions python/tvm/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,26 @@ def _callback(op):


def _schedule_dense_small_batch(cfg, s, C):
A, _ = C.op.input_tensors
_, in_dim = get_const_tuple(A.shape)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
A, weights = C.op.input_tensors
_, in_dim_weights = get_const_tuple(weights.shape)
_, in_dim_A = get_const_tuple(A.shape)

if isinstance(in_dim_A, int):
in_dim = in_dim_A
elif isinstance(in_dim_weights, int):
in_dim = in_dim_weights
else:
in_dim = None

if in_dim is not None:
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
_, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
else:
tile_k = 64
_, kf = s[C].split(C.op.reduce_axis[0], tile_k)

_, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
CF = s.rfactor(C, kf)

if C.op in s.outputs:
Expand Down

0 comments on commit 668a004

Please sign in to comment.