@@ -2479,6 +2479,93 @@ def _check_set(df, cond, check_dtypes=True):
24792479 expected = df [df ['a' ] == 1 ].reindex (df .index )
24802480 assert_frame_equal (result , expected )
24812481
2482+ def test_where_array_like (self ):
2483+ # see gh-15414
2484+ klasses = [list , tuple , np .array ]
2485+
2486+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2487+ cond = [[False ], [True ], [True ]]
2488+ expected = DataFrame ({'a' : [np .nan , 2 , 3 ]})
2489+
2490+ for klass in klasses :
2491+ result = df .where (klass (cond ))
2492+ assert_frame_equal (result , expected )
2493+
2494+ df ['b' ] = 2
2495+ expected ['b' ] = [2 , np .nan , 2 ]
2496+ cond = [[False , True ], [True , False ], [True , True ]]
2497+
2498+ for klass in klasses :
2499+ result = df .where (klass (cond ))
2500+ assert_frame_equal (result , expected )
2501+
2502+ def test_where_invalid_input (self ):
2503+ # see gh-15414: only boolean arrays accepted
2504+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2505+ msg = "Boolean array expected for the condition"
2506+
2507+ conds = [
2508+ [[1 ], [0 ], [1 ]],
2509+ Series ([[2 ], [5 ], [7 ]]),
2510+ [["True" ], ["False" ], ["True" ]],
2511+ [[Timestamp ("2017-01-01" )],
2512+ [pd .NaT ], [Timestamp ("2017-01-02" )]]
2513+ ]
2514+
2515+ for cond in conds :
2516+ with tm .assertRaisesRegexp (ValueError , msg ):
2517+ df .where (cond )
2518+
2519+ df ['b' ] = 2
2520+ conds = [
2521+ [[0 , 1 ], [1 , 0 ], [1 , 1 ]],
2522+ Series ([[0 , 2 ], [5 , 0 ], [4 , 7 ]]),
2523+ [["False" , "True" ], ["True" , "False" ],
2524+ ["True" , "True" ]],
2525+ [[pd .NaT , Timestamp ("2017-01-01" )],
2526+ [Timestamp ("2017-01-02" ), pd .NaT ],
2527+ [Timestamp ("2017-01-03" ), Timestamp ("2017-01-03" )]]
2528+ ]
2529+
2530+ for cond in conds :
2531+ with tm .assertRaisesRegexp (ValueError , msg ):
2532+ df .where (cond )
2533+
2534+ def test_where_dataframe_col_match (self ):
2535+ df = DataFrame ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
2536+ cond = DataFrame ([[True , False , True ], [False , False , True ]])
2537+
2538+ out = df .where (cond )
2539+ expected = DataFrame ([[1.0 , np .nan , 3 ], [np .nan , np .nan , 6 ]])
2540+ tm .assert_frame_equal (out , expected )
2541+
2542+ cond .columns = ["a" , "b" , "c" ] # Columns no longer match.
2543+ msg = "Boolean array expected for the condition"
2544+ with tm .assertRaisesRegexp (ValueError , msg ):
2545+ df .where (cond )
2546+
2547+ def test_where_ndframe_align (self ):
2548+ msg = "Array conditional must be same shape as self"
2549+ df = DataFrame ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
2550+
2551+ cond = [True ]
2552+ with tm .assertRaisesRegexp (ValueError , msg ):
2553+ df .where (cond )
2554+
2555+ expected = DataFrame ([[1 , 2 , 3 ], [np .nan , np .nan , np .nan ]])
2556+
2557+ out = df .where (Series (cond ))
2558+ tm .assert_frame_equal (out , expected )
2559+
2560+ cond = np .array ([False , True , False , True ])
2561+ with tm .assertRaisesRegexp (ValueError , msg ):
2562+ df .where (cond )
2563+
2564+ expected = DataFrame ([[np .nan , np .nan , np .nan ], [4 , 5 , 6 ]])
2565+
2566+ out = df .where (Series (cond ))
2567+ tm .assert_frame_equal (out , expected )
2568+
24822569 def test_where_bug (self ):
24832570
24842571 # GH 2793
0 commit comments