1414
1515from nibabel .loadsave import load as _nbload
1616from nibabel .affines import from_matvec
17+ from nibabel .arrayproxy import get_obj_dtype
1718
1819from nitransforms .base import (
1920 ImageGrid ,
@@ -216,14 +217,13 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
216217 is_array = cls != Affine
217218 errors = []
218219 for potential_fmt in fmtlist :
219- if ( potential_fmt == "itk" and Path (filename ).suffix == ".mat" ) :
220+ if potential_fmt == "itk" and Path (filename ).suffix == ".mat" :
220221 is_array = False
221222 cls = Affine
222223
223224 try :
224225 struct = get_linear_factory (
225- potential_fmt ,
226- is_array = is_array
226+ potential_fmt , is_array = is_array
227227 ).from_filename (filename )
228228 except (TransformFileError , FileNotFoundError ) as err :
229229 errors .append ((potential_fmt , err ))
@@ -316,6 +316,11 @@ def __init__(self, transforms, reference=None):
316316 )
317317 self ._inverse = np .linalg .inv (self ._matrix )
318318
319+ def __iter__ (self ):
320+ """Enable iterating over the series of transforms."""
321+ for _m in self .matrix :
322+ yield Affine (_m , reference = self ._reference )
323+
319324 def __getitem__ (self , i ):
320325 """Enable indexed access to the series of matrices."""
321326 return Affine (self .matrix [i , ...], reference = self ._reference )
@@ -436,6 +441,7 @@ def apply(
436441 The data imaged after resampling to reference space.
437442
438443 """
444+
439445 if reference is not None and isinstance (reference , (str , Path )):
440446 reference = _nbload (str (reference ))
441447
@@ -446,40 +452,49 @@ def apply(
446452 if isinstance (spatialimage , (str , Path )):
447453 spatialimage = _nbload (str (spatialimage ))
448454
449- data = np .squeeze (np .asanyarray (spatialimage .dataobj ))
450- output_dtype = output_dtype or data .dtype
455+ # Avoid opening the data array just yet
456+ input_dtype = get_obj_dtype (spatialimage .dataobj )
457+ output_dtype = output_dtype or input_dtype
451458
452- ycoords = self .map (_ref .ndcoords .T )
453- targets = ImageGrid (spatialimage ).index ( # data should be an image
454- _as_homogeneous (np .vstack (ycoords ), dim = _ref .ndim )
455- )
459+ # Prepare physical coordinates of input (grid, points)
460+ xcoords = _ref .ndcoords .astype ("f4" ).T
456461
457- if data .ndim == 4 :
458- if len (self ) != data .shape [- 1 ]:
459- raise ValueError (
460- "Attempting to apply %d transforms on a file with "
461- "%d timepoints" % (len (self ), data .shape [- 1 ])
462- )
463- targets = targets .reshape ((len (self ), - 1 , targets .shape [- 1 ]))
464- resampled = np .stack (
465- [
466- ndi .map_coordinates (
467- data [..., t ],
468- targets [t , ..., : _ref .ndim ].T ,
469- output = output_dtype ,
470- order = order ,
471- mode = mode ,
472- cval = cval ,
473- prefilter = prefilter ,
474- )
475- for t in range (data .shape [- 1 ])
476- ],
477- axis = 0 ,
462+ # Invert target's (moving) affine once
463+ ras2vox = ~ Affine (spatialimage .affine )
464+
465+ if spatialimage .ndim == 4 and (len (self ) != spatialimage .shape [- 1 ]):
466+ raise ValueError (
467+ "Attempting to apply %d transforms on a file with "
468+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
478469 )
479- elif data .ndim in (2 , 3 ):
480- resampled = ndi .map_coordinates (
481- data ,
482- targets [..., : _ref .ndim ].T ,
470+
471+ # Order F ensures individual volumes are contiguous in memory
472+ # Also matches NIfTI, making final save more efficient
473+ resampled = np .zeros (
474+ (xcoords .shape [0 ], len (self )), dtype = output_dtype , order = "F"
475+ )
476+
477+ dataobj = (
478+ np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
479+ if spatialimage .ndim in (2 , 3 )
480+ else None
481+ )
482+
483+ for t , xfm_t in enumerate (self ):
484+ # Map the input coordinates on to timepoint t of the target (moving)
485+ ycoords = xfm_t .map (xcoords )[..., : _ref .ndim ]
486+
487+ # Calculate corresponding voxel coordinates
488+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
489+
490+ # Interpolate
491+ resampled [..., t ] = ndi .map_coordinates (
492+ (
493+ dataobj
494+ if dataobj is not None
495+ else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
496+ ),
497+ yvoxels .T ,
483498 output = output_dtype ,
484499 order = order ,
485500 mode = mode ,
@@ -488,10 +503,8 @@ def apply(
488503 )
489504
490505 if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
491- newdata = resampled .reshape ((len (self ), * _ref .shape ))
492- moved = spatialimage .__class__ (
493- np .moveaxis (newdata , 0 , - 1 ), _ref .affine , spatialimage .header
494- )
506+ newdata = resampled .reshape (_ref .shape + (len (self ),))
507+ moved = spatialimage .__class__ (newdata , _ref .affine , spatialimage .header )
495508 moved .header .set_data_dtype (output_dtype )
496509 return moved
497510
0 commit comments