@@ -137,6 +137,7 @@ function parse_options(exs...)
137137 )
138138 expr = nothing
139139 nograd = Symbol[]
140+ safe = Symbol[]
140141 ranges = Tuple[]
141142 for ex in exs
142143 # Actual options:
@@ -160,6 +161,16 @@ function parse_options(exs...)
160161 throw (" this accepts nograd=A or nograd=(A,B,C)" )
161162 end
162163
164+ # Safe keyword
165+ elseif isexpr (ex, :(= )) && ex. args[1 ] == :safe
166+ if ex. args[2 ] isa Symbol
167+ push! (safe, ex. args[2 ])
168+ elseif isexpr (ex. args[2 ], :tuple )
169+ append! (safe, ex. args[2 ]. args)
170+ else
171+ throw (" this accepts safe=i or safe=(i,j,k)" )
172+ end
173+
163174 # Ranges specified outside:
164175 elseif isexpr (ex, :call ) && ex. args[1 ] in [:in , :∈ ]
165176 push! (ranges, (ex. args[2 ], ex. args[3 ]))
@@ -201,6 +212,7 @@ function parse_options(exs...)
201212 cuda= opts[:cuda ],
202213 tensor= opts[:tensor ],
203214 nograd= nograd,
215+ safe= safe,
204216 ), ranges, expr
205217end
206218
@@ -586,7 +598,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex
586598 MacroTools_postwalk (i) do x
587599 @capture_ (x, B_[inner__]) || return x
588600 # Now we have found an array which indexes another one, mark its indices unsafe
589- append! (list, filter (j -> j isa Symbol, inner))
601+ append! (list, setdiff ( filter (j -> j isa Symbol, inner), store . safe ))
590602 unique! (list)
591603 # and don't compute a gradient for the inner array
592604 B isa Symbol && push! (store. nograd, B)
0 commit comments