Skip to content

[Relax][Bug] Segmentation fault when using the MergeCompositeFunctions transform #17120

@Cookiee235

Description

@Cookiee235

Actual behavior

Segmentation fault (core dumped)

Environment

TVM: 0.17.dev0
OS: Ubuntu 20.04

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Minimal IRModule used to reproduce the MergeCompositeFunctions segfault
# reported in this issue (#17120). It is written in TVMScript; the bodies are
# parsed by TVM's script parser, so only '#' comments are used here.
@I.ir_module
class Module:
    # Plain (non-composite) TIR kernel: elementwise max(x, 0) over 10 float32
    # values. It is invoked from `main` through R.call_tir.
    @T.prim_func(private=True)
    def relu(x11: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):  (implicit root block, as emitted by the TVMScript printer)
        for i0 in range(T.int64(10)):
            with T.block("compute"):
                v_i0 = T.axis.spatial(T.int64(10), i0)
                T.reads(x11[v_i0])
                T.writes(compute[v_i0])
                compute[v_i0] = T.max(x11[v_i0], T.float32(0))

    # Composite function tagged "compiler_A.gelu" wrapping a single R.nn.gelu.
    # Only "Composite" and "Primitive" attributes are set (no "Codegen" attribute).
    # `cls` is bound but unused in this function.
    @R.function(private=True)
    def fused_relax_nn_gelu(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
        cls = Module
        with R.dataflow():
            gv3 = R.nn.gelu(x21)
            R.output(gv3)
        return gv3

    # Composite function tagged "compiler_A.relu" wrapping a single R.nn.relu.
    # `cls` is bound but unused in this function.
    @R.function(private=True)
    def fused_relax_nn_relu(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
        cls = Module
        with R.dataflow():
            # Disabled alternative body: call the TIR `relu` kernel instead of R.nn.relu.
            # gv2 = R.call_tir(cls.relu, (x11,), out_sinfo=R.Tensor((10,), dtype="float32"))
            gv2 = R.nn.relu(x11)
            R.output(gv2)
        return gv2

    # Entry point. The data flow is:
    #   compiler_A.relu composite -> plain R.call_tir (not a composite) -> compiler_A.gelu composite.
    # NOTE(review): presumably the non-composite call_tir sandwiched between two
    # compiler_A composites is what triggers the crash in MergeCompositeFunctions;
    # this is unconfirmed.
    @R.function
    def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(x1)
            lv2 = R.call_tir(cls.relu, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32"))
            lv3: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv2)
            R.output(lv3)
        return lv3

# Print the input module, then apply the pass under test to it.
mod = Module
mod.show()
merge_composites = relax.transform.MergeCompositeFunctions()
mod = merge_composites(mod)  # reported to crash: "Segmentation fault (core dumped)"

Triage

  • needs-triage

cc @junrushao

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions