Skip to content

Commit 932b704

Browse files
authored
Broadcast the propagation_expr for vector mode AD (#93)
* Broadcast the `propagation_expr` for vector mode AD * Use the `muladd` macro to optimize `propagation_expr` * New release * Fix propagation_expr * Explicit broadcast * Revert "Explicit broadcast" This reverts commit b7d1da2.
2 parents de61cd3 + 91b0be4 commit 932b704

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.5.1"
3+
version = "0.5.2"
4+
5+
[deps]
6+
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
47

58
[compat]
69
julia = "^1.0"
10+
MuladdMacro = "0.2.1"
711

812
[extras]
913
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/rule_definition_tools.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
2+
using MuladdMacro: @muladd
23

34
"""
45
@scalar_rule(f(x₁, x₂, ...),
@@ -208,9 +209,25 @@ end
208209
function propagation_expr(Δs, ∂s)
209210
# This is basically Δs ⋅ ∂s
210211
∂s = map(esc, ∂s)
212+
n∂s = length(∂s)
211213

212-
∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s))
213-
return :(+($(∂_mul_Δs...)))
214+
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
215+
# literals.
216+
∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
217+
218+
# Avoiding the extra `+` operation, it is potentially expensive for vector
219+
# mode AD.
220+
sumed_∂_mul_Δs = if n∂s > 1
221+
# we use `@.` to broadcast `*` and `+`
222+
:(@. +($(∂_mul_Δs...)))
223+
else
224+
# Note: we don't want to do broadcasting with only 1 multiply (no `+`),
225+
# because some arrays overload multiply with scalar. Avoiding
226+
# broadcasting saves compilation time.
227+
∂_mul_Δs[1]
228+
end
229+
230+
return :(@muladd $sumed_∂_mul_Δs)
214231
end
215232

216233
"""

0 commit comments

Comments
 (0)