@@ -456,7 +456,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
456456 }
457457 Py_XDECREF (common );
458458
459- /* Otherwise, set all input operands to StringDType */
459+ /* Otherwise, set all input operands to final_dtype */
460460 for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
461461 PyArray_DTypeMeta * tmp = final_dtype ;
462462 if (signature [i ]) {
@@ -474,21 +474,32 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
474474}
475475
476476static int
477- string_ufunc_promoter ( PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
478- PyArray_DTypeMeta * signature [],
479- PyArray_DTypeMeta * new_op_dtypes [])
477+ string_object_promoter ( PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
478+ PyArray_DTypeMeta * signature [],
479+ PyArray_DTypeMeta * new_op_dtypes [])
480480{
481- return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
481+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
482+ signature , new_op_dtypes ,
483+ (PyArray_DTypeMeta * )& PyArray_ObjectDType );
484+ }
485+
486+ static int
487+ string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
488+ PyArray_DTypeMeta * signature [],
489+ PyArray_DTypeMeta * new_op_dtypes [])
490+ {
491+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
492+ signature , new_op_dtypes ,
482493 (PyArray_DTypeMeta * )& StringDType );
483494}
484495
485496static int
486- pandas_string_ufunc_promoter (PyUFuncObject * ufunc ,
487- PyArray_DTypeMeta * op_dtypes [],
488- PyArray_DTypeMeta * signature [],
489- PyArray_DTypeMeta * new_op_dtypes [])
497+ pandas_string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
498+ PyArray_DTypeMeta * signature [],
499+ PyArray_DTypeMeta * new_op_dtypes [])
490500{
491- return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
501+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
502+ signature , new_op_dtypes ,
492503 (PyArray_DTypeMeta * )& PandasStringDType );
493504}
494505
@@ -538,7 +549,7 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
538549int
539550add_promoter (PyObject * numpy , const char * ufunc_name ,
540551 PyArray_DTypeMeta * ldtype , PyArray_DTypeMeta * rdtype ,
541- PyArray_DTypeMeta * edtype , int is_pandas )
552+ PyArray_DTypeMeta * edtype , promoter_function * promoter_impl )
542553{
543554 PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
544555
@@ -553,16 +564,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
553564 return -1 ;
554565 }
555566
556- PyObject * promoter_capsule = NULL ;
557-
558- if (is_pandas == 0 ) {
559- promoter_capsule = PyCapsule_New ((void * )& string_ufunc_promoter ,
560- "numpy._ufunc_promoter" , NULL );
561- }
562- else {
563- promoter_capsule = PyCapsule_New ((void * )& pandas_string_ufunc_promoter ,
564- "numpy._ufunc_promoter" , NULL );
565- }
567+ PyObject * promoter_capsule = PyCapsule_New ((void * )promoter_impl ,
568+ "numpy._ufunc_promoter" , NULL );
566569
567570 if (promoter_capsule == NULL ) {
568571 Py_DECREF (ufunc );
@@ -592,21 +595,31 @@ init_ufuncs(void)
592595 return -1 ;
593596 }
594597
595- StringDType_type * * dtype_classes = NULL ;
596598 int num_dtypes ;
597599
598600 if (PANDAS_AVAILABLE ) {
599- dtype_classes = malloc (sizeof (StringDType_type * ) * 2 );
600- dtype_classes [0 ] = & StringDType ;
601- dtype_classes [1 ] = & PandasStringDType ;
602601 num_dtypes = 2 ;
603602 }
604603 else {
605- dtype_classes = malloc (sizeof (StringDType_type * ) * 1 );
606- dtype_classes [0 ] = & StringDType ;
607604 num_dtypes = 1 ;
608605 }
609606
607+ StringDType_type * * dtype_classes =
608+ malloc (sizeof (StringDType_type * ) * num_dtypes );
609+ promoter_function * * unicode_promoters =
610+ malloc (sizeof (promoter_function * ) * num_dtypes );
611+ dtype_classes [0 ] = & StringDType ;
612+ unicode_promoters [0 ] = & string_unicode_promoter ;
613+
614+ if (PANDAS_AVAILABLE ) {
615+ dtype_classes [1 ] = & PandasStringDType ;
616+ unicode_promoters [1 ] = & pandas_string_unicode_promoter ;
617+ }
618+
619+ static char * comparison_ufunc_names [6 ] = {"equal" , "not_equal" ,
620+ "greater" , "greater_equal" ,
621+ "less" , "less_equal" };
622+
610623 for (int di = 0 ; di < num_dtypes ; di ++ ) {
611624 PyArray_DTypeMeta * comparison_dtypes [] = {
612625 (PyArray_DTypeMeta * )dtype_classes [di ],
@@ -654,34 +667,32 @@ init_ufuncs(void)
654667 goto error ;
655668 }
656669
657- static char * ufunc_names [6 ] = {"equal" , "not_equal" ,
658- "greater" , "greater_equal" ,
659- "less" , "less_equal" };
660-
661670 for (int i = 0 ; i < 6 ; i ++ ) {
662- if (add_promoter (numpy , ufunc_names [i ],
671+ if (add_promoter (numpy , comparison_ufunc_names [i ],
663672 (PyArray_DTypeMeta * )dtype_classes [di ],
664673 & PyArray_UnicodeDType , & PyArray_BoolDType ,
665- 0 ) < 0 ) {
674+ unicode_promoters [ di ] ) < 0 ) {
666675 goto error ;
667676 }
668677
669- if (add_promoter (numpy , ufunc_names [i ], & PyArray_UnicodeDType ,
678+ if (add_promoter (numpy , comparison_ufunc_names [i ],
679+ & PyArray_UnicodeDType ,
670680 (PyArray_DTypeMeta * )dtype_classes [di ],
671- & PyArray_BoolDType , 0 ) < 0 ) {
681+ & PyArray_BoolDType , unicode_promoters [ di ] ) < 0 ) {
672682 goto error ;
673683 }
674684
675- if (add_promoter (numpy , ufunc_names [i ], & PyArray_ObjectDType ,
676- (PyArray_DTypeMeta * )dtype_classes [di ],
677- & PyArray_BoolDType , 0 ) < 0 ) {
685+ if (add_promoter (
686+ numpy , comparison_ufunc_names [i ], & PyArray_ObjectDType ,
687+ (PyArray_DTypeMeta * )dtype_classes [di ],
688+ & PyArray_BoolDType , & string_object_promoter ) < 0 ) {
678689 goto error ;
679690 }
680691
681- if (add_promoter (numpy , ufunc_names [i ],
692+ if (add_promoter (numpy , comparison_ufunc_names [i ],
682693 (PyArray_DTypeMeta * )dtype_classes [di ],
683694 & PyArray_ObjectDType , & PyArray_BoolDType ,
684- 0 ) < 0 ) {
695+ & string_object_promoter ) < 0 ) {
685696 goto error ;
686697 }
687698 }
@@ -720,10 +731,36 @@ init_ufuncs(void)
720731 }
721732 }
722733
734+ // add promoters for all ufuncs so comparison operations mixing StringDType
735+ // and PandasStringDType work correctly.
736+
737+ if (PANDAS_AVAILABLE ) {
738+ for (int i = 0 ; i < 6 ; i ++ ) {
739+ if (add_promoter (numpy , comparison_ufunc_names [i ],
740+ (PyArray_DTypeMeta * )& StringDType ,
741+ (PyArray_DTypeMeta * )& PandasStringDType ,
742+ & PyArray_BoolDType ,
743+ string_unicode_promoter ) < 0 ) {
744+ goto error ;
745+ }
746+
747+ if (add_promoter (numpy , comparison_ufunc_names [i ],
748+ (PyArray_DTypeMeta * )& PandasStringDType ,
749+ (PyArray_DTypeMeta * )& StringDType ,
750+ & PyArray_BoolDType ,
751+ string_unicode_promoter ) < 0 ) {
752+ goto error ;
753+ }
754+ }
755+ }
756+ free (dtype_classes );
757+ free (unicode_promoters );
723758 Py_DECREF (numpy );
724759 return 0 ;
725760
726761error :
762+ free (dtype_classes );
763+ free (unicode_promoters );
727764 Py_DECREF (numpy );
728765 return -1 ;
729766}
0 commit comments