@@ -2479,6 +2479,95 @@ def _check_set(df, cond, check_dtypes=True):
24792479 expected = df [df ['a' ] == 1 ].reindex (df .index )
24802480 assert_frame_equal (result , expected )
24812481
2482+ def test_where_array_like (self ):
2483+ # see gh-15414
2484+ klasses = [list , tuple , np .array ]
2485+
2486+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2487+ cond = [[False ], [True ], [True ]]
2488+ expected = DataFrame ({'a' : [np .nan , 2 , 3 ]})
2489+
2490+ for klass in klasses :
2491+ result = df .where (klass (cond ))
2492+ assert_frame_equal (result , expected )
2493+
2494+ df ['b' ] = 2
2495+ expected ['b' ] = [2 , np .nan , 2 ]
2496+ cond = [[False , True ], [True , False ], [True , True ]]
2497+
2498+ for klass in klasses :
2499+ result = df .where (klass (cond ))
2500+ assert_frame_equal (result , expected )
2501+
def test_where_invalid_input(self):
    # see gh-15414: only boolean arrays accepted
    frame = DataFrame({'a': [1, 2, 3]})
    err = "Boolean array expected for the condition"

    # Non-boolean conditions matching a single-column frame:
    # integers, Series, DataFrame, strings and datetimes.
    one_col_conds = [
        [[1], [0], [1]],
        Series([[2], [5], [7]]),
        DataFrame({'a': [2, 5, 7]}),
        [["True"], ["False"], ["True"]],
        [[Timestamp("2017-01-01")], [pd.NaT], [Timestamp("2017-01-02")]],
    ]

    for bad in one_col_conds:
        with tm.assertRaisesRegexp(ValueError, err):
            frame.where(bad)

    # Same check once the frame has two columns.
    frame['b'] = 2
    two_col_conds = [
        [[0, 1], [1, 0], [1, 1]],
        Series([[0, 2], [5, 0], [4, 7]]),
        [["False", "True"], ["True", "False"], ["True", "True"]],
        DataFrame({'a': [2, 5, 7], 'b': [4, 8, 9]}),
        [
            [pd.NaT, Timestamp("2017-01-01")],
            [Timestamp("2017-01-02"), pd.NaT],
            [Timestamp("2017-01-03"), Timestamp("2017-01-03")],
        ],
    ]

    for bad in two_col_conds:
        with tm.assertRaisesRegexp(ValueError, err):
            frame.where(bad)
2535+
def test_where_dataframe_col_match(self):
    frame = DataFrame([[1, 2, 3], [4, 5, 6]])
    mask = DataFrame([[True, False, True], [False, False, True]])

    # The mask's column labels match the frame's here: cells where
    # the mask is False are expected to come back as NaN.
    expected = DataFrame([[1.0, np.nan, 3], [np.nan, np.nan, 6]])
    tm.assert_frame_equal(frame.where(mask), expected)

    # Relabel the mask so that its columns no longer match.
    mask.columns = ["a", "b", "c"]
    err = "Boolean array expected for the condition"
    with tm.assertRaisesRegexp(ValueError, err):
        frame.where(mask)
2548+
def test_where_ndframe_align(self):
    err = "Array conditional must be same shape as self"
    frame = DataFrame([[1, 2, 3], [4, 5, 6]])

    # A plain list with fewer entries than the frame has rows is
    # expected to be rejected ...
    too_short = [True]
    with tm.assertRaisesRegexp(ValueError, err):
        frame.where(too_short)

    # ... whereas the same values wrapped in a Series are accepted,
    # with the second row expected to come back as all-NaN.
    expected = DataFrame([[1, 2, 3], [np.nan, np.nan, np.nan]])
    tm.assert_frame_equal(frame.where(Series(too_short)), expected)

    # Likewise for an ndarray with more entries than rows.
    too_long = np.array([False, True, False, True])
    with tm.assertRaisesRegexp(ValueError, err):
        frame.where(too_long)

    expected = DataFrame([[np.nan, np.nan, np.nan], [4, 5, 6]])
    tm.assert_frame_equal(frame.where(Series(too_long)), expected)
2570+
24822571 def test_where_bug (self ):
24832572
24842573 # GH 2793
0 commit comments