Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,28 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (ifm == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<EthosuConv2DAttrs>();
CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
CHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_conv2d type(uint8) or type(int8) for ifm but was " << ifm->dtype;
CHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8))
<< "Expected ethosu_conv2d type(uint8) or type(int8) for weight but was " << weight->dtype;
CHECK(scale_bias->dtype == DataType::UInt(8))
<< "Expected ethosu_conv2d type(uint8) for scale_bias but was " << scale_bias->dtype;

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

// The scale_bias should be provided as a tensor of size {ofm_channels, 10}
reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
Expand Down
33 changes: 24 additions & 9 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,30 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At

const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";
ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was "
<< ifm->dtype;
ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was "
<< weight->dtype;
ICHECK(scale_bias->dtype == DataType::UInt(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was "
<< scale_bias->dtype;

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};
Expand Down
12 changes: 7 additions & 5 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,15 @@ def make_ethosu_conv2d(
ifm_layout="NHWC",
ofm_layout="NHWC",
weight_dtype="int8",
scale_bias_dtype="uint8",
):
# conv params
weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels)
padding = get_pad_tuple(padding, kernel_shape)

scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
scale_bias = relay.const(scale_bias_data, dtype="uint8")
weight_data = generate_weights_data(weight_shape, "int8")
scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
weight_data = generate_weights_data(weight_shape, weight_dtype)
weight = relay.const(weight_data, dtype=weight_dtype)
conv = ethosu_ops.ethosu_conv2d(
ifm,
Expand Down Expand Up @@ -427,13 +428,14 @@ def make_ethosu_depthwise_conv2d(
ifm_layout="NHWC",
ofm_layout="NHWC",
weight_dtype="int8",
scale_bias_dtype="uint8",
):
# params
weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1)
padding = get_pad_tuple(padding, kernel_shape)

scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
scale_bias = relay.const(scale_bias_data, dtype="uint8")
scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
weight_data = generate_weights_data(weight_shape, weight_dtype)
weight = relay.const(weight_data, dtype=weight_dtype)
depthwise = ethosu_ops.ethosu_depthwise_conv2d(
Expand Down
55 changes: 55 additions & 0 deletions tests/python/contrib/test_ethosu/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ def test_ethosu_conv2d_type_inference(
assert tuple(func.body.checked_type.shape) == ofm_shape


@pytest.mark.parametrize(
"ifm_dtype,weight_dtype,scale_bias_dtype",
[("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
)
def test_ethosu_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
ifm_channels = 55
ofm_channels = 122
kernel_shape = (3, 2)
padding = (0, 1, 2, 3)
strides = (1, 2)
dilation = (2, 1)
ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
conv2d = make_ethosu_conv2d(
ifm,
ifm_channels,
ofm_channels,
kernel_shape,
padding,
strides,
dilation,
weight_dtype=weight_dtype,
scale_bias_dtype=scale_bias_dtype,
)
func = relay.Function([ifm], conv2d)
with pytest.raises(TVMError):
run_opt_pass(func, relay.transform.InferType())


@pytest.mark.parametrize(
"ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")]
)
Expand Down Expand Up @@ -94,6 +122,33 @@ def test_ethosu_depthwise_conv2d_type_inference(
assert tuple(func.body.checked_type.shape) == ofm_shape


@pytest.mark.parametrize(
"ifm_dtype,weight_dtype,scale_bias_dtype",
[("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
)
def test_ethosu_depthwise_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
channels = 55
kernel_shape = (3, 2)
padding = (0, 1, 2, 3)
strides = (1, 2)
dilation = (2, 1)
dilation = (2, 1)
ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
depthwise_conv2d = make_ethosu_depthwise_conv2d(
ifm,
channels,
kernel_shape,
padding,
strides,
dilation,
weight_dtype=weight_dtype,
scale_bias_dtype=scale_bias_dtype,
)
func = relay.Function([ifm], depthwise_conv2d)
with pytest.raises(TVMError):
run_opt_pass(func, relay.transform.InferType())


@pytest.mark.parametrize(
"ifm_shape, ifm_layout", [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")]
)
Expand Down