|
1 |
| -using .Distances |
| 1 | +module ZygoteDistancesExt |
| 2 | + |
| 3 | +if isdefined(Base, :get_extension) |
| 4 | + using Zygote |
| 5 | + using Distances |
| 6 | + using LinearAlgebra |
| 7 | +else |
| 8 | + using ..Zygote |
| 9 | + using ..Distances |
| 10 | + using ..LinearAlgebra |
| 11 | +end |
| 12 | + |
| 13 | +using Zygote: @adjoint, AContext, _pullback |
2 | 14 |
|
3 | 15 | @adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
|
4 | 16 | δ = x .- y
|
|
66 | 78 |
|
67 | 79 | _sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)
|
68 | 80 |
|
69 |
| -@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) |
| 81 | +function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)), |
| 82 | + kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean, |
| 83 | + X::AbstractMatrix, Y::AbstractMatrix) |
70 | 84 | # Modify the forwards-pass slightly to ensure stability on the reverse.
|
| 85 | + dims = kws.dims |
71 | 86 | function _pairwise_euclidean(sqdist::SqEuclidean, X, Y)
|
72 | 87 | D2 = pairwise(sqdist, X, Y; dims=dims)
|
73 | 88 | δ = eps(eltype(D2))
|
74 | 89 | return _sqrt_if_positive.(D2, δ)
|
75 | 90 | end
|
76 |
| - return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) |
| 91 | + res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) |
| 92 | + pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...) |
| 93 | + return res, pairwise_Euclidean_pullback |
77 | 94 | end
|
78 | 95 |
|
79 |
| -@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2) |
| 96 | +function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)), |
| 97 | + kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean, |
| 98 | + X::AbstractMatrix) |
80 | 99 | # Modify the forwards-pass slightly to ensure stability on the reverse.
|
| 100 | + dims = kws.dims |
81 | 101 | function _pairwise_euclidean(sqdist::SqEuclidean, X)
|
82 | 102 | D2 = pairwise(sqdist, X; dims=dims)
|
83 | 103 | δ = eps(eltype(D2))
|
84 | 104 | return _sqrt_if_positive.(D2, δ)
|
85 | 105 | end
|
86 |
| - return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X) |
| 106 | + res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X) |
| 107 | + pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...) |
| 108 | + return res, pairwise_Euclidean_pullback |
| 109 | +end |
| 110 | + |
87 | 111 | end
|
0 commit comments