Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/tvm/contrib/ethosu/cascader/cascader_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class CascaderOptions(Object):
The maximum number of Parts in a Plan.
always_copy_size : int
The maximum size of a Tensor that will always be copied into the cascade region.
enable_striping : bool
A boolean option to enable striping

"""

Expand All @@ -50,6 +52,7 @@ def __init__(
stripe_factors: int,
max_plan_size: int,
always_copy_size: int,
enable_striping: bool = False,
):
self.__init_handle_by_constructor__(
_ffi_api.CascaderOptions,
Expand All @@ -58,4 +61,5 @@ def __init__(
stripe_factors,
max_plan_size,
always_copy_size,
enable_striping,
)
2 changes: 1 addition & 1 deletion python/tvm/contrib/ethosu/cascader/device_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _get_subkernel_propagator(
if output_layout == "NHCWB16" and input_layout == "NHWC":
transform[3][-1] = depth
elif output_layout == "NHCWB16" and input_layout == "NHCWB16":
transform[2][-1] = depth // 16
transform[2][-1] = 1 + ((depth - 1) // 16)

return Propagator(transform, ifm_propagator.offset)

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/contrib/ethosu/cascader/plan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from .graph import CascaderGraph, Part, Tensor


def _generate_output_stripe_configs(part: Part, stripe_factors: int) -> List[StripeConfig]:
return list(_ffi_api.GenerateOutputStripeConfigs(part, stripe_factors))
def _generate_output_stripe_configs(
part: Part, stripe_factors: int, enable_striping: bool
) -> List[StripeConfig]:
return list(_ffi_api.GenerateOutputStripeConfigs(part, stripe_factors, enable_striping))


def _generate_single_plans(
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _cascader(te_graph, const_dict, sch):
return _cascader


def _ethos_u55_cascader(sram) -> Callable:
def _ethos_u55_cascader(sram, enable_striping) -> Callable:
# TODO(ekalda): Extract the flash info from ConstantPools once it is implemented
flash = MemoryRegion(name="FLASH", size=10**7, read_bandwidth=4, write_bandwidth=4)

Expand All @@ -368,6 +368,7 @@ def _ethos_u55_cascader(sram) -> Callable:
stripe_factors=5,
max_plan_size=10,
always_copy_size=1024,
enable_striping=enable_striping,
)
return _create_cascader(
options=cascader_options,
Expand Down Expand Up @@ -425,7 +426,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
), "Exactly one workspace pool needs to be provided for the U55 cascader"

sram = extract_memory_info(workspace_memory_pools.pools[0])
tir_mod = LowerToTIR(_ethos_u55_cascader(sram))(mod)
tir_mod = LowerToTIR(_ethos_u55_cascader(sram, util.is_striping_enabled()))(mod)
else:
tir_mod = LowerToTIR(copy_constants())(mod)

Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ def is_cascader_enabled():
return compiler_attrs.enable_cascader


def is_striping_enabled():
"""Determine whether the cascader is enabled"""
compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
return compiler_attrs.enable_striping


def get_arg_count(func):
"""Helper function to get the number of
arguments in a python function"""
Expand Down
9 changes: 6 additions & 3 deletions src/contrib/ethosu/cascader/cascader_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,27 @@ void CascaderOptionsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("stripe_factors", &stripe_factors);
v->Visit("max_plan_size", &max_plan_size);
v->Visit("always_copy_size", &always_copy_size);
v->Visit("enable_striping", &enable_striping);
}

CascaderOptions::CascaderOptions(const MemoryRegion& cascade_region, int max_proposals,
int stripe_factors, int max_plan_size, int always_copy_size) {
int stripe_factors, int max_plan_size, int always_copy_size,
bool enable_striping) {
auto n = make_object<CascaderOptionsNode>();
n->cascade_region = std::move(cascade_region);
n->max_proposals = max_proposals;
n->stripe_factors = stripe_factors;
n->max_plan_size = max_plan_size;
n->always_copy_size = always_copy_size;
n->enable_striping = enable_striping;
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderOptions")
.set_body_typed([](MemoryRegion cascade_region, int max_proposals, int stripe_factors,
int max_plan_size, int always_copy_size) {
int max_plan_size, int always_copy_size, bool enable_striping) {
return CascaderOptions(cascade_region, max_proposals, stripe_factors, max_plan_size,
always_copy_size);
always_copy_size, enable_striping);
});

TVM_REGISTER_NODE_TYPE(CascaderOptionsNode);
Expand Down
4 changes: 3 additions & 1 deletion src/contrib/ethosu/cascader/cascader_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class CascaderOptionsNode : public Object {
int max_plan_size;
/*! \brief The maximum size of Tensor that will always be copied into the cascade region. */
int always_copy_size;
/*! \brief A boolean option to enable striping. */
bool enable_striping;

static constexpr const char* _type_key = "contrib.ethosu.cascader.CascaderOptions";
TVM_DECLARE_FINAL_OBJECT_INFO(CascaderOptionsNode, Object)
Expand All @@ -58,7 +60,7 @@ class CascaderOptionsNode : public Object {
class CascaderOptions : public ObjectRef {
public:
CascaderOptions(const MemoryRegion& cascade_region, int max_proposals, int stripe_factors,
int max_plan_size, int always_copy_size);
int max_plan_size, int always_copy_size, bool enable_striping = true);

TVM_DEFINE_OBJECT_REF_METHODS(CascaderOptions, ObjectRef, CascaderOptionsNode);
};
Expand Down
12 changes: 7 additions & 5 deletions src/contrib/ethosu/cascader/plan_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ std::vector<bool> GetCascadableAxes(const Part& part) {
return cascadable_axes;
}

std::vector<StripeConfig> GenerateOutputStripeConfigs(const Part& part, int stripe_factors) {
std::vector<StripeConfig> GenerateOutputStripeConfigs(const Part& part, int stripe_factors,
bool enable_striping) {
// If stripe_factors is <= 0, then we won't produce any StripeConfigs
if (stripe_factors <= 0) {
return std::vector<StripeConfig>();
Expand Down Expand Up @@ -134,7 +135,7 @@ std::vector<StripeConfig> GenerateOutputStripeConfigs(const Part& part, int stri
auto axis = output_shape[i];
auto axis_align = part->GetStripeAlignHint()[i];
std::set<int> axis_splits; // Note this is a set to remove duplicate splits
if (!cascadable_axes[i]) {
if (!cascadable_axes[i] || (!enable_striping)) {
axis_splits.insert(axis);
} else {
for (float factor : factors) {
Expand Down Expand Up @@ -436,7 +437,7 @@ std::unordered_map<std::vector<Part>, std::vector<Plan>> GenerateGraphPlans(
// output of a Plan. The number generated is a function of stripe_factors and the number of
// cascadable dimensions in the Part.
std::vector<StripeConfig> stripe_configs =
GenerateOutputStripeConfigs(part, options->stripe_factors);
GenerateOutputStripeConfigs(part, options->stripe_factors, options->enable_striping);
// Check to see if the output Tensor is part of any existing open Plans
if (stripe_configs_by_tensor.find(part->GetOutputTensor()) != stripe_configs_by_tensor.end()) {
// If there are other open Plans which have this Part's output Tensor as an input, then
Expand Down Expand Up @@ -514,11 +515,12 @@ std::unordered_map<std::vector<Part>, std::vector<Plan>> GenerateGraphPlans(
}

TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GenerateOutputStripeConfigs")
.set_body_typed([](Part part, int stripe_factors) {
.set_body_typed([](Part part, int stripe_factors, bool enable_striping) {
if (stripe_factors < 0) {
return Array<StripeConfig>();
}
return Array<StripeConfig>(GenerateOutputStripeConfigs(part, stripe_factors));
return Array<StripeConfig>(
GenerateOutputStripeConfigs(part, stripe_factors, enable_striping));
});

TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GenerateSinglePlans")
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/contrib/ethosu/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace ethosu {
struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode> {
String accelerator_config;
bool enable_cascader;
bool enable_striping;

TVM_DECLARE_ATTRS(EthosUCompilerConfigNode, "ext.attrs.EthosUCompilerConfigNode") {
TVM_ATTR_FIELD(accelerator_config)
Expand All @@ -50,6 +51,9 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
TVM_ATTR_FIELD(enable_cascader)
.describe("Whether the cascader should be enabled")
.set_default(false);
TVM_ATTR_FIELD(enable_striping)
.describe("Whether the cascader should be striping")
.set_default(false);
}
};

Expand Down
2 changes: 2 additions & 0 deletions tests/python/contrib/test_ethosu/cascader/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def make_options(
stripe_factors: int = 1,
max_plan_size: int = 1,
always_copy_size: int = 1024,
enable_striping: bool = True,
):
return cs.CascaderOptions(
cascade_region=cascade_region,
max_proposals=max_proposals,
stripe_factors=stripe_factors,
max_plan_size=max_plan_size,
always_copy_size=always_copy_size,
enable_striping=enable_striping,
)


Expand Down
47 changes: 30 additions & 17 deletions tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from .. import infra


def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascader):
def _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader, enable_striping
):
enable_usmp = True

target = tvm.target.Target("c")
Expand All @@ -52,6 +54,7 @@ def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascad
"relay.ext.ethos-u.options": {
"accelerator_config": accel_type,
"enable_cascader": enable_cascader,
"enable_striping": enable_striping,
},
"tir.usmp.enable": enable_usmp,
"tir.usmp.algorithm": "hill_climb",
Expand Down Expand Up @@ -86,7 +89,7 @@ def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascad


@pytest.mark.parametrize(
"accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader",
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
[
("ethos-u55-256", 1067408, 14096),
("ethos-u55-128", 1067408, 3968),
Expand All @@ -95,7 +98,7 @@ def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascad
],
)
def test_double_conv2d(
accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader
accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping
):
np.random.seed(1)
ifm_shape = (1, 321, 212, 6)
Expand Down Expand Up @@ -135,32 +138,37 @@ def tf_graph(x):
# Run the graph without the cascader, with lots of memory
pool_size = 2000000
workspace_size_cascader_disabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=False
mod, params, accel_type, pool_size, enable_cascader=False, enable_striping=False
)
workspace_size_cascader_enabled_striping_disabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True, enable_striping=False
)
# if striping is not done, it should be same as cacader disabled
assert workspace_size_cascader_disabled == workspace_size_cascader_enabled_striping_disabled

# Run the same graph with the cascader, giving it less memory to persuade cascder to cascade
pool_size = 600000
workspace_size_cascader_enabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True
workspace_size_cascader_enabled_striping_enabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True, enable_striping=True
)

assert workspace_size_cascader_disabled == expected_ws_size_without_cascader
assert workspace_size_cascader_enabled == expected_ws_size_with_cascader
assert workspace_size_cascader_disabled == expected_ws_size_without_striping
assert workspace_size_cascader_enabled_striping_enabled == expected_ws_size_with_striping


# TODO(ekalda): Fix a bug in the block config selection that selects block config that is too large
# for the smaller accelerators
@pytest.mark.parametrize(
"accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader",
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
[
("ethos-u55-256", 180096, 5024),
("ethos-u55-128", 180096, 4832),
pytest.param("ethos-u55-64", 180096, 4832, marks=pytest.mark.xfail),
pytest.param("ethos-u55-32", 180096, 4832, marks=pytest.mark.xfail),
("ethos-u55-64", 180096, 6464),
("ethos-u55-32", 180096, 6464),
],
)
def test_depthwise2d_conv2d_pooling(
accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader
accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping
):
np.random.seed(2)
ifm_shape = (1, 80, 75, 3)
Expand Down Expand Up @@ -210,14 +218,19 @@ def tf_graph(x):
# Run the graph without the cascader, with lots of memory
pool_size = 10**6
workspace_size_cascader_disabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=False
mod, params, accel_type, pool_size, enable_cascader=False, enable_striping=False
)
workspace_size_cascader_enabled_striping_disabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True, enable_striping=False
)
# if striping is not done, it should be same as cacader disabled
assert workspace_size_cascader_disabled == workspace_size_cascader_enabled_striping_disabled

# Run the same graph with the cascader, giving it less memory to persuade cascder to cascade
pool_size = 40000
workspace_size_cascader_enabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True
workspace_size_cascader_enabled_striping_enabled = _get_ethosu_workspace_size(
mod, params, accel_type, pool_size, enable_cascader=True, enable_striping=True
)

assert workspace_size_cascader_disabled == expected_ws_size_without_cascader
assert workspace_size_cascader_enabled == expected_ws_size_with_cascader
assert workspace_size_cascader_disabled == expected_ws_size_without_striping
assert workspace_size_cascader_enabled_striping_enabled == expected_ws_size_with_striping
32 changes: 30 additions & 2 deletions tests/python/contrib/test_ethosu/cascader/test_plan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,33 @@ def test_generate_output_stripe_configs():
tensor_1.add_consumer(part_1)
tensor_2.add_producer(part_1)

assert len(_generate_output_stripe_configs(part_1, stripe_factors)) == expected_configs
assert (
len(_generate_output_stripe_configs(part_1, stripe_factors, enable_striping=True))
== expected_configs
)


@pytest.mark.parametrize("stripe_factors", [3, 4, 8, 16, 10])
def test_generate_output_stripe_configs_disable_striping(stripe_factors):
subgraph = cs.TESubgraph([], None)
part_1 = cs.InlinePart(
subgraph,
[
cs.Propagator(
[[2, 0, 0], [0, 2, 0], [0, 0, 1]],
[0, 0],
),
],
)
tensor_1 = cs.Tensor([800, 800], "uint8")
tensor_2 = cs.Tensor([400, 400], "uint8")

part_1.set_input(0, tensor_1)
part_1.set_output(tensor_2)
tensor_1.add_consumer(part_1)
tensor_2.add_producer(part_1)

assert len(_generate_output_stripe_configs(part_1, stripe_factors, enable_striping=False)) == 1


def test_generate_single_plans(SRAM, DRAM):
Expand All @@ -74,7 +100,9 @@ def test_generate_single_plans(SRAM, DRAM):
tensor_2: [SRAM],
}
options = make_options(cascade_region=SRAM, stripe_factors=1)
output_stripe_configs = _generate_output_stripe_configs(part_1, options.stripe_factors)
output_stripe_configs = _generate_output_stripe_configs(
part_1, options.stripe_factors, enable_striping=True
)
plans = _generate_single_plans(part_1, output_stripe_configs, home_map, options)
for plan in plans:
assert plan.interior_region == SRAM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,27 @@ def test_generate_proposals_mobilenetv2diamond(FLASH, SRAM, MobileNetv2DiamondGr
assert min_sram < proposal.memory_usage < max_sram
assert proposal.cycles > 0

def test_generate_proposals_mobilenetv1_disable_striping(FLASH, SRAM, MobileNetv1Graph):
graph = MobileNetv1Graph
home_map = make_simple_home_map(graph, SRAM, FLASH)
options = make_options(
cascade_region=SRAM,
max_proposals=32,
stripe_factors=5,
max_plan_size=10,
enable_striping=False,
)

proposals = generate_proposals(graph, home_map, options)
assert len(proposals) == 1
proposal = proposals[0]
for plan in proposal.plans:
for stripe_config in plan.output_config.stripe_configs:
for shape_dim, stride_dim in list(zip(stripe_config.shape, stripe_config.strides)):
# The striding and shape sizes in each dimension should be the same
# if striping is disabled
assert int(shape_dim) == int(stride_dim)


if __name__ == "__main__":
pytest.main([__file__])