|
10 | 10 | from pandas.core.dtypes.common import ( |
11 | 11 | _ensure_platform_int, |
12 | 12 | is_list_like, is_bool_dtype, |
13 | | - needs_i8_conversion, is_sparse) |
| 13 | + needs_i8_conversion, is_sparse, is_object_dtype) |
14 | 14 | from pandas.core.dtypes.cast import maybe_promote |
15 | 15 | from pandas.core.dtypes.missing import notna |
16 | 16 |
|
@@ -697,7 +697,7 @@ def _convert_level_number(level_num, columns): |
697 | 697 |
|
698 | 698 |
|
699 | 699 | def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, |
700 | | - columns=None, sparse=False, drop_first=False): |
| 700 | + columns=None, sparse=False, drop_first=False, dtype=None): |
701 | 701 | """ |
702 | 702 | Convert categorical variable into dummy/indicator variables |
703 | 703 |
|
@@ -728,6 +728,11 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, |
728 | 728 |
|
729 | 729 | .. versionadded:: 0.18.0 |
730 | 730 |
|
| 731 | + dtype : dtype, default np.uint8 |
| 732 | + Data type for new columns. Only a single dtype is allowed. |
| 733 | +
|
| 734 | + .. versionadded:: 0.22.0 |
| 735 | +
|
731 | 736 | Returns |
732 | 737 | ------- |
733 | 738 | dummies : DataFrame or SparseDataFrame |
@@ -783,6 +788,12 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, |
783 | 788 | 3 0 0 |
784 | 789 | 4 0 0 |
785 | 790 |
|
| 791 | + >>> pd.get_dummies(pd.Series(list('abc')), dtype=float) |
| 792 | + a b c |
| 793 | + 0 1.0 0.0 0.0 |
| 794 | + 1 0.0 1.0 0.0 |
| 795 | + 2 0.0 0.0 1.0 |
| 796 | +
|
786 | 797 | See Also |
787 | 798 | -------- |
788 | 799 | Series.str.get_dummies |
@@ -835,20 +846,29 @@ def check_len(item, name): |
835 | 846 |
|
836 | 847 | dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep, |
837 | 848 | dummy_na=dummy_na, sparse=sparse, |
838 | | - drop_first=drop_first) |
| 849 | + drop_first=drop_first, dtype=dtype) |
839 | 850 | with_dummies.append(dummy) |
840 | 851 | result = concat(with_dummies, axis=1) |
841 | 852 | else: |
842 | 853 | result = _get_dummies_1d(data, prefix, prefix_sep, dummy_na, |
843 | | - sparse=sparse, drop_first=drop_first) |
| 854 | + sparse=sparse, |
| 855 | + drop_first=drop_first, |
| 856 | + dtype=dtype) |
844 | 857 | return result |
845 | 858 |
|
846 | 859 |
|
847 | 860 | def _get_dummies_1d(data, prefix, prefix_sep='_', dummy_na=False, |
848 | | - sparse=False, drop_first=False): |
| 861 | + sparse=False, drop_first=False, dtype=None): |
849 | 862 | # Series avoids inconsistent NaN handling |
850 | 863 | codes, levels = _factorize_from_iterable(Series(data)) |
851 | 864 |
|
| 865 | + if dtype is None: |
| 866 | + dtype = np.uint8 |
| 867 | + dtype = np.dtype(dtype) |
| 868 | + |
| 869 | + if is_object_dtype(dtype): |
| 870 | + raise ValueError("dtype=object is not a valid dtype for get_dummies") |
| 871 | + |
852 | 872 | def get_empty_Frame(data, sparse): |
853 | 873 | if isinstance(data, Series): |
854 | 874 | index = data.index |
@@ -903,18 +923,18 @@ def get_empty_Frame(data, sparse): |
903 | 923 | sp_indices = sp_indices[1:] |
904 | 924 | dummy_cols = dummy_cols[1:] |
905 | 925 | for col, ixs in zip(dummy_cols, sp_indices): |
906 | | - sarr = SparseArray(np.ones(len(ixs), dtype=np.uint8), |
| 926 | + sarr = SparseArray(np.ones(len(ixs), dtype=dtype), |
907 | 927 | sparse_index=IntIndex(N, ixs), fill_value=0, |
908 | | - dtype=np.uint8) |
| 928 | + dtype=dtype) |
909 | 929 | sparse_series[col] = SparseSeries(data=sarr, index=index) |
910 | 930 |
|
911 | 931 | out = SparseDataFrame(sparse_series, index=index, columns=dummy_cols, |
912 | 932 | default_fill_value=0, |
913 | | - dtype=np.uint8) |
| 933 | + dtype=dtype) |
914 | 934 | return out |
915 | 935 |
|
916 | 936 | else: |
917 | | - dummy_mat = np.eye(number_of_cols, dtype=np.uint8).take(codes, axis=0) |
| 937 | + dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=0) |
918 | 938 |
|
919 | 939 | if not dummy_na: |
920 | 940 | # reset NaN GH4446 |
|
0 commit comments