diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..116b56d8 Binary files /dev/null and b/.DS_Store differ diff --git a/cs_tools/sync/redshift/MANIFEST.json b/cs_tools/sync/redshift/MANIFEST.json index b5ed5ca5..14c6bff1 100644 --- a/cs_tools/sync/redshift/MANIFEST.json +++ b/cs_tools/sync/redshift/MANIFEST.json @@ -3,7 +3,7 @@ "syncer_class": "Redshift", "requirements": [ "sqlalchemy-redshift>=1.4.1", - "redshift-connector>=2.0.905", - "s3fs>=2022.2.0" + "redshift_connector>=2.1.0", + "psycopg2-binary>=2.9.9" ] } diff --git a/cs_tools/sync/redshift/syncer.py b/cs_tools/sync/redshift/syncer.py index 53963d38..51d90039 100644 --- a/cs_tools/sync/redshift/syncer.py +++ b/cs_tools/sync/redshift/syncer.py @@ -1,131 +1,71 @@ -from __future__ import annotations - -from typing import Any -import csv -import enum -import io +from typing import Any, Optional import logging +import pathlib -from pydantic.dataclasses import dataclass -from sqlalchemy_redshift import dialect -import s3fs +import pydantic +import redshift_connector import sqlalchemy as sa +from sqlalchemy.engine.url import URL +import sqlmodel -log = logging.getLogger(__name__) - +from cs_tools.sync.base import DatabaseSyncer +from cs_tools.sync.types import TableRows -class AuthType(enum.Enum): - local = "local" - okta = "okta" +log = logging.getLogger(__name__) +class Redshift(DatabaseSyncer): -@dataclass -class Redshift: """ - Interact with an AWS Redshift database. + Interact with Redshift DataBase """ - - username: str - password: str - database: str - aws_access_key: str # for S3 data load - aws_secret_key: str # for S3 data load - aws_endpoint: str # FMT: .xxxxxx..redshift.amazonaws.com - port: int = 5439 - auth_type: AuthType = AuthType.local - # okta_account_name: str = None - # okta_app_id: str = None - truncate_on_load: bool = True - - # DATABASE ATTRIBUTES - __is_database__ = True - - def __post_init_post_parse__(self): - if self.auth_type == AuthType.local: - connect_args = {} - url = sa.engine.URL.create( - drivername="redshift+redshift_connector", - host=self.aws_endpoint, - port=self.port, - database=self.database, - username=self.username, - password=self.password, - ) - - elif self.auth_type == AuthType.okta: - # aws_cluster_id, _, aws_region, *_ = self.aws_endpoint.split('.') - # connect_args = { - # 'credentials_provider': 'OktaCredentialsProvider', - # 'idp_host': '.okta.com', - # 'app_id': '', - # 'app_name': 'amazon_aws_redshift', - # 'cluster_identifier': aws_cluster_id, - # 'region': aws_region, - # 'ssl_insecure': False, - # **connect_args - # } - # url = sa.engine.URL.create( - # drivername='redshift+redshift_connector', - # database=self.database, - # username=self.username, - # password=self.password - # ) - raise NotImplementedError( - "our implementation is best-effort, but lacks testing.. see the source " - "code for ideas on how to implement MFA to Okta." - ) - - self.engine = sa.create_engine(url, connect_args=connect_args) - self.cnxn = self.engine.connect() - - # decorators must be declared here, SQLAlchemy doesn't care about instances - sa.event.listen(sa.schema.MetaData, "after_create", self.capture_metadata) - - def capture_metadata(self, metadata, cnxn, **kw): - self.metadata = metadata - - def __repr__(self): - return f"" - - # MANDATORY PROTOCOL MEMBERS - - @property - def name(self) -> str: - return "redshift" - - def load(self, table: str) -> list[dict[str, Any]]: - t = self.metadata.tables[table] - - with self.cnxn.begin(): - r = self.cnxn.execute(t.select()) - - return [dict(_) for _ in r] - - def dump(self, table: str, *, data: list[dict[str, Any]]) -> None: + __manifest_path__ = pathlib.Path(__file__).parent + __syncer_name__ = "Redshift" + + host : str + database : str + user : str + password : str + port : int + username : str + password : str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + print(kwargs) + self.engine_url=URL.create( + drivername='redshift+psycopg2', # indicate redshift_connector driver and dialect will be used + host=self.host, + port=self.port, + database=self.database, # Amazon Redshift database + username=self.user, # Okta username + password=self.password # Okta password + ) + self._engine = sa.create_engine(self.engine_url) + # self.metadata = sqlmodel.MetaData(schema=self.schema_) + + + def load(self, tablename: str) -> TableRows: + """SELECT rows from Redshift""" + table = self.metadata.tables[f"{tablename}"] + rows = self.session.execute(table.select()) + return [row.model_dump() for row in rows] + + def dump(self, tablename: str, *, data: TableRows) -> None: + + + table = self.metadata.tables[f"{tablename}"] if not data: - log.warning(f"no data to write to syncer {self}") + log.warning(f"No data to write to syncer {table}") return - - t = self.metadata.tables[table] - - if self.truncate_on_load: - with self.cnxn.begin(): - self.cnxn.execute(table.delete().where(True)) - - # 1. Load file to S3 - fs = s3fs.S3FileSystem(key=self.aws_access_key, secret=self.aws_secret_key) - fp = f"s3://{self.s3_bucket_name}/ts_{table}.csv" - - with io.StringIO() as buf, fs.open(fp, "w") as f: - header = list(data[0].keys()) - writer = csv.DictWriter(buf, fieldnames=header, dialect="excel", delimiter="|") - writer.writeheader() - writer.writerows(data) - - f.write(buf.getvalue()) - - # 2. Perform a COPY operation - q = dialect.CopyCommand(t, data_location=fp, ignore_header=0) - - with self.cnxn.begin(): - self.cnxn.execute(q) # .execution_options(autocommit=True) + + if self.load_strategy == "APPEND": + self.session.execute(table.insert(), data) + self.session.commit() + + if self.load_strategy == "TRUNCATE": + self.session.execute(table.delete()) + self.session.execute(table.insert(), data) + + if self.load_strategy == "UPSERT": + self.session.execute(table.merge()) + self.session.commit() #