@@ -436,6 +436,7 @@ def apply(
436436 The data imaged after resampling to reference space.
437437
438438 """
439+
439440 if reference is not None and isinstance (reference , (str , Path )):
440441 reference = _nbload (str (reference ))
441442
@@ -446,40 +447,53 @@ def apply(
446447 if isinstance (spatialimage , (str , Path )):
447448 spatialimage = _nbload (str (spatialimage ))
448449
449- data = np .squeeze (np .asanyarray (spatialimage .dataobj ))
450- output_dtype = output_dtype or data .dtype
450+ # Avoid opening the data array just yet
451+ input_dtype = spatialimage .header .get_data_dtype ()
452+ output_dtype = output_dtype or input_dtype
451453
452- ycoords = self .map (_ref .ndcoords .T )
453- targets = ImageGrid (spatialimage ).index ( # data should be an image
454- _as_homogeneous (np .vstack (ycoords ), dim = _ref .ndim )
455- )
454+ # Prepare physical coordinates of input (grid, points)
455+ xcoords = _ref .ndcoords .astype ("f4" )
456456
457- if data .ndim == 4 :
458- if len (self ) != data .shape [- 1 ]:
457+ # Invert target's (moving) affine once
458+ ras2vox = ~ Affine (spatialimage .affine )
459+
460+ if spatialimage .ndim == 4 :
461+ if len (self ) != spatialimage .shape [- 1 ]:
459462 raise ValueError (
460463 "Attempting to apply %d transforms on a file with "
461- "%d timepoints" % (len (self ), data .shape [- 1 ])
464+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
462465 )
463- targets = targets .reshape ((len (self ), - 1 , targets .shape [- 1 ]))
464- resampled = np .stack (
465- [
466- ndi .map_coordinates (
467- data [..., t ],
468- targets [t , ..., : _ref .ndim ].T ,
469- output = output_dtype ,
470- order = order ,
471- mode = mode ,
472- cval = cval ,
473- prefilter = prefilter ,
474- )
475- for t in range (data .shape [- 1 ])
476- ],
477- axis = 0 ,
466+
467+ # Order F ensures individual volumes are contiguous in memory
468+ # Also matches NIfTI, making final save more efficient
469+ resampled = np .zeros (
470+ (xcoords .T .shape [0 ], ) + spatialimage .shape [- 1 :], dtype = output_dtype , order = "F"
478471 )
479- elif data .ndim in (2 , 3 ):
472+
473+ for t in range (spatialimage .shape [- 1 ]):
474+ # Map the input coordinates on to timepoint t of the target (moving)
475+ ycoords = Affine (self .matrix [t ]).map (xcoords .T )[..., : _ref .ndim ]
476+
477+ # Calculate corresponding voxel coordinates
478+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
479+
480+ # Interpolate
481+ resampled [..., t ] = ndi .map_coordinates (
482+ spatialimage .dataobj [..., t ].astype (input_dtype , copy = False ),
483+ yvoxels .T ,
484+ output = output_dtype ,
485+ order = order ,
486+ mode = mode ,
487+ cval = cval ,
488+ prefilter = prefilter ,
489+ )
490+ elif spatialimage .ndim in (2 , 3 ):
491+ ycoords = self .map (xcoords .T )[..., : _ref .ndim ]
492+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
493+
480494 resampled = ndi .map_coordinates (
481- data ,
482- targets [..., : _ref . ndim ] .T ,
495+ spatialimage . dataobj . astype ( input_dtype , copy = False ) ,
496+ yvoxels .T ,
483497 output = output_dtype ,
484498 order = order ,
485499 mode = mode ,
0 commit comments