Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,83 @@ void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Dif

VECTOR_MATRIX_BINARY_DIFF_IMPL(min)

// copysign
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[PreferRecompute]
[ForwardDerivativeOf(copysign)]
DifferentialPair<T> __d_copysign(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
// copysign(x, y) = sign(y) * abs(x)
// d/dx copysign(x, y) = sign(y) * sign(x) when x != 0 and y != 0, 0 when x == 0 or y == 0
// d/dy copysign(x, y) = 0 (sign function is not differentiable w.r.t y)
let sign_y = select(dpy.p >= T(0.0), T(1.0), T(-1.0));
let sign_x = select(dpx.p >= T(0.0), T(1.0), T(-1.0));
// When x == 0 or y == 0, derivative w.r.t. x should be 0
let dx_coeff = select((dpx.p == T(0.0)) || (dpy.p == T(0.0)), T(0.0), sign_y * sign_x);
return DifferentialPair<T>(
copysign(dpx.p, dpy.p),
__mul_p_d(dx_coeff, dpx.d)
);
}

__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[PreferRecompute]
[BackwardDerivativeOf(copysign)]
void __d_copysign(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
let sign_y = select(dpy.p >= T(0.0), T(1.0), T(-1.0));
let sign_x = select(dpx.p >= T(0.0), T(1.0), T(-1.0));
// When x == 0 or y == 0, derivative w.r.t. x should be 0
let dx_coeff = select((dpx.p == T(0.0)) || (dpy.p == T(0.0)), T(0.0), sign_y * sign_x);
// Gradient flows only to x since d/dy copysign = 0
dpx = diffPair(dpx.p, __mul_p_d(dx_coeff, dOut));
dpy = diffPair(dpy.p, T.dzero());
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[PreferRecompute]
[ForwardDerivativeOf(copysign)]
DifferentialPair<vector<T, N>> __d_copysign_vector(
DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
vector<T, N> result;
vector<T, N>.Differential d_result;
[ForceUnroll] for (int i = 0; i < N; ++i)
{
DifferentialPair<T> dp_elem = __d_copysign(
DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])),
DifferentialPair<T>(dpy.p[i], __slang_noop_cast<T.Differential>(dpy.d[i])));
result[i] = dp_elem.p;
d_result[i] = __slang_noop_cast<T>(dp_elem.d);
}
return DifferentialPair<vector<T, N>>(result, d_result);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[PreferRecompute]
[BackwardDerivativeOf(copysign)]
void __d_copysign_vector(
inout DifferentialPair<vector<T, N>> dpx,
inout DifferentialPair<vector<T, N>> dpy,
vector<T, N>.Differential dOut)
{
vector<T, N>.Differential x_d_result, y_d_result;
[ForceUnroll] for (int i = 0; i < N; ++i)
{
DifferentialPair<T> x_dp = DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]));
DifferentialPair<T> y_dp = DifferentialPair<T>(dpy.p[i], __slang_noop_cast<T.Differential>(dpy.d[i]));
__d_copysign(x_dp, y_dp, __slang_noop_cast<T.Differential>(dOut[i]));
x_d_result[i] = __slang_noop_cast<T>(x_dp.d);
y_d_result[i] = __slang_noop_cast<T>(y_dp.d);
}
dpx = diffPair(dpx.p, x_d_result);
dpy = diffPair(dpy.p, y_d_result);
}

// Lerp
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
Expand Down
10 changes: 5 additions & 5 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -7685,7 +7685,7 @@ matrix<T, N, M> ceil(matrix<T, N, M> x)
/// @category math
__generic<let N: int>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
vector<half,N> copysign_half(vector<half,N> x, vector<half,N> y)
{
let ux = reinterpret<vector<uint16_t,N>>(x);
Expand All @@ -7702,7 +7702,7 @@ vector<half,N> copysign_half(vector<half,N> x, vector<half,N> y)
/// @category math
__generic<let N: int>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
vector<float,N> copysign_float(vector<float,N> x, vector<float,N> y)
{
let ux = reinterpret<vector<uint32_t,N>>(x);
Expand All @@ -7719,7 +7719,7 @@ vector<float,N> copysign_float(vector<float,N> x, vector<float,N> y)
/// @category math
__generic<let N: int>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
vector<double,N> copysign_double(vector<double,N> x, vector<double,N> y)
{
let ux = reinterpret<vector<uint64_t,N>>(x);
Expand All @@ -7740,7 +7740,7 @@ vector<T,N> __real_cast(vector<U,N> val);
/// @category math
__generic<T : __BuiltinFloatingPointType, let N: int>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
vector<T,N> copysign(vector<T,N> x, vector<T,N> y)
{
__target_switch
Expand All @@ -7766,7 +7766,7 @@ vector<T,N> copysign(vector<T,N> x, vector<T,N> y)

__generic<T : __BuiltinFloatingPointType>
[__readNone]
[require(cpp_cuda_glsl_hlsl_metal_spirv)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
T copysign(T x, T y)
{
__target_switch
Expand Down
108 changes: 108 additions & 0 deletions tests/autodiff-dstdlib/dstdlib-copysign.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;

[BackwardDifferentiable]
float diffCopysign(float x, float y)
{
return copysign(x, y);
}

[BackwardDifferentiable]
float2 diffCopysign(float2 x, float2 y)
{
return copysign(x, y);
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
// Test 1: forward diff copysign(3.0, -1.0) with dx=2.0, dy=1.0
{
dpfloat dpx = dpfloat(3.0, 2.0);
dpfloat dpy = dpfloat(-1.0, 1.0);
dpfloat res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[0] = res.p;
outputBuffer[1] = res.d;
}

// Test 2: forward diff copysign(-2.0, 4.0) with dx=1.5, dy=-0.5
{
dpfloat dpx = dpfloat(-2.0, 1.5);
dpfloat dpy = dpfloat(4.0, -0.5);
dpfloat res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[2] = res.p;
outputBuffer[3] = res.d;
}

// Test 3: forward diff copysign(0.0, -1.0) with dx=3.0, dy=2.0
{
dpfloat dpx = dpfloat(0.0, 3.0);
dpfloat dpy = dpfloat(-1.0, 2.0);
dpfloat res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[4] = res.p;
outputBuffer[5] = res.d;
}

// Test 4: vector forward diff
{
dpfloat2 dpx = dpfloat2(float2(5.0, -3.0), float2(1.0, 2.0));
dpfloat2 dpy = dpfloat2(float2(-2.0, 4.0), float2(0.5, -1.0));
dpfloat2 res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[6] = res.p[0];
outputBuffer[7] = res.d[0];
outputBuffer[8] = res.p[1];
outputBuffer[9] = res.d[1];
}

// Test 5: backward diff copysign(4.0, -2.0)
{
dpfloat dpx = dpfloat(4.0, 0.0);
dpfloat dpy = dpfloat(-2.0, 0.0);
__bwd_diff(diffCopysign)(dpx, dpy, 1.0);
outputBuffer[10] = dpx.d;
outputBuffer[11] = dpy.d;
}

// Test 6: backward diff copysign(-3.0, 5.0)
{
dpfloat dpx = dpfloat(-3.0, 0.0);
dpfloat dpy = dpfloat(5.0, 0.0);
__bwd_diff(diffCopysign)(dpx, dpy, 2.0);
outputBuffer[12] = dpx.d;
outputBuffer[13] = dpy.d;
}

// Test 7: copysign with y=0 - derivative should be 0
{
dpfloat dpx = dpfloat(3.0, 2.0);
dpfloat dpy = dpfloat(0.0, 1.0);
dpfloat res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[14] = res.p;
outputBuffer[15] = res.d;
}

// Test 8: copysign with x=0 - derivative should be 0
{
dpfloat dpx = dpfloat(0.0, 2.0);
dpfloat dpy = dpfloat(-1.0, 1.0);
dpfloat res = __fwd_diff(diffCopysign)(dpx, dpy);
outputBuffer[16] = res.p;
outputBuffer[17] = res.d;
}

// Test 9: vector backward diff
{
dpfloat2 dpx = dpfloat2(float2(2.0, -1.0), float2(0.0, 0.0));
dpfloat2 dpy = dpfloat2(float2(-3.0, 4.0), float2(0.0, 0.0));
__bwd_diff(diffCopysign)(dpx, dpy, float2(1.0, 3.0));
outputBuffer[18] = dpx.d[0];
outputBuffer[19] = dpx.d[1];
}
}
21 changes: 21 additions & 0 deletions tests/autodiff-dstdlib/dstdlib-copysign.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
type: float
-3.000000
-2.000000
2.000000
-1.500000
0.000000
0.000000
-5.000000
-1.000000
3.000000
-2.000000
-1.000000
0.000000
-2.000000
0.000000
3.000000
0.000000
0.000000
0.000000
-1.000000
-3.000000
Loading