Commit
CSV importer/exporter can now handle UNIQUE with multiple columns as long
as the last FK resolves to a single column.
bgol committed Nov 28, 2014
1 parent 3785045 commit 6a609d3
Showing 2 changed files with 107 additions and 50 deletions.
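For reference, the importer (cache.py below) reads three markers from a CSV column header: an "@Table.column" suffix tells it to resolve the value through a foreign key, a "unq:" prefix marks the column as part of the table's unique index, and the new "!" prefix marks a column that exists only to help resolve the next foreign key and is not inserted itself. A hypothetical header and data row in that format (the table and column names here are illustrative, not taken from the commit) might look like:

    unq:!name@System.system_id,unq:name@Station.station_id,cost
    'Sol','Abraham Lincoln',1000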
33 changes: 24 additions & 9 deletions cache.py
@@ -623,10 +623,11 @@ def processImportFile(tdenv, db, importPath, tableName):
fkeySelectStr = ("("
"SELECT {newValue}"
" FROM {table}"
" WHERE {table}.{column} = ?"
" WHERE {stmt}"
")"
)
uniquePfx = "unq:"
ignorePfx = "!"

with importPath.open(encoding='utf-8') as importFile:
csvin = csv.reader(importFile, delimiter=',', quotechar="'", doublequote=True)
@@ -636,9 +637,9 @@ def processImportFile(tdenv, db, importPath, tableName):

# split up columns and values
# this is necessary because the insert might use a foreign key
columnNames = []
bindColumns = []
bindValues = []
joinHelper = []
uniqueIndexes = []
for (cIndex, cName) in enumerate(columnDefs):
splitNames = cName.split('@')
@@ -647,23 +648,34 @@ def processImportFile(tdenv, db, importPath, tableName):
if colName.startswith(uniquePfx):
uniqueIndexes += [ cIndex ]
colName = colName[len(uniquePfx):]
columnNames.append(colName)
if colName.startswith(ignorePfx):
# this column is only used to resolve an FK
colName = colName[len(ignorePfx):]
joinHelper.append( "{}@{}".format(colName, splitNames[1]) )
continue

if len(splitNames) == 1:
# no foreign key, straight insert
bindColumns.append(colName)
bindValues.append('?')
else:
# foreign key, we need to make a select
splitJoin = splitNames[1].split('.')
joinTable = splitJoin[0]
joinColumn = splitJoin[1]
bindColumns.append(joinColumn)
splitJoin = splitNames[1].split('.')
joinTable = [ splitJoin[0] ]
joinStmt = []
for joinRow in joinHelper:
helperNames = joinRow.split('@')
helperJoin = helperNames[1].split('.')
joinTable.append( "INNER JOIN {} USING({})".format(helperJoin[0], helperJoin[1]) )
joinStmt.append( "{}.{} = ?".format(helperJoin[0], helperNames[0]) )
joinHelper = []
joinStmt.append("{}.{} = ?".format(splitJoin[0], colName))
bindColumns.append(splitJoin[1])
bindValues.append(
fkeySelectStr.format(
newValue=splitNames[1],
table=joinTable,
column=colName,
table=" ".join(joinTable),
stmt=" AND ".join(joinStmt),
)
)
# now we can make the sql statement
@@ -730,6 +742,9 @@ def processImportFile(tdenv, db, importPath, tableName):
)
) from None
importCount += 1
else:
if not tdenv.quiet:
print("Wrong number of columns ({}:{}): {}".format(importPath, lineNo, ', '.join(linein)))
db.commit()
tdenv.DEBUG0("{count} {table}s imported",
count=importCount,
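A minimal standalone sketch of what the new fkeySelectStr/joinStmt handling would bind for the hypothetical header above; the values are hard-coded to mirror one pass of the loop, since the full INSERT construction lies outside the hunks shown:

    # Sketch only: reproduces the sub-select cache.py would bind for a column
    # header "name@Station.station_id" preceded by the helper "!name@System.system_id".
    fkeySelectStr = (
        "("
        "SELECT {newValue}"
        " FROM {table}"
        " WHERE {stmt}"
        ")"
    )
    joinTable = ["Station", "INNER JOIN System USING(system_id)"]
    joinStmt = ["System.name = ?", "Station.name = ?"]
    bindValue = fkeySelectStr.format(
        newValue="Station.station_id",
        table=" ".join(joinTable),
        stmt=" AND ".join(joinStmt),
    )
    print(bindValue)
    # (SELECT Station.station_id FROM Station INNER JOIN System USING(system_id) WHERE System.name = ? AND Station.name = ?)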
124 changes: 83 additions & 41 deletions commands/export_cmd.py
@@ -10,12 +10,11 @@
#
# Generate the CSV files for the master data of the database.
#
# Note: This script makes some assumptions about the structure
# of the database:
# * The column name of a foreign key reference must be the same
# * The referenced table must have a column named "name"
# which is UNIQUE
# * One column primary keys will be handled by the database engine
# Note: This command makes some assumptions about the structure
# of the database:
# * The table should only have one UNIQUE index
# * The referenced table must have one UNIQUE index
# * One column primary keys will be handled by the database engine
#
######################################################################
# CAUTION: If the database structure gets changed this script might
@@ -79,9 +78,9 @@
######################################################################
# Helpers

def search_dict(list, key, val):
def search_keyList(list, val):
for row in list:
if row[key] == val: return row
if row['from'] == row['to'] == val: return row

def getUniqueIndex(conn, tableName):
# return the first unique index
@@ -99,6 +98,41 @@ def getUniqueIndex(conn, tableName):
return unqIndex
return unqIndex

def getFKeyList(conn, tableName):
# get all single column FKs
keyList = []
keyCount = -1
keyCursor = conn.cursor()
for keyRow in keyCursor.execute("PRAGMA foreign_key_list('%s')" % tableName):
if keyRow['seq'] == 0:
keyCount += 1
keyList.append( {'table': keyRow['table'],
'from': keyRow['from'],
'to': keyRow['to']}
)
if keyRow['seq'] == 1:
# if there is a second column, remove it from the list
keyList.remove( keyList[keyCount] )
keyCount -= 1

return keyList

def buildFKeyStmt(conn, tableName, key):
unqIndex = getUniqueIndex(conn, key['table'])
keyList = getFKeyList(conn, key['table'])
keyStmt = []
for colName in unqIndex:
# check if the column is a foreign key
keyKey = search_keyList(keyList, colName)
if keyKey:
newStmt = buildFKeyStmt(conn, key['table'], keyKey)
for row in newStmt:
keyStmt.append(row)
else:
keyStmt.append( {'table': tableName, 'column': colName, 'joinTable': key['table'], 'joinColumn': key['to']} )

return keyStmt
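
For illustration, assume a schema in the spirit of this commit: a hypothetical exported table ShipVendor with a foreign key station_id -> Station, where Station's unique index is (system_id, name) and system_id is itself a foreign key to System, whose unique index is just name (these names are assumptions, not taken from the diff). The recursion above would then return roughly:

    # Hypothetical result of
    # buildFKeyStmt(conn, 'ShipVendor', {'table': 'Station', 'from': 'station_id', 'to': 'station_id'})
    keyStmt = [
        # produced by the recursive call that follows Station's own FK column system_id
        {'table': 'Station',    'column': 'name', 'joinTable': 'System',  'joinColumn': 'system_id'},
        # Station's remaining unique column, resolved against the FK target itself
        {'table': 'ShipVendor', 'column': 'name', 'joinTable': 'Station', 'joinColumn': 'station_id'},
    ]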

######################################################################
# Perform query and populate result set

@@ -160,46 +194,61 @@ def run(results, cmdenv, tdb):
exportOut = csv.writer(exportFile, delimiter=",", quotechar="'", doublequote=True, quoting=csv.QUOTE_NONNUMERIC, lineterminator="\n")

cur = conn.cursor()
keyList = []
for key in cur.execute("PRAGMA foreign_key_list('%s')" % tableName):
# ignore FKs to table StationItem
if key['table'] != 'StationItem':
# only support FK joins with the same column name
if key['from'] == key['to']:
keyList += [ {'table': key['table'], 'column': key['from']} ]

# check for single PRIMARY KEY
pkCount = 0
for col in cur.execute("PRAGMA table_info('%s')" % tableName):
for columnRow in cur.execute("PRAGMA table_info('%s')" % tableName):
# count the columns of the primary key
if col['pk'] > 0: pkCount += 1
if columnRow['pk'] > 0: pkCount += 1

# build column list
columnList = []
for columnRow in cur.execute("PRAGMA table_info('%s')" % tableName):
# if there is only one PK column, ignore it
if columnRow['pk'] > 0 and pkCount == 1: continue
columnList.append(columnRow)

# reverse the first two columns for some tables
if tableName in reverseList:
columnList[0], columnList[1] = columnList[1], columnList[0]

# initialize helper lists
csvHead = []
stmtColumn = []
stmtTable = [ tableName ]
stmtOrder = []
unqIndex = getUniqueIndex(conn, tableName)
keyList = getFKeyList(conn, tableName)

# iterate over all columns of the table
for col in cur.execute("PRAGMA table_info('%s')" % tableName):
# if there is only one PK column, ignore it
if col['pk'] > 0 and pkCount == 1: continue
cmdenv.DEBUG0('UNIQUE: ' + ", ".join(unqIndex))

# iterate over all columns of the table
for col in columnList:
# check if the column is a foreign key
key = search_dict(keyList, 'column', col['name'])
key = search_keyList(keyList, col['name'])
if key:
# there must be a "name" column in the referenced table
if col['name'] in unqIndex:
# column is part of an unique index
csvHead += [ uniquePfx + "name@{}.{}".format(key['table'], key['column']) ]
else:
csvHead += [ "name@{}.{}".format(key['table'], key['column']) ]
stmtColumn += [ "{}.name".format(key['table']) ]
if col['notnull']:
stmtTable += [ 'INNER JOIN {} USING({})'.format(key['table'], key['column']) ]
else:
stmtTable += [ 'LEFT OUTER JOIN {} USING({})'.format(key['table'], key['column']) ]
stmtOrder += [ "{}.name".format(key['table']) ]
# make the join statement
keyStmt = buildFKeyStmt(conn, tableName, key)
for keyRow in keyStmt:
if cmdenv.debug > 0:
print('FK-Stmt: {}'.format(keyRow))
# is the join for the same table
if keyRow['table'] == tableName:
csvPfx = ''
joinStmt = 'USING({})'.format(keyRow['joinColumn'])
else:
csvPfx = '!'
joinStmt = 'ON {}.{} = {}.{}'.format(keyRow['table'], keyRow['joinColumn'], keyRow['joinTable'], keyRow['joinColumn'])
if col['name'] in unqIndex:
# column is part of an unique index
csvPfx = uniquePfx + csvPfx
csvHead += [ "{}{}@{}.{}".format(csvPfx, keyRow['column'], keyRow['joinTable'], keyRow['joinColumn']) ]
stmtColumn += [ "{}.{}".format(keyRow['joinTable'], keyRow['column']) ]
if col['notnull']:
stmtTable += [ 'INNER JOIN {} {}'.format(keyRow['joinTable'], joinStmt) ]
else:
stmtTable += [ 'LEFT OUTER JOIN {} {}'.format(keyRow['joinTable'], joinStmt) ]
stmtOrder += [ "{}.{}".format(keyRow['joinTable'], keyRow['column']) ]
else:
# ordinary column
if col['name'] in unqIndex:
@@ -210,13 +259,6 @@ def run(results, cmdenv, tdb):
csvHead += [ col['name'] ]
stmtColumn += [ "{}.{}".format(tableName, col['name']) ]

# reverse the first two columns for some tables
if tableName in reverseList:
csvHead[0], csvHead[1] = csvHead[1], csvHead[0]
stmtColumn[0], stmtColumn[1] = stmtColumn[1], stmtColumn[0]
if len(stmtOrder) > 1:
stmtOrder[0], stmtOrder[1] = stmtOrder[1], stmtOrder[0]

# build the SQL statement
sqlStmt = "SELECT {} FROM {}".format(",".join(stmtColumn)," ".join(stmtTable))
if len(stmtOrder) > 0:
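Continuing the hypothetical ShipVendor/Station/System example, and assuming station_id is NOT NULL and part of ShipVendor's unique index, the loop above would extend the helper lists roughly as follows; note that the first header carries the '!' prefix that the new import code in cache.py consumes:

    # Hypothetical contents after processing only the station_id column
    csvHead    = ['unq:!name@System.system_id', 'unq:name@Station.station_id']
    stmtColumn = ['System.name', 'Station.name']
    stmtTable  = ['ShipVendor',
                  'INNER JOIN System ON Station.system_id = System.system_id',
                  'INNER JOIN Station USING(station_id)']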
