@@ -850,6 +850,31 @@ def test_astype_column_metadata(self, dtype):
850850 df = df .astype (dtype )
851851 tm .assert_index_equal (df .columns , columns )
852852
853+ def test_df_where_change_dtype (self ):
854+ # GH 16979
855+ df = DataFrame (np .arange (2 * 3 ).reshape (2 , 3 ), columns = list ("ABC" ))
856+ mask = np .array ([[True , False , False ], [False , False , True ]])
857+
858+ result = df .where (mask )
859+ expected = DataFrame (
860+ [[0 , np .nan , np .nan ], [np .nan , np .nan , 5 ]], columns = list ("ABC" )
861+ )
862+
863+ tm .assert_frame_equal (result , expected )
864+
865+ # change type to category
866+ df .A = df .A .astype ("category" )
867+ df .B = df .B .astype ("category" )
868+ df .C = df .C .astype ("category" )
869+
870+ result = df .where (mask )
871+ A = pd .Categorical ([0 , np .nan ], categories = [0 , 3 ])
872+ B = pd .Categorical ([np .nan , np .nan ], categories = [1 , 4 ])
873+ C = pd .Categorical ([np .nan , 5 ], categories = [2 , 5 ])
874+ expected = DataFrame ({"A" : A , "B" : B , "C" : C })
875+
876+ tm .assert_frame_equal (result , expected )
877+
853878 @pytest .mark .parametrize ("dtype" , ["M8" , "m8" ])
854879 @pytest .mark .parametrize ("unit" , ["ns" , "us" , "ms" , "s" , "h" , "m" , "D" ])
855880 def test_astype_from_datetimelike_to_objectt (self , dtype , unit ):
0 commit comments