Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/tvm/driver/tvmc/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,25 @@ def _generate_codegen_args(parser, codegen_name):
for tvm_type, python_type in INTERNAL_TO_NATIVE_TYPE.items():
if field.type_info.startswith(tvm_type):
target_option = field.name
default_value = None

# Retrieve the default value string from attrs(field) of config node
# Eg: "default=target_cpu_name"
target_option_default_str = field.type_info.split("default=")[1]

# Extract the defalut value based on the tvm type
if target_option_default_str and tvm_type == "runtime.String":
default_value = target_option_default_str
elif target_option_default_str and tvm_type == "IntImm":
# Extract the numeric value from the python Int string, Eg: T.int64(8)
str_slice = target_option_default_str.split("(")[1]
default_value = str_slice.split(")")[0]

target_group.add_argument(
f"--target-{codegen_name}-{target_option}",
type=python_type,
help=field.description,
default=default_value,
)


Expand Down Expand Up @@ -133,7 +148,6 @@ def reconstruct_target_args(args):
codegen_options = _reconstruct_codegen_args(args, codegen_name)
if codegen_options:
reconstructed[codegen_name] = codegen_options

return reconstructed


Expand Down
16 changes: 16 additions & 0 deletions tests/python/driver/tvmc/test_target_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def test_target_to_argparse_for_mrvl_hybrid():
assert parsed.target_mrvl_mcpu == "cnf10kb"


@tvm.testing.requires_mrvl
def test_default_arg_for_mrvl_hybrid():
parser = argparse.ArgumentParser()
generate_target_args(parser)
parsed, _ = parser.parse_known_args(
[
"--target=mrvl, llvm",
]
)
assert parsed.target == "mrvl, llvm"
assert parsed.target_mrvl_mcpu == "cn10ka"
assert parsed.target_mrvl_num_tiles == 8


@tvm.testing.requires_cmsisnn
def test_mapping_target_args():
parser = argparse.ArgumentParser()
generate_target_args(parser)
Expand Down Expand Up @@ -129,6 +144,7 @@ def test_ethosu_compiler_attrs():
}


@tvm.testing.requires_cmsisnn
def test_skip_target_from_codegen():
parser = argparse.ArgumentParser()
generate_target_args(parser)
Expand Down