@@ -32,12 +32,12 @@ function insert_symbolic_gradient(axislist, store)
3232
3333 inbody, prebody = [], []
3434 for (dt, t) in unique (targets)
35- drdt = leibnitz (store. right, t)
35+ drdt = leibnitz (store. right, t, store . nograd )
3636 deltar = if store. finaliser == :identity
3737 simplitimes (simpliconj (drdt), :($ dZ[$ (store. leftraw... )]))
3838 else
3939 rhs = :($ ZED[$ (store. leftraw... )])
40- dldr = leibfinal (store. finaliser, rhs)
40+ dldr = leibfinal (store. finaliser, rhs, store . nograd )
4141 simplitimes (simpliconj (drdt), simpliconj (dldr), :($ dZ[$ (store. leftraw... )]))
4242 end
4343 if store. redfun == :+
@@ -84,16 +84,16 @@ function insert_symbolic_gradient(axislist, store)
8484
8585end
8686
87- leibfinal (fun:: Symbol , res) =
87+ leibfinal (fun:: Symbol , res, no = () ) =
8888 if fun == :log
8989 :(exp (- $ res)) # this exp gets done at every element :(
9090 # :(inv(exp($res)))
9191 else
92- _leibfinal (:($ fun ($ RHS)), res)
92+ _leibfinal (:($ fun ($ RHS)), res, no )
9393 end
9494
95- _leibfinal (out, res) = begin
96- grad1 = leibnitz (out, RHS)
95+ _leibfinal (out, res, no ) = begin
96+ grad1 = leibnitz (out, RHS, no )
9797 grad2 = MacroTools_postwalk (grad1) do ex
9898 # @show ex ex == out
9999 ex == out ? res : ex
@@ -103,13 +103,13 @@ _leibfinal(out, res) = begin
103103 end
104104end
105105
106- leibfinal (ex:: Expr , res) = begin
106+ leibfinal (ex:: Expr , res, no = () ) = begin
107107 if ex. head == :call && ex. args[1 ] isa Expr &&
108108 ex. args[1 ]. head == :(-> ) && ex. args[1 ]. args[1 ] == RHS # then it came from underscores
109109 inner = ex. args[1 ]. args[2 ]
110110 if inner isa Expr && inner. head == :block
111111 lines = filter (a -> ! (a isa LineNumberNode), inner. args)
112- length (lines) == 1 && return _leinfinal (first (lines), res)
112+ length (lines) == 1 && return _leibfinal (first (lines), res, no) # not tested!
113113 end
114114 end
115115 throw (" couldn't understand finaliser" )
@@ -191,9 +191,9 @@ symbwalk(targets, store) = ex -> begin
191191 return ex
192192 end
193193
194- leibnitz (s:: Number , target) = 0
195- leibnitz (s:: Symbol , target) = s == target ? 1 : 0
196- leibnitz (ex:: Expr , target) = begin
194+ leibnitz (s:: Number , target, no = () ) = 0
195+ leibnitz (s:: Symbol , target, no = () ) = s == target ? 1 : 0
196+ leibnitz (ex:: Expr , target, no = () ) = begin
197197 ex == target && return 1
198198 @capture_ (ex, B_[ijk__]) && return 0
199199 if ex. head == Symbol (" '" )
@@ -202,34 +202,35 @@ leibnitz(ex::Expr, target) = begin
202202 end
203203 ex. head == :call || throw (" expected a functionn call, got $ex ." )
204204 fun = ex. args[1 ]
205+ fun in no && return 0
205206 if fun == :log # catch log(a*b) and especially log(a/b)
206207 arg = ex. args[2 ]
207208 if arg isa Expr && arg. args[1 ] == :* && length (arg. args) == 3
208209 newex = :(log ($ (arg. args[2 ])) + log ($ (arg. args[3 ])))
209- return leibnitz (newex, target)
210+ return leibnitz (newex, target, no )
210211 elseif arg isa Expr && arg. args[1 ] == :/
211212 newex = :(log ($ (arg. args[2 ])) - log ($ (arg. args[3 ])))
212- return leibnitz (newex, target)
213+ return leibnitz (newex, target, no )
213214 end
214215 end
215216 if length (ex. args) == 2 # one-arg function
216217 fx = mydiffrule (fun, ex. args[2 ])
217- dx = leibnitz (ex. args[2 ], target)
218+ dx = leibnitz (ex. args[2 ], target, no )
218219 return simplitimes (fx, dx)
219220 elseif length (ex. args) == 3 # two-arg function
220221 fx, fy = mydiffrule (fun, ex. args[2 : end ]. .. )
221- dx = leibnitz (ex. args[2 ], target)
222- dy = leibnitz (ex. args[3 ], target)
222+ dx = leibnitz (ex. args[2 ], target, no )
223+ dy = leibnitz (ex. args[3 ], target, no )
223224 return simpliplus (simplitimes (fx, dx), simplitimes (fy, dy))
224225 elseif fun in [:+ , :* ]
225- fun == :* && return leibnitz (:(* ($ (ex. args[2 ]), * ($ (ex. args[3 : end ]. .. )))), target)
226- dxs = [leibnitz (x, target) for x in ex. args[2 : end ]]
226+ fun == :* && return leibnitz (:(* ($ (ex. args[2 ]), * ($ (ex. args[3 : end ]. .. )))), target, no )
227+ dxs = [leibnitz (x, target, no ) for x in ex. args[2 : end ]]
227228 fun == :+ && return simpliplus (dxs... )
228229 elseif length (ex. args) == 4 # three-arg function such as ifelse
229230 fx, fy, fz = mydiffrule (fun, ex. args[2 : end ]. .. )
230- dx = leibnitz (ex. args[2 ], target)
231- dy = leibnitz (ex. args[3 ], target)
232- dz = leibnitz (ex. args[4 ], target)
231+ dx = leibnitz (ex. args[2 ], target, no )
232+ dy = leibnitz (ex. args[3 ], target, no )
233+ dz = leibnitz (ex. args[4 ], target, no )
233234 return simpliplus (simplitimes (fx, dx), simplitimes (fy, dy), simplitimes (fz, dz))
234235 end
235236 throw (" don't know how to handle $ex ." )
0 commit comments