Description
When using te.extern_primfunc, we need to ensure that the tensors we pass in have the same shape as the buffers they are loaded into. This check is done in python/tvm/te/operation.py, which has the following code:
def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs):
    access_map = {
        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
    }
    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
    input_buffers = in_buffers
    # [some lines omitted for brevity]
    for obuf in out_buffers:
        # [loop body omitted for brevity]
    assert len(input_buffers) == len(input_tensors), (
        "The number of provided input input_tensors does not match the number of ",
        "input buffers in the primfunc",
    )
    for tensor, buffer in zip(input_tensors, input_buffers):
        assert len(tensor.shape) == len(buffer.shape)
        for d1, d2 in zip(tensor.shape, buffer.shape):
            assert d1 == d2, (
                "The input input_tensors provided do not match the input buffers in the ",
                "primfunc. Please check that the order of input te.Input_Tensors and the ",
                "order of the primfunc variables in the params list agree.",
            )

Specifically, because it uses zip, this code requires the order of tensors in the list input_tensors to match the order of buffers in input_buffers, which is built by iterating over the dictionary access_map. It's easy to see how this might pose a problem.
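To make the positional dependence concrete, here is a minimal sketch that uses plain tuples in place of TVM tensors and buffers. The names check_shapes, tensor_shapes, and the assertion messages are hypothetical stand-ins for illustration and are not part of TVM.

```python
def check_shapes(tensor_shapes, buffer_shapes):
    """Mimic the positional shape check: item i is compared with item i."""
    assert len(tensor_shapes) == len(buffer_shapes), "count mismatch"
    for tensor_shape, buffer_shape in zip(tensor_shapes, buffer_shapes):
        # Ranks must agree first, then every dimension must agree.
        assert len(tensor_shape) == len(buffer_shape), "rank mismatch"
        for d1, d2 in zip(tensor_shape, buffer_shape):
            assert d1 == d2, "dimension mismatch"
    return True


tensor_shapes = [(4, 8), (8,), (4,)]

# Buffers listed in the same order as the tensors: the check passes.
same_order_ok = check_shapes(tensor_shapes, [(4, 8), (8,), (4,)])

# The same buffers in a different order: the first pair already differs in rank.
try:
    check_shapes(tensor_shapes, [(8,), (4, 8), (4,)])
    reordered_error = None
except AssertionError as err:
    reordered_error = str(err)
```

The set of buffers is identical in both calls; only the iteration order differs, and that alone is enough to trip the rank assertion.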
However, the issue is slightly more nuanced. In Python 3.7 and above (and in CPython 3.6 as an implementation detail), dictionaries preserve insertion order. Furthermore, there are unit tests for the code in question - why doesn't it always break?
While extern_primfunc probably shouldn't use dictionaries to store buffers since order is important, the issue isn't here. Instead, the problem is that tvm.arith._ffi_api.DomainTouchedAccessMap returns a tvm.ir.container.Map. With a small number of elements (e.g. 3), tvm.ir.container.Map happens to preserve order, but it does not guarantee this, which is why the bug is not detected by the unit tests. When a larger number of elements (e.g. 5) is present in the map, an error is produced:
E tvm._ffi.base.TVMError: Traceback (most recent call last):
E 22: TVMFuncCall
[some lines omitted for brevity]
E File "/home/guberti/tvm/python/tvm/te/operation.py", line 440, in extern_primfunc
E assert len(tensor.shape) == len(buffer.shape)
E TVMError: AssertionError
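The size-dependent behaviour described above can be imitated with a toy container. This is only a sketch of the general pattern (ordered while small, hash-ordered once large). It is not TVM's implementation, and ToyMap, SMALL_LIMIT, NUM_BUCKETS, and the bucket function are all hypothetical choices made for illustration.

```python
class ToyMap:
    """Toy container: insertion-ordered while small, bucket-ordered once it grows.

    NOT TVM's Map; the threshold and bucket function are arbitrary assumptions.
    """

    SMALL_LIMIT = 4  # hypothetical threshold
    NUM_BUCKETS = 7  # hypothetical bucket count

    def __init__(self):
        self._pairs = []  # (key, value) pairs in insertion order

    def __setitem__(self, key, value):
        # Replace an existing key in place, otherwise append.
        for i, (k, _) in enumerate(self._pairs):
            if k == key:
                self._pairs[i] = (key, value)
                return
        self._pairs.append((key, value))

    def items(self):
        if len(self._pairs) <= self.SMALL_LIMIT:
            return list(self._pairs)
        # Large case: iterate in bucket order, which ignores insertion order.
        return sorted(self._pairs, key=lambda kv: (kv[0] * 3) % self.NUM_BUCKETS)


def keys_after_inserting(keys):
    toy = ToyMap()
    for k in keys:
        toy[k] = f"buf{k}"
    return [k for k, _ in toy.items()]


small = keys_after_inserting([5, 1, 4])        # stays in insertion order
large = keys_after_inserting([5, 1, 4, 2, 6])  # order changes past the threshold
```

A test that only ever inserts three entries would see insertion order every time and pass, which matches how the unit tests could miss the problem.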
Suggested fix
As far as I can tell, extern_primfunc needs to be given the buffers in the same order as the tensors. Thus, DomainTouchedAccessMap should be changed to return a list of tuples instead of a map. I'm blocked by this bug, so I'm just gonna fix it.
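A rough sketch of what consuming a list of tuples could look like is shown below. It is not the actual patch: touched_access_list, split_inputs_outputs, and the (name, reads, writes) layout are hypothetical, and plain strings stand in for TVM buffers.

```python
from typing import Dict, List, Tuple

# (buffer name, read regions, write regions); a stand-in for the real result type.
AccessEntry = Tuple[str, list, list]


def touched_access_list(
    param_order: List[str], access_by_name: Dict[str, Tuple[list, list]]
) -> List[AccessEntry]:
    """Return access info as a list whose order follows `param_order`."""
    return [(name, *access_by_name[name]) for name in param_order]


def split_inputs_outputs(entries: List[AccessEntry]):
    """Split entries into read and written buffers, keeping list order."""
    in_buffers = [name for name, reads, _ in entries if len(reads)]
    out_buffers = [name for name, _, writes in entries if len(writes)]
    return in_buffers, out_buffers


param_order = ["a", "b", "c", "d", "e"]
# Deliberately scrambled mapping: the list result should not depend on its order.
access_by_name = {
    "e": ([], ["w"]),
    "c": (["r"], []),
    "a": (["r"], []),
    "d": (["r"], []),
    "b": (["r"], []),
}

entries = touched_access_list(param_order, access_by_name)
in_buffers, out_buffers = split_inputs_outputs(entries)
```

Because the order is fixed by the parameter list rather than by the container's iteration order, zipping in_buffers against the caller's tensors stays stable no matter how many buffers there are.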
- tir:arith
CC @areusch, @Lunderberg