11import logging
2- import uuid
32import time
43import re
5- from typing import Dict , Tuple , List , Optional , Any , Union , TYPE_CHECKING , Set
4+ from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
65
7- from databricks .sql .backend .sea .models .base import ExternalLink
6+ from databricks .sql .backend .sea .models .base import ResultManifest
87from databricks .sql .backend .sea .utils .constants import (
98 ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
109 ResultFormat ,
1110 ResultDisposition ,
1211 ResultCompression ,
1312 WaitTimeout ,
13+ MetadataCommands ,
1414)
1515
1616if TYPE_CHECKING :
2525 BackendType ,
2626 ExecuteResponse ,
2727)
28- from databricks .sql .exc import ServerOperationError
28+ from databricks .sql .exc import DatabaseError , ServerOperationError
2929from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
30- from databricks .sql .thrift_api .TCLIService import ttypes
3130from databricks .sql .types import SSLOptions
3231
3332from databricks .sql .backend .sea .models import (
4140 ExecuteStatementResponse ,
4241 GetStatementResponse ,
4342 CreateSessionResponse ,
44- GetChunksResponse ,
4543)
4644from databricks .sql .backend .sea .models .responses import (
47- parse_status ,
48- parse_manifest ,
49- parse_result ,
45+ _parse_status ,
46+ _parse_manifest ,
47+ _parse_result ,
5048)
5149
5250logger = logging .getLogger (__name__ )
@@ -92,7 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
9290 STATEMENT_PATH = BASE_PATH + "statements"
9391 STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9492 CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
95- CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
93+
94+ # SEA constants
95+ POLL_INTERVAL_SECONDS = 0.2
9696
9797 def __init__ (
9898 self ,
@@ -124,7 +124,7 @@ def __init__(
124124 http_path ,
125125 )
126126
127- super (). __init__ ( ssl_options , ** kwargs )
127+ self . _max_download_threads = kwargs . get ( "max_download_threads" , 10 )
128128
129129 # Extract warehouse ID from http_path
130130 self .warehouse_id = self ._extract_warehouse_id (http_path )
@@ -136,7 +136,7 @@ def __init__(
136136 http_path = http_path ,
137137 http_headers = http_headers ,
138138 auth_provider = auth_provider ,
139- ssl_options = self . _ssl_options ,
139+ ssl_options = ssl_options ,
140140 ** kwargs ,
141141 )
142142
@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
291291 """
292292 return list (ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP .keys ())
293293
294- def _extract_description_from_manifest (self , manifest_obj ) -> Optional [List ]:
294+ def _extract_description_from_manifest (
295+ self , manifest : ResultManifest
296+ ) -> Optional [List ]:
295297 """
296- Extract column description from a manifest object.
298+ Extract column description from a manifest object, in the format defined by
299+ the spec: https://peps.python.org/pep-0249/#description
297300
298301 Args:
299- manifest_obj : The ResultManifest object containing schema information
302+ manifest : The ResultManifest object containing schema information
300303
301304 Returns:
302305 Optional[List]: A list of column tuples or None if no columns are found
303306 """
304307
305- schema_data = manifest_obj .schema
308+ schema_data = manifest .schema
306309 columns_data = schema_data .get ("columns" , [])
307310
308311 if not columns_data :
309312 return None
310313
311314 columns = []
312315 for col_data in columns_data :
313- if not isinstance (col_data , dict ):
314- continue
315-
316316 # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
317317 columns .append (
318318 (
@@ -328,38 +328,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
328328
329329 return columns if columns else None
330330
331- def get_chunk_link (self , statement_id : str , chunk_index : int ) -> ExternalLink :
332- """
333- Get links for chunks starting from the specified index.
334-
335- Args:
336- statement_id: The statement ID
337- chunk_index: The starting chunk index
338-
339- Returns:
340- ExternalLink: External link for the chunk
341- """
342-
343- response_data = self .http_client ._make_request (
344- method = "GET" ,
345- path = self .CHUNK_PATH_WITH_ID_AND_INDEX .format (statement_id , chunk_index ),
346- )
347- response = GetChunksResponse .from_dict (response_data )
348-
349- links = response .external_links
350- link = next ((l for l in links if l .chunk_index == chunk_index ), None )
351- if not link :
352- raise ServerOperationError (
353- f"No link found for chunk index { chunk_index } " ,
354- {
355- "operation-id" : statement_id ,
356- "diagnostic-info" : None ,
357- },
358- )
359-
360- return link
361-
362- def _results_message_to_execute_response (self , sea_response , command_id ):
331+ def _results_message_to_execute_response (
332+ self , response : GetStatementResponse
333+ ) -> ExecuteResponse :
363334 """
364335 Convert a SEA response to an ExecuteResponse and extract result data.
365336
@@ -368,33 +339,65 @@ def _results_message_to_execute_response(self, sea_response, command_id):
368339 command_id: The command ID
369340
370341 Returns:
371- tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
372- result data object, and manifest object
342+ ExecuteResponse: The normalized execute response
373343 """
374344
375- # Parse the response
376- status = parse_status (sea_response )
377- manifest_obj = parse_manifest (sea_response )
378- result_data_obj = parse_result (sea_response )
379-
380345 # Extract description from manifest schema
381- description = self ._extract_description_from_manifest (manifest_obj )
346+ description = self ._extract_description_from_manifest (response . manifest )
382347
383348 # Check for compression
384- lz4_compressed = manifest_obj .result_compression == "LZ4_FRAME"
349+ lz4_compressed = (
350+ response .manifest .result_compression == ResultCompression .LZ4_FRAME
351+ )
385352
386353 execute_response = ExecuteResponse (
387- command_id = command_id ,
388- status = status .state ,
354+ command_id = CommandId . from_sea_statement_id ( response . statement_id ) ,
355+ status = response . status .state ,
389356 description = description ,
390357 has_been_closed_server_side = False ,
391358 lz4_compressed = lz4_compressed ,
392359 is_staging_operation = False ,
393360 arrow_schema_bytes = None ,
394- result_format = manifest_obj .format ,
361+ result_format = response . manifest .format ,
395362 )
396363
397- return execute_response , result_data_obj , manifest_obj
364+ return execute_response
365+
def _check_command_not_in_failed_or_closed_state(
    self, state: CommandState, command_id: CommandId
) -> None:
    """
    Raise an error when a command has ended in the CLOSED or FAILED state.

    Any other state passes through silently and the method returns None.

    Args:
        state: The command state to inspect
        command_id: The command identifier; it is interpolated into the error
            message and attached to the error context under "operation-id"

    Raises:
        DatabaseError: If the state is CommandState.CLOSED
        ServerOperationError: If the state is CommandState.FAILED
    """

    # Both error paths attach the same context payload.
    error_context = {
        "operation-id": command_id,
    }

    if state == CommandState.CLOSED:
        closed_message = "Command {} unexpectedly closed server side".format(
            command_id
        )
        raise DatabaseError(closed_message, error_context)

    if state == CommandState.FAILED:
        failed_message = "Command {} failed".format(command_id)
        raise ServerOperationError(failed_message, error_context)
383+
def _wait_until_command_done(
    self, response: ExecuteStatementResponse
) -> CommandState:
    """
    Poll the command's state until it is no longer PENDING or RUNNING.

    Polling starts from the state carried by the given response. Between polls
    the method sleeps for POLL_INTERVAL_SECONDS and then refreshes the state via
    get_query_state. Once polling stops, the final state is passed to
    _check_command_not_in_failed_or_closed_state, which may raise.

    Args:
        response: The statement response providing the initial status and the
            statement ID of the command to wait for

    Returns:
        CommandState: The state observed when polling stopped
    """

    current_state = response.status.state
    command_id = CommandId.from_sea_statement_id(response.statement_id)

    while True:
        still_in_progress = current_state in [
            CommandState.PENDING,
            CommandState.RUNNING,
        ]
        if not still_in_progress:
            break

        # Pause between polls so the status endpoint is not hit back-to-back.
        time.sleep(self.POLL_INTERVAL_SECONDS)
        current_state = self.get_query_state(command_id)

    self._check_command_not_in_failed_or_closed_state(current_state, command_id)

    return current_state
398401
399402 def execute_command (
400403 self ,
@@ -405,7 +408,7 @@ def execute_command(
405408 lz4_compression : bool ,
406409 cursor : "Cursor" ,
407410 use_cloud_fetch : bool ,
408- parameters : List ,
411+ parameters : List [ Dict [ str , Any ]] ,
409412 async_op : bool ,
410413 enforce_embedded_schema_correctness : bool ,
411414 ) -> Union ["ResultSet" , None ]:
@@ -439,9 +442,9 @@ def execute_command(
439442 for param in parameters :
440443 sea_parameters .append (
441444 StatementParameter (
442- name = param . name ,
443- value = param . value ,
444- type = param . type if hasattr ( param , "type" ) else None ,
445+ name = param [ " name" ] ,
446+ value = param [ " value" ] ,
447+ type = param [ " type" ] if "type" in param else None ,
445448 )
446449 )
447450
@@ -493,24 +496,7 @@ def execute_command(
493496 if async_op :
494497 return None
495498
496- # For synchronous operation, wait for the statement to complete
497- status = response .status
498- state = status .state
499-
500- # Keep polling until we reach a terminal state
501- while state in [CommandState .PENDING , CommandState .RUNNING ]:
502- time .sleep (0.5 ) # add a small delay to avoid excessive API calls
503- state = self .get_query_state (command_id )
504-
505- if state != CommandState .SUCCEEDED :
506- raise ServerOperationError (
507- f"Statement execution did not succeed: { status .error .message if status .error else 'Unknown error' } " ,
508- {
509- "operation-id" : command_id .to_sea_statement_id (),
510- "diagnostic-info" : None ,
511- },
512- )
513-
499+ self ._wait_until_command_done (response )
514500 return self .get_execution_result (command_id , cursor )
515501
516502 def cancel_command (self , command_id : CommandId ) -> None :
@@ -622,25 +608,21 @@ def get_execution_result(
622608 path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
623609 data = request .to_dict (),
624610 )
611+ response = GetStatementResponse .from_dict (response_data )
625612
626613 # Create and return a SeaResultSet
627614 from databricks .sql .result_set import SeaResultSet
628615
629- # Convert the response to an ExecuteResponse and extract result data
630- (
631- execute_response ,
632- result_data ,
633- manifest ,
634- ) = self ._results_message_to_execute_response (response_data , command_id )
616+ execute_response = self ._results_message_to_execute_response (response )
635617
636618 return SeaResultSet (
637619 connection = cursor .connection ,
638620 execute_response = execute_response ,
639621 sea_client = self ,
640622 buffer_size_bytes = cursor .buffer_size_bytes ,
641623 arraysize = cursor .arraysize ,
642- result_data = result_data ,
643- manifest = manifest ,
624+ result_data = response . result ,
625+ manifest = response . manifest ,
644626 )
645627
646628 # == Metadata Operations ==
@@ -654,7 +636,7 @@ def get_catalogs(
654636 ) -> "ResultSet" :
655637 """Get available catalogs by executing 'SHOW CATALOGS'."""
656638 result = self .execute_command (
657- operation = "SHOW CATALOGS" ,
639+ operation = MetadataCommands . SHOW_CATALOGS . value ,
658640 session_id = session_id ,
659641 max_rows = max_rows ,
660642 max_bytes = max_bytes ,
@@ -681,10 +663,10 @@ def get_schemas(
681663 if not catalog_name :
682664 raise ValueError ("Catalog name is required for get_schemas" )
683665
684- operation = f"SHOW SCHEMAS IN ` { catalog_name } `"
666+ operation = MetadataCommands . SHOW_SCHEMAS . value . format ( catalog_name )
685667
686668 if schema_name :
687- operation += f" LIKE ' { schema_name } '"
669+ operation += MetadataCommands . LIKE_PATTERN . value . format ( schema_name )
688670
689671 result = self .execute_command (
690672 operation = operation ,
@@ -716,17 +698,19 @@ def get_tables(
716698 if not catalog_name :
717699 raise ValueError ("Catalog name is required for get_tables" )
718700
719- operation = "SHOW TABLES IN " + (
720- "ALL CATALOGS"
701+ operation = (
702+ MetadataCommands . SHOW_TABLES_ALL_CATALOGS . value
721703 if catalog_name in [None , "*" , "%" ]
722- else f"CATALOG `{ catalog_name } `"
704+ else MetadataCommands .SHOW_TABLES .value .format (
705+ MetadataCommands .CATALOG_SPECIFIC .value .format (catalog_name )
706+ )
723707 )
724708
725709 if schema_name :
726- operation += f" SCHEMA LIKE ' { schema_name } '"
710+ operation += MetadataCommands . SCHEMA_LIKE_PATTERN . value . format ( schema_name )
727711
728712 if table_name :
729- operation += f" LIKE ' { table_name } '"
713+ operation += MetadataCommands . LIKE_PATTERN . value . format ( table_name )
730714
731715 result = self .execute_command (
732716 operation = operation ,
@@ -742,7 +726,7 @@ def get_tables(
742726 )
743727 assert result is not None , "execute_command returned None in synchronous mode"
744728
745- # Apply client-side filtering by table_types if specified
729+ # Apply client-side filtering by table_types
746730 from databricks .sql .backend .filters import ResultSetFilter
747731
748732 result = ResultSetFilter .filter_tables_by_type (result , table_types )
@@ -764,16 +748,16 @@ def get_columns(
764748 if not catalog_name :
765749 raise ValueError ("Catalog name is required for get_columns" )
766750
767- operation = f"SHOW COLUMNS IN CATALOG ` { catalog_name } `"
751+ operation = MetadataCommands . SHOW_COLUMNS . value . format ( catalog_name )
768752
769753 if schema_name :
770- operation += f" SCHEMA LIKE ' { schema_name } '"
754+ operation += MetadataCommands . SCHEMA_LIKE_PATTERN . value . format ( schema_name )
771755
772756 if table_name :
773- operation += f" TABLE LIKE ' { table_name } '"
757+ operation += MetadataCommands . TABLE_LIKE_PATTERN . value . format ( table_name )
774758
775759 if column_name :
776- operation += f" LIKE ' { column_name } '"
760+ operation += MetadataCommands . LIKE_PATTERN . value . format ( column_name )
777761
778762 result = self .execute_command (
779763 operation = operation ,
0 commit comments