|
22 | 22 |
|
23 | 23 | var isComplexDataType = require( '@stdlib/ndarray/base/assert/is-complex-floating-point-data-type' ); |
24 | 24 | var isRealDataType = require( '@stdlib/ndarray/base/assert/is-real-data-type' ); |
| 25 | +var isComplexArray = require( '@stdlib/array/base/assert/is-complex-typed-array' ); |
| 26 | +var isBooleanArray = require( '@stdlib/array/base/assert/is-booleanarray' ); |
25 | 27 | var iterationOrder = require( '@stdlib/ndarray/base/iteration-order' ); |
26 | 28 | var castReturn = require( '@stdlib/complex/base/cast-return' ); |
27 | 29 | var complexCtors = require( '@stdlib/complex/ctors' ); |
28 | 30 | var minmaxViewBufferIndex = require( '@stdlib/ndarray/base/minmax-view-buffer-index' ); |
29 | 31 | var ndarray2object = require( '@stdlib/ndarray/base/ndarraylike2object' ); |
| 32 | +var reinterpretComplex = require( '@stdlib/strided/base/reinterpret-complex' ); |
| 33 | +var reinterpretBoolean = require( '@stdlib/strided/base/reinterpret-boolean' ); |
| 34 | +var gscal = require( '@stdlib/blas/base/gscal' ); |
30 | 35 | var blockedaccessorassign2d = require( './2d_blocked_accessors.js' ); |
31 | 36 | var blockedaccessorassign3d = require( './3d_blocked_accessors.js' ); |
32 | 37 | var blockedaccessorassign4d = require( './4d_blocked_accessors.js' ); |
@@ -123,6 +128,57 @@ var BLOCKED_ACCESSOR_ASSIGN = [ |
123 | 128 | ]; |
124 | 129 | var MAX_DIMS = ASSIGN.length - 1; |
125 | 130 |
|
| 131 | +// TODO: consider adding a package utility for mapping a complex dtype to its complementary real-valued counterpart |
| 132 | +var COMPLEX_TO_REAL = { // WARNING: this table needs to be manually updated if we add support for additional complex number dtypes |
| 133 | + 'complex128': 'float64', |
| 134 | + 'complex64': 'float32' |
| 135 | +}; |
| 136 | + |
| 137 | + |
| 138 | +// FUNCTIONS // |
| 139 | + |
| 140 | +/** |
| 141 | +* Converts a boolean ndarray to an 8-bit unsigned integer ndarray. |
| 142 | +* |
| 143 | +* ## Notes |
| 144 | +* |
| 145 | +* - The function mutates the input ndarray object. |
| 146 | +* |
| 147 | +* @private |
| 148 | +* @param {Object} x - input ndarray object |
| 149 | +* @returns {Object} output ndarray object |
| 150 | +*/ |
| 151 | +function boolean2uint8( x ) { |
| 152 | + x.data = reinterpretBoolean( x.data, 0 ); |
| 153 | + x.accessorProtocol = false; |
| 154 | + return x; |
| 155 | +} |
| 156 | + |
| 157 | +/** |
| 158 | +* Converts a complex-valued floating-point ndarray to a real-valued floating-point ndarray. |
| 159 | +* |
| 160 | +* ## Notes |
| 161 | +* |
| 162 | +* - The function mutates the input ndarray object. |
| 163 | +* |
| 164 | +* @private |
| 165 | +* @param {Object} x - input ndarray object |
| 166 | +* @returns {Object} output ndarray object |
| 167 | +*/ |
| 168 | +function complex2real( x ) { |
| 169 | + x.data = reinterpretComplex( x.data, 0 ); |
| 170 | + x.accessorProtocol = false; |
| 171 | + x.dtype = COMPLEX_TO_REAL[ x.dtype ]; |
| 172 | + x.strides = gscal( x.shape.length, 2, x.strides, 1 ); |
| 173 | + x.offset *= 2; |
| 174 | + |
| 175 | + // Append a trailing dimension where each element is the real and imaginary component for a corresponding element in the original input ndarray (note: this means that a two-dimensional complex-valued ndarray becomes a three-dimensional real-valued ndarray; while this does entail additional loop overhead, it is still significantly faster than sending complex-valued ndarrays down the accessor path): |
| 176 | + x.shape.push( 2 ); // real and imaginary components |
| 177 | + x.strides.push( 1 ); // real and imaginary components are assumed to be adjacent in memory |
| 178 | + |
| 179 | + return x; |
| 180 | +} |
| 181 | + |
126 | 182 |
|
127 | 183 | // MAIN // |
128 | 184 |
|
@@ -210,8 +266,16 @@ function assign( arrays ) { |
210 | 266 | x = ndarray2object( arrays[ 0 ] ); |
211 | 267 | y = ndarray2object( arrays[ 1 ] ); |
212 | 268 |
|
| 269 | + // Check for known array types which can be reinterpreted for better iteration performance... |
| 270 | + if ( isBooleanArray( x.data ) && isBooleanArray( y.data ) ) { |
| 271 | + x = boolean2uint8( x ); |
| 272 | + y = boolean2uint8( y ); |
| 273 | + } else if ( isComplexArray( x.data ) && isComplexArray( y.data ) ) { |
| 274 | + x = complex2real( x ); |
| 275 | + y = complex2real( y ); |
| 276 | + } |
213 | 277 | // Determine whether we are casting a real data type to a complex data type and we need to use a specialized accessor (note: we don't support the other way, complex-to-real, as this is not an allowed (mostly) safe cast)... |
214 | | - if ( isRealDataType( x.dtype ) && isComplexDataType( y.dtype ) ) { |
| 278 | + else if ( isRealDataType( x.dtype ) && isComplexDataType( y.dtype ) ) { |
215 | 279 | x.accessorProtocol = true; |
216 | 280 | x.accessors[ 0 ] = castReturn( x.accessors[ 0 ], 2, complexCtors( y.dtype ) ); // eslint-disable-line max-len |
217 | 281 | } |
|
0 commit comments