Skip to content

Commit

Permalink
feat: inspectdb support sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed Apr 1, 2022
1 parent 75480e2 commit 801dde1
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

### 0.6.3

- Improve `inspectdb` and support `postgres`.
- Improve `inspectdb` and support `postgres` & `sqlite`.

### 0.6.2

Expand Down
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ style: deps
check: deps
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles)
@bandit -x tests -r $(checkfiles)
#@mypy $(checkfiles)

test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Now your db is rolled back to the specified version.

### Inspect db tables to TortoiseORM model

Currently `inspectdb` only supports MySQL & Postgres.
Currently `inspectdb` support MySQL & Postgres & SQLite.

```shell
Usage: aerich inspectdb [OPTIONS]
Expand Down
3 changes: 3 additions & 0 deletions aerich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aerich.exceptions import DowngradeError
from aerich.inspect.mysql import InspectMySQL
from aerich.inspect.postgres import InspectPostgres
from aerich.inspect.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.models import Aerich
from aerich.utils import (
Expand Down Expand Up @@ -114,6 +115,8 @@ async def inspectdb(self, tables: List[str]) -> str:
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
Expand Down
2 changes: 1 addition & 1 deletion aerich/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def wrapper(*args, **kwargs):
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init"]:
if Tortoise._inited:
loop.run_until_complete(Tortoise.close_connections())

return wrapper
Expand Down
11 changes: 7 additions & 4 deletions aerich/inspect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ def translate(self) -> dict:
pk = "pk=True, "
if self.unique:
unique = "unique=True, "
if self.data_type == "varchar":
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type == "decimal":
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type == "tinyint":
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz"]:
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
Expand Down Expand Up @@ -66,7 +66,10 @@ class Inspect:

def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.database = conn.database
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables

@property
Expand Down
5 changes: 3 additions & 2 deletions aerich/inspect/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def field_map(self) -> dict:
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
Expand All @@ -30,8 +31,8 @@ async def get_all_tables(self) -> List[str]:
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
ret = await self.conn.execute_query(sql, [self.database, table])
for row in ret[1]:
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
columns.append(
Column(
name=row["COLUMN_NAME"],
Expand Down
6 changes: 4 additions & 2 deletions aerich/inspect/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def field_map(self) -> dict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
Expand Down Expand Up @@ -53,8 +55,8 @@ async def get_columns(self, table: str) -> List[Column]:
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3;"""
ret = await self.conn.execute_query(sql, [self.database, table, self.schema])
for row in ret[1]:
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
Expand Down
49 changes: 49 additions & 0 deletions aerich/inspect/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import List

from aerich.inspect import Column, Inspect


class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}

async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=False, # can't get this simply
)
)
return columns

async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

0 comments on commit 801dde1

Please sign in to comment.