[Bug] DomainTouchedAccessMap does not preserve buffer order #13330

@guberti

Description

When using te.extern_primfunc, we need to ensure the tensors we pass in have the same shape as the buffers we are loading them into. This check is done in python/tvm/te/operation.py, which has the following code:

def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs):
    access_map = {
        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
    }
    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
    input_buffers = in_buffers
    # [some lines omitted for brevity]
    for obuf in out_buffers:
        ...  # [loop body omitted for brevity]

    assert len(input_buffers) == len(input_tensors), (
        "The number of provided input input_tensors does not match the number of ",
        "input buffers in the primfunc",
    )
    for tensor, buffer in zip(input_tensors, input_buffers):
        assert len(tensor.shape) == len(buffer.shape)
        for d1, d2 in zip(tensor.shape, buffer.shape):
            assert d1 == d2, (
                "The input input_tensors provided do not match the input buffers in the ",
                "primfunc. Please check that the order of input te.Input_Tensors and the ",
                "order of the primfunc variables in the params list agree.",
            )

Specifically, by using zip, this code requires that the order of tensors in the list input_tensors matches the order of buffers in input_buffers, a list built by iterating over the dictionary access_map. It's easy to see how this might pose a problem.

However, the issue is slightly more nuanced. Python dictionaries preserve insertion order (in CPython 3.6 as an implementation detail, and as a language guarantee since Python 3.7). Furthermore, there are unit tests for the code in question, so why doesn't it always break?

While extern_primfunc probably shouldn't use dictionaries to store buffers since order is important, the issue isn't here. Instead, the problem is that tvm.arith._ffi_api.DomainTouchedAccessMap returns a tvm.ir.container.Map. When used with a small number of elements (e.g. 3), tvm.ir.container.Map happens to preserve order, but it does not guarantee this (which is why this bug is not detected by the unit tests). However, when a larger number of elements (e.g. 5) is present in the map, the order changes and an error is produced:

E           tvm._ffi.base.TVMError: Traceback (most recent call last):
E             22: TVMFuncCall
[some lines omitted for brevity]
E             File "/home/guberti/tvm/python/tvm/te/operation.py", line 440, in extern_primfunc
E               assert len(tensor.shape) == len(buffer.shape)
E           TVMError: AssertionError
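The failure mode can be reproduced in miniature without TVM. The sketch below is a toy model only: toy_map_items, its small_limit threshold, and the pseudo-hash ordering are all made up for illustration and are not how tvm.ir.container.Map is implemented. It just shows why pairing by zip passes with a few entries and fails once the container stops returning entries in insertion order.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]
Entry = Tuple[str, Shape]


def toy_map_items(entries: List[Entry], small_limit: int = 4) -> List[Entry]:
    """Toy stand-in for a map with no ordering guarantee (NOT TVM's code).

    Up to small_limit entries come back in insertion order; beyond that,
    entries come back ordered by a deterministic pseudo-hash of the name.
    Both the threshold and the hash are assumptions for illustration.
    """
    if len(entries) <= small_limit:
        return list(entries)

    def pseudo_hash(entry: Entry) -> int:
        return sum(ord(c) for c in entry[0]) * 31 % 7

    return sorted(entries, key=pseudo_hash)


def shapes_line_up(tensor_shapes: List[Shape], buffers: List[Entry]) -> bool:
    """Mimic the zip-based check: compare shapes purely by position."""
    return all(t == shape for t, (_name, shape) in zip(tensor_shapes, buffers))


# Buffers in the order the caller expects (the primfunc parameter order).
declared = [("A", (8,)), ("B", (8, 8)), ("C", (8, 8, 8)), ("D", (4,)), ("E", (4, 4))]
tensor_shapes = [shape for _name, shape in declared]

# Three entries: order happens to survive, so the positional check passes.
small_ok = shapes_line_up(tensor_shapes[:3], toy_map_items(declared[:3]))

# Five entries: the toy map reorders them, so the positional check fails.
large_order = [name for name, _shape in toy_map_items(declared)]
large_ok = shapes_line_up(tensor_shapes, toy_map_items(declared))

print(small_ok, large_ok, large_order)
```

A unit test that only exercises the three-buffer case would never see the reordering, which matches the behavior described above.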

Suggested fix

As far as I can tell, extern_primfunc needs to be given the buffers in the same order as the tensors. Thus, DomainTouchedAccessMap should be changed to return a list of tuples instead of a map. I'm blocked by this bug, so I'm just gonna fix it.
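A rough sketch of what "a list of tuples" buys us, again in plain Python rather than against the real FFI. The helper ordered_accesses and the simplified (reads, writes) access format are hypothetical. The idea is only that once the (buffer, access) pairs come back in parameter order, the existing zip-based pairing becomes safe.

```python
from typing import Dict, List, Tuple

# Simplified access record: (read regions, write regions). Hypothetical format.
Access = Tuple[list, list]


def ordered_accesses(
    param_order: List[str], access_map: Dict[str, Access]
) -> List[Tuple[str, Access]]:
    """Return (buffer, access) pairs in the declared parameter order.

    Hypothetical helper: it ignores whatever order access_map iterates in
    and follows param_order instead.
    """
    return [(name, access_map[name]) for name in param_order if name in access_map]


params = ["A", "B", "C", "D", "E"]

# Deliberately scrambled to stand in for an unordered map.
scrambled = {
    "D": ([1], []),
    "B": ([1], []),
    "E": ([], [1]),  # written only, so it is an output buffer
    "C": ([1], []),
    "A": ([1], []),
}

pairs = ordered_accesses(params, scrambled)

# Same filtering as in extern_primfunc, but over an ordered list of tuples.
input_buffers = [name for name, (reads, _writes) in pairs if len(reads)]
output_buffers = [name for name, (_reads, writes) in pairs if len(writes)]

print(input_buffers, output_buffers)
```

With the order pinned to the parameter list, input_buffers lines up with input_tensors regardless of how many buffers the primfunc has.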

  • tir:arith

CC @areusch, @Lunderberg

    Labels

    needs-triage, type: bug
