55
66PyTypeObject * StringScalar_Type = NULL ;
77static PyTypeObject * StringNA_Type = NULL ;
8- static PyObject * NA_OBJ = NULL ;
8+ PyObject * NA_OBJ = NULL ;
99
1010/*
1111 * Internal helper to create new instances
1212 */
1313StringDTypeObject *
14- new_stringdtype_instance (void )
14+ new_stringdtype_instance (PyObject * na_object )
1515{
1616 StringDTypeObject * new = (StringDTypeObject * )PyArrayDescr_Type .tp_new (
1717 (PyTypeObject * )& StringDType , NULL , NULL );
1818 if (new == NULL ) {
1919 return NULL ;
2020 }
21+ Py_INCREF (na_object );
22+ new -> na_object = na_object ;
2123 new -> base .elsize = sizeof (ss );
2224 new -> base .alignment = _Alignof(ss );
2325 new -> base .flags |= NPY_NEEDS_INIT ;
@@ -72,15 +74,15 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
7274 return NULL ;
7375 }
7476
75- PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance ();
77+ PyArray_Descr * ret = (PyArray_Descr * )new_stringdtype_instance (NA_OBJ );
7678 if (ret == NULL ) {
7779 return NULL ;
7880 }
7981 return ret ;
8082}
8183
8284static PyObject *
83- get_value (PyObject * scalar )
85+ get_value (PyObject * scalar , PyObject * na_object )
8486{
8587 PyObject * ret = NULL ;
8688 PyTypeObject * scalar_type = Py_TYPE (scalar );
@@ -96,7 +98,7 @@ get_value(PyObject *scalar)
9698 return NULL ;
9799 }
98100 }
99- else if (scalar_type == StringNA_Type ) {
101+ else if (scalar == na_object ) {
100102 ret = scalar ;
101103 Py_INCREF (ret );
102104 }
@@ -107,7 +109,7 @@ get_value(PyObject *scalar)
107109 return NULL ;
108110 }
109111 if (npy_isnan (scalar_val )) {
110- ret = NA_OBJ ;
112+ ret = na_object ;
111113 Py_INCREF (ret );
112114 }
113115 else {
@@ -128,10 +130,9 @@ get_value(PyObject *scalar)
128130// Take a python object `obj` and insert it into the array of dtype `descr` at
129131// the position given by dataptr.
130132static int
131- stringdtype_setitem (StringDTypeObject * NPY_UNUSED (descr ), PyObject * obj ,
132- char * * dataptr )
133+ stringdtype_setitem (StringDTypeObject * descr , PyObject * obj , char * * dataptr )
133134{
134- PyObject * val_obj = get_value (obj );
135+ PyObject * val_obj = get_value (obj , descr -> na_object );
135136
136137 if (val_obj == NULL ) {
137138 return -1 ;
@@ -143,15 +144,9 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
143144 // ssfree does a NULL check
144145 ssfree (sdata );
145146
146- // RichCompareBool short-circuits to a pointer comparison fast-path
147- // so no need to do pointer comparison first
148- int eq_res = PyObject_RichCompareBool (val_obj , NA_OBJ , Py_EQ );
149-
150- if (eq_res < 0 ) {
151- goto error ;
152- }
153-
154- if (eq_res == 1 ) {
147+ // setting NA *must* check pointer equality since NA types might not
148+ // allow equality
149+ if (val_obj == descr -> na_object ) {
155150 // do nothing, ssfree already NULLed the struct ssdata points to
156151 // so it already contains a NA value
157152 }
@@ -185,14 +180,14 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
185180}
186181
187182static PyObject *
188- stringdtype_getitem (StringDTypeObject * NPY_UNUSED ( descr ) , char * * dataptr )
183+ stringdtype_getitem (StringDTypeObject * descr , char * * dataptr )
189184{
190185 PyObject * val_obj = NULL ;
191186 ss * sdata = (ss * )dataptr ;
192187
193188 if (ss_isnull (sdata )) {
194- Py_INCREF (NA_OBJ );
195- val_obj = NA_OBJ ;
189+ Py_INCREF (descr -> na_object );
190+ val_obj = descr -> na_object ;
196191 }
197192 else {
198193 char * data = sdata -> buf ;
@@ -359,28 +354,48 @@ static PyType_Slot StringDType_Slots[] = {
359354static PyObject *
360355stringdtype_new (PyTypeObject * NPY_UNUSED (cls ), PyObject * args , PyObject * kwds )
361356{
362- static char * kwargs_strs [] = {"size" , NULL };
357+ static char * kwargs_strs [] = {"size" , "na_object" , NULL };
363358
364359 long size = 0 ;
360+ PyObject * na_object = NULL ;
365361
366- if (!PyArg_ParseTupleAndKeywords (args , kwds , "|l :StringDType" , kwargs_strs ,
367- & size )) {
362+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "|lO :StringDType" ,
363+ kwargs_strs , & size , & na_object )) {
368364 return NULL ;
369365 }
370366
371- return (PyObject * )new_stringdtype_instance ();
367+ if (na_object == NULL ) {
368+ na_object = NA_OBJ ;
369+ }
370+
371+ Py_INCREF (na_object );
372+
373+ PyObject * ret = (PyObject * )new_stringdtype_instance (na_object );
374+
375+ Py_DECREF (na_object );
376+
377+ return ret ;
372378}
373379
374380static void
375381stringdtype_dealloc (StringDTypeObject * self )
376382{
383+ Py_DECREF (self -> na_object );
377384 PyArrayDescr_Type .tp_dealloc ((PyObject * )self );
378385}
379386
380387static PyObject *
381- stringdtype_repr (StringDTypeObject * NPY_UNUSED ( self ) )
388+ stringdtype_repr (StringDTypeObject * self )
382389{
383- return PyUnicode_FromString ("StringDType()" );
390+ PyObject * ret = NULL ;
391+ if (self -> na_object != NA_OBJ ) {
392+ ret = PyUnicode_FromFormat ("StringDType(na_object=%R)" ,
393+ self -> na_object );
394+ }
395+ else {
396+ ret = PyUnicode_FromString ("StringDType()" );
397+ }
398+ return ret ;
384399}
385400
386401static int PICKLE_VERSION = 1 ;
@@ -485,6 +500,7 @@ init_string_dtype(void)
485500 PyArrayMethod_Spec * * casts = get_casts ();
486501
487502 PyArrayDTypeMeta_Spec StringDType_DTypeSpec = {
503+ .flags = NPY_DT_PARAMETRIC ,
488504 .typeobj = StringScalar_Type ,
489505 .slots = StringDType_Slots ,
490506 .casts = casts ,
0 commit comments