Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ class DmlOperatorPadding : public DmlOperator, public PaddingHelper
}
};

void CALLBACK QueryPad(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
// DML_PADDING1_OPERATOR_DESC doesn't support negative padding counts i.e. StartPadding and EndPadding
// can't contain negative elements.
// For opset < 11,
// if attribute 'pads' contains negative element, fall back to CPU
// opset >= 11
// DML EP continues to produce wrong result. [TODO: After DML1.9 release, introduce new API for pad to
// handle negative values for StartPadding and EndPadding]
*isSupported = true;

MLOperatorAttributes attributes(context);

std::vector<int32_t> padding = attributes.GetOptionalAttributeVectorInt32(AttrName::Pads);
*isSupported = std::none_of(padding.begin(), padding.end(), [](int32_t padCount) {return padCount < 0; });
}

DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel<DmlOperatorPadding, 13>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Resize);
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
DML_OP_EXTERN_QUERY_FUNCTION(RecurrentNeuralNetwork);
DML_OP_EXTERN_QUERY_FUNCTION(BatchNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(Pad);

constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
Expand Down Expand Up @@ -428,7 +429,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 13, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryPad)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
Expand Down