Skip to content

Conversation

@sunggg
Copy link
Contributor

@sunggg sunggg commented Jan 24, 2022

This PR includes BYOC builder/runner infra and its test case for TensorRT.
Thanks for your time to review this.
cc: @junrushao1994

Please note that previous PR is closed due to the overlap with previous merge.

@junrushao
Copy link
Member

CC @zxybazh would love you guys to review each other’s code :-)

@junrushao junrushao changed the title [MetatSchedule] testcase for TensorRT builder/runner [MetatSchedule][M4b] Testcases for TensorRT builder/runner Jan 26, 2022
@junrushao junrushao changed the title [MetatSchedule][M4b] Testcases for TensorRT builder/runner [MetaSchedule][M4b] Testcases for TensorRT builder/runner Jan 27, 2022
Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some minor nitpicks

Comment on lines 30 to 32
# from tvm import script
# from tvm._ffi import register_func
# from tvm.runtime import Module
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this?

):
if use_meta_sched:
# With meta_schedule
dev = "nvidia/geforce-rtx-2080"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dev = "nvidia/geforce-rtx-2080"
dev = "cuda"

Comment on lines 102 to 115
def relay_build_with_tensorrt(
mod: Module,
target: Target,
params: dict,
) -> List[BuilderResult]:
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

mod, config = partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(
opt_level=3, config={"relay.ext.tensorrt.options": config}
):
return tvm.relay.build_module._build_module_no_factory(
mod, "cuda", "llvm", params
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should refactor these functions, put them under python/tvm/meta_schedule/testing/byoc_trt.py, so that others could conveniently reuse these cool stuff

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

target: Target,
params: dict,
) -> List[BuilderResult]:
# @Sung: Weird. Cannot pass keyword arg
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you have time, you may add a proxy function to _build_module_no_factory to allow kwargs

@register_func("tvm.relay.build")
def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
    target, target_host = Target.check_and_update_host_consist(target, target_host)
    return build(mod, target, params=params, mod_name=mod_name).module


def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"):
    """A wrapper around build which discards the Python GraphFactoryRuntime.
    This wrapper is suitable to be used from other programming languages as
    the runtime::Module can be freely passed between language boundaries.
    """
    return _build_module_no_factory_impl(mod, target, target_host, params, mod_name)

"model_name",
["resnet-50", "mobilenet"],
)
@pytest.mark.parametrize("batch_size", [1, 8, 16])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.parametrize("batch_size", [1, 8, 16])
@pytest.mark.parametrize("batch_size", [1])

)


# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cannot reproduce this, so let's double confirm :-) If there is no problem, let's remove this line

sunggg and others added 16 commits January 28, 2022 15:45
…0010)

Adding interfaces into Pipeline Executor to "run", "stop","set input",
and "get input" from the pipeline executor,

In this patch, we also implemented the "BackendRuntime" structure to
wrap the graph runtime interface in order to support  pipeline executor
interface and implement data copy method. This method is used to
transfer data between two backend runtimes.
…che#10036)

* introduce profile_all_alignments option

* add profile_all_alignment option to API

* wip

* fixed dynamic case

* black

* update gen_gemm too

* minor improvement

* fix

* all tests work

* add doc

* fixed for sm = 75 case

* fix typo

* remove unused import

* profile_all -> find_first_valid

* fix
…10049)

* Add ApplyHisotryBest.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>

* Retrigger CI.

* Update integration.py

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
* WIP

* WIP

* WIP

* test cases

* add examples

* lint

* Amend co-authors information

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>

* WIP

* address comments and changed tensorized comparator

* update

* nit

* fix example

* lint

* lint

* lint

* remove unused

* trigger ci

* clang-format

* fix

* rebase

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
…ardware (apache#9993)

* Add env variable to micro tflite tutorial

* Address @gromero comments

* address @areusch comment

* fix scope

* trigger

* trigger
This commit introduces BaseAddress ObjectRef to determine
base addresses in the codegen for microNPU. This is
required when multiple memory pools become available. Thus,
base addresses could not be statically determined in the
source module.
… than call node (apache#10069)

Co-authored-by: pranav jonnalagadda-SJ1 Eng_ML <[email protected]>
* multi level tiling

* remove tensor core related code

* pylint

* fix

Co-authored-by: Junru Shao <[email protected]>
…)" (apache#10072)

Because of the failure of LSTM conversion from Pytorch
…ort tensoflow 2.6 (apache#9978)

On tensorflow 2.4 the test is expected to fail as the generated graph is not forzen.
On tensorflow 2.6 the generated graph is identified as frozen, therefore the test is not needed
)

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
@junrushao junrushao merged commit ba65197 into apache:main Jan 29, 2022
@junrushao
Copy link
Member

Thanks @sunggg! It's finally merged :-)

ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.