Skip to content

Commit 674c359

Browse files
Craigacprachguo
authored and
rachguo
committed
String Tensor SplitToSequence fix (#19942)
1 parent 0941cc7 commit 674c359

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

onnxruntime/core/providers/cpu/sequence/sequence_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
453453
int num_remaining_splits = 0;
454454
InlinedVector<int64_t> split_sizes;
455455
const bool is_string_type = input.IsDataTypeString();
456-
const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size();
456+
const size_t element_size = input.DataType()->Size();
457457

458458
// figure out split_scalar or split_sizes
459459
if (p_split_input) {

onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc

+13
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
442442
test.Run();
443443
}
444444

445+
TEST(SequenceOpsTest, SplitToSequence_StringSplit) {
446+
OpTester test("SplitToSequence", 11);
447+
test.AddInput<std::string>("input", {3}, std::vector<std::string>({"Test string", "Another string", "A third and much longer string"}));
448+
int64_t axis = 0;
449+
test.AddAttribute("axis", axis);
450+
SeqTensors<std::string> output;
451+
output.AddTensor({1}, {"Test string"});
452+
output.AddTensor({1}, {"Another string"});
453+
output.AddTensor({1}, {"A third and much longer string"});
454+
test.AddSeqOutput("S2", output);
455+
test.Run();
456+
}
457+
445458
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
446459
OpTester test("SplitToSequence", 11);
447460
test.AddInput<float>("input", {5, 2}, GetConsecutiveVector<float>(1.f, 10));

0 commit comments

Comments
 (0)