Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Upgrade to aws-sdk-go-v2 #37

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package athena

import (
"context"

"github.com/aws/aws-sdk-go-v2/service/athena"
)

type athenaAPI interface {
GetQueryExecution(context.Context, *athena.GetQueryExecutionInput, ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error)
GetQueryResults(context.Context, *athena.GetQueryResultsInput, ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error)
StartQueryExecution(context.Context, *athena.StartQueryExecutionInput, ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error)
StopQueryExecution(context.Context, *athena.StopQueryExecutionInput, ...func(*athena.Options)) (*athena.StopQueryExecutionOutput, error)
}
38 changes: 19 additions & 19 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"errors"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/athena"
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/athena"
"github.com/aws/aws-sdk-go-v2/service/athena/types"
)

type conn struct {
athena athenaiface.AthenaAPI
athena athenaAPI
db string
OutputLocation string

Expand All @@ -38,7 +38,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}

func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) {
queryID, err := c.startQuery(query)
queryID, err := c.startQuery(ctx, query)
if err != nil {
return nil, err
}
Expand All @@ -47,7 +47,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
return nil, err
}

return newRows(rowsConfig{
return newRows(ctx, rowsConfig{
Athena: c.athena,
QueryID: queryID,
// todo add check for ddl queries to not skip header(#10)
Expand All @@ -56,13 +56,13 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
}

// startQuery starts an Athena query and returns its ID.
func (c *conn) startQuery(query string) (string, error) {
resp, err := c.athena.StartQueryExecution(&athena.StartQueryExecutionInput{
func (c *conn) startQuery(ctx context.Context, query string) (string, error) {
resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
QueryString: aws.String(query),
QueryExecutionContext: &athena.QueryExecutionContext{
QueryExecutionContext: &types.QueryExecutionContext{
Database: aws.String(c.db),
},
ResultConfiguration: &athena.ResultConfiguration{
ResultConfiguration: &types.ResultConfiguration{
OutputLocation: aws.String(c.OutputLocation),
},
})
Expand All @@ -76,28 +76,28 @@ func (c *conn) startQuery(query string) (string, error) {
// waitOnQuery blocks until a query finishes, returning an error if it failed.
func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
for {
statusResp, err := c.athena.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
QueryExecutionId: aws.String(queryID),
})
if err != nil {
return err
}

switch *statusResp.QueryExecution.Status.State {
case athena.QueryExecutionStateCancelled:
switch statusResp.QueryExecution.Status.State {
case types.QueryExecutionStateCancelled:
return context.Canceled
case athena.QueryExecutionStateFailed:
case types.QueryExecutionStateFailed:
reason := *statusResp.QueryExecution.Status.StateChangeReason
return errors.New(reason)
case athena.QueryExecutionStateSucceeded:
case types.QueryExecutionStateSucceeded:
return nil
case athena.QueryExecutionStateQueued:
case athena.QueryExecutionStateRunning:
case types.QueryExecutionStateQueued:
case types.QueryExecutionStateRunning:
}

select {
case <-ctx.Done():
c.athena.StopQueryExecution(&athena.StopQueryExecutionInput{
c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{
QueryExecutionId: aws.String(queryID),
})

Expand All @@ -109,7 +109,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
panic("Athena doesn't support prepared statements")
panic("The go-athena driver doesn't support prepared statements yet")
}

func (c *conn) Begin() (driver.Tx, error) {
Expand Down
52 changes: 28 additions & 24 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -35,8 +35,9 @@ func init() {
}

func TestQuery(t *testing.T) {
harness := setup(t)
// defer harness.teardown()
ctx := context.Background()
harness := setup(ctx, t)
// defer harness.teardown(ctx)

expected := []dummyRow{
{
Expand Down Expand Up @@ -77,9 +78,9 @@ func TestQuery(t *testing.T) {
},
}
expectedTypeNames := []string{"varchar", "smallint", "integer", "bigint", "boolean", "float", "double", "varchar", "timestamp", "date", "decimal"}
harness.uploadData(expected)
harness.uploadData(ctx, expected)

rows := harness.mustQuery("select * from %s", harness.table)
rows := harness.mustQuery(ctx, "select * from %s", harness.table)
index := -1
for rows.Next() {
index++
Expand Down Expand Up @@ -115,8 +116,10 @@ func TestQuery(t *testing.T) {
}

func TestOpen(t *testing.T) {
db, err := Open(Config{
Session: session.Must(session.NewSession()),
awsConfig, err := config.LoadDefaultConfig(context.Background())
require.NoError(t, err, "LoadDefaultConfig")
db, err := Open(DriverConfig{
Config: &awsConfig,
Database: AthenaDatabase,
OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket),
})
Expand All @@ -143,28 +146,29 @@ type dummyRow struct {
type athenaHarness struct {
t *testing.T
db *sql.DB
s3 *s3.S3
s3 *s3.Client

table string
}

func setup(t *testing.T) *athenaHarness {
harness := athenaHarness{t: t, s3: s3.New(session.New())}
func setup(ctx context.Context, t *testing.T) *athenaHarness {
awsConfig, err := config.LoadDefaultConfig(ctx)
require.NoError(t, err)
harness := athenaHarness{t: t, s3: s3.NewFromConfig(awsConfig)}

var err error
harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3://%s/output", AthenaDatabase, S3Bucket))
require.NoError(t, err)

harness.setupTable()
harness.setupTable(ctx)

return &harness
}

func (a *athenaHarness) setupTable() {
func (a *athenaHarness) setupTable(ctx context.Context) {
// tables cannot start with numbers or contain dashes
id := uuid.NewV4()
a.table = "t_" + strings.Replace(id.String(), "-", "_", -1)
a.mustExec(`CREATE EXTERNAL TABLE %[1]s (
a.mustExec(ctx, `CREATE EXTERNAL TABLE %[1]s (
nullValue string,
smallintType smallint,
intType int,
Expand All @@ -184,32 +188,32 @@ WITH SERDEPROPERTIES (
fmt.Printf("created table: %s", a.table)
}

func (a *athenaHarness) teardown() {
a.mustExec("drop table %s", a.table)
func (a *athenaHarness) teardown(ctx context.Context) {
a.mustExec(ctx, "drop table %s", a.table)
}

func (a *athenaHarness) mustExec(sql string, args ...interface{}) {
func (a *athenaHarness) mustExec(ctx context.Context, sql string, args ...interface{}) {
query := fmt.Sprintf(sql, args...)
_, err := a.db.ExecContext(context.TODO(), query)
_, err := a.db.ExecContext(ctx, query)
require.NoError(a.t, err, query)
}

func (a *athenaHarness) mustQuery(sql string, args ...interface{}) *sql.Rows {
func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...interface{}) *sql.Rows {
query := fmt.Sprintf(sql, args...)
rows, err := a.db.QueryContext(context.TODO(), query)
rows, err := a.db.QueryContext(ctx, query)
require.NoError(a.t, err, query)
return rows
}

func (a *athenaHarness) uploadData(rows []dummyRow) {
func (a *athenaHarness) uploadData(ctx context.Context, rows []dummyRow) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
for _, row := range rows {
err := enc.Encode(row)
require.NoError(a.t, err)
}

_, err := a.s3.PutObject(&s3.PutObjectInput{
_, err := a.s3.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(S3Bucket),
Key: aws.String(fmt.Sprintf("%s/fixture.json", a.table)),
Body: bytes.NewReader(buf.Bytes()),
Expand Down
40 changes: 21 additions & 19 deletions driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package athena

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand All @@ -9,9 +10,9 @@ import (
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/athena"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/athena"
)

var (
Expand All @@ -21,15 +22,15 @@ var (

// Driver is a sql.Driver. It's intended for db/sql.Open().
type Driver struct {
cfg *Config
cfg *DriverConfig
}

// NewDriver allows you to register your own driver with `sql.Register`.
// It's useful for more complex use cases. Read more in PR #3.
// https://github.com/segmentio/go-athena/pull/3
//
// Generally, sql.Open() or athena.Open() should suffice.
func NewDriver(cfg *Config) *Driver {
func NewDriver(cfg *DriverConfig) *Driver {
return &Driver{cfg}
}

Expand Down Expand Up @@ -65,7 +66,8 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
cfg := d.cfg
if cfg == nil {
var err error
cfg, err = configFromConnectionString(connStr)
// TODO: Implement DriverContext to get proper access to context
cfg, err = configFromConnectionString(context.TODO(), connStr)
if err != nil {
return nil, err
}
Expand All @@ -76,7 +78,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
}

return &conn{
athena: athena.New(cfg.Session),
athena: athena.NewFromConfig(*cfg.Config),
db: cfg.Database,
OutputLocation: cfg.OutputLocation,
pollFrequency: cfg.PollFrequency,
Expand All @@ -86,7 +88,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
// Open is a more robust version of `db.Open`, as it accepts a raw aws.Session.
// This is useful if you have a complex AWS session since the driver doesn't
// currently attempt to serialize all options into a string.
func Open(cfg Config) (*sql.DB, error) {
func Open(cfg DriverConfig) (*sql.DB, error) {
if cfg.Database == "" {
return nil, errors.New("db is required")
}
Expand All @@ -95,8 +97,8 @@ func Open(cfg Config) (*sql.DB, error) {
return nil, errors.New("s3_staging_url is required")
}

if cfg.Session == nil {
return nil, errors.New("session is required")
if cfg.Config == nil {
return nil, errors.New("AWS config is required")
}

// This hack was copied from jackc/pgx. Sorry :(
Expand All @@ -111,30 +113,30 @@ func Open(cfg Config) (*sql.DB, error) {
}

// Config is the input to Open().
type Config struct {
Session *session.Session
type DriverConfig struct {
Config *aws.Config
Database string
OutputLocation string

PollFrequency time.Duration
}

func configFromConnectionString(connStr string) (*Config, error) {
func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) {
args, err := url.ParseQuery(connStr)
if err != nil {
return nil, err
}

var cfg Config
var cfg DriverConfig

var acfg []*aws.Config
if region := args.Get("region"); region != "" {
acfg = append(acfg, &aws.Config{Region: aws.String(region)})
}
cfg.Session, err = session.NewSession(acfg...)
awsConfig, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}
if region := args.Get("region"); region != "" {
awsConfig.Region = region
}
cfg.Config = &awsConfig

cfg.Database = args.Get("db")
cfg.OutputLocation = args.Get("output_location")
Expand Down
21 changes: 19 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@ module github.com/segmentio/go-athena
go 1.21

require (
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/aws-sdk-go-v2 v1.30.4
github.com/aws/aws-sdk-go-v2/config v1.27.30
github.com/aws/aws-sdk-go-v2/service/athena v1.44.5
github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1
github.com/satori/go.uuid v1.2.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.29 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading