Skip to content

Commit

Permalink
chore: Upgrade to aws-sdk-go-v2 (#37)
Browse files Browse the repository at this point in the history
The conversion is essentially complete. The only gap is in the
implementation of `Rows.Next()`, where a context is not available from
the caller for use when fetching more results from Athena, because the
function is defined in its interface to not include a context.

Related to contexts, the driver does not yet implement `DriverContext`,
but should be, separately.
  • Loading branch information
bhavanki authored Sep 3, 2024
1 parent 3d220fb commit 76f4666
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 121 deletions.
14 changes: 14 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package athena

import (
"context"

"github.com/aws/aws-sdk-go-v2/service/athena"
)

type athenaAPI interface {
GetQueryExecution(context.Context, *athena.GetQueryExecutionInput, ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error)
GetQueryResults(context.Context, *athena.GetQueryResultsInput, ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error)
StartQueryExecution(context.Context, *athena.StartQueryExecutionInput, ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error)
StopQueryExecution(context.Context, *athena.StopQueryExecutionInput, ...func(*athena.Options)) (*athena.StopQueryExecutionOutput, error)
}
38 changes: 19 additions & 19 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"errors"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/athena"
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/athena"
"github.com/aws/aws-sdk-go-v2/service/athena/types"
)

type conn struct {
athena athenaiface.AthenaAPI
athena athenaAPI
db string
OutputLocation string

Expand All @@ -38,7 +38,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}

func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) {
queryID, err := c.startQuery(query)
queryID, err := c.startQuery(ctx, query)
if err != nil {
return nil, err
}
Expand All @@ -47,7 +47,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
return nil, err
}

return newRows(rowsConfig{
return newRows(ctx, rowsConfig{
Athena: c.athena,
QueryID: queryID,
// todo add check for ddl queries to not skip header(#10)
Expand All @@ -56,13 +56,13 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
}

// startQuery starts an Athena query and returns its ID.
func (c *conn) startQuery(query string) (string, error) {
resp, err := c.athena.StartQueryExecution(&athena.StartQueryExecutionInput{
func (c *conn) startQuery(ctx context.Context, query string) (string, error) {
resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
QueryString: aws.String(query),
QueryExecutionContext: &athena.QueryExecutionContext{
QueryExecutionContext: &types.QueryExecutionContext{
Database: aws.String(c.db),
},
ResultConfiguration: &athena.ResultConfiguration{
ResultConfiguration: &types.ResultConfiguration{
OutputLocation: aws.String(c.OutputLocation),
},
})
Expand All @@ -76,28 +76,28 @@ func (c *conn) startQuery(query string) (string, error) {
// waitOnQuery blocks until a query finishes, returning an error if it failed.
func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
for {
statusResp, err := c.athena.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
QueryExecutionId: aws.String(queryID),
})
if err != nil {
return err
}

switch *statusResp.QueryExecution.Status.State {
case athena.QueryExecutionStateCancelled:
switch statusResp.QueryExecution.Status.State {
case types.QueryExecutionStateCancelled:
return context.Canceled
case athena.QueryExecutionStateFailed:
case types.QueryExecutionStateFailed:
reason := *statusResp.QueryExecution.Status.StateChangeReason
return errors.New(reason)
case athena.QueryExecutionStateSucceeded:
case types.QueryExecutionStateSucceeded:
return nil
case athena.QueryExecutionStateQueued:
case athena.QueryExecutionStateRunning:
case types.QueryExecutionStateQueued:
case types.QueryExecutionStateRunning:
}

select {
case <-ctx.Done():
c.athena.StopQueryExecution(&athena.StopQueryExecutionInput{
c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{
QueryExecutionId: aws.String(queryID),
})

Expand All @@ -109,7 +109,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
panic("Athena doesn't support prepared statements")
panic("The go-athena driver doesn't support prepared statements yet")
}

func (c *conn) Begin() (driver.Tx, error) {
Expand Down
52 changes: 28 additions & 24 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -35,8 +35,9 @@ func init() {
}

func TestQuery(t *testing.T) {
harness := setup(t)
// defer harness.teardown()
ctx := context.Background()
harness := setup(ctx, t)
// defer harness.teardown(ctx)

expected := []dummyRow{
{
Expand Down Expand Up @@ -77,9 +78,9 @@ func TestQuery(t *testing.T) {
},
}
expectedTypeNames := []string{"varchar", "smallint", "integer", "bigint", "boolean", "float", "double", "varchar", "timestamp", "date", "decimal"}
harness.uploadData(expected)
harness.uploadData(ctx, expected)

rows := harness.mustQuery("select * from %s", harness.table)
rows := harness.mustQuery(ctx, "select * from %s", harness.table)
index := -1
for rows.Next() {
index++
Expand Down Expand Up @@ -115,8 +116,10 @@ func TestQuery(t *testing.T) {
}

func TestOpen(t *testing.T) {
db, err := Open(Config{
Session: session.Must(session.NewSession()),
awsConfig, err := config.LoadDefaultConfig(context.Background())
require.NoError(t, err, "LoadDefaultConfig")
db, err := Open(DriverConfig{
Config: &awsConfig,
Database: AthenaDatabase,
OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket),
})
Expand All @@ -143,28 +146,29 @@ type dummyRow struct {
type athenaHarness struct {
t *testing.T
db *sql.DB
s3 *s3.S3
s3 *s3.Client

table string
}

func setup(t *testing.T) *athenaHarness {
harness := athenaHarness{t: t, s3: s3.New(session.New())}
func setup(ctx context.Context, t *testing.T) *athenaHarness {
awsConfig, err := config.LoadDefaultConfig(ctx)
require.NoError(t, err)
harness := athenaHarness{t: t, s3: s3.NewFromConfig(awsConfig)}

var err error
harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3://%s/output", AthenaDatabase, S3Bucket))
require.NoError(t, err)

harness.setupTable()
harness.setupTable(ctx)

return &harness
}

func (a *athenaHarness) setupTable() {
func (a *athenaHarness) setupTable(ctx context.Context) {
// tables cannot start with numbers or contain dashes
id := uuid.NewV4()
a.table = "t_" + strings.Replace(id.String(), "-", "_", -1)
a.mustExec(`CREATE EXTERNAL TABLE %[1]s (
a.mustExec(ctx, `CREATE EXTERNAL TABLE %[1]s (
nullValue string,
smallintType smallint,
intType int,
Expand All @@ -184,32 +188,32 @@ WITH SERDEPROPERTIES (
fmt.Printf("created table: %s", a.table)
}

func (a *athenaHarness) teardown() {
a.mustExec("drop table %s", a.table)
func (a *athenaHarness) teardown(ctx context.Context) {
a.mustExec(ctx, "drop table %s", a.table)
}

func (a *athenaHarness) mustExec(sql string, args ...interface{}) {
func (a *athenaHarness) mustExec(ctx context.Context, sql string, args ...interface{}) {
query := fmt.Sprintf(sql, args...)
_, err := a.db.ExecContext(context.TODO(), query)
_, err := a.db.ExecContext(ctx, query)
require.NoError(a.t, err, query)
}

func (a *athenaHarness) mustQuery(sql string, args ...interface{}) *sql.Rows {
func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...interface{}) *sql.Rows {
query := fmt.Sprintf(sql, args...)
rows, err := a.db.QueryContext(context.TODO(), query)
rows, err := a.db.QueryContext(ctx, query)
require.NoError(a.t, err, query)
return rows
}

func (a *athenaHarness) uploadData(rows []dummyRow) {
func (a *athenaHarness) uploadData(ctx context.Context, rows []dummyRow) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
for _, row := range rows {
err := enc.Encode(row)
require.NoError(a.t, err)
}

_, err := a.s3.PutObject(&s3.PutObjectInput{
_, err := a.s3.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(S3Bucket),
Key: aws.String(fmt.Sprintf("%s/fixture.json", a.table)),
Body: bytes.NewReader(buf.Bytes()),
Expand Down
40 changes: 21 additions & 19 deletions driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package athena

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand All @@ -9,9 +10,9 @@ import (
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/athena"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/athena"
)

var (
Expand All @@ -21,15 +22,15 @@ var (

// Driver is a sql.Driver. It's intended for db/sql.Open().
type Driver struct {
cfg *Config
cfg *DriverConfig
}

// NewDriver allows you to register your own driver with `sql.Register`.
// It's useful for more complex use cases. Read more in PR #3.
// https://github.com/segmentio/go-athena/pull/3
//
// Generally, sql.Open() or athena.Open() should suffice.
func NewDriver(cfg *Config) *Driver {
func NewDriver(cfg *DriverConfig) *Driver {
return &Driver{cfg}
}

Expand Down Expand Up @@ -65,7 +66,8 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
cfg := d.cfg
if cfg == nil {
var err error
cfg, err = configFromConnectionString(connStr)
// TODO: Implement DriverContext to get proper access to context
cfg, err = configFromConnectionString(context.TODO(), connStr)
if err != nil {
return nil, err
}
Expand All @@ -76,7 +78,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
}

return &conn{
athena: athena.New(cfg.Session),
athena: athena.NewFromConfig(*cfg.Config),
db: cfg.Database,
OutputLocation: cfg.OutputLocation,
pollFrequency: cfg.PollFrequency,
Expand All @@ -86,7 +88,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
// Open is a more robust version of `db.Open`, as it accepts a raw aws.Session.
// This is useful if you have a complex AWS session since the driver doesn't
// currently attempt to serialize all options into a string.
func Open(cfg Config) (*sql.DB, error) {
func Open(cfg DriverConfig) (*sql.DB, error) {
if cfg.Database == "" {
return nil, errors.New("db is required")
}
Expand All @@ -95,8 +97,8 @@ func Open(cfg Config) (*sql.DB, error) {
return nil, errors.New("s3_staging_url is required")
}

if cfg.Session == nil {
return nil, errors.New("session is required")
if cfg.Config == nil {
return nil, errors.New("AWS config is required")
}

// This hack was copied from jackc/pgx. Sorry :(
Expand All @@ -111,30 +113,30 @@ func Open(cfg Config) (*sql.DB, error) {
}

// Config is the input to Open().
type Config struct {
Session *session.Session
type DriverConfig struct {
Config *aws.Config
Database string
OutputLocation string

PollFrequency time.Duration
}

func configFromConnectionString(connStr string) (*Config, error) {
func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) {
args, err := url.ParseQuery(connStr)
if err != nil {
return nil, err
}

var cfg Config
var cfg DriverConfig

var acfg []*aws.Config
if region := args.Get("region"); region != "" {
acfg = append(acfg, &aws.Config{Region: aws.String(region)})
}
cfg.Session, err = session.NewSession(acfg...)
awsConfig, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}
if region := args.Get("region"); region != "" {
awsConfig.Region = region
}
cfg.Config = &awsConfig

cfg.Database = args.Get("db")
cfg.OutputLocation = args.Get("output_location")
Expand Down
21 changes: 19 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@ module github.com/segmentio/go-athena
go 1.21

require (
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/aws-sdk-go-v2 v1.30.4
github.com/aws/aws-sdk-go-v2/config v1.27.30
github.com/aws/aws-sdk-go-v2/service/athena v1.44.5
github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1
github.com/satori/go.uuid v1.2.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.29 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit 76f4666

Please sign in to comment.