@@ -250,13 +250,13 @@ cdef class HashTable:
250250
251251{{py:
252252
253- # name, dtype, null_condition, float_group
254- dtypes = [('Float64', 'float64', 'val != val', True ),
255- ('UInt64', 'uint64', ' False', False ),
256- ('Int64', 'int64', 'val == iNaT', False )]
253+ # name, dtype, float_group, default_na_value
254+ dtypes = [('Float64', 'float64', True, 'nan' ),
255+ ('UInt64', 'uint64', False, 0 ),
256+ ('Int64', 'int64', False, ' iNaT')]
257257
258258def get_dispatch(dtypes):
259- for (name, dtype, null_condition, float_group ) in dtypes:
259+ for (name, dtype, float_group, default_na_value ) in dtypes:
260260 unique_template = """\
261261 cdef:
262262 Py_ssize_t i, n = len(values)
@@ -298,13 +298,13 @@ def get_dispatch(dtypes):
298298 return uniques.to_array()
299299 """
300300
301- unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)
301+ unique_template = unique_template.format(name=name, dtype=dtype, float_group=float_group)
302302
303- yield (name, dtype, null_condition, float_group , unique_template)
303+ yield (name, dtype, float_group, default_na_value , unique_template)
304304}}
305305
306306
307- {{for name, dtype, null_condition, float_group , unique_template in get_dispatch(dtypes)}}
307+ {{for name, dtype, float_group, default_na_value , unique_template in get_dispatch(dtypes)}}
308308
309309cdef class {{name}}HashTable(HashTable):
310310
@@ -408,24 +408,36 @@ cdef class {{name}}HashTable(HashTable):
408408 @cython.boundscheck(False)
409409 def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
410410 Py_ssize_t count_prior, Py_ssize_t na_sentinel,
411- bint check_null=True ):
411+ object na_value=None ):
412412 cdef:
413413 Py_ssize_t i, n = len(values)
414414 int64_t[:] labels
415415 Py_ssize_t idx, count = count_prior
416416 int ret = 0
417- {{dtype}}_t val
417+ {{dtype}}_t val, na_value2
418418 khiter_t k
419419 {{name}}VectorData *ud
420+ bint use_na_value
420421
421422 labels = np.empty(n, dtype=np.int64)
422423 ud = uniques.data
424+ use_na_value = na_value is not None
425+
426+ if use_na_value:
427+ # We need this na_value2 because we want to allow users
428+ # to *optionally* specify an NA sentinel *of the correct* type.
429+ # We use None, to make it optional, which requires `object` type
430+ # for the parameter. To please the compiler, we use na_value2,
431+ # which is only used if it's *specified*.
432+ na_value2 = <{{dtype}}_t>na_value
433+ else:
434+ na_value2 = {{default_na_value}}
423435
424436 with nogil:
425437 for i in range(n):
426438 val = values[i]
427439
428- if check_null and {{null_condition}} :
440+ if val != val or (use_na_value and val == na_value2) :
429441 labels[i] = na_sentinel
430442 continue
431443
@@ -695,7 +707,7 @@ cdef class StringHashTable(HashTable):
695707 @cython.boundscheck(False)
696708 def get_labels(self, ndarray[object] values, ObjectVector uniques,
697709 Py_ssize_t count_prior, int64_t na_sentinel,
698- bint check_null=1 ):
710+ object na_value=None ):
699711 cdef:
700712 Py_ssize_t i, n = len(values)
701713 int64_t[:] labels
@@ -706,18 +718,21 @@ cdef class StringHashTable(HashTable):
706718 char *v
707719 char **vecs
708720 khiter_t k
721+ bint use_na_value
709722
710723 # these by-definition *must* be strings
711724 labels = np.zeros(n, dtype=np.int64)
712725 uindexer = np.empty(n, dtype=np.int64)
726+ use_na_value = na_value is not None
713727
714728 # pre-filter out missing
715729 # and assign pointers
716730 vecs = <char **> malloc(n * sizeof(char *))
717731 for i in range(n):
718732 val = values[i]
719733
720- if PyUnicode_Check(val) or PyString_Check(val):
734+ if ((PyUnicode_Check(val) or PyString_Check(val)) and
735+ not (use_na_value and val == na_value)):
721736 v = util.get_c_string(val)
722737 vecs[i] = v
723738 else:
@@ -868,22 +883,25 @@ cdef class PyObjectHashTable(HashTable):
868883
869884 def get_labels(self, ndarray[object] values, ObjectVector uniques,
870885 Py_ssize_t count_prior, int64_t na_sentinel,
871- bint check_null=True ):
886+ object na_value=None ):
872887 cdef:
873888 Py_ssize_t i, n = len(values)
874889 int64_t[:] labels
875890 Py_ssize_t idx, count = count_prior
876891 int ret = 0
877892 object val
878893 khiter_t k
894+ bint use_na_value
879895
880896 labels = np.empty(n, dtype=np.int64)
897+ use_na_value = na_value is not None
881898
882899 for i in range(n):
883900 val = values[i]
884901 hash(val)
885902
886- if check_null and val != val or val is None:
903+ if ((val != val or val is None) or
904+ (use_na_value and val == na_value)):
887905 labels[i] = na_sentinel
888906 continue
889907
0 commit comments