diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..dd0473f --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,91 @@ + +name: CD + +on: + push: + tags: + - "v[0-9]+.[0-9]+.[0-9]+*" + +env: + FORCE_COLOR: 1 + +jobs: + pre-commit: + runs-on: ubuntu-latest + timeout-minutes: 90 + strategy: + matrix: + python: ["3.8"] + steps: + - uses: actions/checkout@v3.3.0 + - name: "Install dependencies" + run: sudo apt-get -y install graphviz + - name: Cache python dependencies + id: cache-pip + uses: actions/cache@v3.2.4 + with: + path: ~/.cache/pip + key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/setup.json') }} + restore-keys: pip-${{ matrix.python }}-tests- + - name: Set up Python + uses: actions/setup-python@v4.5.0 + with: + python-version: ${{ matrix.python }} + - name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version. + run: pip install virtualenv>20 + - name: Install Tox + run: pip install tox + - name: Run pre-commit in Tox + run: tox -e pre-commit + test: + runs-on: "ubuntu-latest" + strategy: + matrix: + python: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3.3.0 + - name: "Install dependencies" + run: sudo apt-get -y install graphviz + - uses: actions/checkout@v3.3.0 + - name: Cache python dependencies + id: cache-pip + uses: actions/cache@v3.2.4 + with: + path: ~/.cache/pip + key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/setup.json') }} + restore-keys: pip-${{ matrix.python }}-tests- + - name: Set up Python + uses: actions/setup-python@v4.5.0 + with: + python-version: ${{ matrix.python }} + - name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version. + run: pip install virtualenv>20 + - name: Install Tox + run: pip install tox + - name: "Test with tox" + run: | + tox -e py${{ matrix.python }} -- tests/ --cov=./sqlalchemy_schemadisplay --cov-append --cov-report=xml --cov-report=term-missing + + release: + if: "github.event_name == 'push' && startsWith(github.ref, 'refs/tags')" + runs-on: "ubuntu-latest" + needs: "test" + steps: + - uses: "actions/checkout@v3" + + - name: "Install dependencies" + run: | + python -m pip install --upgrade pip + pip install build + + - name: "Build" + run: | + python -m build + git status --ignored + + - name: "Publish" + uses: "pypa/gh-action-pypi-publish@release/v1" + with: + user: "__token__" + password: "${{ secrets.TEST_PYPI_API_TOKEN }}" + repository_url: "https://test.pypi.org/legacy/" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2f89310 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,63 @@ +name: CI + +on: [push, pull_request] + +env: + FORCE_COLOR: 1 + +jobs: + pre-commit: + runs-on: ubuntu-latest + timeout-minutes: 90 + strategy: + matrix: + python: ["3.8"] + steps: + - uses: actions/checkout@v3.3.0 + - name: "Install dependencies" + run: sudo apt-get -y install graphviz + - name: Cache python dependencies + id: cache-pip + uses: actions/cache@v3.2.4 + with: + path: ~/.cache/pip + key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/setup.json') }} + restore-keys: pip-${{ matrix.python }}-tests- + - name: Set up Python + uses: actions/setup-python@v4.5.0 + with: + python-version: ${{ matrix.python }} + - name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version. + run: pip install virtualenv>20 + - name: Install Tox + run: pip install tox + - name: Run pre-commit in Tox + run: tox -e pre-commit + test: + runs-on: "ubuntu-latest" + strategy: + matrix: + python: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3.3.0 + - name: "Install dependencies" + run: sudo apt-get -y install graphviz + - uses: actions/checkout@v3.3.0 + - name: Cache python dependencies + id: cache-pip + uses: actions/cache@v3.2.4 + with: + path: ~/.cache/pip + key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/setup.json') }} + restore-keys: pip-${{ matrix.python }}-tests- + - name: Set up Python + uses: actions/setup-python@v4.5.0 + with: + python-version: ${{ matrix.python }} + - name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version. + run: pip install virtualenv>20 + - name: Install Tox + run: pip install tox + - name: "Test with tox" + run: | + tox -e py${{ matrix.python }} -- tests/ --cov=./sqlalchemy_schemadisplay --cov-append --cov-report=xml --cov-report=term-missing diff --git a/.github/workflows/test-and-release.yml b/.github/workflows/test-and-release.yml deleted file mode 100644 index fe617d4..0000000 --- a/.github/workflows/test-and-release.yml +++ /dev/null @@ -1,60 +0,0 @@ -name: "Test & Release" - -on: - push: - branches: [ "master" ] - tags: [ "*" ] - pull_request: - branches: [ "master" ] - -env: - FORCE_COLOR: 1 - -jobs: - test: - runs-on: "ubuntu-latest" - strategy: - fail-fast: false - matrix: - include: - - tox-envs: "sqla06-py27,sqla07-py27,sqla08-py27,sqla09-py27,sqlalchemy-py27" - - tox-envs: "sqlalchemy-py3" - steps: - - uses: "actions/checkout@v3" - - name: "Install dependencies" - run: | - sudo apt-get -y install graphviz - python -m pip install tox flake8 - - name: "Lint with flake8" - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: "Test with tox" - run: | - tox -v -e ${{ matrix.tox-envs }} -- -v --color=yes - - release: - if: "github.event_name == 'push' && startsWith(github.ref, 'refs/tags')" - runs-on: "ubuntu-latest" - needs: "test" - steps: - - uses: "actions/checkout@v3" - - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install build - - - name: "Build" - run: | - python -m build - git status --ignored - - - name: "Publish" - uses: "pypa/gh-action-pypi-publish@release/v1" - with: - user: "__token__" - password: "${{ secrets.TEST_PYPI_API_TOKEN }}" - repository_url: "https://test.pypi.org/legacy/" diff --git a/.gitignore b/.gitignore index 542fe77..ce6890b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /.coverage /dist/ /htmlcov/ +.vscode/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..5d3e27a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +# Install pre-commit hooks via +# pre-commit install +exclude: > + (?x)^(tests/) + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-json + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/asottile/pyupgrade + rev: v2.31.1 + hooks: + - id: pyupgrade + args: [--py38-plus] + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black diff --git a/LICENSE b/LICENSE index 11713d9..e70f96a 100644 --- a/LICENSE +++ b/LICENSE @@ -17,4 +17,4 @@ INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PA PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. \ No newline at end of file +DEALINGS IN THE SOFTWARE. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..786598c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,98 @@ +[build-system] +requires = ["flit_core >=3.4,<4"] +build-backend = "flit_core.buildapi" + +[project] +name = "sqlalchemy_schemadisplay" +dynamic = ["version", "description"] +authors = [{name = "Florian Schulze", email = "florian.schulze@gmx.net"}] +readme = "README.rst" +license = {file = "LICENSE"} +classifiers = [ + "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Database :: Front-Ends", + "Operating System :: OS Independent" +] +keywords = ["aiida", "workflows", "lammps"] +requires-python = ">=3.8" +dependencies = [ + 'setuptools', + 'sqlalchemy>=2.0,<3', + 'pydot', + 'Pillow' +] + +[project.urls] +Documentation = "https://github.com/fschulze/sqlalchemy_schemadisplay/blob/master/README.rst" +Source = "https://github.com/fschulze/sqlalchemy_schemadisplay" + +[project.optional-dependencies] +testing = [ + "attrs>=17.4.0", + "pgtest", + "pytest", + "pytest-cov", + "coverage", + "pytest-timeout", + "pytest-regressions" +] + +pre-commit = [ + "pre-commit", + "tox>=3.23.0", + "virtualenv>20" +] + +[tool.flit.module] +name = "sqlalchemy_schemadisplay" + +[tool.flit.sdist] +exclude = [ + "docs/", + "tests/", +] + +[tool.coverage.run] +# Configuration of [coverage.py](https://coverage.readthedocs.io) +# reporting which lines of your plugin are covered by tests +source=["sqlalchemy_schemadisplay"] + +[tool.isort] +skip = ["venv"] +# Force imports to be sorted by module, independent of import type +force_sort_within_sections = true +# Group first party and local folder imports together +no_lines_before = ["LOCALFOLDER"] + +# Configure isort to work without access to site-packages +known_first_party = ["sqlalchemy_schemadisplay"] + +[tool.tox] +legacy_tox_ini = """ +[tox] +envlist = pre-commit,py{3.8,3.9,3.10} +requires = virtualenv >= 20 +isolated_build = True + +[testenv] +commands = pytest {posargs} +extras = testing + +[testenv:pre-commit] +allowlist_externals = bash +commands = bash -ec 'pre-commit run --all-files || ( git diff; git status; exit 1; )' +extras = + pre-commit + tests + +[flake8] +max-line-length = 140 +import-order-style = edited +[pycodestyle] +max-line-length = 140 +""" diff --git a/setup.py b/setup.py deleted file mode 100644 index 7474e0c..0000000 --- a/setup.py +++ /dev/null @@ -1,33 +0,0 @@ -from setuptools import setup - -import os - -version = '1.4.dev0' - -long_description = open(os.path.join(os.path.dirname(__file__), 'README.rst')).read() - -setup( - name='sqlalchemy_schemadisplay', - version=version, - description="Turn SQLAlchemy DB Model into a graph", - author="Florian Schulze", - author_email="florian.schulze@gmx.net", - license="MIT License", - long_description=long_description[long_description.find('\n\n'):], - url='https://github.com/fschulze/sqlalchemy_schemadisplay', - py_modules=['sqlalchemy_schemadisplay'], - zip_safe=True, - install_requires=[ - 'setuptools', - 'sqlalchemy', - 'pydot', - ], - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python", - "Programming Language :: Python :: Implementation :: CPython", - "Topic :: Database :: Front-Ends", - "Operating System :: OS Independent"], -) diff --git a/sqlalchemy_schemadisplay.py b/sqlalchemy_schemadisplay.py deleted file mode 100644 index 6bbf947..0000000 --- a/sqlalchemy_schemadisplay.py +++ /dev/null @@ -1,289 +0,0 @@ -# updated SQLA schema display to work with pydot 1.0.2 - -from sqlalchemy.orm.properties import RelationshipProperty -from sqlalchemy.orm import sync -import pydot -import types - -__all__ = ['create_uml_graph', 'create_schema_graph', 'show_uml_graph', 'show_schema_graph'] - -def _mk_label(mapper, show_operations, show_attributes, show_datatypes, show_inherited, bordersize): - html = '<' % (bordersize, mapper.class_.__name__) - def format_col(col): - colstr = '+%s' % (col.name) - if show_datatypes: - colstr += ' : %s' % (col.type.__class__.__name__) - return colstr - - if show_attributes: - if not show_inherited: - cols = [c for c in mapper.columns if c.table == mapper.tables[0]] - else: - cols = mapper.columns - html += '' % '
'.join(format_col(col) for col in cols) - else: - [format_col(col) for col in sorted(mapper.columns, key=lambda col:not col.primary_key)] - if show_operations: - html += '' % '
'.join( - '%s(%s)' % (name,", ".join(default is _mk_label and ("%s") % arg or ("%s=%s" % (arg,repr(default))) for default,arg in - zip((func.func_defaults and len(func.func_code.co_varnames)-1-(len(func.func_defaults) or 0) or func.func_code.co_argcount-1)*[_mk_label]+list(func.func_defaults or []), func.func_code.co_varnames[1:]) - )) - for name,func in mapper.class_.__dict__.items() if isinstance(func, types.FunctionType) and func.__module__ == mapper.class_.__module__ - ) - html+= '
%s
%s
%s
>' - return html - - -def escape(name): - return '"%s"' % name - - -def create_uml_graph(mappers, show_operations=True, show_attributes=True, show_inherited=True, show_multiplicity_one=False, show_datatypes=True, linewidth=1.0, font="Bitstream-Vera Sans"): - graph = pydot.Dot(prog='neato',mode="major",overlap="0", sep="0.01",dim="3", pack="True", ratio=".75") - relations = set() - for mapper in mappers: - graph.add_node(pydot.Node(escape(mapper.class_.__name__), - shape="plaintext", label=_mk_label(mapper, show_operations, show_attributes, show_datatypes, show_inherited, linewidth), - fontname=font, fontsize="8.0", - )) - if mapper.inherits: - graph.add_edge(pydot.Edge(escape(mapper.inherits.class_.__name__),escape(mapper.class_.__name__), - arrowhead='none',arrowtail='empty', style="setlinewidth(%s)" % linewidth, arrowsize=str(linewidth))) - for loader in mapper.iterate_properties: - if isinstance(loader, RelationshipProperty) and loader.mapper in mappers: - if hasattr(loader, 'reverse_property'): - relations.add(frozenset([loader, loader.reverse_property])) - else: - relations.add(frozenset([loader])) - - for relation in relations: - #if len(loaders) > 2: - # raise Exception("Warning: too many loaders for join %s" % join) - args = {} - def multiplicity_indicator(prop): - if prop.uselist: - return ' *' - if hasattr(prop, 'local_side'): - cols = prop.local_side - else: - cols = prop.local_columns - if any(col.nullable for col in cols): - return ' 0..1' - if show_multiplicity_one: - return ' 1' - return '' - - if len(relation) == 2: - src, dest = relation - from_name = escape(src.parent.class_.__name__) - to_name = escape(dest.parent.class_.__name__) - - def calc_label(src,dest): - return '+' + src.key + multiplicity_indicator(src) - args['headlabel'] = calc_label(src,dest) - - args['taillabel'] = calc_label(dest,src) - args['arrowtail'] = 'none' - args['arrowhead'] = 'none' - args['constraint'] = False - else: - prop, = relation - from_name = escape(prop.parent.class_.__name__) - to_name = escape(prop.mapper.class_.__name__) - args['headlabel'] = '+%s%s' % (prop.key, multiplicity_indicator(prop)) - args['arrowtail'] = 'none' - args['arrowhead'] = 'vee' - - graph.add_edge(pydot.Edge(from_name,to_name, - fontname=font, fontsize="7.0", style="setlinewidth(%s)"%linewidth, arrowsize=str(linewidth), - **args) - ) - - return graph - -from sqlalchemy.dialects.postgresql.base import PGDialect -from sqlalchemy import Table, text, ForeignKeyConstraint - - -def _render_table_html( - table, metadata, - show_indexes, show_datatypes, show_column_keys, show_schema_name, - format_schema_name, format_table_name -): - # add in (PK) OR (FK) suffixes to column names that are considered to be primary key or foreign key - use_column_key_attr = hasattr(ForeignKeyConstraint, 'column_keys') # sqlalchemy > 1.0 uses column_keys to return list of strings for foreign keys, previously was columns - if show_column_keys: - if (use_column_key_attr): - # sqlalchemy > 1.0 - fk_col_names = set([h for f in table.foreign_key_constraints for h in f.columns.keys()]) - else: - # sqlalchemy pre 1.0? - fk_col_names = set([h.name for f in table.foreign_keys for h in f.constraint.columns]) - # fk_col_names = set([h for f in table.foreign_key_constraints for h in f.columns.keys()]) - pk_col_names = set([f for f in table.primary_key.columns.keys()]) - else: - fk_col_names = set() - pk_col_names = set() - - def format_col_type(col): - try: - return col.type.get_col_spec() - except (AttributeError, NotImplementedError): - return str(col.type) - def format_col_str(col): - # add in (PK) OR (FK) suffixes to column names that are considered to be primary key or foreign key - suffix = '(FK)' if col.name in fk_col_names else '(PK)' if col.name in pk_col_names else '' - if show_datatypes: - return "- %s : %s" % (col.name + suffix, format_col_type(col)) - else: - return "- %s" % (col.name + suffix) - - def format_name(obj_name, format_dict): - # Check if format_dict was provided - if format_dict is not None: - # Should color be checked? Could use /^#([A-Fa-f0-9]{8}|[A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$/ - return '{bld}{it}{name}{e_it}{e_bld}'.format( - name=obj_name, - color=format_dict.get('color') if 'color' in format_dict else 'initial', - size=float(format_dict['fontsize']) if 'fontsize' in format_dict else 'initial', - it='' if format_dict.get('italics') else '', - e_it='' if format_dict.get('italics') else '', - bld='' if format_dict.get('bold') else '', - e_bld='' if format_dict.get('bold') else '' - ) - else: - return obj_name - - schema_str = "" - if show_schema_name == True and hasattr(table, 'schema') and table.schema is not None: - # Build string for schema name, empty if show_schema_name is False - schema_str = format_name(table.schema, format_schema_name) - table_str = format_name(table.name, format_table_name) - - # Assemble table header - html = '<' % ( - schema_str, - '.' if show_schema_name else '', - table_str - ) - - html += ''.join('' % (col.name, format_col_str(col)) for col in table.columns) - if metadata.bind and isinstance(metadata.bind.dialect, PGDialect): - # postgres engine doesn't reflect indexes - indexes = dict((name,defin) for name,defin in metadata.bind.execute( - text("SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '%s'" % table.name) - )) - if indexes and show_indexes: - html += '' - for index, defin in indexes.items(): - ilabel = 'UNIQUE' in defin and 'UNIQUE ' or 'INDEX ' - ilabel += defin[defin.index('('):] - html += '' % ilabel - html += '
%s%s%s
%s
%s
>' - return html - -def create_schema_graph(tables=None, metadata=None, show_indexes=True, show_datatypes=True, font="Bitstream-Vera Sans", - concentrate=True, relation_options={}, rankdir='TB', show_column_keys=False, restrict_tables=None, - show_schema_name=False, format_schema_name=None, format_table_name=None): - """ - Args: - - metadata (sqlalchemy.MetaData, default=None): SqlAlchemy `MetaData` with reference to related tables. If none - is provided, uses metadata from first entry of `tables` argument. - - concentrate (bool, default=True): Specifies if multiedges should be merged into a single edge & partially - parallel edges to share overlapping path. Passed to `pydot.Dot` object. - - relation_options (dict, default: None): kwargs passed to pydot.Edge init. Most attributes in - pydot.EDGE_ATTRIBUTES are viable options. A few values are set programmatically. - - rankdir (string, default='TB'): Sets direction of graph layout. Passed to `pydot.Dot` object. Options are - 'TB' (top to bottom), 'BT' (bottom to top), 'LR' (left to right), 'RL' (right to left). - - show_column_keys (bool, default=False): If true then add a PK/FK suffix to columns names that are primary and - foreign keys. - - restrict_tables (None or list of strings): Restrict the graph to only consider tables whose name are defined - `restrict_tables`. - - show_schema_name (bool, default=False): If true, then prepend '.' to the table name resulting in - '.'. - - format_schema_name (dict, default=None): If provided, allowed keys include: 'color' (hex color code incl #), - 'fontsize' as a float, and 'bold' and 'italics' as bools. - - format_table_name (dict, default=None): If provided, allowed keys include: 'color' (hex color code incl #), - 'fontsize' as a float, and 'bold' and 'italics' as bools. - """ - - relation_kwargs = { - 'fontsize':"7.0", - 'dir':'both' - } - relation_kwargs.update(relation_options) - - if metadata is None and tables is not None and len(tables): - metadata = tables[0].metadata - elif tables is None and metadata is not None: - if not len(metadata.tables): - metadata.reflect() - tables = metadata.tables.values() - else: - raise ValueError("You need to specify at least tables or metadata") - - # check if unexpected keys were used in format_schema_name param - if format_schema_name is not None and \ - len(set(format_schema_name.keys()).difference({'color','fontsize', 'italics', 'bold'})) > 0: - raise KeyError('Unrecognized keys were used in dict provided for `format_schema_name` parameter') - # check if unexpected keys were used in format_table_name param - if format_table_name is not None and \ - len(set(format_table_name.keys()).difference({'color','fontsize', 'italics', 'bold'})) > 0: - raise KeyError('Unrecognized keys were used in dict provided for `format_table_name` parameter') - - graph = pydot.Dot(prog="dot",mode="ipsep",overlap="ipsep",sep="0.01",concentrate=str(concentrate), rankdir=rankdir) - if restrict_tables is None: - restrict_tables = set([t.name.lower() for t in tables]) - else: - restrict_tables = set([t.lower() for t in restrict_tables]) - tables = [t for t in tables if t.name.lower() in restrict_tables] - for table in tables: - - graph.add_node(pydot.Node(str(table.name), - shape="plaintext", - label=_render_table_html( - table, metadata, - show_indexes, show_datatypes, show_column_keys, show_schema_name, - format_schema_name, format_table_name - ), - fontname=font, fontsize="7.0" - )) - - for table in tables: - for fk in table.foreign_keys: - if fk.column.table not in tables: - continue - edge = [table.name, fk.column.table.name] - is_inheritance = fk.parent.primary_key and fk.column.primary_key - if is_inheritance: - edge = edge[::-1] - graph_edge = pydot.Edge( - headlabel="+ %s"%fk.column.name, taillabel='+ %s'%fk.parent.name, - arrowhead=is_inheritance and 'none' or 'odot' , - arrowtail=(fk.parent.primary_key or fk.parent.unique) and 'empty' or 'crow' , - fontname=font, - #samehead=fk.column.name, sametail=fk.parent.name, - *edge, **relation_kwargs - ) - graph.add_edge(graph_edge) - -# not sure what this part is for, doesn't work with pydot 1.0.2 -# graph_edge.parent_graph = graph.parent_graph -# if table.name not in [e.get_source() for e in graph.get_edge_list()]: -# graph.edge_src_list.append(table.name) -# if fk.column.table.name not in graph.edge_dst_list: -# graph.edge_dst_list.append(fk.column.table.name) -# graph.sorted_graph_elements.append(graph_edge) - return graph - -def show_uml_graph(*args, **kwargs): - from cStringIO import StringIO - from PIL import Image - iostream = StringIO(create_uml_graph(*args, **kwargs).create_png()) - Image.open(iostream).show(command=kwargs.get('command','gwenview')) - -def show_schema_graph(*args, **kwargs): - from cStringIO import StringIO - from PIL import Image - iostream = StringIO(create_schema_graph(*args, **kwargs).create_png()) - Image.open(iostream).show(command=kwargs.get('command','gwenview')) diff --git a/sqlalchemy_schemadisplay/__init__.py b/sqlalchemy_schemadisplay/__init__.py new file mode 100644 index 0000000..d1e5336 --- /dev/null +++ b/sqlalchemy_schemadisplay/__init__.py @@ -0,0 +1,13 @@ +"""Package for the generation of diagrams based on SQLAlchemy ORM models and or the database itself""" +from .db_diagram import create_schema_graph +from .model_diagram import create_uml_graph +from .utils import show_schema_graph, show_uml_graph + +__version__ = "0.1.0" + +__all__ = ( + "create_schema_graph", + "create_uml_graph", + "show_schema_graph", + "show_uml_graph", +) diff --git a/sqlalchemy_schemadisplay/db_diagram.py b/sqlalchemy_schemadisplay/db_diagram.py new file mode 100644 index 0000000..7e02a78 --- /dev/null +++ b/sqlalchemy_schemadisplay/db_diagram.py @@ -0,0 +1,321 @@ +"""Set of functions to generate the diagram of the actual database""" +from typing import List, Union + +import pydot +from sqlalchemy import Column, ForeignKeyConstraint, MetaData, Table, text +from sqlalchemy.dialects.postgresql.base import PGDialect +from sqlalchemy.engine import Engine + + +def _render_table_html( + table: Table, + engine: Engine, + show_indexes: bool, + show_datatypes: bool, + show_column_keys: bool, + show_schema_name: bool, + format_schema_name: dict, + format_table_name: dict, +) -> str: + # pylint: disable=too-many-locals,too-many-arguments + """Create a rendering of a table in the database + + Args: + table (sqlalchemy.Table): SqlAlchemy table which is going to be rendered. + engine (sqlalchemy.engine.Engine): SqlAlchemy database engine to connect to the database. + show_indexes (bool): Whether to display the index column in the table + show_datatypes (bool): Whether to display the type of the columns in the table + show_column_keys (bool): If true then add a PK/FK suffix to columns names that \ + are primary and foreign keys + show_schema_name (bool): If true, then prepend '.' to the table \ + name resulting in '.
' + format_schema_name (dict): If provided, allowed keys include: \ + 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' \ + as bools + format_table_name (dict): If provided, allowed keys include: \ + 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' as \ + bools + + Returns: + str: html string with the rendering of the table + """ + # add in (PK) OR (FK) suffixes to column names that are considered to be primary key or + # foreign key + use_column_key_attr = hasattr(ForeignKeyConstraint, "column_keys") + # sqlalchemy > 1.0 uses column_keys to return list of strings for foreign keys,previously + # was columns + if show_column_keys: + if use_column_key_attr: + # sqlalchemy > 1.0 + fk_col_names = { + h for f in table.foreign_key_constraints for h in f.columns.keys() + } + else: + # sqlalchemy pre 1.0? + fk_col_names = { + h.name for f in table.foreign_keys for h in f.constraint.columns + } + # fk_col_names = set([h for f in table.foreign_key_constraints for h in f.columns.keys()]) + pk_col_names = set(list(table.primary_key.columns.keys())) + else: + fk_col_names = set() + pk_col_names = set() + + def format_col_type(col: Column) -> str: + """Get the type of the column as a string + + Args: + col (Column): SqlAlchemy column of the table that is being rendered + + Returns: + str: column type + """ + try: + return col.type.get_col_spec() + except (AttributeError, NotImplementedError): + return str(col.type) + + def format_col_str(col: Column) -> str: + """Generate the column name so that it takes into account any possible suffix. + + Args: + col (sqlalchemy.Column): SqlAlchemy column of the table that is being rendered + + Returns: + str: name of the column with the appropriate suffixes + """ + # add in (PK) OR (FK) suffixes to column names that are considered to be primary key + # or foreign key + suffix = ( + "(FK)" + if col.name in fk_col_names + else "(PK)" + if col.name in pk_col_names + else "" + ) + if show_datatypes: + return f"- {col.name + suffix} : {format_col_type(col)}" + return f"- {col.name + suffix}" + + def format_name(obj_name: str, format_dict: Union[dict, None]) -> str: + """Format the name of the object so that it is rendered differently. + + Args: + obj_name (str): name of the object being rendered + format_dict (Union[dict,None]): dictionary with the rendering options. \ + If None nothing is done + + Returns: + str: formatted name of the object + """ + # Check if format_dict was provided + if format_dict is not None: + # Should color be checked? + # Could use /^#([A-Fa-f0-9]{8}|[A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$/ + _color = format_dict.get("color") if "color" in format_dict else "initial" + _point_size = ( + float(format_dict["fontsize"]) + if "fontsize" in format_dict + else "initial" + ) + _bold = "" if format_dict.get("bold") else "" + _italic = "" if format_dict.get("italics") else "" + _text = f'' + _text += f'{_bold}{_italic}{obj_name}{"" if format_dict.get("italics") else ""}' + _text += f'{"" if format_dict.get("bold") else ""}' + + return _text + return obj_name + + schema_str = "" + if show_schema_name and hasattr(table, "schema") and table.schema is not None: + # Build string for schema name, empty if show_schema_name is False + schema_str = format_name(table.schema, format_schema_name) + table_str = format_name(table.name, format_table_name) + + # Assemble table header + html = '<
' + html += '' + + html += "".join( + f'' + for col in table.columns + ) + if isinstance(engine, Engine) and isinstance(engine.engine.dialect, PGDialect): + # postgres engine doesn't reflect indexes + with engine.connect() as connection: + indexes = { + key: value + for key, value in connection.execute( + text( + f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table.name}'" + ) + ) + } + if indexes and show_indexes: + html += '' + for value in indexes.values(): + i_label = "UNIQUE " if "UNIQUE" in value else "INDEX " + i_label += value[value.index("(") :] + html += f'' + html += "
' + html += f'{schema_str}{"." if show_schema_name else ""}{table_str}
{format_col_str(col)}
{i_label}
>" + return html + + +def create_schema_graph( + engine: Engine, + tables: List[Table] = None, + metadata: MetaData = None, + show_indexes: bool = True, + show_datatypes: bool = True, + font: str = "Bitstream-Vera Sans", + concentrate: bool = True, + relation_options: Union[dict, None] = None, + rankdir: str = "TB", + show_column_keys: bool = False, + restrict_tables: Union[List[str], None] = None, + show_schema_name: bool = False, + format_schema_name: Union[dict, None] = None, + format_table_name: Union[dict, None] = None, +) -> pydot.Dot: + # pylint: disable=too-many-locals,too-many-arguments + """Create a diagram for the database schema. + + Args: + engine (sqlalchemy.engine.Engine): SqlAlchemy database engine to connect to the database. + tables (List[sqlalchemy.Table], optional): SqlAlchemy database tables. Defaults to None. + metadata (sqlalchemy.MetaData, optional): SqlAlchemy `MetaData` with reference to related \ + tables. Defaults to None. + show_indexes (bool, optional): Whether to display the index column in the table. \ + Defaults to True. + show_datatypes (bool, optional): Whether to display the type of the columns in the table. \ + Defaults to True. + font (str, optional): font to be used in the diagram. Defaults to "Bitstream-Vera Sans". + concentrate (bool, optional): Specifies if multi-edges should be merged into a single edge \ + & partially parallel edges to share overlapping path. Passed to `pydot.Dot` object. \ + Defaults to True. + relation_options (Union[dict, None], optional): kwargs passed to pydot.Edge init. \ + Most attributes in pydot.EDGE_ATTRIBUTES are viable options. A few values are set \ + programmatically. Defaults to None. + rankdir (str, optional): Sets direction of graph layout. Passed to `pydot.Dot` object. \ + Options are 'TB' (top to bottom), 'BT' (bottom to top), 'LR' (left to right), \ + 'RL' (right to left). Defaults to 'TB'. + show_column_keys (bool, optional): If true then add a PK/FK suffix to columns names that \ + are primary and foreign keys. Defaults to False. + restrict_tables (Union[List[str], optional): Restrict the graph to only consider tables \ + whose name are defined `restrict_tables`. Defaults to None. + show_schema_name (bool, optional): If true, then prepend '.' to the table \ + name resulting in '.'. Defaults to False. + format_schema_name (Union[dict, None], optional): If provided, allowed keys include: \ + 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' \ + as bools. Defaults to None. + format_table_name (Union[dict, None], optional): If provided, allowed keys include: \ + 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' as \ + bools. Defaults to None. + + Raises: + ValueError: One needs to specify either the metadata or the tables + KeyError: raised when unexpected keys are given to `format_schema_name` or \ + `format_table_name` + Returns: + pydot.Dot: pydot object with the schema of the database + """ + + if not relation_options: + relation_options = {} + + relation_kwargs = {"fontsize": "7.0", "dir": "both"} + relation_kwargs.update(relation_options) + + if not metadata and not tables: + raise ValueError("You need to specify at least tables or metadata") + + if metadata and not tables: + metadata.reflect(bind=engine) + tables = metadata.tables.values() + + _accepted_keys = {"color", "fontsize", "italics", "bold"} + + # check if unexpected keys were used in format_schema_name param + if ( + format_schema_name is not None + and len(set(format_schema_name.keys()).difference(_accepted_keys)) > 0 + ): + raise KeyError( + "Unrecognized keys were used in dict provided for `format_schema_name` parameter" + ) + # check if unexpected keys were used in format_table_name param + if ( + format_table_name is not None + and len(set(format_table_name.keys()).difference(_accepted_keys)) > 0 + ): + raise KeyError( + "Unrecognized keys were used in dict provided for `format_table_name` parameter" + ) + + graph = pydot.Dot( + prog="dot", + mode="ipsep", + overlap="ipsep", + sep="0.01", + concentrate=str(concentrate), + rankdir=rankdir, + ) + if restrict_tables is None: + restrict_tables = {t.name.lower() for t in tables} + else: + restrict_tables = {t.lower() for t in restrict_tables} + tables = [t for t in tables if t.name.lower() in restrict_tables] + for table in tables: + + graph.add_node( + pydot.Node( + str(table.name), + shape="plaintext", + label=_render_table_html( + table=table, + engine=engine, + show_indexes=show_indexes, + show_datatypes=show_datatypes, + show_column_keys=show_column_keys, + show_schema_name=show_schema_name, + format_schema_name=format_schema_name, + format_table_name=format_table_name, + ), + fontname=font, + fontsize="7.0", + ), + ) + + for table in tables: + for fk in table.foreign_keys: + if fk.column.table not in tables: + continue + edge = [table.name, fk.column.table.name] + is_inheritance = fk.parent.primary_key and fk.column.primary_key + if is_inheritance: + edge = edge[::-1] + graph_edge = pydot.Edge( + headlabel="+ %s" % fk.column.name, + taillabel="+ %s" % fk.parent.name, + arrowhead=is_inheritance and "none" or "odot", + arrowtail=(fk.parent.primary_key or fk.parent.unique) + and "empty" + or "crow", + fontname=font, + # samehead=fk.column.name, sametail=fk.parent.name, + *edge, + **relation_kwargs, + ) + graph.add_edge(graph_edge) + + # not sure what this part is for, doesn't work with pydot 1.0.2 + # graph_edge.parent_graph = graph.parent_graph + # if table.name not in [e.get_source() for e in graph.get_edge_list()]: + # graph.edge_src_list.append(table.name) + # if fk.column.table.name not in graph.edge_dst_list: + # graph.edge_dst_list.append(fk.column.table.name) + # graph.sorted_graph_elements.append(graph_edge) + return graph diff --git a/sqlalchemy_schemadisplay/model_diagram.py b/sqlalchemy_schemadisplay/model_diagram.py new file mode 100644 index 0000000..5c720d2 --- /dev/null +++ b/sqlalchemy_schemadisplay/model_diagram.py @@ -0,0 +1,243 @@ +""" +Set of functions to generate the diagram related to the ORM models +""" +import types +from typing import List + +import pydot +from sqlalchemy.orm import Mapper, Relationship +from sqlalchemy.orm.properties import RelationshipProperty + + +def _mk_label( + mapper: Mapper, + show_operations: bool, + show_attributes: bool, + show_datatypes: bool, + show_inherited: bool, + bordersize: float, +) -> str: + # pylint: disable=too-many-arguments + """Generate the rendering of a given orm model. + + Args: + mapper (sqlalchemy.orm.Mapper): mapper for the SqlAlchemy orm class. + show_operations (bool): whether to show functions defined in the orm. + show_attributes (bool): whether to show the attributes of the class. + show_datatypes (bool): Whether to display the type of the columns in the model. + show_inherited (bool): whether to show inherited columns. + bordersize (float): thickness of the border lines in the diagram + + Returns: + str: html string to render the orm model + """ + html = ( + f'<
' + + def format_col(col): + colstr = f"+{col.name}" + if show_datatypes: + colstr += f" : {col.type.__class__.__name__}" + return colstr + + if show_attributes: + if not show_inherited: + cols = [c for c in mapper.columns if c.table == mapper.tables[0]] + else: + cols = mapper.columns + html += '' % '
'.join( + format_col(col) for col in cols + ) + else: + _ = [ + format_col(col) + for col in sorted(mapper.columns, key=lambda col: not col.primary_key) + ] + if show_operations: + html += '' % '
'.join( + "%s(%s)" + % ( + name, + ", ".join( + default is _mk_label and (f"{arg}") or (f"{arg}={repr(default)}") + for default, arg in zip( + ( + func.func_defaults + and len(func.func_code.co_varnames) + - 1 + - (len(func.func_defaults) or 0) + or func.func_code.co_argcount - 1 + ) + * [_mk_label] + + list(func.func_defaults or []), + func.func_code.co_varnames[1:], + ) + ), + ) + for name, func in mapper.class_.__dict__.items() + if isinstance(func, types.FunctionType) + and func.__module__ == mapper.class_.__module__ + ) + html += "
{mapper.class_.__name__}
%s
%s
>" + return html + + +def escape(name: str) -> str: + """Set the name of the object between quotations to avoid reading errors + + Args: + name (str): name of the object + + Returns: + str: name of the object between quotations to avoid reading errors + """ + return f'"{name}"' + + +def create_uml_graph( + mappers: List[Mapper], + show_operations: bool = True, + show_attributes: bool = True, + show_inherited: bool = True, + show_multiplicity_one: bool = False, + show_datatypes: bool = True, + linewidth: float = 1.0, + font: str = "Bitstream-Vera Sans", +) -> pydot.Dot: + # pylint: disable=too-many-locals,too-many-arguments + """Create rendering of the orm models associated with the database + + Args: + mappers (List[sqlalchemy.orm.Mapper]): SqlAlchemy list of mappers of the orm classes. + show_operations (bool, optional): whether to show functions defined in the orm. \ + Defaults to True. + show_attributes (bool, optional): whether to show the attributes of the class. \ + Defaults to True. + show_inherited (bool, optional): whether to show inherited columns. Defaults to True. + show_multiplicity_one (bool, optional): whether to show the multiplicity as a float or \ + integer. Defaults to False. + show_datatypes (bool, optional): Whether to display the type of the columns in the model. \ + Defaults to True. + linewidth (float, optional): thickness of the lines in the diagram. Defaults to 1.0. + font (str, optional): type of fond to be used for the diagram. \ + Defaults to "Bitstream-Vera Sans". + + Returns: + pydot.Dot: pydot object with the diagram for the orm models. + """ + graph = pydot.Dot( + prog="neato", + mode="major", + overlap="0", + sep="0.01", + dim="3", + pack="True", + ratio=".75", + ) + relations = set() + for mapper in mappers: + graph.add_node( + pydot.Node( + escape(mapper.class_.__name__), + shape="plaintext", + label=_mk_label( + mapper, + show_operations, + show_attributes, + show_datatypes, + show_inherited, + linewidth, + ), + fontname=font, + fontsize="8.0", + ) + ) + if mapper.inherits: + graph.add_edge( + pydot.Edge( + escape(mapper.inherits.class_.__name__), + escape(mapper.class_.__name__), + arrowhead="none", + arrowtail="empty", + style="setlinewidth(%s)" % linewidth, + arrowsize=str(linewidth), + ), + ) + for loader in mapper.iterate_properties: + if isinstance(loader, RelationshipProperty) and loader.mapper in mappers: + if hasattr(loader, "reverse_property"): + relations.add(frozenset([loader, loader.reverse_property])) + else: + relations.add(frozenset([loader])) + + def multiplicity_indicator(prop: Relationship) -> str: + """Indicate the multiplicity of a given relationship + + Args: + prop (sqlalchemy.orm.Relationship): relationship associated with this model + + Returns: + str: string indicating the multiplicity of the relationship + """ + if prop.uselist: + return " *" + if hasattr(prop, "local_side"): + cols = prop.local_side + else: + cols = prop.local_columns + if any(col.nullable for col in cols): + return " 0..1" + if show_multiplicity_one: + return " 1" + return "" + + def calc_label(src: Relationship) -> str: + """Generate the label for a given relationship + + Args: + src (Relationship): relationship associated with this model + + Returns: + str: relationship label + """ + return "+" + src.key + multiplicity_indicator(src) + + for relation in relations: + # if len(loaders) > 2: + # raise Exception("Warning: too many loaders for join %s" % join) + args = {} + + if len(relation) == 2: + src, dest = relation + from_name = escape(src.parent.class_.__name__) + to_name = escape(dest.parent.class_.__name__) + + args["headlabel"] = calc_label(src) + + args["taillabel"] = calc_label(dest) + args["arrowtail"] = "none" + args["arrowhead"] = "none" + args["constraint"] = False + else: + (prop,) = relation + from_name = escape(prop.parent.class_.__name__) + to_name = escape(prop.mapper.class_.__name__) + args["headlabel"] = f"+{prop.key}{multiplicity_indicator(prop)}" + args["arrowtail"] = "none" + args["arrowhead"] = "vee" + + graph.add_edge( + pydot.Edge( + from_name, + to_name, + fontname=font, + fontsize="7.0", + style=f"setlinewidth({linewidth})", + arrowsize=str(linewidth), + **args, + ), + ) + + return graph diff --git a/sqlalchemy_schemadisplay/utils.py b/sqlalchemy_schemadisplay/utils.py new file mode 100644 index 0000000..6ea0c99 --- /dev/null +++ b/sqlalchemy_schemadisplay/utils.py @@ -0,0 +1,28 @@ +""" +Set of functions to display the database diagrams generated. +""" +from io import StringIO + +from PIL import Image + +from sqlalchemy_schemadisplay import create_schema_graph, create_uml_graph + + +def show_uml_graph(*args, **kwargs): + """ + Show the SQLAlchemy ORM diagram generated. + """ + iostream = StringIO( + create_uml_graph(*args, **kwargs).create_png() + ) # pylint: disable=no-member + Image.open(iostream).show(command=kwargs.get("command", "gwenview")) + + +def show_schema_graph(*args, **kwargs): + """ + Show the database diagram generated + """ + iostream = StringIO( + create_schema_graph(*args, **kwargs).create_png() + ) # pylint: disable=no-member + Image.open(iostream).show(command=kwargs.get("command", "gwenview")) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_schema_graph.py b/tests/test_schema_graph.py index 50b19fe..3309d0e 100644 --- a/tests/test_schema_graph.py +++ b/tests/test_schema_graph.py @@ -1,156 +1,199 @@ -from sqlalchemy import types -from sqlalchemy import Column -from sqlalchemy import ForeignKey -from sqlalchemy import MetaData -from sqlalchemy import Table -from utils import parse_graph +"""Set of tests for database diagrams""" import pydot import pytest -import sqlalchemy_schemadisplay as sasd +from sqlalchemy import Column, ForeignKey, MetaData, Table, create_engine, types + +import sqlalchemy_schemadisplay +from .utils import parse_graph @pytest.fixture def metadata(request): - return MetaData('sqlite:///:memory:') + engine = create_engine("sqlite:///:memory:") + _metadata = MetaData() + _metadata.reflect(engine) + return _metadata + + +@pytest.fixture +def engine(): + return create_engine("sqlite:///:memory:") def plain_result(**kw): - if 'metadata' in kw: - kw['metadata'].create_all() - elif 'tables' in kw: - if len(kw['tables']): - kw['tables'][0].metadata.create_all() - return parse_graph(sasd.create_schema_graph(**kw)) + if "metadata" in kw: + kw["metadata"].create_all(kw['engine']) + elif "tables" in kw: + if len(kw["tables"]): + kw["tables"][0].metadata.create_all(kw['engine']) + return parse_graph(sqlalchemy_schemadisplay.create_schema_graph(**kw)) -def test_no_args(): +def test_no_args(engine): with pytest.raises(ValueError) as e: - sasd.create_schema_graph() - assert e.value.args[0] == 'You need to specify at least tables or metadata' + sqlalchemy_schemadisplay.create_schema_graph(engine=engine) + assert e.value.args[0] == "You need to specify at least tables or metadata" -def test_empty_db(metadata): - graph = sasd.create_schema_graph(metadata=metadata) +def test_empty_db(metadata, engine): + graph = sqlalchemy_schemadisplay.create_schema_graph(engine=engine, + metadata=metadata) assert isinstance(graph, pydot.Graph) - assert graph.create_plain() == b'graph 1 0 0\nstop\n' + assert graph.create_plain() == b"graph 1 0 0\nstop\n" -def test_empty_table(metadata): - Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) - result = plain_result(metadata=metadata) - assert list(result.keys()) == ['1'] - assert list(result['1']['nodes'].keys()) == ['foo'] - assert '- id : INTEGER' in result['1']['nodes']['foo'] +def test_empty_table(metadata, engine): + foo = Table("foo", metadata, Column("id", types.Integer, primary_key=True)) + result = plain_result(engine=engine, metadata=metadata) + assert list(result.keys()) == ["1"] + assert list(result["1"]["nodes"].keys()) == ["foo"] + assert "- id : INTEGER" in result["1"]["nodes"]["foo"] -def test_empty_table_with_key_suffix(metadata): - Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) - result = plain_result(metadata=metadata, show_column_keys=True) +def test_empty_table_with_key_suffix(metadata, engine): + foo = Table("foo", metadata, Column("id", types.Integer, primary_key=True)) + result = plain_result( + engine=engine, + metadata=metadata, + show_column_keys=True, + ) print(result) - assert list(result.keys()) == ['1'] - assert list(result['1']['nodes'].keys()) == ['foo'] - assert '- id(PK) : INTEGER' in result['1']['nodes']['foo'] + assert list(result.keys()) == ["1"] + assert list(result["1"]["nodes"].keys()) == ["foo"] + assert "- id(PK) : INTEGER" in result["1"]["nodes"]["foo"] -def test_foreign_key(metadata): +def test_foreign_key(metadata, engine): foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) - Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id))) - result = plain_result(metadata=metadata) - assert list(result.keys()) == ['1'] - assert sorted(result['1']['nodes'].keys()) == ['bar', 'foo'] - assert '- id : INTEGER' in result['1']['nodes']['foo'] - assert '- foo_id : INTEGER' in result['1']['nodes']['bar'] - assert 'edges' in result['1'] - assert ('bar', 'foo') in result['1']['edges'] - - -def test_foreign_key_with_key_suffix(metadata): + "foo", + metadata, + Column("id", types.Integer, primary_key=True), + ) + bar = Table( + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + ) + result = plain_result(engine=engine, metadata=metadata) + assert list(result.keys()) == ["1"] + assert sorted(result["1"]["nodes"].keys()) == ["bar", "foo"] + assert "- id : INTEGER" in result["1"]["nodes"]["foo"] + assert "- foo_id : INTEGER" in result["1"]["nodes"]["bar"] + assert "edges" in result["1"] + assert ("bar", "foo") in result["1"]["edges"] + + +def test_foreign_key_with_key_suffix(metadata, engine): foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) - Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id))) - result = plain_result(metadata=metadata, show_column_keys=True) - assert list(result.keys()) == ['1'] - assert sorted(result['1']['nodes'].keys()) == ['bar', 'foo'] - assert '- id(PK) : INTEGER' in result['1']['nodes']['foo'] - assert '- foo_id(FK) : INTEGER' in result['1']['nodes']['bar'] - assert 'edges' in result['1'] - assert ('bar', 'foo') in result['1']['edges'] - - -def test_table_filtering(metadata): + "foo", + metadata, + Column("id", types.Integer, primary_key=True), + ) + bar = Table( + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + ) + result = plain_result(engine=engine, + metadata=metadata, + show_column_keys=True) + assert list(result.keys()) == ["1"] + assert sorted(result["1"]["nodes"].keys()) == ["bar", "foo"] + assert "- id(PK) : INTEGER" in result["1"]["nodes"]["foo"] + assert "- foo_id(FK) : INTEGER" in result["1"]["nodes"]["bar"] + assert "edges" in result["1"] + assert ("bar", "foo") in result["1"]["edges"] + + +def test_table_filtering(engine, metadata): foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) + "foo", + metadata, + Column("id", types.Integer, primary_key=True), + ) bar = Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id))) - result = plain_result(tables=[bar]) - assert list(result.keys()) == ['1'] - assert list(result['1']['nodes'].keys()) == ['bar'] - assert '- foo_id : INTEGER' in result['1']['nodes']['bar'] - -def test_table_rendering_without_schema(metadata): + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + ) + result = plain_result(engine=engine, tables=[bar]) + assert list(result.keys()) == ["1"] + assert list(result["1"]["nodes"].keys()) == ["bar"] + assert "- foo_id : INTEGER" in result["1"]["nodes"]["bar"] + + +def test_table_rendering_without_schema(metadata, engine): foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True)) + "foo", + metadata, + Column("id", types.Integer, primary_key=True), + ) bar = Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id))) + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + ) try: - sasd.create_schema_graph(metadata=metadata).create_png() + sqlalchemy_schemadisplay.create_schema_graph( + engine=engine, metadata=metadata).create_png() except Exception as ex: - assert False, "An exception of type {} was produced when attempting to render a png of the graph".format(ex.__class__.__name__) + assert ( + False + ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" -def test_table_rendering_with_schema(metadata): - foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True), - schema='sch_foo' - ) + +def test_table_rendering_with_schema(metadata, engine): + foo = Table("foo", + metadata, + Column("id", types.Integer, primary_key=True), + schema="sch_foo") bar = Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id)), - schema='sch_bar' + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + schema="sch_bar", ) try: - sasd.create_schema_graph( + sqlalchemy_schemadisplay.create_schema_graph( + engine=engine, metadata=metadata, show_schema_name=True, ).create_png() except Exception as ex: - assert False, "An exception of type {} was produced when attempting to render a png of the graph".format(ex.__class__.__name__) + assert ( + False + ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" -def test_table_rendering_with_schema_and_formatting(metadata): - foo = Table( - 'foo', metadata, - Column('id', types.Integer, primary_key=True), - schema='sch_foo' - ) + +def test_table_rendering_with_schema_and_formatting(metadata, engine): + foo = Table("foo", + metadata, + Column("id", types.Integer, primary_key=True), + schema="sch_foo") bar = Table( - 'bar', metadata, - Column('foo_id', types.Integer, ForeignKey(foo.c.id)), - schema='sch_bar' + "bar", + metadata, + Column("foo_id", types.Integer, ForeignKey(foo.c.id)), + schema="sch_bar", ) try: - sasd.create_schema_graph( + sqlalchemy_schemadisplay.create_schema_graph( + engine=engine, metadata=metadata, show_schema_name=True, - format_schema_name={'fontsize':8.0, 'color': '#888888'}, - format_table_name={'bold':True, 'fontsize': 10.0}, + format_schema_name={ + "fontsize": 8.0, + "color": "#888888" + }, + format_table_name={ + "bold": True, + "fontsize": 10.0 + }, ).create_png() except Exception as ex: - assert False, "An exception of type {} was produced when attempting to render a png of the graph".format(ex.__class__.__name__) + assert ( + False + ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" diff --git a/tests/test_uml_graph.py b/tests/test_uml_graph.py index ff73257..ee4d7a7 100644 --- a/tests/test_uml_graph.py +++ b/tests/test_uml_graph.py @@ -1,18 +1,16 @@ -from sqlalchemy import types -from sqlalchemy import Column -from sqlalchemy import ForeignKey -from sqlalchemy import MetaData -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import class_mapper -from sqlalchemy.orm import relationship -from utils import parse_graph +"""Set of tests for the ORM diagrams""" import pytest -import sqlalchemy_schemadisplay as sasd +from sqlalchemy import Column, ForeignKey, MetaData, types +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import class_mapper, relationship + +import sqlalchemy_schemadisplay +from .utils import parse_graph @pytest.fixture def metadata(request): - return MetaData('sqlite:///:memory:') + return MetaData("sqlite:///:memory:") @pytest.fixture @@ -21,7 +19,7 @@ def Base(request, metadata): def plain_result(mapper, **kw): - return parse_graph(sasd.create_uml_graph(mapper, **kw)) + return parse_graph(sqlalchemy_schemadisplay.create_uml_graph(mapper, **kw)) def mappers(*args): @@ -30,50 +28,62 @@ def mappers(*args): def test_simple_class(Base, capsys): class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(types.Integer, primary_key=True) + result = plain_result(mappers(Foo)) - assert list(result.keys()) == ['1'] - assert list(result['1']['nodes'].keys()) == ['Foo'] - assert '+id : Integer' in result['1']['nodes']['Foo'] + assert list(result.keys()) == ["1"] + assert list(result["1"]["nodes"].keys()) == ["Foo"] + assert "+id : Integer" in result["1"]["nodes"]["Foo"] out, err = capsys.readouterr() - assert out == u'' - assert err == u'' + assert out == "" + assert err == "" def test_relation(Base): class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(types.Integer, primary_key=True) + class Bar(Base): - __tablename__ = 'bar' + __tablename__ = "bar" id = Column(types.Integer, primary_key=True) foo_id = Column(types.Integer, ForeignKey(Foo.id)) + Foo.bars = relationship(Bar) - graph = sasd.create_uml_graph(mappers(Foo, Bar)) - assert sorted(graph.obj_dict['nodes'].keys()) == ['"Bar"', '"Foo"'] - assert '+id : Integer' in graph.obj_dict['nodes']['"Foo"'][0]['attributes']['label'] - assert '+foo_id : Integer' in graph.obj_dict['nodes']['"Bar"'][0]['attributes']['label'] - assert 'edges' in graph.obj_dict - assert ('"Foo"', '"Bar"') in graph.obj_dict['edges'] - assert graph.obj_dict['edges'][('"Foo"', '"Bar"')][0]['attributes']['headlabel'] == '+bars *' + graph = sqlalchemy_schemadisplay.create_uml_graph(mappers(Foo, Bar)) + assert sorted(graph.obj_dict["nodes"].keys()) == ['"Bar"', '"Foo"'] + assert "+id : Integer" in graph.obj_dict["nodes"]['"Foo"'][0][ + "attributes"]["label"] + assert ("+foo_id : Integer" + in graph.obj_dict["nodes"]['"Bar"'][0]["attributes"]["label"]) + assert "edges" in graph.obj_dict + assert ('"Foo"', '"Bar"') in graph.obj_dict["edges"] + assert (graph.obj_dict["edges"][( + '"Foo"', '"Bar"')][0]["attributes"]["headlabel"] == "+bars *") def test_backref(Base): class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(types.Integer, primary_key=True) + class Bar(Base): - __tablename__ = 'bar' + __tablename__ = "bar" id = Column(types.Integer, primary_key=True) foo_id = Column(types.Integer, ForeignKey(Foo.id)) - Foo.bars = relationship(Bar, backref='foo') - graph = sasd.create_uml_graph(mappers(Foo, Bar)) - assert sorted(graph.obj_dict['nodes'].keys()) == ['"Bar"', '"Foo"'] - assert '+id : Integer' in graph.obj_dict['nodes']['"Foo"'][0]['attributes']['label'] - assert '+foo_id : Integer' in graph.obj_dict['nodes']['"Bar"'][0]['attributes']['label'] - assert 'edges' in graph.obj_dict - assert ('"Foo"', '"Bar"') in graph.obj_dict['edges'] - assert ('"Bar"', '"Foo"') in graph.obj_dict['edges'] - assert graph.obj_dict['edges'][('"Foo"', '"Bar"')][0]['attributes']['headlabel'] == '+bars *' - assert graph.obj_dict['edges'][('"Bar"', '"Foo"')][0]['attributes']['headlabel'] == '+foo 0..1' + + Foo.bars = relationship(Bar, backref="foo") + graph = sqlalchemy_schemadisplay.create_uml_graph(mappers(Foo, Bar)) + assert sorted(graph.obj_dict["nodes"].keys()) == ['"Bar"', '"Foo"'] + assert "+id : Integer" in graph.obj_dict["nodes"]['"Foo"'][0][ + "attributes"]["label"] + assert ("+foo_id : Integer" + in graph.obj_dict["nodes"]['"Bar"'][0]["attributes"]["label"]) + assert "edges" in graph.obj_dict + assert ('"Foo"', '"Bar"') in graph.obj_dict["edges"] + assert ('"Bar"', '"Foo"') in graph.obj_dict["edges"] + assert (graph.obj_dict["edges"][( + '"Foo"', '"Bar"')][0]["attributes"]["headlabel"] == "+bars *") + assert (graph.obj_dict["edges"][( + '"Bar"', '"Foo"')][0]["attributes"]["headlabel"] == "+foo 0..1") diff --git a/tests/utils.py b/tests/utils.py index e716207..f509319 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,30 +1,27 @@ -try: - from cStringIO import StringIO -except ImportError: - from io import StringIO +from io import StringIO def parse_graph(graph): result = {} graph_bytes = graph.create_plain() - sio = StringIO(graph_bytes.decode('utf-8')) + sio = StringIO(graph_bytes.decode("utf-8")) graph = None for line in sio: line = line.strip() if not line: continue - if line.startswith('graph'): + if line.startswith("graph"): parts = line.split(None, 4) - graph = result.setdefault(parts[1], {'nodes': {}}) + graph = result.setdefault(parts[1], {"nodes": {}}) if len(parts) > 4: - graph['options'] = parts[4] - elif line.startswith('node'): + graph["options"] = parts[4] + elif line.startswith("node"): parts = line.split(None, 6) - graph['nodes'][parts[1]] = parts[6] - elif line.startswith('edge'): + graph["nodes"][parts[1]] = parts[6] + elif line.startswith("edge"): parts = line.split(None, 3) - graph.setdefault('edges', {})[(parts[1], parts[2])] = parts[3] - elif line == 'stop': + graph.setdefault("edges", {})[(parts[1], parts[2])] = parts[3] + elif line == "stop": graph = None else: raise ValueError("Don't know how to handle line:\n%s" % line) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index bf29f6e..0000000 --- a/tox.ini +++ /dev/null @@ -1,31 +0,0 @@ -[tox] - -[testenv] -commands = py.test --cov-report html --cov-report term --cov sqlalchemy_schemadisplay {posargs} -deps = - pytest - pytest-cov - -[testenv:sqla06-py27] -deps = - {[testenv]deps} - sqlalchemy==0.6.* - -[testenv:sqla07-py27] -deps = - {[testenv]deps} - sqlalchemy==0.7.* - -[testenv:sqla08-py27] -deps = - {[testenv]deps} - sqlalchemy==0.8.* - -[testenv:sqla09-py27] -deps = - {[testenv]deps} - sqlalchemy==0.9.* - -[testenv:sqlalchemy-py27] - -[testenv:sqlalchemy-py3]