@@ -282,6 +282,7 @@ quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtyp
282282 // Determine target backend and if casting is needed
283283 NPY_CASTING casting = NPY_NO_CASTING;
284284 if (descr_in1->backend != descr_in2->backend ) {
285+
285286 target_backend = BACKEND_LONGDOUBLE;
286287 casting = NPY_SAFE_CASTING;
287288 }
@@ -397,12 +398,12 @@ static int
397398quad_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
398399 PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
399400{
400- printf ( " called comparison promoter \n " );
401+
401402 int nin = ufunc->nin ;
402403 int nargs = ufunc->nargs ;
403404 PyArray_DTypeMeta *common = NULL ;
404405 bool has_quad = false ;
405- printf ( " dtyp1: %s dtype2: %s \n " , get_dtype_name (op_dtypes[ 0 ]), get_dtype_name (op_dtypes[ 1 ]));
406+
406407 // Handle the special case for reductions
407408 if (op_dtypes[0 ] == NULL ) {
408409 assert (nin == 2 && ufunc->nout == 1 ); /* must be reduction */
@@ -416,7 +417,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
416417 // Check if any input or signature is QuadPrecision
417418 for (int i = 0 ; i < nin; i++) {
418419 if (op_dtypes[i] == &QuadPrecDType) {
419- printf ( " Quaddtype found at index: %d \n " , i);
420+
420421 has_quad = true ;
421422 }
422423 }
@@ -460,7 +461,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
460461 else {
461462 // Otherwise, use the common dtype
462463 Py_INCREF (common);
463- printf ( " setting output to %s dtype \n " , get_dtype_name (common));
464+
464465 new_op_dtypes[i] = common;
465466 }
466467 }
@@ -560,6 +561,47 @@ init_quad_binary_ops(PyObject *numpy)
560561
561562// comparison functions
562563
564+ static NPY_CASTING
565+ quad_comparison_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *const dtypes[],
566+ PyArray_Descr *const given_descrs[],
567+ PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED (view_offset))
568+ {
569+ QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0 ];
570+ QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1 ];
571+ QuadBackendType target_backend;
572+
573+ // As dealing with different backends then cast to boolean
574+ NPY_CASTING casting = NPY_NO_CASTING;
575+ if (descr_in1->backend != descr_in2->backend ) {
576+ target_backend = BACKEND_LONGDOUBLE;
577+ casting = NPY_SAFE_CASTING;
578+ }
579+ else {
580+ target_backend = descr_in1->backend ;
581+ }
582+
583+ // Set up input descriptors, casting if necessary
584+ for (int i = 0 ; i < 2 ; i++) {
585+ if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
586+ loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance (target_backend);
587+ if (!loop_descrs[i]) {
588+ return (NPY_CASTING)-1 ;
589+ }
590+ }
591+ else {
592+ Py_INCREF (given_descrs[i]);
593+ loop_descrs[i] = given_descrs[i];
594+ }
595+ }
596+
597+ // Set up output descriptor
598+ loop_descrs[2 ] = PyArray_DescrFromType (NPY_BOOL);
599+ if (!loop_descrs[2 ]) {
600+ return (NPY_CASTING)-1 ;
601+ }
602+ return casting;
603+ }
604+
563605template <cmp_quad_def sleef_comp, cmp_londouble_def ld_comp>
564606int
565607quad_generic_comp_strided_loop (PyArrayMethod_Context *context, char *const data[],
@@ -581,15 +623,18 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
581623 while (N--) {
582624 memcpy (&in1, in1_ptr, elem_size);
583625 memcpy (&in2, in2_ptr, elem_size);
626+ npy_bool result;
584627
585628 if (backend == BACKEND_SLEEF) {
586- *((npy_bool *)out_ptr) = sleef_comp (&in1.sleef_value , &in2.sleef_value );
629+ result = sleef_comp (&in1.sleef_value , &in2.sleef_value );
587630 }
588631 else {
589- printf ( " %Lf % Lf \n " , in1. longdouble_value , in2. longdouble_value );
590- *((npy_bool *)out_ptr) = ld_comp (&in1.longdouble_value , &in2.longdouble_value );
632+
633+ result = ld_comp (&in1.longdouble_value , &in2.longdouble_value );
591634 }
592635
636+ *((npy_bool *)out_ptr) = result;
637+
593638 in1_ptr += in1_stride;
594639 in2_ptr += in2_stride;
595640 out_ptr += out_stride;
@@ -624,6 +669,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
624669 PyArray_DTypeMeta *dtypes[3 ] = {&QuadPrecDType, &QuadPrecDType, &PyArray_BoolDType};
625670
626671 PyType_Slot slots[] = {
672+ {NPY_METH_resolve_descriptors, (void *)&quad_comparison_op_resolve_descriptors},
627673 {NPY_METH_strided_loop, (void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
628674 {NPY_METH_unaligned_strided_loop,
629675 (void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
@@ -633,7 +679,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
633679 .name = " quad_comp" ,
634680 .nin = 2 ,
635681 .nout = 1 ,
636- .casting = NPY_NO_CASTING ,
682+ .casting = NPY_SAFE_CASTING ,
637683 .flags = NPY_METH_SUPPORTS_UNALIGNED,
638684 .dtypes = dtypes,
639685 .slots = slots,
0 commit comments