created redshift syncer #121

Merged: 3 commits, Apr 10, 2024
Changes from all commits

.DS_Store: binary file added (binary file not shown)

cs_tools/sync/redshift/MANIFEST.json (4 changes: 2 additions & 2 deletions)

@@ -3,7 +3,7 @@
   "syncer_class": "Redshift",
   "requirements": [
     "sqlalchemy-redshift>=1.4.1",
-    "redshift-connector>=2.0.905",
-    "s3fs>=2022.2.0"
+    "redshift_connector>=2.1.0",
+    "psycopg2-binary>=2.9.9"
   ]
 }
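
For context on the dependency swap above, here is a minimal sketch (not part of the PR) of the SQLAlchemy connection URLs each requirement set enables through the sqlalchemy-redshift dialect: the old requirements drove the native redshift_connector DBAPI, while the new syncer below connects through psycopg2. The cluster endpoint and credentials are placeholders.

import sqlalchemy as sa

# Old path: sqlalchemy-redshift dialect over the redshift-connector DBAPI.
old_url = sa.engine.URL.create(
    drivername="redshift+redshift_connector",
    host="my-cluster.abc123xyz456.us-east-1.redshift.amazonaws.com",
    port=5439,
    database="dev",
    username="etl_user",
    password="example-password",
)

# New path: same dialect, but over psycopg2 (hence the psycopg2-binary requirement).
new_url = sa.engine.URL.create(
    drivername="redshift+psycopg2",
    host="my-cluster.abc123xyz456.us-east-1.redshift.amazonaws.com",
    port=5439,
    database="dev",
    username="etl_user",
    password="example-password",
)

engine = sa.create_engine(new_url)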

cs_tools/sync/redshift/syncer.py (180 changes: 60 additions & 120 deletions)

@@ -1,131 +1,71 @@
 from __future__ import annotations
 
-from typing import Any
-import csv
-import enum
-import io
+from typing import Any, Optional
 import logging
 import pathlib
 
-from pydantic.dataclasses import dataclass
-from sqlalchemy_redshift import dialect
-import s3fs
+import pydantic
+import redshift_connector
 import sqlalchemy as sa
+from sqlalchemy.engine.url import URL
+import sqlmodel
 
-log = logging.getLogger(__name__)
+from cs_tools.sync.base import DatabaseSyncer
+from cs_tools.sync.types import TableRows
+
+log = logging.getLogger(__name__)
 
 
-class AuthType(enum.Enum):
-    local = "local"
-    okta = "okta"
-
-@dataclass
-class Redshift:
+class Redshift(DatabaseSyncer):
     """
-    Interact with an AWS Redshift database.
+    Interact with Redshift DataBase
     """
 
-    username: str
-    password: str
-    database: str
-    aws_access_key: str # for S3 data load
-    aws_secret_key: str # for S3 data load
-    aws_endpoint: str # FMT: <clusterid>.xxxxxx.<aws-region>.redshift.amazonaws.com
-    port: int = 5439
-    auth_type: AuthType = AuthType.local
-    # okta_account_name: str = None
-    # okta_app_id: str = None
-    truncate_on_load: bool = True
-
-    # DATABASE ATTRIBUTES
-    __is_database__ = True
-
-    def __post_init_post_parse__(self):
-        if self.auth_type == AuthType.local:
-            connect_args = {}
-            url = sa.engine.URL.create(
-                drivername="redshift+redshift_connector",
-                host=self.aws_endpoint,
-                port=self.port,
-                database=self.database,
-                username=self.username,
-                password=self.password,
-            )
-
-        elif self.auth_type == AuthType.okta:
-            # aws_cluster_id, _, aws_region, *_ = self.aws_endpoint.split('.')
-            # connect_args = {
-            #     'credentials_provider': 'OktaCredentialsProvider',
-            #     'idp_host': '<prefix>.okta.com',
-            #     'app_id': '<appid>',
-            #     'app_name': 'amazon_aws_redshift',
-            #     'cluster_identifier': aws_cluster_id,
-            #     'region': aws_region,
-            #     'ssl_insecure': False,
-            #     **connect_args
-            # }
-            # url = sa.engine.URL.create(
-            #     drivername='redshift+redshift_connector',
-            #     database=self.database,
-            #     username=self.username,
-            #     password=self.password
-            # )
-            raise NotImplementedError(
-                "our implementation is best-effort, but lacks testing.. see the source "
-                "code for ideas on how to implement MFA to Okta."
-            )
-
-        self.engine = sa.create_engine(url, connect_args=connect_args)
-        self.cnxn = self.engine.connect()
-
-        # decorators must be declared here, SQLAlchemy doesn't care about instances
-        sa.event.listen(sa.schema.MetaData, "after_create", self.capture_metadata)
-
-    def capture_metadata(self, metadata, cnxn, **kw):
-        self.metadata = metadata
-
-    def __repr__(self):
-        return f"<Database ({self.name}) sync: conn_string='{self.engine.url}'>"
-
-    # MANDATORY PROTOCOL MEMBERS
-
-    @property
-    def name(self) -> str:
-        return "redshift"
-
-    def load(self, table: str) -> list[dict[str, Any]]:
-        t = self.metadata.tables[table]
-
-        with self.cnxn.begin():
-            r = self.cnxn.execute(t.select())
-
-        return [dict(_) for _ in r]
-
-    def dump(self, table: str, *, data: list[dict[str, Any]]) -> None:
+    __manifest_path__ = pathlib.Path(__file__).parent
+    __syncer_name__ = "Redshift"
+
+    host : str
+    database : str
+    user : str
+    password : str
+    port : int
+    username : str
+    password : str
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        print(kwargs)
+        self.engine_url=URL.create(
+            drivername='redshift+psycopg2', # indicate redshift_connector driver and dialect will be used
+            host=self.host,
+            port=self.port,
+            database=self.database, # Amazon Redshift database
+            username=self.user, # Okta username
+            password=self.password # Okta password
+        )
+        self._engine = sa.create_engine(self.engine_url)
+        # self.metadata = sqlmodel.MetaData(schema=self.schema_)
+
+    def load(self, tablename: str) -> TableRows:
+        """SELECT rows from Redshift"""
+        table = self.metadata.tables[f"{tablename}"]
+        rows = self.session.execute(table.select())
+        return [row.model_dump() for row in rows]
+
+    def dump(self, tablename: str, *, data: TableRows) -> None:
+        table = self.metadata.tables[f"{tablename}"]
         if not data:
-            log.warning(f"no data to write to syncer {self}")
+            log.warning(f"No data to write to syncer {table}")
             return
 
-        t = self.metadata.tables[table]
-
-        if self.truncate_on_load:
-            with self.cnxn.begin():
-                self.cnxn.execute(table.delete().where(True))
-
-        # 1. Load file to S3
-        fs = s3fs.S3FileSystem(key=self.aws_access_key, secret=self.aws_secret_key)
-        fp = f"s3://{self.s3_bucket_name}/ts_{table}.csv"
-
-        with io.StringIO() as buf, fs.open(fp, "w") as f:
-            header = list(data[0].keys())
-            writer = csv.DictWriter(buf, fieldnames=header, dialect="excel", delimiter="|")
-            writer.writeheader()
-            writer.writerows(data)
-
-            f.write(buf.getvalue())
-
-        # 2. Perform a COPY operation
-        q = dialect.CopyCommand(t, data_location=fp, ignore_header=0)
-
-        with self.cnxn.begin():
-            self.cnxn.execute(q) # .execution_options(autocommit=True)
-
+        if self.load_strategy == "APPEND":
+            self.session.execute(table.insert(), data)
+            self.session.commit()
+
+        if self.load_strategy == "TRUNCATE":
+            self.session.execute(table.delete())
+            self.session.execute(table.insert(), data)
+
+        if self.load_strategy == "UPSERT":
+            self.session.execute(table.merge())
+            self.session.commit()
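
A minimal usage sketch of the new syncer follows; it is not part of the PR. It assumes the DatabaseSyncer base class accepts the declared connection fields plus a load_strategy of APPEND, TRUNCATE, or UPSERT, and that the target table has been reflected into self.metadata before dump() and load() are called, as the code above implies. All hosts, credentials, and table names are placeholders.

from cs_tools.sync.redshift.syncer import Redshift

# Both "user" and "username" are supplied because the class declares both fields.
syncer = Redshift(
    host="my-cluster.abc123xyz456.us-east-1.redshift.amazonaws.com",
    port=5439,
    database="dev",
    user="etl_user",
    username="etl_user",
    password="example-password",
    load_strategy="TRUNCATE",  # APPEND inserts, TRUNCATE deletes then inserts, UPSERT merges
)

# dump() resolves the table name through self.metadata, so the framework (or the
# caller) must have reflected "example_table" into that MetaData first.
syncer.dump("example_table", data=[{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])
rows = syncer.load("example_table")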