@@ -75,6 +75,9 @@ static PyType_Slot s2s_slots[] = {
7575 {0 , NULL }};
7676
7777static char * s2s_name = "cast_StringDType_to_StringDType" ;
78+ static char * p2p_name = "cast_PandasStringDType_to_PandasStringDType" ;
79+ static char * s2p_name = "cast_StringDType_to_PandasStringDType" ;
80+ static char * p2s_name = "cast_PandasStringDType_to_StringDType" ;
7881
7982// unicode to string
8083
@@ -476,38 +479,80 @@ get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
476479}
477480
478481PyArrayMethod_Spec * *
479- get_casts (void )
482+ get_casts (PyArray_DTypeMeta * this , PyArray_DTypeMeta * other )
480483{
481- PyArray_DTypeMeta * * s2s_dtypes = get_dtypes ( NULL , NULL ) ;
484+ char * t2t_name = NULL ;
482485
483- PyArrayMethod_Spec * StringToStringCastSpec =
484- get_cast_spec (s2s_name , NPY_NO_CASTING ,
485- NPY_METH_SUPPORTS_UNALIGNED , s2s_dtypes , s2s_slots );
486+ if (this == (PyArray_DTypeMeta * )& StringDType ) {
487+ t2t_name = s2s_name ;
488+ }
489+ else {
490+ t2t_name = p2p_name ;
491+ }
492+
493+ PyArray_DTypeMeta * * t2t_dtypes = get_dtypes (this , this );
494+
495+ PyArrayMethod_Spec * ThisToThisCastSpec =
496+ get_cast_spec (t2t_name , NPY_NO_CASTING ,
497+ NPY_METH_SUPPORTS_UNALIGNED , t2t_dtypes , s2s_slots );
498+
499+ PyArrayMethod_Spec * ThisToOtherCastSpec = NULL ;
500+ PyArrayMethod_Spec * OtherToThisCastSpec = NULL ;
501+
502+ int is_pandas = (this == (PyArray_DTypeMeta * )& PandasStringDType );
503+
504+ int num_casts = 5 ;
505+
506+ if (is_pandas ) {
507+ num_casts = 7 ;
508+
509+ PyArray_DTypeMeta * * t2o_dtypes = get_dtypes (this , other );
486510
487- PyArray_DTypeMeta * * u2s_dtypes = get_dtypes (& PyArray_UnicodeDType , NULL );
511+ ThisToOtherCastSpec = get_cast_spec (p2s_name , NPY_NO_CASTING ,
512+ NPY_METH_SUPPORTS_UNALIGNED ,
513+ t2o_dtypes , s2s_slots );
514+
515+ PyArray_DTypeMeta * * o2t_dtypes = get_dtypes (other , this );
516+
517+ OtherToThisCastSpec = get_cast_spec (s2p_name , NPY_NO_CASTING ,
518+ NPY_METH_SUPPORTS_UNALIGNED ,
519+ o2t_dtypes , s2s_slots );
520+ }
521+
522+ PyArray_DTypeMeta * * u2s_dtypes = get_dtypes (& PyArray_UnicodeDType , this );
488523
489524 PyArrayMethod_Spec * UnicodeToStringCastSpec = get_cast_spec (
490525 u2s_name , NPY_SAFE_CASTING , NPY_METH_NO_FLOATINGPOINT_ERRORS ,
491526 u2s_dtypes , u2s_slots );
492527
493- PyArray_DTypeMeta * * s2u_dtypes = get_dtypes (NULL , & PyArray_UnicodeDType );
528+ PyArray_DTypeMeta * * s2u_dtypes = get_dtypes (this , & PyArray_UnicodeDType );
494529
495530 PyArrayMethod_Spec * StringToUnicodeCastSpec = get_cast_spec (
496531 s2u_name , NPY_SAFE_CASTING , NPY_METH_NO_FLOATINGPOINT_ERRORS ,
497532 s2u_dtypes , s2u_slots );
498533
499- PyArray_DTypeMeta * * s2b_dtypes = get_dtypes (NULL , & PyArray_BoolDType );
534+ PyArray_DTypeMeta * * s2b_dtypes = get_dtypes (this , & PyArray_BoolDType );
500535
501536 PyArrayMethod_Spec * StringToBoolCastSpec = get_cast_spec (
502537 s2b_name , NPY_UNSAFE_CASTING , NPY_METH_NO_FLOATINGPOINT_ERRORS ,
503538 s2b_dtypes , s2b_slots );
504539
505- PyArrayMethod_Spec * * casts = malloc (5 * sizeof (PyArrayMethod_Spec * ));
506- casts [0 ] = StringToStringCastSpec ;
540+ PyArrayMethod_Spec * * casts = NULL ;
541+
542+ casts = malloc (num_casts * sizeof (PyArrayMethod_Spec * ));
543+
544+ casts [0 ] = ThisToThisCastSpec ;
507545 casts [1 ] = UnicodeToStringCastSpec ;
508546 casts [2 ] = StringToUnicodeCastSpec ;
509547 casts [3 ] = StringToBoolCastSpec ;
510- casts [4 ] = NULL ;
548+ if (is_pandas ) {
549+ casts [4 ] = ThisToOtherCastSpec ;
550+ casts [5 ] = OtherToThisCastSpec ;
551+ casts [6 ] = NULL ;
552+ }
553+ else {
554+ casts [4 ] = NULL ;
555+ }
511556
512557 return casts ;
513558}
0 commit comments