@@ -317,6 +317,11 @@ def __init__(self, transforms, reference=None):
317317 )
318318 self ._inverse = np .linalg .inv (self ._matrix )
319319
320+ def __iter__ (self ):
321+ """Enable iterating over the series of transforms."""
322+ for _m in self .matrix :
323+ yield Affine (_m , reference = self ._reference )
324+
320325 def __getitem__ (self , i ):
321326 """Enable indexed access to the series of matrices."""
322327 return Affine (self .matrix [i , ...], reference = self ._reference )
@@ -458,42 +463,37 @@ def apply(
458463 # Invert target's (moving) affine once
459464 ras2vox = ~ Affine (spatialimage .affine )
460465
461- if spatialimage .ndim == 4 :
462- if len (self ) != spatialimage .shape [- 1 ]:
463- raise ValueError (
464- "Attempting to apply %d transforms on a file with "
465- "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
466- )
467-
468- # Order F ensures individual volumes are contiguous in memory
469- # Also matches NIfTI, making final save more efficient
470- resampled = np .zeros (
471- (xcoords .T .shape [0 ], ) + spatialimage .shape [- 1 :], dtype = output_dtype , order = "F"
466+ if spatialimage .ndim == 4 and (len (self ) != spatialimage .shape [- 1 ]):
467+ raise ValueError (
468+ "Attempting to apply %d transforms on a file with "
469+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
472470 )
473471
474- for t in range (spatialimage .shape [- 1 ]):
475- # Map the input coordinates on to timepoint t of the target (moving)
476- ycoords = Affine (self .matrix [t ]).map (xcoords .T )[..., : _ref .ndim ]
477-
478- # Calculate corresponding voxel coordinates
479- yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
480-
481- # Interpolate
482- resampled [..., t ] = ndi .map_coordinates (
483- spatialimage .dataobj [..., t ].astype (input_dtype , copy = False ),
484- yvoxels .T ,
485- output = output_dtype ,
486- order = order ,
487- mode = mode ,
488- cval = cval ,
489- prefilter = prefilter ,
490- )
491- elif spatialimage .ndim in (2 , 3 ):
492- ycoords = self .map (xcoords .T )[..., : _ref .ndim ]
472+ # Order F ensures individual volumes are contiguous in memory
473+ # Also matches NIfTI, making final save more efficient
474+ resampled = np .zeros (
475+ (xcoords .T .shape [0 ], len (self )), dtype = output_dtype , order = "F"
476+ )
477+
478+ dataobj = (
479+ np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
480+ if spatialimage .ndim in (2 , 3 )
481+ else None
482+ )
483+
484+ for t , xfm_t in enumerate (self ):
485+ # Map the input coordinates on to timepoint t of the target (moving)
486+ ycoords = xfm_t .map (xcoords .T )[..., : _ref .ndim ]
487+
488+ # Calculate corresponding voxel coordinates
493489 yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
494490
495- resampled = ndi .map_coordinates (
496- spatialimage .dataobj .astype (input_dtype , copy = False ),
491+ # Interpolate
492+ resampled [..., t ] = ndi .map_coordinates (
493+ (
494+ dataobj if dataobj is not None
495+ else np .asanyarray (spatialimage .dataobj [..., t ], dtype = input_dtype )
496+ ),
497497 yvoxels .T ,
498498 output = output_dtype ,
499499 order = order ,
@@ -503,9 +503,9 @@ def apply(
503503 )
504504
505505 if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
506- newdata = resampled .reshape ((len (self ), * _ref . shape ))
506+ newdata = resampled .reshape (_ref . shape + (len (self ), ))
507507 moved = spatialimage .__class__ (
508- np . moveaxis ( newdata , 0 , - 1 ) , _ref .affine , spatialimage .header
508+ newdata , _ref .affine , spatialimage .header
509509 )
510510 moved .header .set_data_dtype (output_dtype )
511511 return moved
0 commit comments