44
55import unittest
66from unittest .mock import MagicMock , patch
7- import sys
8- from typing import List , Dict , Any
9-
10- # Add the necessary path to import the filter module
11- sys .path .append ("/home/varun.edachali/conn/databricks-sql-python/src" )
127
138from databricks .sql .backend .filters import ResultSetFilter
149
@@ -20,24 +15,39 @@ def setUp(self):
2015 """Set up test fixtures."""
2116 # Create a mock SeaResultSet
2217 self .mock_sea_result_set = MagicMock ()
23- self .mock_sea_result_set ._response = {
24- "result" : {
25- "data_array" : [
26- ["catalog1" , "schema1" , "table1" , "TABLE" , "" ],
27- ["catalog1" , "schema1" , "table2" , "VIEW" , "" ],
28- ["catalog1" , "schema1" , "table3" , "SYSTEM TABLE" , "" ],
29- ["catalog1" , "schema1" , "table4" , "EXTERNAL TABLE" , "" ],
30- ],
31- "row_count" : 4 ,
32- }
33- }
18+
19+ # Set up the remaining_rows method on the results attribute
20+ self .mock_sea_result_set .results = MagicMock ()
21+ self .mock_sea_result_set .results .remaining_rows .return_value = [
22+ ["catalog1" , "schema1" , "table1" , "owner1" , "2023-01-01" , "TABLE" , "" ],
23+ ["catalog1" , "schema1" , "table2" , "owner1" , "2023-01-01" , "VIEW" , "" ],
24+ [
25+ "catalog1" ,
26+ "schema1" ,
27+ "table3" ,
28+ "owner1" ,
29+ "2023-01-01" ,
30+ "SYSTEM TABLE" ,
31+ "" ,
32+ ],
33+ [
34+ "catalog1" ,
35+ "schema1" ,
36+ "table4" ,
37+ "owner1" ,
38+ "2023-01-01" ,
39+ "EXTERNAL TABLE" ,
40+ "" ,
41+ ],
42+ ]
3443
3544 # Set up the connection and other required attributes
3645 self .mock_sea_result_set .connection = MagicMock ()
3746 self .mock_sea_result_set .backend = MagicMock ()
3847 self .mock_sea_result_set .buffer_size_bytes = 1000
3948 self .mock_sea_result_set .arraysize = 100
4049 self .mock_sea_result_set .statement_id = "test-statement-id"
50+ self .mock_sea_result_set .lz4_compressed = False
4151
4252 # Create a mock CommandId
4353 from databricks .sql .backend .types import CommandId , BackendType
@@ -50,70 +60,102 @@ def setUp(self):
5060 ("catalog_name" , "string" , None , None , None , None , True ),
5161 ("schema_name" , "string" , None , None , None , None , True ),
5262 ("table_name" , "string" , None , None , None , None , True ),
63+ ("owner" , "string" , None , None , None , None , True ),
64+ ("creation_time" , "string" , None , None , None , None , True ),
5365 ("table_type" , "string" , None , None , None , None , True ),
5466 ("remarks" , "string" , None , None , None , None , True ),
5567 ]
5668 self .mock_sea_result_set .has_been_closed_server_side = False
69+ self .mock_sea_result_set ._arrow_schema_bytes = None
5770
58- def test_filter_tables_by_type (self ):
59- """Test filtering tables by type ."""
60- # Test with specific table types
61- table_types = ["TABLE " , "VIEW " ]
71+ def test_filter_by_column_values (self ):
72+ """Test filtering by column values with various options ."""
73+ # Case 1: Case-sensitive filtering
74+ allowed_values = ["table1 " , "table3 " ]
6275
63- # Make the mock_sea_result_set appear to be a SeaResultSet
6476 with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
6577 with patch (
6678 "databricks.sql.result_set.SeaResultSet"
6779 ) as mock_sea_result_set_class :
68- # Set up the mock to return a new mock when instantiated
6980 mock_instance = MagicMock ()
7081 mock_sea_result_set_class .return_value = mock_instance
7182
72- result = ResultSetFilter .filter_tables_by_type (
73- self .mock_sea_result_set , table_types
83+ # Call filter_by_column_values on the table_name column (index 2)
84+ result = ResultSetFilter .filter_by_column_values (
85+ self .mock_sea_result_set , 2 , allowed_values , case_sensitive = True
7486 )
7587
7688 # Verify the filter was applied correctly
7789 mock_sea_result_set_class .assert_called_once ()
7890
79- def test_filter_tables_by_type_case_insensitive (self ):
80- """Test filtering tables by type with case insensitivity."""
81- # Test with lowercase table types
82- table_types = ["table" , "view" ]
91+ # Check the filtered data passed to the constructor
92+ args , kwargs = mock_sea_result_set_class .call_args
93+ result_data = kwargs .get ("result_data" )
94+ self .assertIsNotNone (result_data )
95+ self .assertEqual (len (result_data .data ), 2 )
96+ self .assertIn (result_data .data [0 ][2 ], allowed_values )
97+ self .assertIn (result_data .data [1 ][2 ], allowed_values )
8398
84- # Make the mock_sea_result_set appear to be a SeaResultSet
99+ # Case 2: Case-insensitive filtering
100+ mock_sea_result_set_class .reset_mock ()
85101 with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
86102 with patch (
87103 "databricks.sql.result_set.SeaResultSet"
88104 ) as mock_sea_result_set_class :
89- # Set up the mock to return a new mock when instantiated
90105 mock_instance = MagicMock ()
91106 mock_sea_result_set_class .return_value = mock_instance
92107
93- result = ResultSetFilter .filter_tables_by_type (
94- self .mock_sea_result_set , table_types
108+ # Call filter_by_column_values with case-insensitive matching
109+ result = ResultSetFilter .filter_by_column_values (
110+ self .mock_sea_result_set ,
111+ 2 ,
112+ ["TABLE1" , "TABLE3" ],
113+ case_sensitive = False ,
95114 )
96-
97- # Verify the filter was applied correctly
98115 mock_sea_result_set_class .assert_called_once ()
99116
100- def test_filter_tables_by_type_default (self ):
101- """Test filtering tables by type with default types."""
102- # Make the mock_sea_result_set appear to be a SeaResultSet
103- with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
104- with patch (
105- "databricks.sql.result_set.SeaResultSet"
106- ) as mock_sea_result_set_class :
107- # Set up the mock to return a new mock when instantiated
108- mock_instance = MagicMock ()
109- mock_sea_result_set_class .return_value = mock_instance
117+ # Case 3: Unsupported result set type
118+ mock_unsupported_result_set = MagicMock ()
119+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = False ):
120+ with patch ("databricks.sql.backend.filters.logger" ) as mock_logger :
121+ result = ResultSetFilter .filter_by_column_values (
122+ mock_unsupported_result_set , 0 , ["value" ], True
123+ )
124+ mock_logger .warning .assert_called_once ()
125+ self .assertEqual (result , mock_unsupported_result_set )
126+
127+ def test_filter_tables_by_type (self ):
128+ """Test filtering tables by type with various options."""
129+ # Case 1: Specific table types
130+ table_types = ["TABLE" , "VIEW" ]
110131
111- result = ResultSetFilter .filter_tables_by_type (
112- self .mock_sea_result_set , None
132+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
133+ with patch .object (
134+ ResultSetFilter , "filter_by_column_values"
135+ ) as mock_filter :
136+ ResultSetFilter .filter_tables_by_type (
137+ self .mock_sea_result_set , table_types
113138 )
139+ args , kwargs = mock_filter .call_args
140+ self .assertEqual (args [0 ], self .mock_sea_result_set )
141+ self .assertEqual (args [1 ], 5 ) # Table type column index
142+ self .assertEqual (args [2 ], table_types )
143+ self .assertEqual (kwargs .get ("case_sensitive" ), True )
114144
115- # Verify the filter was applied correctly
116- mock_sea_result_set_class .assert_called_once ()
145+ # Case 2: Default table types (None or empty list)
146+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
147+ with patch .object (
148+ ResultSetFilter , "filter_by_column_values"
149+ ) as mock_filter :
150+ # Test with None
151+ ResultSetFilter .filter_tables_by_type (self .mock_sea_result_set , None )
152+ args , kwargs = mock_filter .call_args
153+ self .assertEqual (args [2 ], ["TABLE" , "VIEW" , "SYSTEM TABLE" ])
154+
155+ # Test with empty list
156+ ResultSetFilter .filter_tables_by_type (self .mock_sea_result_set , [])
157+ args , kwargs = mock_filter .call_args
158+ self .assertEqual (args [2 ], ["TABLE" , "VIEW" , "SYSTEM TABLE" ])
117159
118160
119161if __name__ == "__main__" :
0 commit comments