-
Notifications
You must be signed in to change notification settings. Forks: 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 3 changed files with 62 additions and 122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,131 +1,71 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
import csv | ||
import enum | ||
import io | ||
from typing import Any, Optional | ||
import logging | ||
import pathlib | ||
|
||
from pydantic.dataclasses import dataclass | ||
from sqlalchemy_redshift import dialect | ||
import s3fs | ||
import pydantic | ||
import redshift_connector | ||
import sqlalchemy as sa | ||
from sqlalchemy.engine.url import URL | ||
import sqlmodel | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
from cs_tools.sync.base import DatabaseSyncer | ||
from cs_tools.sync.types import TableRows | ||
|
||
class AuthType(enum.Enum): | ||
local = "local" | ||
okta = "okta" | ||
log = logging.getLogger(__name__) | ||
|
||
class Redshift(DatabaseSyncer): | ||
|
||
@dataclass | ||
class Redshift: | ||
""" | ||
Interact with an AWS Redshift database. | ||
Interact with Redshift DataBase | ||
""" | ||
|
||
username: str | ||
password: str | ||
database: str | ||
aws_access_key: str # for S3 data load | ||
aws_secret_key: str # for S3 data load | ||
aws_endpoint: str # FMT: <clusterid>.xxxxxx.<aws-region>.redshift.amazonaws.com | ||
port: int = 5439 | ||
auth_type: AuthType = AuthType.local | ||
# okta_account_name: str = None | ||
# okta_app_id: str = None | ||
truncate_on_load: bool = True | ||
|
||
# DATABASE ATTRIBUTES | ||
__is_database__ = True | ||
|
||
def __post_init_post_parse__(self): | ||
if self.auth_type == AuthType.local: | ||
connect_args = {} | ||
url = sa.engine.URL.create( | ||
drivername="redshift+redshift_connector", | ||
host=self.aws_endpoint, | ||
port=self.port, | ||
database=self.database, | ||
username=self.username, | ||
password=self.password, | ||
) | ||
|
||
elif self.auth_type == AuthType.okta: | ||
# aws_cluster_id, _, aws_region, *_ = self.aws_endpoint.split('.') | ||
# connect_args = { | ||
# 'credentials_provider': 'OktaCredentialsProvider', | ||
# 'idp_host': '<prefix>.okta.com', | ||
# 'app_id': '<appid>', | ||
# 'app_name': 'amazon_aws_redshift', | ||
# 'cluster_identifier': aws_cluster_id, | ||
# 'region': aws_region, | ||
# 'ssl_insecure': False, | ||
# **connect_args | ||
# } | ||
# url = sa.engine.URL.create( | ||
# drivername='redshift+redshift_connector', | ||
# database=self.database, | ||
# username=self.username, | ||
# password=self.password | ||
# ) | ||
raise NotImplementedError( | ||
"our implementation is best-effort, but lacks testing.. see the source " | ||
"code for ideas on how to implement MFA to Okta." | ||
) | ||
|
||
self.engine = sa.create_engine(url, connect_args=connect_args) | ||
self.cnxn = self.engine.connect() | ||
|
||
# decorators must be declared here, SQLAlchemy doesn't care about instances | ||
sa.event.listen(sa.schema.MetaData, "after_create", self.capture_metadata) | ||
|
||
def capture_metadata(self, metadata, cnxn, **kw): | ||
self.metadata = metadata | ||
|
||
def __repr__(self): | ||
return f"<Database ({self.name}) sync: conn_string='{self.engine.url}'>" | ||
|
||
# MANDATORY PROTOCOL MEMBERS | ||
|
||
@property | ||
def name(self) -> str: | ||
return "redshift" | ||
|
||
def load(self, table: str) -> list[dict[str, Any]]: | ||
t = self.metadata.tables[table] | ||
|
||
with self.cnxn.begin(): | ||
r = self.cnxn.execute(t.select()) | ||
|
||
return [dict(_) for _ in r] | ||
|
||
def dump(self, table: str, *, data: list[dict[str, Any]]) -> None: | ||
__manifest_path__ = pathlib.Path(__file__).parent | ||
__syncer_name__ = "Redshift" | ||
|
||
host : str | ||
database : str | ||
user : str | ||
password : str | ||
port : int | ||
username : str | ||
password : str | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
print(kwargs) | ||
self.engine_url=URL.create( | ||
drivername='redshift+psycopg2', # indicate redshift_connector driver and dialect will be used | ||
host=self.host, | ||
port=self.port, | ||
database=self.database, # Amazon Redshift database | ||
username=self.user, # Okta username | ||
password=self.password # Okta password | ||
) | ||
self._engine = sa.create_engine(self.engine_url) | ||
# self.metadata = sqlmodel.MetaData(schema=self.schema_) | ||
|
||
|
||
def load(self, tablename: str) -> TableRows: | ||
"""SELECT rows from Redshift""" | ||
table = self.metadata.tables[f"{tablename}"] | ||
rows = self.session.execute(table.select()) | ||
return [row.model_dump() for row in rows] | ||
|
||
def dump(self, tablename: str, *, data: TableRows) -> None: | ||
|
||
|
||
table = self.metadata.tables[f"{tablename}"] | ||
if not data: | ||
log.warning(f"no data to write to syncer {self}") | ||
log.warning(f"No data to write to syncer {table}") | ||
return | ||
|
||
t = self.metadata.tables[table] | ||
|
||
if self.truncate_on_load: | ||
with self.cnxn.begin(): | ||
self.cnxn.execute(table.delete().where(True)) | ||
|
||
# 1. Load file to S3 | ||
fs = s3fs.S3FileSystem(key=self.aws_access_key, secret=self.aws_secret_key) | ||
fp = f"s3://{self.s3_bucket_name}/ts_{table}.csv" | ||
|
||
with io.StringIO() as buf, fs.open(fp, "w") as f: | ||
header = list(data[0].keys()) | ||
writer = csv.DictWriter(buf, fieldnames=header, dialect="excel", delimiter="|") | ||
writer.writeheader() | ||
writer.writerows(data) | ||
|
||
f.write(buf.getvalue()) | ||
|
||
# 2. Perform a COPY operation | ||
q = dialect.CopyCommand(t, data_location=fp, ignore_header=0) | ||
|
||
with self.cnxn.begin(): | ||
self.cnxn.execute(q) # .execution_options(autocommit=True) | ||
|
||
if self.load_strategy == "APPEND": | ||
self.session.execute(table.insert(), data) | ||
self.session.commit() | ||
|
||
if self.load_strategy == "TRUNCATE": | ||
self.session.execute(table.delete()) | ||
self.session.execute(table.insert(), data) | ||
|
||
if self.load_strategy == "UPSERT": | ||
self.session.execute(table.merge()) | ||
self.session.commit() # |