Conversation

@jwfromm
Contributor

@jwfromm jwfromm commented May 3, 2022

This PR adds a new pass to relay.transform that creates a dispatcher around an input module to handle multiple input shapes. For example, consider a case where I'd like to optimize my model to handle both batch_size=1 and batch_size=4. I can now do so elegantly as follows:

shape_dict = {'data': [1, 3, 224, 224]}
model_bs1 = tvmc.load('my_model.onnx', shape_dict=shape_dict)
tvmc.tune(model_bs1, log_file='batch_1.logs')

shape_dict = {'data': [4, 3, 224, 224]}
model_bs4 = tvmc.load('my_model.onnx', shape_dict=shape_dict)
tvmc.tune(model_bs4, log_file='batch_4.logs')

# Create dispatcher for multiple batch sizes.
flex_mod = relay.transform.FlexibleShapeDispatch(buckets=[1, 4])(model_bs1.mod)

with ApplyHistoryBest(['batch_1.logs', 'batch_4.logs']):
    exe = relay.vm.compile(flex_mod, "llvm")
vm = tvm.runtime.vm.VirtualMachine(exe, tvm.cpu())

# Now we can run inputs with either batch 1 or batch 4 and get the tuned performance!
batch_1 = np.random.rand(1, 3, 224, 224).astype("float32")
vm.benchmark(tvm.cpu(), batch_1, func_name="main")

batch_4 = np.random.rand(4, 3, 224, 224).astype("float32")
vm.benchmark(tvm.cpu(), batch_4, func_name="main")

As seen above, FlexibleShapeDispatch is a simple halfway point between fully static and fully dynamic graphs that lets us leverage TVM tuning. If an input shape is not covered by buckets, it will either run fully dynamically using relay.Any or, when the auto_pad argument of FlexibleShapeDispatch is set, be padded up to the closest bucket.
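The dispatch behavior described above can be sketched in plain Python. This is a simplified model of the bucket-selection and padding logic, not the actual pass implementation; `pick_bucket` and `pad_batch` are illustrative names invented for the sketch:

```python
import numpy as np

def pick_bucket(batch, buckets, auto_pad=False):
    """Return the bucket to dispatch to, or None for the fully dynamic path.

    With auto_pad, inputs smaller than some bucket are padded up to the
    closest bucket that fits; otherwise unmatched sizes fall through to
    the fully dynamic (relay.Any) function.
    """
    if batch in buckets:
        return batch
    if auto_pad:
        larger = [b for b in buckets if b > batch]
        if larger:
            return min(larger)
    return None  # fully dynamic fallback

def pad_batch(x, bucket):
    """Zero-pad the batch axis of x up to the bucket size."""
    pad = bucket - x.shape[0]
    widths = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, widths)

x = np.ones((3, 3, 224, 224), dtype="float32")
bucket = pick_bucket(x.shape[0], buckets=[1, 4], auto_pad=True)  # -> 4
padded = pad_batch(x, bucket)  # shape (4, 3, 224, 224)
```

In the real pass the output would then be sliced back down to the original batch size, so padding stays invisible to the caller.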

There are a few special cases that this pass handles. Multiple dynamic inputs (like those you might see in BERT) can be handled by setting input_indices to indicate which inputs have a dynamic axis. affects_output can be set to False when the output shape does not depend on the input dynamism, for example a dynamic-resolution input feeding a fixed-size output.
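For the multi-input case, the idea is that every input selected by input_indices gets padded to the same bucket along its dynamic axis. A rough plain-Python illustration (not the pass itself; `pad_to_bucket` is a made-up helper for the sketch):

```python
import numpy as np

def pad_to_bucket(arr, bucket, axis):
    """Zero-pad one array along the given dynamic axis up to bucket size."""
    widths = [(0, 0)] * arr.ndim
    widths[axis] = (0, bucket - arr.shape[axis])
    return np.pad(arr, widths)

# BERT-style inputs sharing a dynamic sequence-length axis (axis=1).
input_ids = np.ones((1, 100), dtype="int64")
attention_mask = np.ones((1, 100), dtype="int64")

bucket = 128  # closest bucket >= the actual sequence length of 100
inputs = [pad_to_bucket(t, bucket, axis=1) for t in (input_ids, attention_mask)]
# Both inputs now have shape (1, 128). With affects_output=True the
# result would be sliced back to the original length; with
# affects_output=False (e.g. a fixed-size classification head) no
# slicing is needed.
```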

To make applying tuning logs more convenient, I also added the ability to load and merge multiple files to both autotvm and autoscheduler.
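Conceptually, merging several log files amounts to collecting records from each file and keeping the best entry per workload. The sketch below uses plain (workload, cost) pairs as a stand-in for the actual autotvm/auto_scheduler record format, which carries much more structure:

```python
def merge_best_records(record_lists):
    """Merge records from multiple logs, keeping the lowest-cost entry
    per workload key. Each record is a (workload, cost) pair here."""
    best = {}
    for records in record_lists:
        for workload, cost in records:
            if workload not in best or cost < best[workload]:
                best[workload] = cost
    return best

batch_1_logs = [("conv2d_b1", 0.8), ("dense", 0.5)]
batch_4_logs = [("conv2d_b4", 2.1), ("dense", 0.6)]
merged = merge_best_records([batch_1_logs, batch_4_logs])
# merged == {"conv2d_b1": 0.8, "dense": 0.5, "conv2d_b4": 2.1}
```

Records tuned for different buckets have distinct workload keys, so logs from separate tuning runs compose without clobbering each other.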

Thanks @jroesch for providing the backbone of this implementation.

@jwfromm jwfromm requested review from AndrewZhaoLuo and jroesch May 3, 2022 02:29
Member

@jroesch jroesch left a comment


LGTM modulo feedback

@jroesch
Member

jroesch commented May 3, 2022

cc @mbs-octoml

@jwfromm
Contributor Author

jwfromm commented May 3, 2022

@jroesch I think the documentation should now be substantially improved based on your input. Let me know what you think.

Contributor

@mbs-octoml mbs-octoml left a comment


LGTM, some optional nits:

  • consider asserting dim is Any in override_shape to avoid surprises
  • rename dim -> axis, value -> dim
  • should we force a TypeInfer afterwards to be sure the Any->dim subst is sound?

@jwfromm
Contributor Author

jwfromm commented May 4, 2022

All feedback has been addressed and tests are green. @jroesch do you want to give this one more pass?

@jwfromm jwfromm merged commit 98aa41e into apache:main May 6, 2022
shtinsa pushed a commit to Deelvin/tvm that referenced this pull request May 17, 2022
* Added pass that creates a semi-dynamic dispatcher around a relay module.

* Added automatic padding feature.

* Output slicing working.

* Multiple input support working i think.

* Added test file.

* Improve comments.

* Fix lint.

* Allow default values.

* Fix docstring.

* Improved documentation based on feedback.

* Add extra check for record loading.

* Improve variable names.

* Add type inference to make sure things worked.

* Added support for multiple outputs.
SebastianBoblest pushed a commit to SebastianBoblest/tvm that referenced this pull request May 27, 2022
@jwfromm jwfromm deleted the flexible_shape_pass branch April 12, 2023 15:55