Skip to content

Commit

Permalink
Throw exception in TensorPrimitives for unsupported span overlaps (do…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored and michaelgsharp committed Oct 20, 2023
1 parent fdff01f commit bbd26a2
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,7 @@
<data name="Argument_SpansMustHaveSameLength" xml:space="preserve">
<value>Input span arguments must all have the same length.</value>
</data>
</root>
<data name="Argument_InputAndDestinationSpanMustNotOverlap" xml:space="preserve">
<value>The destination span may only overlap with an input span if the two spans start at the same memory location.</value>
</data>
</root>

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,8 @@ private static unsafe void InvokeSpanIntoSpan<TUnaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;
Expand Down Expand Up @@ -1313,6 +1315,9 @@ private static unsafe void InvokeSpanSpanIntoSpan<TBinaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down Expand Up @@ -1428,6 +1433,8 @@ private static unsafe void InvokeSpanScalarIntoSpan<TBinaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;
Expand Down Expand Up @@ -1553,6 +1560,10 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);
ValidateInputOutputSpanNonOverlapping(z, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float zRef = ref MemoryMarshal.GetReference(z);
Expand Down Expand Up @@ -1681,6 +1692,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down Expand Up @@ -1814,6 +1828,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(z, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float zRef = ref MemoryMarshal.GetReference(z);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ private static void InvokeSpanIntoSpan<TUnaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;
Expand Down Expand Up @@ -354,6 +356,9 @@ private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down Expand Up @@ -408,6 +413,8 @@ private static void InvokeSpanScalarIntoSpan<TBinaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;
Expand Down Expand Up @@ -467,6 +474,10 @@ private static void InvokeSpanSpanSpanIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);
ValidateInputOutputSpanNonOverlapping(z, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float zRef = ref MemoryMarshal.GetReference(z);
Expand Down Expand Up @@ -531,6 +542,9 @@ private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(y, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float yRef = ref MemoryMarshal.GetReference(y);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down Expand Up @@ -596,6 +610,9 @@ private static void InvokeSpanScalarSpanIntoSpan<TTernaryOperator>(
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);
ValidateInputOutputSpanNonOverlapping(z, destination);

ref float xRef = ref MemoryMarshal.GetReference(x);
ref float zRef = ref MemoryMarshal.GetReference(z);
ref float dRef = ref MemoryMarshal.GetReference(destination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@ public static void ThrowArgument_SpansMustHaveSameLength() =>
[DoesNotReturn]
public static void ThrowArgument_SpansMustBeNonEmpty() =>
throw new ArgumentException(SR.Argument_SpansMustBeNonEmpty);

[DoesNotReturn]
public static void ThrowArgument_InputAndDestinationSpanMustNotOverlap() =>
throw new ArgumentException(SR.Argument_InputAndDestinationSpanMustNotOverlap, "destination");
}
}
Loading

0 comments on commit bbd26a2

Please sign in to comment.