@@ -254,12 +254,21 @@ def test_read_dta4(self, file):
254254 )
255255
256256 # these are all categoricals
257- expected = pd .concat (
258- [expected [col ].astype ("category" ) for col in expected ], axis = 1
259- )
257+ for col in expected :
258+ orig = expected [col ].copy ()
259+
260+ categories = np .asarray (expected ["fully_labeled" ][orig .notna ()])
261+ if col == "incompletely_labeled" :
262+ categories = orig
263+
264+ cat = orig .astype ("category" )._values
265+ cat = cat .set_categories (categories , ordered = True )
266+ cat .categories .rename (None , inplace = True )
267+
268+ expected [col ] = cat
260269
261270 # stata doesn't save .category metadata
262- tm .assert_frame_equal (parsed , expected , check_categorical = False )
271+ tm .assert_frame_equal (parsed , expected )
263272
264273 # File containing strls
265274 def test_read_dta12 (self ):
@@ -952,19 +961,27 @@ def test_categorical_writing(self, version):
952961 original = pd .concat (
953962 [original [col ].astype ("category" ) for col in original ], axis = 1
954963 )
964+ expected .index .name = "index"
955965
956966 expected ["incompletely_labeled" ] = expected ["incompletely_labeled" ].apply (str )
957967 expected ["unlabeled" ] = expected ["unlabeled" ].apply (str )
958- expected = pd .concat (
959- [expected [col ].astype ("category" ) for col in expected ], axis = 1
960- )
961- expected .index .name = "index"
968+ for col in expected :
969+ orig = expected [col ].copy ()
970+
971+ cat = orig .astype ("category" )._values
972+ cat = cat .as_ordered ()
973+ if col == "unlabeled" :
974+ cat = cat .set_categories (orig , ordered = True )
975+
976+ cat .categories .rename (None , inplace = True )
977+
978+ expected [col ] = cat
962979
963980 with tm .ensure_clean () as path :
964981 original .to_stata (path , version = version )
965982 written_and_read_again = self .read_dta (path )
966983 res = written_and_read_again .set_index ("index" )
967- tm .assert_frame_equal (res , expected , check_categorical = False )
984+ tm .assert_frame_equal (res , expected )
968985
969986 def test_categorical_warnings_and_errors (self ):
970987 # Warning for non-string labels
@@ -1056,9 +1073,11 @@ def test_categorical_sorting(self, file):
10561073 parsed .index = np .arange (parsed .shape [0 ])
10571074 codes = [- 1 , - 1 , 0 , 1 , 1 , 1 , 2 , 2 , 3 , 4 ]
10581075 categories = ["Poor" , "Fair" , "Good" , "Very good" , "Excellent" ]
1059- cat = pd .Categorical .from_codes (codes = codes , categories = categories )
1076+ cat = pd .Categorical .from_codes (
1077+ codes = codes , categories = categories , ordered = True
1078+ )
10601079 expected = pd .Series (cat , name = "srh" )
1061- tm .assert_series_equal (expected , parsed ["srh" ], check_categorical = False )
1080+ tm .assert_series_equal (expected , parsed ["srh" ])
10621081
10631082 @pytest .mark .parametrize ("file" , ["dta19_115" , "dta19_117" ])
10641083 def test_categorical_ordering (self , file ):
0 commit comments