Skip to content

Commit

Permalink
Merge pull request #745 from MasoniteFramework/feature/737
Browse files Browse the repository at this point in the history
added schema for postgres
  • Loading branch information
josephmancuso authored Jul 4, 2022
2 parents 25480d7 + 68291ca commit ac3ce22
Show file tree
Hide file tree
Showing 19 changed files with 96 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def up(self):
table.timestamps()

if not self.schema._dry:
User.on(self.connection).create({
User.on(self.connection).set_schema(self.schema_name).create({
'name': 'Joe',
'email': '[email protected]',
'password': 'secret'
Expand Down
2 changes: 2 additions & 0 deletions src/masoniteorm/commands/MigrateCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class MigrateCommand(Command):
{--c|connection=default : The connection you want to run migrations on}
{--f|force : Force migrations without prompt in production}
{--s|show : Shows the output of SQL for migrations that would be running}
{--schema=? : Sets the schema to be migrated}
{--d|directory=databases/migrations : The location of the migration directory}
"""

Expand All @@ -32,6 +33,7 @@ def handle(self):
connection=self.option("connection"),
migration_directory=self.option("directory"),
config_path=self.option("config"),
schema=self.option("schema"),
)
migration.create_table_if_not_exists()
if not migration.get_unran_migrations():
Expand Down
3 changes: 3 additions & 0 deletions src/masoniteorm/commands/MigrateRefreshCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class MigrateRefreshCommand(Command):
{--c|connection=default : The connection you want to run migrations on}
{--d|directory=databases/migrations : The location of the migration directory}
{--s|seed=? : Seed database after refresh. The seeder to be ran can be provided in argument}
{--schema=? : Sets the schema to be migrated}
{--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.}
"""

Expand All @@ -22,6 +23,7 @@ def handle(self):
connection=self.option("connection"),
migration_directory=self.option("directory"),
config_path=self.option("config"),
schema=self.option("schema")
)

migration.refresh(self.option("migration"))
Expand All @@ -33,6 +35,7 @@ def handle(self):
"seed:run",
f"None --directory {self.option('seed-directory')} --connection {self.option('connection', 'default')}",
)

elif self.option("seed"):
self.call(
"seed:run",
Expand Down
3 changes: 3 additions & 0 deletions src/masoniteorm/commands/MigrateResetCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class MigrateResetCommand(Command):
migrate:reset
{--m|migration=all : Migration's name to be rollback}
{--c|connection=default : The connection you want to run migrations on}
{--schema=? : Sets the schema to be migrated}
{--d|directory=databases/migrations : The location of the migration directory}
"""

Expand All @@ -18,5 +19,7 @@ def handle(self):
connection=self.option("connection"),
migration_directory=self.option("directory"),
config_path=self.option("config"),
schema=self.option("schema"),
)

migration.reset(self.option("migration"))
2 changes: 2 additions & 0 deletions src/masoniteorm/commands/MigrateRollbackCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class MigrateRollbackCommand(Command):
{--m|migration=all : Migration's name to be rollback}
{--c|connection=default : The connection you want to run migrations on}
{--s|show : Shows the output of SQL for migrations that would be running}
{--schema=? : Sets the schema to be migrated}
{--d|directory=databases/migrations : The location of the migration directory}
"""

Expand All @@ -19,4 +20,5 @@ def handle(self):
connection=self.option("connection"),
migration_directory=self.option("directory"),
config_path=self.option("config"),
schema=self.option("schema"),
).rollback(migration=self.option("migration"), output=self.option("show"))
2 changes: 2 additions & 0 deletions src/masoniteorm/commands/MigrateStatusCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class MigrateStatusCommand(Command):
migrate:status
{--c|connection=default : The connection you want to run migrations on}
{--schema=? : Sets the schema to be migrated}
{--d|directory=databases/migrations : The location of the migration directory}
"""

Expand All @@ -17,6 +18,7 @@ def handle(self):
connection=self.option("connection"),
migration_directory=self.option("directory"),
config_path=self.option("config"),
schema=self.option("schema"),
)
migration.create_table_if_not_exists()
table = self.table()
Expand Down
4 changes: 4 additions & 0 deletions src/masoniteorm/connections/BaseConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def dry(self):
self._dry = True
return self

def set_schema(self, schema):
self.schema = schema
return self

def log(
self, query, bindings, query_time=0, logger="masoniteorm.connections.queries"
):
Expand Down
8 changes: 6 additions & 2 deletions src/masoniteorm/connections/ConnectionResolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def set_connection_details(self, connection_details):
def get_connection_details(self):
return self._connection_details

def set_connection_option(self, connection: str, options: dict):
self._connection_details.get(connection).update(options)
return self

def get_global_connections(self):
return self._connections

Expand Down Expand Up @@ -104,11 +108,11 @@ def get_connection_information(self, name):
"full_details": details.get(name, {}),
}

def get_schema_builder(self, connection="default"):
def get_schema_builder(self, connection="default", schema=None):
from ..schema import Schema

return Schema(
connection=connection, connection_details=self.get_connection_details()
connection=connection, connection_details=self.get_connection_details(), schema=schema
)

def get_query_builder(self, connection="default"):
Expand Down
4 changes: 4 additions & 0 deletions src/masoniteorm/connections/PostgresConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self._cursor = None
self.transaction_level = 0
self.open = 0
self.schema = None
if name:
self.name = name

Expand All @@ -56,12 +57,15 @@ def make_connection(self):
if self.has_global_connection():
return self.get_global_connection()

schema = self.schema or self.full_details.get("schema")

self._connection = psycopg2.connect(
database=self.database,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
options=f"-c search_path={schema}" if schema else "",
)

self._connection.autocommit = True
Expand Down
21 changes: 17 additions & 4 deletions src/masoniteorm/migrations/Migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@ def __init__(
command_class=None,
migration_directory="databases/migrations",
config_path=None,
schema=None,
):
self.connection = connection
self.migration_directory = migration_directory
self.last_migrations_ran = []
self.command_class = command_class

self.schema_name = schema

DB = load_config(config_path).DB

DATABASES = DB.get_connection_details()

self.schema = Schema(
connection=connection, connection_details=DATABASES, dry=dry
connection=connection, connection_details=DATABASES, dry=dry, schema=self.schema_name
)

self.migration_model = MigrationModel.on(self.connection)
if self.schema_name:
self.migration_model.set_schema(self.schema_name)

def create_table_if_not_exists(self):
if not self.schema.has_table("migrations"):
Expand Down Expand Up @@ -132,7 +137,7 @@ def migrate(self, migration="all", output=False):
f"<comment>Migrating:</comment> <question>{migration}</question>"
)

migration_class = migration_class(connection=self.connection)
migration_class = migration_class(connection=self.connection, schema=self.schema_name)

if output:
migration_class.schema.dry()
Expand Down Expand Up @@ -182,7 +187,7 @@ def rollback(self, migration="all", output=False):
self.command_class.line(f"<error>Not Found: {migration}</error>")
continue

migration_class = migration_class(connection=self.connection)
migration_class = migration_class(connection=self.connection, schema=self.schema_name)

if output:
migration_class.schema.dry()
Expand Down Expand Up @@ -230,14 +235,22 @@ def reset(self, migration="all"):
default_migrations = self.get_all_migrations(reverse=True)
migrations = default_migrations if migration == "all" else [migration]

if not len(migrations):
if self.command_class:
self.command_class.line(
"<info>Nothing to reset</info>"
)
else:
print("Nothing to reset")

for migration in migrations:
if self.command_class:
self.command_class.line(
f"<comment>Rolling back:</comment> <question>{migration}</question>"
)

try:
self.locate(migration)(connection=self.connection).down()
self.locate(migration)(connection=self.connection, schema=self.schema_name).down()
except TypeError:
self.command_class.line(f"<error>Not Found: {migration}</error>")
continue
Expand Down
1 change: 1 addition & 0 deletions src/masoniteorm/models/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class Model(TimeStampsMixin, ObservesEvents, metaclass=ModelMeta):
"select_raw",
"select",
"set_global_scope",
"set_schema",
"shared_lock",
"simple_paginate",
"skip",
Expand Down
20 changes: 14 additions & 6 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
connection_driver="default",
model=None,
scopes=None,
schema=None,
dry=False,
):
"""QueryBuilder initializer
Expand All @@ -70,6 +71,7 @@ def __init__(
self._connection_driver = connection_driver
self._scopes = scopes or {}
self.lock = False
self._schema = schema
self._eager_relation = EagerRelations()
if model:
self._global_scopes = model._global_scopes
Expand Down Expand Up @@ -119,6 +121,10 @@ def _set_creates_related(self, fields: dict):
self._creates_related = fields
return self

def set_schema(self, schema):
self._schema = schema
return self

def shared_lock(self):
return self.make_lock("share")

Expand Down Expand Up @@ -1163,9 +1169,7 @@ def or_where_has(self, relationship, callback):
)
continue

last_builder = related.query_has(
last_builder, method="where_exists"
)
last_builder = related.query_has(last_builder, method="where_exists")
else:
related = getattr(self._model, relationship)
related.query_where_exists(self, callback, method="or_where_exists")
Expand Down Expand Up @@ -1937,9 +1941,13 @@ def new_connection(self):
if self._connection:
return self._connection

self._connection = self.connection_class(
**self.get_connection_information(), name=self.connection
).make_connection()
self._connection = (
self.connection_class(
**self.get_connection_information(), name=self.connection
)
.set_schema(self._schema)
.make_connection()
)
return self._connection

def get_connection(self):
Expand Down
4 changes: 3 additions & 1 deletion src/masoniteorm/schema/Blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def __init__(
table="",
connection=None,
platform=None,
schema=None,
action=None,
default_string_length=None,
dry=False,
Expand All @@ -16,6 +17,7 @@ def __init__(
self._last_column = None
self._default_string_length = default_string_length
self.platform = platform
self.schema = schema
self._dry = dry
self._action = action
self.connection = connection
Expand Down Expand Up @@ -689,7 +691,7 @@ def to_sql(self):
if not self._dry:
# get current table schema
table = self.platform().get_current_schema(
self.connection, self.table.name
self.connection, self.table.name, schema=self.schema
)
self.table.from_table = table

Expand Down
23 changes: 19 additions & 4 deletions src/masoniteorm/schema/Schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
platform=None,
grammar=None,
connection_details=None,
connection_driver=None,
schema=None,
):
self._dry = dry
self.connection = connection
Expand All @@ -68,6 +68,7 @@ def __init__(
self.connection_details = connection_details or {}
self._blueprint = None
self._sql = None
self.schema = schema

if not self.connection_class:
self.on(self.connection)
Expand Down Expand Up @@ -133,6 +134,7 @@ def create(self, table):
table=Table(table),
action="create",
platform=self.platform,
schema=self.schema,
default_string_length=self._default_string_length,
dry=self._dry,
)
Expand All @@ -148,6 +150,7 @@ def create_table_if_not_exists(self, table):
table=Table(table),
action="create_table_if_not_exists",
platform=self.platform,
schema=self.schema,
default_string_length=self._default_string_length,
dry=self._dry,
)
Expand All @@ -173,6 +176,7 @@ def table(self, table):
table=TableDiff(table),
action="alter",
platform=self.platform,
schema=self.schema,
default_string_length=self._default_string_length,
dry=self._dry,
)
Expand Down Expand Up @@ -203,7 +207,7 @@ def new_connection(self):

self._connection = self.connection_class(
**self.get_connection_information()
).make_connection()
).set_schema(self.schema).make_connection()

return self._connection

Expand All @@ -225,7 +229,9 @@ def has_column(self, table, column, query_only=False):
return bool(self.new_connection().query(sql, ()))

def get_columns(self, table, dict=True):
table = self.platform().get_current_schema(self.new_connection(), table)
table = self.platform().get_current_schema(
self.new_connection(), table, schema=self.get_schema()
)
result = {}
if dict:
for column in table.get_added_columns().items():
Expand Down Expand Up @@ -278,6 +284,13 @@ def truncate(self, table, foreign_keys=False):

return bool(self.new_connection().query(sql, ()))

def get_schema(self):
"""Gets the schema set on the migration class
"""
return self.schema or self.get_connection_information().get("full_details").get(
"schema"
)

def has_table(self, table, query_only=False):
"""Checks if the a database has a specific table
Arguments:
Expand All @@ -286,7 +299,9 @@ def has_table(self, table, query_only=False):
masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object.
"""
sql = self.platform().compile_table_exists(
table, database=self.get_connection_information().get("database")
table,
database=self.get_connection_information().get("database"),
schema=self.get_schema(),
)

if self._dry:
Expand Down
Loading

0 comments on commit ac3ce22

Please sign in to comment.