@@ -39,23 +39,15 @@ def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, exe
39
39
return RowIterator (resp , columnNameMode , dbname )
40
40
41
41
res = next (resp )
42
- columnNames = getColumnNames (res , columnNameMode )
43
- rows = unpack_rows (res , columnNameMode , columnNames )
44
- return fix_colnames (rows , dbname , columnNameMode )
42
+ columnNames = getColumnNames (res , dbname , columnNameMode )
43
+ return unpack_rows (res , columnNameMode , columnNames )
45
44
46
45
47
- def fix_colnames (ret , dbname , columnNameMode ):
46
+ def fix_colnames (cols , dbname , columnNameMode ):
48
47
if columnNameMode not in [constants .COLUMN_NAME_MODE_DATABASE , constants .COLUMN_NAME_MODE_FULL ]:
49
- return ret
48
+ return cols
50
49
51
- # newer DB version don't insert database name anymore, we need to
52
- # process it manually
53
- for i , t in enumerate (ret ):
54
- newkeys = [
55
- x .replace ("[@DB]" , dbname .decode ("utf-8" )) for x in t .keys ()]
56
- k = dict (zip (newkeys , list (t .values ())))
57
- ret [i ] = k
58
- return ret
50
+ return [x .replace ("[@DB]" , dbname .decode ("utf-8" )) for x in cols ]
59
51
60
52
61
53
def unpack_rows (resp , columnNameMode , colNames ):
@@ -69,34 +61,36 @@ def unpack_rows(resp, columnNameMode, colNames):
69
61
return result
70
62
71
63
72
- def getColumnNames (resp , columnNameMode ):
73
- columnNames = []
74
- if columnNameMode != constants .COLUMN_NAME_MODE_NONE :
75
- for column in resp .columns :
76
- # note that depending on the version parts can be
77
- # '(dbname.tablename.fieldname)' *or*
78
- # '(tablename.fieldname)' without dbnname.
79
- # In that case we mimic the old behavior by using [@DB] as placeholder
80
- # that will be replaced at higher level.
81
- parts = column .name .strip ("()" ).split ("." )
82
- if columnNameMode == constants .COLUMN_NAME_MODE_FIELD :
83
- columnNames .append (parts [- 1 ])
84
- continue
85
- if columnNameMode == constants .COLUMN_NAME_MODE_TABLE :
86
- columnNames .append ("." .join (parts [- 2 :]))
87
- continue
88
- print (
89
- "Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated" )
90
- if len (parts ) == 2 :
91
- parts .insert (0 , "[@DB]" )
92
- if columnNameMode == constants .COLUMN_NAME_MODE_DATABASE :
93
- columnNames .append ("." .join (parts ))
94
- continue
95
- if columnNameMode == constants .COLUMN_NAME_MODE_FULL :
96
- columnNames .append ("(" + "." .join (parts )+ ")" )
97
- continue
98
- raise ErrPySDKInvalidColumnMode
99
- return columnNames
64
+ def getColumnNames (resp , dbname , columnNameMode ):
65
+ cols = []
66
+ if columnNameMode == constants .COLUMN_NAME_MODE_NONE :
67
+ return cols
68
+
69
+ for column in resp .columns :
70
+ # note that depending on the version parts can be
71
+ # '(dbname.tablename.fieldname)' *or*
72
+ # '(tablename.fieldname)' without dbnname.
73
+ # In that case we mimic the old behavior by using [@DB] as placeholder
74
+ # that will be replaced at higher level.
75
+ parts = column .name .strip ("()" ).split ("." )
76
+ if columnNameMode == constants .COLUMN_NAME_MODE_FIELD :
77
+ cols .append (parts [- 1 ])
78
+ continue
79
+ if columnNameMode == constants .COLUMN_NAME_MODE_TABLE :
80
+ cols .append ("." .join (parts [- 2 :]))
81
+ continue
82
+ print (
83
+ "Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated" )
84
+ if len (parts ) == 2 :
85
+ parts .insert (0 , "[@DB]" )
86
+ if columnNameMode == constants .COLUMN_NAME_MODE_DATABASE :
87
+ cols .append ("." .join (parts ))
88
+ continue
89
+ if columnNameMode == constants .COLUMN_NAME_MODE_FULL :
90
+ cols .append ("(" + "." .join (parts )+ ")" )
91
+ continue
92
+ raise ErrPySDKInvalidColumnMode
93
+ return fix_colnames (cols , dbname , columnNameMode )
100
94
101
95
102
96
class ClosedIterator (BaseException ):
@@ -109,7 +103,7 @@ def __init__(self, grpcIt, colNameMode, dbname) -> None:
109
103
self ._nextRow = 0
110
104
self ._rows = []
111
105
self ._columns = None
112
- self ._colNameMode = colNameMode if colNameMode != constants . COLUMN_NAME_MODE_NONE else constants . COLUMN_NAME_MODE_FIELD
106
+ self ._colNameMode = colNameMode
113
107
self ._dbname = dbname
114
108
self ._closed = False
115
109
@@ -132,22 +126,21 @@ def _fetch_next(self):
132
126
133
127
res = next (self ._grpcIt )
134
128
if self ._columns == None :
135
- self ._columns = getColumnNames (res , self ._colNameMode )
129
+ self ._columns = getColumnNames (res , self ._dbname , self . _colsMode () )
136
130
137
131
self ._rows = unpack_rows (
138
- res , constants . COLUMN_NAME_MODE_NONE , self ._columns )
132
+ res , self . _colNameMode , self ._columns )
139
133
self ._nextRow = 0
140
134
141
135
if len (self ._rows ) == 0 :
142
136
raise StopIteration
143
137
138
+ def _colsMode (self ):
139
+ return self ._colNameMode if self ._colNameMode != constants .COLUMN_NAME_MODE_NONE else constants .COLUMN_NAME_MODE_FIELD
140
+
144
141
def columns (self ):
145
142
self ._fetch_next ()
146
-
147
- if self ._colNameMode not in [constants .COLUMN_NAME_MODE_DATABASE , constants .COLUMN_NAME_MODE_FULL ]:
148
- return self ._columns
149
-
150
- return [x .replace ("[@DB]" , self ._dbname .decode ("utf-8" )) for x in self ._columns ]
143
+ return self ._columns
151
144
152
145
def close (self ):
153
146
if self ._closed :
0 commit comments