@@ -50,6 +50,7 @@ identical memory layout to a Julia `Array` of the same size.
5050`st` should be the stride(s) *in bytes* between elements in each dimension
5151"""
5252function f_contiguous (:: Type{T} , sz:: NTuple{N,Int} , st:: NTuple{N,Int} ) where {T,N}
53+ N == 0 && return true # 0-dimensional arrays have 1 element, always contiguous
5354 if st[1 ] != sizeof (T)
5455 # not contiguous
5556 return false
@@ -153,77 +154,72 @@ function copy(a::PyArray{T,N}) where {T,N}
153154 return A
154155end
155156
156- # TODO : need to do bounds-checking of these indices!
157- # TODO : need to GC root these `a`s to guard against the PyArray getting gc'd,
158- # e.g. if it's a temporary in a function:
159- # `two_rands() = pycall(np.rand, PyArray, 10)[1:2]`
160-
161-
162- getindex (a:: PyArray{T,0} ) where {T} = unsafe_load (a. data)
163- getindex (a:: PyArray{T,1} , i:: Integer ) where {T} = unsafe_load (a. data, 1 + (i- 1 )* a. st[1 ])
157+ unsafe_data_load (a:: PyArray , i:: Integer ) = GC. @preserve a unsafe_load (a. data, i)
158+
159+ @inline data_index (a:: PyArray{<:Any,N} , i:: CartesianIndex{N} ) where {N} =
160+ 1 + sum (ntuple (dim -> (i[dim]- 1 ) * a. st[dim], Val {N} ())) # Val lets julia unroll/inline
161+ data_index (a:: PyArray{<:Any,0} , i:: CartesianIndex{0} ) = 1
162+
163+ # handle passing fewer/more indices than dimensions by canonicalizing to M==N
164+ @inline function fixindex (a:: PyArray{<:Any,N} , i:: CartesianIndex{M} ) where {M,N}
165+ if M == N
166+ return i
167+ elseif M < N
168+ @boundscheck (all (ntuple (k -> size (a,k+ M)== 1 , Val {N-M} ())) ||
169+ throw (BoundsError (a, i))) # trailing sizes must == 1
170+ return CartesianIndex (Tuple (i)... , ntuple (k -> 1 , Val {N-M} ())... )
171+ else # M > N
172+ @boundscheck (all (ntuple (k -> i[k+ N]== 1 , Val {M-N} ())) ||
173+ throw (BoundsError (a, i))) # trailing indices must == 1
174+ return CartesianIndex (ntuple (k -> i[k], Val {N} ()))
175+ end
176+ end
164177
165- getindex (a:: PyArray{T,2} , i:: Integer , j:: Integer ) where {T} =
166- unsafe_load (a. data, 1 + (i- 1 )* a. st[1 ] + (j- 1 )* a. st[2 ])
178+ @inline function getindex (a:: PyArray , i:: CartesianIndex )
179+ j = fixindex (a, i)
180+ @boundscheck checkbounds (a, j)
181+ unsafe_data_load (a, data_index (a, j))
182+ end
183+ @inline getindex (a:: PyArray , i:: Integer... ) = a[CartesianIndex (i)]
184+ @inline getindex (a:: PyArray{<:Any,1} , i:: Integer ) = a[CartesianIndex (i)]
167185
186+ # linear indexing
168187function getindex (a:: PyArray , i:: Integer )
188+ @boundscheck checkbounds (a, i)
169189 if a. f_contig
170- return unsafe_load (a . data , i)
190+ return unsafe_data_load (a , i)
171191 else
172- return a[ind2sub (a . dims, i) ... ]
192+ @inbounds return a[CartesianIndices (a)[i] ]
173193 end
174194end
175195
176- function getindex (a:: PyArray , is:: Integer... )
177- index = 1
178- n = min (length (is),length (a. st))
179- for i = 1 : n
180- index += (is[i]- 1 )* a. st[i]
181- end
182- for i = n+ 1 : length (is)
183- if is[i] != 1
184- throw (BoundsError ())
185- end
186- end
187- unsafe_load (a. data, index)
188- end
189-
190196function writeok_assign (a:: PyArray , v, i:: Integer )
191197 if a. info. readonly
192198 throw (ArgumentError (" read-only PyArray" ))
193199 else
194- unsafe_store! (a. data, v, i)
200+ GC . @preserve a unsafe_store! (a. data, v, i)
195201 end
196- return a
202+ return v
197203end
198204
199- setindex! (a:: PyArray{T,0} , v) where {T} = writeok_assign (a, v, 1 )
200- setindex! (a:: PyArray{T,1} , v, i:: Integer ) where {T} = writeok_assign (a, v, 1 + (i- 1 )* a. st[1 ])
201-
202- setindex! (a:: PyArray{T,2} , v, i:: Integer , j:: Integer ) where {T} =
203- writeok_assign (a, v, 1 + (i- 1 )* a. st[1 ] + (j- 1 )* a. st[2 ])
205+ @inline function setindex! (a:: PyArray , v, i:: CartesianIndex )
206+ j = fixindex (a, i)
207+ @boundscheck checkbounds (a, j)
208+ writeok_assign (a, v, data_index (a, j))
209+ end
210+ @inline setindex! (a:: PyArray , v, i:: Integer... ) = setindex! (a, v, CartesianIndex (i))
211+ @inline setindex! (a:: PyArray{<:Any,1} , v, i:: Integer ) = setindex! (a, v, CartesianIndex (i))
204212
213+ # linear indexing
205214function setindex! (a:: PyArray , v, i:: Integer )
215+ @boundscheck checkbounds (a, i)
206216 if a. f_contig
207217 return writeok_assign (a, v, i)
208218 else
209- return setindex! (a, v, ind2sub (a . dims, i) ... )
219+ @inbounds return setindex! (a, v, CartesianIndices (a)[i] )
210220 end
211221end
212222
213- function setindex! (a:: PyArray , v, is:: Integer... )
214- index = 1
215- n = min (length (is),length (a. st))
216- for i = 1 : n
217- index += (is[i]- 1 )* a. st[i]
218- end
219- for i = n+ 1 : length (is)
220- if is[i] != 1
221- throw (BoundsError ())
222- end
223- end
224- writeok_assign (a, v, index)
225- end
226-
227223stride (a:: PyArray , i:: Integer ) = a. st[i]
228224
229225Base. unsafe_convert (:: Type{Ptr{T}} , a:: PyArray{T} ) where {T} = a. data
@@ -244,68 +240,56 @@ summary(a::PyArray{T}) where {T} = string(Base.dims2string(size(a)), " ",
244240# ########################################################################
245241# PyArray <-> PyObject conversions
246242
247- const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr}
243+ const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr,PyObject }
248244
249245PyObject (a:: PyArray ) = a. o
250246
251247convert (:: Type{PyArray} , o:: PyObject ) = PyArray (o)
252248
249+ # PyObject arrays are created by taking a NumPy array of PyPtr and converting
250+ pyo2ptr (T:: Type ) = T
251+ pyo2ptr (:: Type{PyObject} ) = PyPtr
252+ pyocopy (a) = copy (a)
253+ pyocopy (a:: AbstractArray{PyPtr} ) = GC. @preserve a map (pyincref, a)
254+
253255function convert (:: Type{Array{T, 1}} , o:: PyObject ) where T<: PYARR_TYPES
254256 try
255- copy (PyArray {T , 1} (o, PyArray_Info (o))) # will check T and N vs. info
257+ return pyocopy (PyArray {pyo2ptr(T) , 1} (o, PyArray_Info (o))) # will check T and N vs. info
256258 catch
257- len = @pycheckz ccall ((@pysym :PySequence_Size ), Int, (PyPtr,), o)
258- A = Array {pyany_toany(T)} (undef, len)
259- py2array (T, A, o, 1 , 1 )
259+ return py2vector (T, o)
260260 end
261261end
262262
263263function convert (:: Type{Array{T}} , o:: PyObject ) where T<: PYARR_TYPES
264264 try
265265 info = PyArray_Info (o)
266266 try
267- copy (PyArray {T , length(info.sz)} (o, info)) # will check T == eltype(info)
267+ return pyocopy (PyArray {pyo2ptr(T) , length(info.sz)} (o, info)) # will check T == eltype(info)
268268 catch
269- return py2array (T, Array {pyany_toany(T) } (undef, info. sz... ), o, 1 , 1 )
269+ return py2array (T, Array {T } (undef, info. sz... ), o, 1 , 1 )
270270 end
271271 catch
272- py2array (T, o)
272+ return py2array (T, o)
273273 end
274274end
275275
276276function convert (:: Type{Array{T,N}} , o:: PyObject ) where {T<: PYARR_TYPES ,N}
277277 try
278278 info = PyArray_Info (o)
279279 try
280- copy (PyArray {T ,N} (o, info)) # will check T,N == eltype(info),ndims(info)
280+ pyocopy (PyArray {pyo2ptr(T) ,N} (o, info)) # will check T,N == eltype(info),ndims(info)
281281 catch
282282 nd = length (info. sz)
283- if nd != N
284- throw (ArgumentError (" cannot convert $(nd) d array to $(N) d" ))
285- end
286- return py2array (T, Array {pyany_toany(T)} (undef, info. sz... ), o, 1 , 1 )
283+ nd == N || throw (ArgumentError (" cannot convert $(nd) d array to $(N) d" ))
284+ return py2array (T, Array {T} (undef, info. sz... ), o, 1 , 1 )
287285 end
288286 catch
289287 A = py2array (T, o)
290- if ndims (A) != N
291- throw (ArgumentError (" cannot convert $(ndims (A)) d array to $(N) d" ))
292- end
293- A
288+ ndims (A) == N || throw (ArgumentError (" cannot convert $(ndims (A)) d array to $(N) d" ))
289+ return A
294290 end
295291end
296292
297- function convert (:: Type{Array{PyObject}} , o:: PyObject )
298- map (pyincref, convert (Array{PyPtr}, o))
299- end
300-
301- function convert (:: Type{Array{PyObject,1}} , o:: PyObject )
302- map (pyincref, convert (Array{PyPtr, 1 }, o))
303- end
304-
305- function convert (:: Type{Array{PyObject,N}} , o:: PyObject ) where N
306- map (pyincref, convert (Array{PyPtr, N}, o))
307- end
308-
309293array_format (o:: PyObject ) = array_format (PyBuffer (o, PyBUF_ND_STRIDED))
310294
311295"""
0 commit comments