From 0b01fa38d843a0fe84e733701f2d3e10a36cea14 Mon Sep 17 00:00:00 2001 From: long2ice Date: Tue, 5 Apr 2022 19:38:08 +0800 Subject: [PATCH] feat: add index inspect --- aerich/__init__.py | 2 +- aerich/inspect/__init__.py | 25 +++++++++++++++---------- aerich/inspect/mysql.py | 20 +++++++++++++++++++- aerich/inspect/postgres.py | 3 ++- aerich/inspect/sqlite.py | 14 +++++++++++++- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/aerich/__init__.py b/aerich/__init__.py index 1caf1bd..7303e58 100644 --- a/aerich/__init__.py +++ b/aerich/__init__.py @@ -108,7 +108,7 @@ async def history(self): ret.append(version) return ret - async def inspectdb(self, tables: List[str]) -> str: + async def inspectdb(self, tables: List[str] = None) -> str: connection = get_app_connection(self.tortoise_config, self.app) dialect = connection.schema_generator.DIALECT if dialect == "mysql": diff --git a/aerich/inspect/__init__.py b/aerich/inspect/__init__.py index 5667468..25d7d24 100644 --- a/aerich/inspect/__init__.py +++ b/aerich/inspect/__init__.py @@ -12,17 +12,22 @@ class Column(BaseModel): comment: Optional[str] pk: bool unique: bool + index: bool length: Optional[int] extra: Optional[str] decimal_places: Optional[int] max_digits: Optional[int] def translate(self) -> dict: - comment = default = length = unique = null = pk = "" + comment = default = length = index = null = pk = "" if self.pk: pk = "pk=True, " - if self.unique: - unique = "unique=True, " + else: + if self.unique: + index = "unique=True, " + else: + if self.index: + index = "index=True, " if self.data_type in ["varchar", "VARCHAR"]: length = f"max_length={self.length}, " if self.data_type == "decimal": @@ -53,7 +58,7 @@ def translate(self) -> dict: return { "name": self.name, "pk": pk, - "unique": unique, + "index": index, "null": null, "default": default, "length": length, @@ -99,7 +104,7 @@ async def get_all_tables(self) -> List[str]: @classmethod def decimal_field(cls, **kwargs) -> str: - return "{name} = fields.DecimalField({pk}{unique}{length}{null}{default}{comment})".format( + return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format( **kwargs ) @@ -125,21 +130,21 @@ def text_field(cls, **kwargs) -> str: @classmethod def char_field(cls, **kwargs) -> str: - return "{name} = fields.CharField({pk}{unique}{length}{null}{default}{comment})".format( + return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format( **kwargs ) @classmethod def int_field(cls, **kwargs) -> str: - return "{name} = fields.IntField({pk}{unique}{comment})".format(**kwargs) + return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs) @classmethod def smallint_field(cls, **kwargs) -> str: - return "{name} = fields.SmallIntField({pk}{unique}{comment})".format(**kwargs) + return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs) @classmethod def bigint_field(cls, **kwargs) -> str: - return "{name} = fields.BigIntField({pk}{unique}{default}{comment})".format(**kwargs) + return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs) @classmethod def bool_field(cls, **kwargs) -> str: @@ -147,7 +152,7 @@ def bool_field(cls, **kwargs) -> str: @classmethod def uuid_field(cls, **kwargs) -> str: - return "{name} = fields.UUIDField({pk}{unique}{default}{comment})".format(**kwargs) + return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs) @classmethod def json_field(cls, **kwargs) -> str: diff --git a/aerich/inspect/mysql.py b/aerich/inspect/mysql.py index 28c480a..7d62bff 100644 --- a/aerich/inspect/mysql.py +++ b/aerich/inspect/mysql.py @@ -30,9 +30,25 @@ async def get_all_tables(self) -> List[str]: async def get_columns(self, table: str) -> List[Column]: columns = [] - sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s" + sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME +from information_schema.COLUMNS c + left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME + and c.TABLE_SCHEMA = s.TABLE_SCHEMA + and c.COLUMN_NAME = s.COLUMN_NAME +where c.TABLE_SCHEMA = %s + and c.TABLE_NAME = %s""" ret = await self.conn.execute_query_dict(sql, [self.database, table]) for row in ret: + non_unique = row["NON_UNIQUE"] + if non_unique is None: + unique = False + else: + unique = not non_unique + index_name = row["INDEX_NAME"] + if index_name is None: + index = False + else: + index = row["INDEX_NAME"] != "PRIMARY" columns.append( Column( name=row["COLUMN_NAME"], @@ -43,6 +59,8 @@ async def get_columns(self, table: str) -> List[Column]: comment=row["COLUMN_COMMENT"], unique=row["COLUMN_KEY"] == "UNI", extra=row["EXTRA"], + unque=unique, + index=index, length=row["CHARACTER_MAXIMUM_LENGTH"], max_digits=row["NUMERIC_PRECISION"], decimal_places=row["NUMERIC_SCALE"], diff --git a/aerich/inspect/postgres.py b/aerich/inspect/postgres.py index 205db7f..a30b947 100644 --- a/aerich/inspect/postgres.py +++ b/aerich/inspect/postgres.py @@ -54,7 +54,7 @@ async def get_columns(self, table: str) -> List[Column]: right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name) where c.table_catalog = $1 and c.table_name = $2 - and c.table_schema = $3;""" + and c.table_schema = $3""" ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) for row in ret: columns.append( @@ -69,6 +69,7 @@ async def get_columns(self, table: str) -> List[Column]: comment=row["column_comment"], pk=row["column_key"] == "PRIMARY KEY", unique=False, # can't get this simply + index=False, # can't get this simply ) ) return columns diff --git a/aerich/inspect/sqlite.py b/aerich/inspect/sqlite.py index 5ce5bf6..885b9c0 100644 --- a/aerich/inspect/sqlite.py +++ b/aerich/inspect/sqlite.py @@ -25,6 +25,7 @@ async def get_columns(self, table: str) -> List[Column]: columns = [] sql = f"PRAGMA table_info({table})" ret = await self.conn.execute_query_dict(sql) + columns_index = await self._get_columns_index(table) for row in ret: try: length = row["type"].split("(")[1].split(")")[0] @@ -38,11 +39,22 @@ async def get_columns(self, table: str) -> List[Column]: default=row["dflt_value"], length=length, pk=row["pk"] == 1, - unique=False, # can't get this simply + unique=columns_index.get(row["name"]) == "unique", + index=columns_index.get(row["name"]) == "index", ) ) return columns + async def _get_columns_index(self, table: str): + sql = f"PRAGMA index_list ({table})" + indexes = await self.conn.execute_query_dict(sql) + ret = {} + for index in indexes: + sql = f"PRAGMA index_info({index['name']})" + index_info = (await self.conn.execute_query_dict(sql))[0] + ret[index_info["name"]] = "unique" if index["unique"] else "index" + return ret + async def get_all_tables(self) -> List[str]: sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" ret = await self.conn.execute_query_dict(sql)