@@ -20,6 +20,38 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
2020
2121 Py_XINCREF (na_object );
2222 ((StringDTypeObject * )new )-> na_object = na_object ;
23+ int hasnull = na_object != NULL ;
24+ int has_nan_na = 0 ;
25+ int has_string_na = 0 ;
26+ ss default_string = EMPTY_STRING ;
27+ if (hasnull ) {
28+ double na_float = PyFloat_AsDouble (na_object );
29+ if (na_float == -1.0 && PyErr_Occurred ()) {
30+ // not a float, still treat as nan if PyObject_IsTrue raises
31+ // (e.g. pandas.NA)
32+ PyErr_Clear ();
33+ int is_truthy = PyObject_IsTrue (na_object );
34+ if (is_truthy == -1 ) {
35+ PyErr_Clear ();
36+ has_nan_na = 1 ;
37+ }
38+ }
39+ else if (npy_isnan (na_float )) {
40+ has_nan_na = 1 ;
41+ }
42+
43+ if (PyUnicode_Check (na_object )) {
44+ has_string_na = 1 ;
45+ Py_ssize_t size = 0 ;
46+ const char * buf = PyUnicode_AsUTF8AndSize (na_object , & size );
47+ default_string .len = size ;
48+ // discards const, how to avoid?
49+ default_string .buf = (char * )buf ;
50+ }
51+ }
52+ ((StringDTypeObject * )new )-> has_nan_na = has_nan_na ;
53+ ((StringDTypeObject * )new )-> has_string_na = has_string_na ;
54+ ((StringDTypeObject * )new )-> default_string = default_string ;
2355 ((StringDTypeObject * )new )-> coerce = coerce ;
2456
2557 PyArray_Descr * base = (PyArray_Descr * )new ;
@@ -28,6 +60,9 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
2860 base -> flags |= NPY_NEEDS_INIT ;
2961 base -> flags |= NPY_LIST_PICKLE ;
3062 base -> flags |= NPY_ITEM_REFCOUNT ;
63+ if (hasnull && !(has_string_na && has_nan_na )) {
64+ base -> flags |= NPY_NEEDS_PYAPI ;
65+ }
3166
3267 return new ;
3368}
@@ -227,25 +262,43 @@ int
227262_compare (void * a , void * b , StringDTypeObject * descr )
228263{
229264 int hasnull = descr -> na_object != NULL ;
265+ int has_string_na = descr -> has_string_na ;
266+ int has_nan_na = descr -> has_nan_na ;
267+ if (hasnull && !(has_string_na && has_nan_na )) {
268+ // check if an error occurred already to avoid setting an error again
269+ if (PyErr_Occurred ()) {
270+ return 0 ;
271+ }
272+ }
273+ const ss * default_string = & descr -> default_string ;
230274 const ss * ss_a = (ss * )a ;
231275 const ss * ss_b = (ss * )b ;
232276 int a_is_null = ss_isnull (ss_a );
233277 int b_is_null = ss_isnull (ss_b );
234278 if (NPY_UNLIKELY (a_is_null || b_is_null )) {
235- if (hasnull ) {
236- if (a_is_null ) {
237- return 1 ;
279+ if (hasnull && !has_string_na ) {
280+ if (has_nan_na ) {
281+ if (a_is_null ) {
282+ return 1 ;
283+ }
284+ else if (b_is_null ) {
285+ return -1 ;
286+ }
238287 }
239- else if (b_is_null ) {
240- return -1 ;
288+ else {
289+ // we must hold the GIL in this branch
290+ PyErr_SetString (
291+ PyExc_ValueError ,
292+ "Cannot compare null this is not a nan-like value" );
293+ return 0 ;
241294 }
242295 }
243296 else {
244297 if (a_is_null ) {
245- ss_a = & EMPTY_STRING ;
298+ ss_a = default_string ;
246299 }
247300 if (b_is_null ) {
248- ss_b = & EMPTY_STRING ;
301+ ss_b = default_string ;
249302 }
250303 }
251304 }
@@ -349,6 +402,94 @@ stringdtype_get_fill_zero_loop(void *NPY_UNUSED(traverse_context),
349402 return 0 ;
350403}
351404
405+ static int
406+ stringdtype_is_known_scalar_type (PyArray_DTypeMeta * NPY_UNUSED (cls ),
407+ PyTypeObject * pytype )
408+ {
409+ if (pytype == & PyFloat_Type ) {
410+ return 1 ;
411+ }
412+ if (pytype == & PyLong_Type ) {
413+ return 1 ;
414+ }
415+ if (pytype == & PyBool_Type ) {
416+ return 1 ;
417+ }
418+ if (pytype == & PyComplex_Type ) {
419+ return 1 ;
420+ }
421+ if (pytype == & PyUnicode_Type ) {
422+ return 1 ;
423+ }
424+ if (pytype == & PyBytes_Type ) {
425+ return 1 ;
426+ }
427+ if (pytype == & PyBoolArrType_Type ) {
428+ return 1 ;
429+ }
430+ if (pytype == & PyByteArrType_Type ) {
431+ return 1 ;
432+ }
433+ if (pytype == & PyShortArrType_Type ) {
434+ return 1 ;
435+ }
436+ if (pytype == & PyIntArrType_Type ) {
437+ return 1 ;
438+ }
439+ if (pytype == & PyLongArrType_Type ) {
440+ return 1 ;
441+ }
442+ if (pytype == & PyLongLongArrType_Type ) {
443+ return 1 ;
444+ }
445+ if (pytype == & PyUByteArrType_Type ) {
446+ return 1 ;
447+ }
448+ if (pytype == & PyUShortArrType_Type ) {
449+ return 1 ;
450+ }
451+ if (pytype == & PyUIntArrType_Type ) {
452+ return 1 ;
453+ }
454+ if (pytype == & PyULongArrType_Type ) {
455+ return 1 ;
456+ }
457+ if (pytype == & PyULongLongArrType_Type ) {
458+ return 1 ;
459+ }
460+ if (pytype == & PyHalfArrType_Type ) {
461+ return 1 ;
462+ }
463+ if (pytype == & PyFloatArrType_Type ) {
464+ return 1 ;
465+ }
466+ if (pytype == & PyDoubleArrType_Type ) {
467+ return 1 ;
468+ }
469+ if (pytype == & PyLongDoubleArrType_Type ) {
470+ return 1 ;
471+ }
472+ if (pytype == & PyCFloatArrType_Type ) {
473+ return 1 ;
474+ }
475+ if (pytype == & PyCDoubleArrType_Type ) {
476+ return 1 ;
477+ }
478+ if (pytype == & PyCLongDoubleArrType_Type ) {
479+ return 1 ;
480+ }
481+ if (pytype == & PyIntpArrType_Type ) {
482+ return 1 ;
483+ }
484+ if (pytype == & PyUIntpArrType_Type ) {
485+ return 1 ;
486+ }
487+ if (pytype == & PyDatetimeArrType_Type ) {
488+ return 1 ;
489+ }
490+ return 0 ;
491+ }
492+
352493static PyType_Slot StringDType_Slots [] = {
353494 {NPY_DT_common_instance , & common_instance },
354495 {NPY_DT_common_dtype , & common_dtype },
@@ -363,6 +504,7 @@ static PyType_Slot StringDType_Slots[] = {
363504 {NPY_DT_PyArray_ArrFuncs_argmin , & argmin },
364505 {NPY_DT_get_clear_loop , & stringdtype_get_clear_loop },
365506 {NPY_DT_get_fill_zero_loop , & stringdtype_get_fill_zero_loop },
507+ {_NPY_DT_is_known_scalar_type , & stringdtype_is_known_scalar_type },
366508 {0 , NULL }};
367509
368510static PyObject *
@@ -530,7 +672,7 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
530672 // pointer equality catches pandas.NA and other NA singletons
531673 eq = 1 ;
532674 }
533- else {
675+ else if ( PyFloat_Check ( sna ) && PyFloat_Check ( ona )) {
534676 // nan check catches np.nan and float('nan')
535677 double sna_float = PyFloat_AsDouble (sna );
536678 if (sna_float == -1.0 && PyErr_Occurred ()) {
@@ -543,13 +685,12 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
543685 if (npy_isnan (sna_float ) && npy_isnan (ona_float )) {
544686 eq = 1 ;
545687 }
546-
688+ }
689+ else {
547690 // finally check if a python equals comparison returns True
548- else if (PyObject_RichCompareBool (sna , ona , Py_EQ ) == 1 ) {
549- eq = 1 ;
550- }
551- else {
552- eq = 0 ;
691+ eq = PyObject_RichCompareBool (sna , ona , Py_EQ );
692+ if (eq == -1 ) {
693+ return NULL ;
553694 }
554695 }
555696
0 commit comments