@@ -111,16 +111,113 @@ end
111111
112112function ChainRulesCore. rrule (s:: Sinus , x:: AbstractVector , y:: AbstractVector )
113113 d = x - y
114- sind = sinpi .(d)
115- abs2_sind_r = abs2 .(sind) ./ s. r
114+ abs2_sind_r = (sinpi .(d) ./ s. r) .^ 2
116115 val = sum (abs2_sind_r)
117- gradx = twoπ .* cospi .(d) .* sind ./ ( s. r .^ 2 )
116+ gradx = π .* sinpi .( 2 .* d) ./ s. r .^ 2
118117 function evaluate_pullback (Δ:: Any )
119- return (r= - 2 Δ .* abs2_sind_r,), Δ * gradx, - Δ * gradx
118+ r̄ = - 2 Δ .* abs2_sind_r ./ s. r
119+ s̄ = ChainRulesCore. Tangent {typeof(s)} (; r= r̄)
120+ return s̄, Δ * gradx, - Δ * gradx
120121 end
121122 return val, evaluate_pullback
122123end
123124
125+ function ChainRulesCore. rrule (
126+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix ; dims= 2
127+ )
128+ project_x = ProjectTo (x)
129+ function pairwise_pullback (z̄)
130+ Δ = unthunk (z̄)
131+ n = size (x, dims)
132+ x̄ = collect (zero (x))
133+ r̄ = zero (d. r)
134+ if dims == 1
135+ for j in 1 : n, i in 1 : n
136+ xi = view (x, i, :)
137+ xj = view (x, j, :)
138+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- xj)) ./ d. r .^ 2
139+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
140+ x̄[i, :] += ds
141+ x̄[j, :] -= ds
142+ end
143+ elseif dims == 2
144+ for j in 1 : n, i in 1 : n
145+ xi = view (x, :, i)
146+ xj = view (x, :, j)
147+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
148+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
149+ x̄[:, i] .+ = ds
150+ x̄[:, j] .- = ds
151+ end
152+ end
153+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
154+ return NoTangent (), d̄, @thunk (project_x (x̄))
155+ end
156+ return Distances. pairwise (d, x; dims), pairwise_pullback
157+ end
158+
159+ function ChainRulesCore. rrule (
160+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix ; dims= 2
161+ )
162+ project_x = ProjectTo (x)
163+ project_y = ProjectTo (y)
164+ function pairwise_pullback (z̄)
165+ Δ = unthunk (z̄)
166+ n = size (x, dims)
167+ m = size (y, dims)
168+ x̄ = collect (zero (x))
169+ ȳ = collect (zero (y))
170+ r̄ = zero (d. r)
171+ if dims == 1
172+ for j in 1 : m, i in 1 : n
173+ xi = view (x, i, :)
174+ yj = view (y, j, :)
175+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- yj)) ./ d. r .^ 2
176+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
177+ x̄[i, :] .+ = ds
178+ ȳ[j, :] .- = ds
179+ end
180+ elseif dims == 2
181+ for j in 1 : m, i in 1 : n
182+ xi = view (x, :, i)
183+ yj = view (y, :, j)
184+ ds = π .* Δ[i, j] .* sinpi .(2 .* (xi .- yj)) ./ d. r .^ 2
185+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
186+ x̄[:, i] .+ = ds
187+ ȳ[:, j] .- = ds
188+ end
189+ end
190+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
191+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
192+ end
193+ return Distances. pairwise (d, x, y; dims), pairwise_pullback
194+ end
195+
196+ function ChainRulesCore. rrule (
197+ :: typeof (Distances. colwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix
198+ )
199+ project_x = ProjectTo (x)
200+ project_y = ProjectTo (y)
201+ function colwise_pullback (z̄)
202+ Δ = unthunk (z̄)
203+ n = size (x, 2 )
204+ x̄ = collect (zero (x))
205+ ȳ = collect (zero (y))
206+ r̄ = zero (d. r)
207+ for i in 1 : n
208+ xi = view (x, :, i)
209+ yi = view (y, :, i)
210+ ds = π .* Δ[i] .* sinpi .(2 .* (xi .- yi)) ./ d. r .^ 2
211+ r̄ .- = 2 .* Δ[i] .* sinpi .(xi .- yi) .^ 2 ./ d. r .^ 3
212+ x̄[:, i] .+ = ds
213+ ȳ[:, i] .- = ds
214+ end
215+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
216+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
217+ end
218+ return Distances. colwise (d, x, y), colwise_pullback
219+ end
220+
124221# # Reverse Rules for matrix wrappers
125222
126223function ChainRulesCore. rrule (:: Type{<:ColVecs} , X:: AbstractMatrix )
0 commit comments