created redshift syncer (#121)
* created redshift syncer
saurabhsingh1608 authored Apr 10, 2024
1 parent b45bb06 commit bdb4094
Showing 3 changed files with 62 additions and 122 deletions.
Binary file added .DS_Store
4 changes: 2 additions & 2 deletions cs_tools/sync/redshift/MANIFEST.json
@@ -3,7 +3,7 @@
   "syncer_class": "Redshift",
   "requirements": [
     "sqlalchemy-redshift>=1.4.1",
-    "redshift-connector>=2.0.905",
-    "s3fs>=2022.2.0"
+    "redshift_connector>=2.1.0",
+    "psycopg2-binary>=2.9.9"
   ]
 }
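The dependency change above pairs with the driver change in syncer.py below: connections now go through sqlalchemy-redshift's "redshift+psycopg2" dialect (backed by psycopg2-binary) instead of "redshift+redshift_connector". A minimal sketch of what the new requirements support, with a placeholder endpoint, database, and credentials:

import sqlalchemy as sa

# sqlalchemy-redshift registers the "redshift+psycopg2" dialect; psycopg2-binary
# supplies the DBAPI. The endpoint, database, and credentials are placeholders.
engine = sa.create_engine(
    "redshift+psycopg2://etl_user:secret@my-cluster.abc123.us-east-1.redshift.amazonaws.com:5439/dev"
)

with engine.connect() as cnxn:
    print(cnxn.execute(sa.text("SELECT current_database()")).scalar())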
180 changes: 60 additions & 120 deletions cs_tools/sync/redshift/syncer.py
@@ -1,131 +1,71 @@
 from __future__ import annotations
 
-from typing import Any
-import csv
-import enum
-import io
+from typing import Any, Optional
 import logging
+import pathlib
 
-from pydantic.dataclasses import dataclass
-from sqlalchemy_redshift import dialect
-import s3fs
-import pydantic
+import redshift_connector
 import sqlalchemy as sa
+from sqlalchemy.engine.url import URL
+import sqlmodel
+
+from cs_tools.sync.base import DatabaseSyncer
+from cs_tools.sync.types import TableRows
 
 log = logging.getLogger(__name__)
 
-class AuthType(enum.Enum):
-    local = "local"
-    okta = "okta"
-
-@dataclass
-class Redshift:
+class Redshift(DatabaseSyncer):
     """
-    Interact with an AWS Redshift database.
+    Interact with an Amazon Redshift database.
     """

-    username: str
-    password: str
-    database: str
-    aws_access_key: str  # for S3 data load
-    aws_secret_key: str  # for S3 data load
-    aws_endpoint: str  # FMT: <clusterid>.xxxxxx.<aws-region>.redshift.amazonaws.com
-    port: int = 5439
-    auth_type: AuthType = AuthType.local
-    # okta_account_name: str = None
-    # okta_app_id: str = None
-    truncate_on_load: bool = True
-
-    # DATABASE ATTRIBUTES
-    __is_database__ = True
-
-    def __post_init_post_parse__(self):
-        if self.auth_type == AuthType.local:
-            connect_args = {}
-            url = sa.engine.URL.create(
-                drivername="redshift+redshift_connector",
-                host=self.aws_endpoint,
-                port=self.port,
-                database=self.database,
-                username=self.username,
-                password=self.password,
-            )
-
-        elif self.auth_type == AuthType.okta:
-            # aws_cluster_id, _, aws_region, *_ = self.aws_endpoint.split('.')
-            # connect_args = {
-            #     'credentials_provider': 'OktaCredentialsProvider',
-            #     'idp_host': '<prefix>.okta.com',
-            #     'app_id': '<appid>',
-            #     'app_name': 'amazon_aws_redshift',
-            #     'cluster_identifier': aws_cluster_id,
-            #     'region': aws_region,
-            #     'ssl_insecure': False,
-            #     **connect_args
-            # }
-            # url = sa.engine.URL.create(
-            #     drivername='redshift+redshift_connector',
-            #     database=self.database,
-            #     username=self.username,
-            #     password=self.password
-            # )
-            raise NotImplementedError(
-                "our implementation is best-effort, but lacks testing.. see the source "
-                "code for ideas on how to implement MFA to Okta."
-            )
-
-        self.engine = sa.create_engine(url, connect_args=connect_args)
-        self.cnxn = self.engine.connect()
-
-        # decorators must be declared here, SQLAlchemy doesn't care about instances
-        sa.event.listen(sa.schema.MetaData, "after_create", self.capture_metadata)
-
-    def capture_metadata(self, metadata, cnxn, **kw):
-        self.metadata = metadata
-
-    def __repr__(self):
-        return f"<Database ({self.name}) sync: conn_string='{self.engine.url}'>"
-
-    # MANDATORY PROTOCOL MEMBERS
-
-    @property
-    def name(self) -> str:
-        return "redshift"
-
-    def load(self, table: str) -> list[dict[str, Any]]:
-        t = self.metadata.tables[table]
-
-        with self.cnxn.begin():
-            r = self.cnxn.execute(t.select())
-
-        return [dict(_) for _ in r]
-
-    def dump(self, table: str, *, data: list[dict[str, Any]]) -> None:
-        if not data:
-            log.warning(f"no data to write to syncer {self}")
-            return
-
-        t = self.metadata.tables[table]
-
-        if self.truncate_on_load:
-            with self.cnxn.begin():
-                self.cnxn.execute(t.delete().where(True))
-
-        # 1. Load file to S3
-        fs = s3fs.S3FileSystem(key=self.aws_access_key, secret=self.aws_secret_key)
-        fp = f"s3://{self.s3_bucket_name}/ts_{table}.csv"
-
-        with io.StringIO() as buf, fs.open(fp, "w") as f:
-            header = list(data[0].keys())
-            writer = csv.DictWriter(buf, fieldnames=header, dialect="excel", delimiter="|")
-            writer.writeheader()
-            writer.writerows(data)
-
-            f.write(buf.getvalue())
-
-        # 2. Perform a COPY operation
-        q = dialect.CopyCommand(t, data_location=fp, ignore_header=0)
-
-        with self.cnxn.begin():
-            self.cnxn.execute(q)
+    __manifest_path__ = pathlib.Path(__file__).parent
+    __syncer_name__ = "Redshift"
+
+    host: str
+    database: str
+    username: str
+    password: str
+    port: int
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.engine_url = URL.create(
+            drivername="redshift+psycopg2",  # the redshift dialect over the psycopg2 driver
+            host=self.host,
+            port=self.port,
+            database=self.database,
+            username=self.username,
+            password=self.password,
+        )
+        self._engine = sa.create_engine(self.engine_url)
+        # self.metadata = sqlmodel.MetaData(schema=self.schema_)
+
+    def load(self, tablename: str) -> TableRows:
+        """SELECT rows from Redshift."""
+        table = self.metadata.tables[tablename]
+        rows = self.session.execute(table.select())
+        # Row objects are not dicts; convert through the row's mapping view
+        return [dict(row._mapping) for row in rows]
+
+    def dump(self, tablename: str, *, data: TableRows) -> None:
+        """INSERT rows into Redshift, honoring the configured load strategy."""
+        table = self.metadata.tables[tablename]
+
+        if not data:
+            log.warning(f"no data to write to syncer {table}")
+            return
+
+        if self.load_strategy == "APPEND":
+            self.session.execute(table.insert(), data)
+            self.session.commit()
+
+        if self.load_strategy == "TRUNCATE":
+            self.session.execute(table.delete())
+            self.session.execute(table.insert(), data)
+            self.session.commit()
+
+        if self.load_strategy == "UPSERT":
+            # sqlalchemy Tables have no .merge(); emulate MERGE by deleting any
+            # row whose primary key matches the incoming data, then inserting.
+            # Assumes the table defines a primary key and data includes it.
+            for row in data:
+                self.session.execute(
+                    table.delete().where(
+                        sa.and_(*(table.c[key.name] == row[key.name] for key in table.primary_key))
+                    )
+                )
+            self.session.execute(table.insert(), data)
+            self.session.commit()

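A hypothetical usage sketch of the new syncer (not part of this commit): it assumes the DatabaseSyncer base class wires up session, metadata, and load_strategy, none of which appear in this diff, and every connection value and table name below is a placeholder.

from cs_tools.sync.redshift.syncer import Redshift

# placeholder endpoint, credentials, and table name throughout
syncer = Redshift(
    host="my-cluster.abc123.us-east-1.redshift.amazonaws.com",
    database="dev",
    username="etl_user",
    password="secret",
    port=5439,
    load_strategy="TRUNCATE",  # APPEND | TRUNCATE | UPSERT, as handled in dump()
)

# dump() writes rows according to load_strategy; load() SELECTs them back
syncer.dump("example_table", data=[{"id": 1, "name": "example"}])
rows = syncer.load("example_table")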