Skip to content

Commit 172e062

Browse files
committed
Fix colname mode
1 parent fbc9fb1 commit 172e062

File tree

1 file changed

+42
-49
lines changed

1 file changed

+42
-49
lines changed

immudb/handler/sqlquery.py

+42-49
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,15 @@ def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, exe
3939
return RowIterator(resp, columnNameMode, dbname)
4040

4141
res = next(resp)
42-
columnNames = getColumnNames(res, columnNameMode)
43-
rows = unpack_rows(res, columnNameMode, columnNames)
44-
return fix_colnames(rows, dbname, columnNameMode)
42+
columnNames = getColumnNames(res, dbname, columnNameMode)
43+
return unpack_rows(res, columnNameMode, columnNames)
4544

4645

47-
def fix_colnames(ret, dbname, columnNameMode):
46+
def fix_colnames(cols, dbname, columnNameMode):
4847
if columnNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]:
49-
return ret
48+
return cols
5049

51-
# newer DB version don't insert database name anymore, we need to
52-
# process it manually
53-
for i, t in enumerate(ret):
54-
newkeys = [
55-
x.replace("[@DB]", dbname.decode("utf-8")) for x in t.keys()]
56-
k = dict(zip(newkeys, list(t.values())))
57-
ret[i] = k
58-
return ret
50+
return [x.replace("[@DB]", dbname.decode("utf-8")) for x in cols]
5951

6052

6153
def unpack_rows(resp, columnNameMode, colNames):
@@ -69,34 +61,36 @@ def unpack_rows(resp, columnNameMode, colNames):
6961
return result
7062

7163

72-
def getColumnNames(resp, columnNameMode):
73-
columnNames = []
74-
if columnNameMode != constants.COLUMN_NAME_MODE_NONE:
75-
for column in resp.columns:
76-
# note that depending on the version parts can be
77-
# '(dbname.tablename.fieldname)' *or*
78-
# '(tablename.fieldname)' without dbnname.
79-
# In that case we mimic the old behavior by using [@DB] as placeholder
80-
# that will be replaced at higher level.
81-
parts = column.name.strip("()").split(".")
82-
if columnNameMode == constants.COLUMN_NAME_MODE_FIELD:
83-
columnNames.append(parts[-1])
84-
continue
85-
if columnNameMode == constants.COLUMN_NAME_MODE_TABLE:
86-
columnNames.append(".".join(parts[-2:]))
87-
continue
88-
print(
89-
"Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated")
90-
if len(parts) == 2:
91-
parts.insert(0, "[@DB]")
92-
if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE:
93-
columnNames.append(".".join(parts))
94-
continue
95-
if columnNameMode == constants.COLUMN_NAME_MODE_FULL:
96-
columnNames.append("("+".".join(parts)+")")
97-
continue
98-
raise ErrPySDKInvalidColumnMode
99-
return columnNames
64+
def getColumnNames(resp, dbname, columnNameMode):
65+
cols = []
66+
if columnNameMode == constants.COLUMN_NAME_MODE_NONE:
67+
return cols
68+
69+
for column in resp.columns:
70+
# note that depending on the version parts can be
71+
# '(dbname.tablename.fieldname)' *or*
72+
# '(tablename.fieldname)' without dbnname.
73+
# In that case we mimic the old behavior by using [@DB] as placeholder
74+
# that will be replaced at higher level.
75+
parts = column.name.strip("()").split(".")
76+
if columnNameMode == constants.COLUMN_NAME_MODE_FIELD:
77+
cols.append(parts[-1])
78+
continue
79+
if columnNameMode == constants.COLUMN_NAME_MODE_TABLE:
80+
cols.append(".".join(parts[-2:]))
81+
continue
82+
print(
83+
"Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated")
84+
if len(parts) == 2:
85+
parts.insert(0, "[@DB]")
86+
if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE:
87+
cols.append(".".join(parts))
88+
continue
89+
if columnNameMode == constants.COLUMN_NAME_MODE_FULL:
90+
cols.append("("+".".join(parts)+")")
91+
continue
92+
raise ErrPySDKInvalidColumnMode
93+
return fix_colnames(cols, dbname, columnNameMode)
10094

10195

10296
class ClosedIterator(BaseException):
@@ -109,7 +103,7 @@ def __init__(self, grpcIt, colNameMode, dbname) -> None:
109103
self._nextRow = 0
110104
self._rows = []
111105
self._columns = None
112-
self._colNameMode = colNameMode if colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD
106+
self._colNameMode = colNameMode
113107
self._dbname = dbname
114108
self._closed = False
115109

@@ -132,22 +126,21 @@ def _fetch_next(self):
132126

133127
res = next(self._grpcIt)
134128
if self._columns == None:
135-
self._columns = getColumnNames(res, self._colNameMode)
129+
self._columns = getColumnNames(res, self._dbname, self._colsMode())
136130

137131
self._rows = unpack_rows(
138-
res, constants.COLUMN_NAME_MODE_NONE, self._columns)
132+
res, self._colNameMode, self._columns)
139133
self._nextRow = 0
140134

141135
if len(self._rows) == 0:
142136
raise StopIteration
143137

138+
def _colsMode(self):
139+
return self._colNameMode if self._colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD
140+
144141
def columns(self):
145142
self._fetch_next()
146-
147-
if self._colNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]:
148-
return self._columns
149-
150-
return [x.replace("[@DB]", self._dbname.decode("utf-8")) for x in self._columns]
143+
return self._columns
151144

152145
def close(self):
153146
if self._closed:

0 commit comments

Comments
 (0)