diff --git a/api.go b/api.go new file mode 100644 index 0000000..294a01f --- /dev/null +++ b/api.go @@ -0,0 +1,14 @@ +package athena + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/athena" +) + +type athenaAPI interface { + GetQueryExecution(context.Context, *athena.GetQueryExecutionInput, ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) + GetQueryResults(context.Context, *athena.GetQueryResultsInput, ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) + StartQueryExecution(context.Context, *athena.StartQueryExecutionInput, ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) + StopQueryExecution(context.Context, *athena.StopQueryExecutionInput, ...func(*athena.Options)) (*athena.StopQueryExecutionOutput, error) +} diff --git a/conn.go b/conn.go index 6c8f9d3..b131e4c 100644 --- a/conn.go +++ b/conn.go @@ -6,13 +6,13 @@ import ( "errors" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/athena" - "github.com/aws/aws-sdk-go/service/athena/athenaiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/athena" + "github.com/aws/aws-sdk-go-v2/service/athena/types" ) type conn struct { - athena athenaiface.AthenaAPI + athena athenaAPI db string OutputLocation string @@ -38,7 +38,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) { - queryID, err := c.startQuery(query) + queryID, err := c.startQuery(ctx, query) if err != nil { return nil, err } @@ -47,7 +47,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) return nil, err } - return newRows(rowsConfig{ + return newRows(ctx, rowsConfig{ Athena: c.athena, QueryID: queryID, // todo add check for ddl queries to not skip header(#10) @@ -56,13 +56,13 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) } // startQuery starts an Athena query and returns its ID. -func (c *conn) startQuery(query string) (string, error) { - resp, err := c.athena.StartQueryExecution(&athena.StartQueryExecutionInput{ +func (c *conn) startQuery(ctx context.Context, query string) (string, error) { + resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{ QueryString: aws.String(query), - QueryExecutionContext: &athena.QueryExecutionContext{ + QueryExecutionContext: &types.QueryExecutionContext{ Database: aws.String(c.db), }, - ResultConfiguration: &athena.ResultConfiguration{ + ResultConfiguration: &types.ResultConfiguration{ OutputLocation: aws.String(c.OutputLocation), }, }) @@ -76,28 +76,28 @@ func (c *conn) startQuery(query string) (string, error) { // waitOnQuery blocks until a query finishes, returning an error if it failed. func (c *conn) waitOnQuery(ctx context.Context, queryID string) error { for { - statusResp, err := c.athena.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{ + statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(queryID), }) if err != nil { return err } - switch *statusResp.QueryExecution.Status.State { - case athena.QueryExecutionStateCancelled: + switch statusResp.QueryExecution.Status.State { + case types.QueryExecutionStateCancelled: return context.Canceled - case athena.QueryExecutionStateFailed: + case types.QueryExecutionStateFailed: reason := *statusResp.QueryExecution.Status.StateChangeReason return errors.New(reason) - case athena.QueryExecutionStateSucceeded: + case types.QueryExecutionStateSucceeded: return nil - case athena.QueryExecutionStateQueued: - case athena.QueryExecutionStateRunning: + case types.QueryExecutionStateQueued: + case types.QueryExecutionStateRunning: } select { case <-ctx.Done(): - c.athena.StopQueryExecution(&athena.StopQueryExecutionInput{ + c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{ QueryExecutionId: aws.String(queryID), }) @@ -109,7 +109,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error { } func (c *conn) Prepare(query string) (driver.Stmt, error) { - panic("Athena doesn't support prepared statements") + panic("The go-athena driver doesn't support prepared statements yet") } func (c *conn) Begin() (driver.Tx, error) { diff --git a/db_test.go b/db_test.go index 2684f8a..a784853 100644 --- a/db_test.go +++ b/db_test.go @@ -11,9 +11,9 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,8 +35,9 @@ func init() { } func TestQuery(t *testing.T) { - harness := setup(t) - // defer harness.teardown() + ctx := context.Background() + harness := setup(ctx, t) + // defer harness.teardown(ctx) expected := []dummyRow{ { @@ -77,9 +78,9 @@ func TestQuery(t *testing.T) { }, } expectedTypeNames := []string{"varchar", "smallint", "integer", "bigint", "boolean", "float", "double", "varchar", "timestamp", "date", "decimal"} - harness.uploadData(expected) + harness.uploadData(ctx, expected) - rows := harness.mustQuery("select * from %s", harness.table) + rows := harness.mustQuery(ctx, "select * from %s", harness.table) index := -1 for rows.Next() { index++ @@ -115,8 +116,10 @@ func TestQuery(t *testing.T) { } func TestOpen(t *testing.T) { - db, err := Open(Config{ - Session: session.Must(session.NewSession()), + awsConfig, err := config.LoadDefaultConfig(context.Background()) + require.NoError(t, err, "LoadDefaultConfig") + db, err := Open(DriverConfig{ + Config: &awsConfig, Database: AthenaDatabase, OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket), }) @@ -143,28 +146,29 @@ type dummyRow struct { type athenaHarness struct { t *testing.T db *sql.DB - s3 *s3.S3 + s3 *s3.Client table string } -func setup(t *testing.T) *athenaHarness { - harness := athenaHarness{t: t, s3: s3.New(session.New())} +func setup(ctx context.Context, t *testing.T) *athenaHarness { + awsConfig, err := config.LoadDefaultConfig(ctx) + require.NoError(t, err) + harness := athenaHarness{t: t, s3: s3.NewFromConfig(awsConfig)} - var err error harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3://%s/output", AthenaDatabase, S3Bucket)) require.NoError(t, err) - harness.setupTable() + harness.setupTable(ctx) return &harness } -func (a *athenaHarness) setupTable() { +func (a *athenaHarness) setupTable(ctx context.Context) { // tables cannot start with numbers or contain dashes id := uuid.NewV4() a.table = "t_" + strings.Replace(id.String(), "-", "_", -1) - a.mustExec(`CREATE EXTERNAL TABLE %[1]s ( + a.mustExec(ctx, `CREATE EXTERNAL TABLE %[1]s ( nullValue string, smallintType smallint, intType int, @@ -184,24 +188,24 @@ WITH SERDEPROPERTIES ( fmt.Printf("created table: %s", a.table) } -func (a *athenaHarness) teardown() { - a.mustExec("drop table %s", a.table) +func (a *athenaHarness) teardown(ctx context.Context) { + a.mustExec(ctx, "drop table %s", a.table) } -func (a *athenaHarness) mustExec(sql string, args ...interface{}) { +func (a *athenaHarness) mustExec(ctx context.Context, sql string, args ...interface{}) { query := fmt.Sprintf(sql, args...) - _, err := a.db.ExecContext(context.TODO(), query) + _, err := a.db.ExecContext(ctx, query) require.NoError(a.t, err, query) } -func (a *athenaHarness) mustQuery(sql string, args ...interface{}) *sql.Rows { +func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...interface{}) *sql.Rows { query := fmt.Sprintf(sql, args...) - rows, err := a.db.QueryContext(context.TODO(), query) + rows, err := a.db.QueryContext(ctx, query) require.NoError(a.t, err, query) return rows } -func (a *athenaHarness) uploadData(rows []dummyRow) { +func (a *athenaHarness) uploadData(ctx context.Context, rows []dummyRow) { var buf bytes.Buffer enc := json.NewEncoder(&buf) for _, row := range rows { @@ -209,7 +213,7 @@ func (a *athenaHarness) uploadData(rows []dummyRow) { require.NoError(a.t, err) } - _, err := a.s3.PutObject(&s3.PutObjectInput{ + _, err := a.s3.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(S3Bucket), Key: aws.String(fmt.Sprintf("%s/fixture.json", a.table)), Body: bytes.NewReader(buf.Bytes()), diff --git a/driver.go b/driver.go index 2942bb7..03dceca 100644 --- a/driver.go +++ b/driver.go @@ -1,6 +1,7 @@ package athena import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -9,9 +10,9 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/athena" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/athena" ) var ( @@ -21,7 +22,7 @@ var ( // Driver is a sql.Driver. It's intended for db/sql.Open(). type Driver struct { - cfg *Config + cfg *DriverConfig } // NewDriver allows you to register your own driver with `sql.Register`. @@ -29,7 +30,7 @@ type Driver struct { // https://github.com/segmentio/go-athena/pull/3 // // Generally, sql.Open() or athena.Open() should suffice. -func NewDriver(cfg *Config) *Driver { +func NewDriver(cfg *DriverConfig) *Driver { return &Driver{cfg} } @@ -65,7 +66,8 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { cfg := d.cfg if cfg == nil { var err error - cfg, err = configFromConnectionString(connStr) + // TODO: Implement DriverContext to get proper access to context + cfg, err = configFromConnectionString(context.TODO(), connStr) if err != nil { return nil, err } @@ -76,7 +78,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { } return &conn{ - athena: athena.New(cfg.Session), + athena: athena.NewFromConfig(*cfg.Config), db: cfg.Database, OutputLocation: cfg.OutputLocation, pollFrequency: cfg.PollFrequency, @@ -86,7 +88,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { // Open is a more robust version of `db.Open`, as it accepts a raw aws.Session. // This is useful if you have a complex AWS session since the driver doesn't // currently attempt to serialize all options into a string. -func Open(cfg Config) (*sql.DB, error) { +func Open(cfg DriverConfig) (*sql.DB, error) { if cfg.Database == "" { return nil, errors.New("db is required") } @@ -95,8 +97,8 @@ func Open(cfg Config) (*sql.DB, error) { return nil, errors.New("s3_staging_url is required") } - if cfg.Session == nil { - return nil, errors.New("session is required") + if cfg.Config == nil { + return nil, errors.New("AWS config is required") } // This hack was copied from jackc/pgx. Sorry :( @@ -111,30 +113,30 @@ func Open(cfg Config) (*sql.DB, error) { } // Config is the input to Open(). -type Config struct { - Session *session.Session +type DriverConfig struct { + Config *aws.Config Database string OutputLocation string PollFrequency time.Duration } -func configFromConnectionString(connStr string) (*Config, error) { +func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) { args, err := url.ParseQuery(connStr) if err != nil { return nil, err } - var cfg Config + var cfg DriverConfig - var acfg []*aws.Config - if region := args.Get("region"); region != "" { - acfg = append(acfg, &aws.Config{Region: aws.String(region)}) - } - cfg.Session, err = session.NewSession(acfg...) + awsConfig, err := config.LoadDefaultConfig(ctx) if err != nil { return nil, err } + if region := args.Get("region"); region != "" { + awsConfig.Region = region + } + cfg.Config = &awsConfig cfg.Database = args.Get("db") cfg.OutputLocation = args.Get("output_location") diff --git a/go.mod b/go.mod index 70f7659..873c6e7 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,31 @@ module github.com/segmentio/go-athena go 1.21 require ( - github.com/aws/aws-sdk-go v1.55.5 + github.com/aws/aws-sdk-go-v2 v1.30.4 + github.com/aws/aws-sdk-go-v2/config v1.27.30 + github.com/aws/aws-sdk-go-v2/service/athena v1.44.5 + github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1 github.com/satori/go.uuid v1.2.0 github.com/stretchr/testify v1.9.0 ) require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.29 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f09d78c..22ee5ec 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,50 @@ -github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= -github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw= +github.com/aws/aws-sdk-go-v2/config v1.27.30 h1:AQF3/+rOgeJBQP3iI4vojlPib5X6eeOYoa/af7OxAYg= +github.com/aws/aws-sdk-go-v2/config v1.27.30/go.mod h1:yxqvuubha9Vw8stEgNiStO+yZpP68Wm9hLmcm+R/Qk4= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29 h1:CwGsupsXIlAFYuDVHv1nnK0wnxO0wZ/g1L8DSK/xiIw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29/go.mod h1:BPJ/yXV92ZVq6G8uYvbU0gSl8q94UB63nMT5ctNO38g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJCkMC0lMy6FaCD51jm6ayE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 h1:mimdLQkIX1zr8GIPY1ZtALdBQGxcASiBd2MOp8m/dMc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16/go.mod h1:YHk6owoSwrIsok+cAH9PENCOGoH5PU2EllX4vLtSrsY= +github.com/aws/aws-sdk-go-v2/service/athena v1.44.5 h1:l6fpIrGjYc8zfeBo3QHWxQf3d8TwIxITJXCLOKEhMWw= +github.com/aws/aws-sdk-go-v2/service/athena v1.44.5/go.mod h1:JKpavcrQ83Uy6ntM2pIt0vfVpHR9kvI3dkUeAKQstpc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 h1:GckUnpm4EJOAio1c8o25a+b3lVfwVzC9gnSBqiiNmZM= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18/go.mod h1:Br6+bxfG33Dk3ynmkhsW2Z/t9D4+lRqdLDNCKi85w0U= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 h1:jg16PhLPUiHIj8zYIW6bqzeQSuHVEiWnGA0Brz5Xv2I= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16/go.mod h1:Uyk1zE1VVdsHSU7096h/rwnXDzOzYQVl+FNPhPw7ShY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1 h1:mx2ucgtv+MWzJesJY9Ig/8AFHgoE5FwLXwUVgW/FGdI= +github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1/go.mod h1:BSPI0EfnYUuNHPS0uqIo5VrRwzie+Fp+YhQOUs16sKI= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5/go.mod h1:ZeDX1SnKsVlejeuz41GiajjZpRSWR7/42q/EyA/QEiM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 h1:SKvPgvdvmiTWoi0GAJ7AsJfOz3ngVkD/ERbs5pUnHNI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5/go.mod h1:20sz31hv/WsPa3HhU3hfrIet2kxM4Pe0r20eBZ20Tac= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wINh+4UK+k/0Yo/q8= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rows.go b/rows.go index f010cb9..83045bc 100644 --- a/rows.go +++ b/rows.go @@ -1,16 +1,16 @@ package athena import ( + "context" "database/sql/driver" "io" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/athena" - "github.com/aws/aws-sdk-go/service/athena/athenaiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/athena" ) type rows struct { - athena athenaiface.AthenaAPI + athena athenaAPI queryID string done bool @@ -19,19 +19,19 @@ type rows struct { } type rowsConfig struct { - Athena athenaiface.AthenaAPI + Athena athenaAPI QueryID string SkipHeader bool } -func newRows(cfg rowsConfig) (*rows, error) { +func newRows(ctx context.Context, cfg rowsConfig) (*rows, error) { r := rows{ athena: cfg.Athena, queryID: cfg.QueryID, skipHeaderRow: cfg.SkipHeader, } - shouldContinue, err := r.fetchNextPage(nil) + shouldContinue, err := r.fetchNextPage(ctx, nil) if err != nil { return nil, err } @@ -69,7 +69,9 @@ func (r *rows) Next(dest []driver.Value) error { return io.EOF } - cont, err := r.fetchNextPage(r.out.NextToken) + // A context cannot be passed into the Next function because it is defined + // in the database.sql.driver.Rows interface. + cont, err := r.fetchNextPage(context.Background(), r.out.NextToken) if err != nil { return err } @@ -90,9 +92,9 @@ func (r *rows) Next(dest []driver.Value) error { return nil } -func (r *rows) fetchNextPage(token *string) (bool, error) { +func (r *rows) fetchNextPage(ctx context.Context, token *string) (bool, error) { var err error - r.out, err = r.athena.GetQueryResults(&athena.GetQueryResultsInput{ + r.out, err = r.athena.GetQueryResults(ctx, &athena.GetQueryResultsInput{ QueryExecutionId: aws.String(r.queryID), NextToken: token, }) diff --git a/rows_test.go b/rows_test.go index 8226fba..a6bcbfe 100644 --- a/rows_test.go +++ b/rows_test.go @@ -1,14 +1,15 @@ package athena import ( + "context" "database/sql/driver" "errors" "io" "math/rand" "testing" - "github.com/aws/aws-sdk-go/service/athena" - "github.com/aws/aws-sdk-go/service/athena/athenaiface" + "github.com/aws/aws-sdk-go-v2/service/athena" + "github.com/aws/aws-sdk-go-v2/service/athena/types" "github.com/stretchr/testify/assert" ) @@ -22,22 +23,22 @@ var queryToResultsGenMap = map[string]genQueryResultsOutputByToken{ "iteration_fail": dummyFailedIterationResponse, } -func genColumnInfo(column string) *athena.ColumnInfo { +func genColumnInfo(column string) types.ColumnInfo { caseSensitive := true catalogName := "hive" - nullable := "UNKNOWN" - precision := int64(2147483647) - scale := int64(0) + nullable := types.ColumnNullableUnknown + precision := int32(2147483647) + scale := int32(0) schemaName := "" tableName := "" columnType := "varchar" - return &athena.ColumnInfo{ - CaseSensitive: &caseSensitive, + return types.ColumnInfo{ + CaseSensitive: caseSensitive, CatalogName: &catalogName, - Nullable: &nullable, - Precision: &precision, - Scale: &scale, + Nullable: nullable, + Precision: precision, + Scale: scale, SchemaName: &schemaName, TableName: &tableName, Type: &columnType, @@ -55,21 +56,21 @@ func randomString() string { return string(s) } -func genRow(isHeader bool, columns []*athena.ColumnInfo) *athena.Row { - var data []*athena.Datum +func genRow(isHeader bool, columns []types.ColumnInfo) types.Row { + var data []types.Datum for i := 0; i < len(columns); i++ { if isHeader { - data = append(data, &athena.Datum{ + data = append(data, types.Datum{ VarCharValue: columns[i].Name, }) } else { s := randomString() - data = append(data, &athena.Datum{ + data = append(data, types.Datum{ VarCharValue: &s, }) } } - return &athena.Row{ + return types.Row{ Data: data, } } @@ -78,17 +79,17 @@ func dummySelectQueryResponse(token string) (*athena.GetQueryResultsOutput, erro switch token { case "": var nextToken = "page_1" - columns := []*athena.ColumnInfo{ + columns := []types.ColumnInfo{ genColumnInfo("first_name"), genColumnInfo("last_name"), } return &athena.GetQueryResultsOutput{ NextToken: &nextToken, - ResultSet: &athena.ResultSet{ - ResultSetMetadata: &athena.ResultSetMetadata{ + ResultSet: &types.ResultSet{ + ResultSetMetadata: &types.ResultSetMetadata{ ColumnInfo: columns, }, - Rows: []*athena.Row{ + Rows: []types.Row{ genRow(true, columns), genRow(false, columns), genRow(false, columns), @@ -98,16 +99,16 @@ func dummySelectQueryResponse(token string) (*athena.GetQueryResultsOutput, erro }, }, nil case "page_1": - columns := []*athena.ColumnInfo{ + columns := []types.ColumnInfo{ genColumnInfo("first_name"), genColumnInfo("last_name"), } return &athena.GetQueryResultsOutput{ - ResultSet: &athena.ResultSet{ - ResultSetMetadata: &athena.ResultSetMetadata{ + ResultSet: &types.ResultSet{ + ResultSetMetadata: &types.ResultSetMetadata{ ColumnInfo: columns, }, - Rows: []*athena.Row{ + Rows: []types.Row{ genRow(false, columns), genRow(false, columns), genRow(false, columns), @@ -122,15 +123,15 @@ func dummySelectQueryResponse(token string) (*athena.GetQueryResultsOutput, erro } func dummyShowResponse(_ string) (*athena.GetQueryResultsOutput, error) { - columns := []*athena.ColumnInfo{ + columns := []types.ColumnInfo{ genColumnInfo("partition"), } return &athena.GetQueryResultsOutput{ - ResultSet: &athena.ResultSet{ - ResultSetMetadata: &athena.ResultSetMetadata{ + ResultSet: &types.ResultSet{ + ResultSetMetadata: &types.ResultSetMetadata{ ColumnInfo: columns, }, - Rows: []*athena.Row{ + Rows: []types.Row{ genRow(false, columns), genRow(false, columns), }, @@ -142,17 +143,17 @@ func dummyFailedIterationResponse(token string) (*athena.GetQueryResultsOutput, switch token { case "": var nextToken = "page_1" - columns := []*athena.ColumnInfo{ + columns := []types.ColumnInfo{ genColumnInfo("first_name"), genColumnInfo("last_name"), } return &athena.GetQueryResultsOutput{ NextToken: &nextToken, - ResultSet: &athena.ResultSet{ - ResultSetMetadata: &athena.ResultSetMetadata{ + ResultSet: &types.ResultSet{ + ResultSetMetadata: &types.ResultSetMetadata{ ColumnInfo: columns, }, - Rows: []*athena.Row{ + Rows: []types.Row{ genRow(true, columns), genRow(false, columns), genRow(false, columns), @@ -167,10 +168,10 @@ func dummyFailedIterationResponse(token string) (*athena.GetQueryResultsOutput, } type mockAthenaClient struct { - athenaiface.AthenaAPI + athenaAPI } -func (m *mockAthenaClient) GetQueryResults(query *athena.GetQueryResultsInput) (*athena.GetQueryResultsOutput, error) { +func (m *mockAthenaClient) GetQueryResults(ctx context.Context, query *athena.GetQueryResultsInput, opts ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) { var nextToken = "" if query.NextToken != nil { nextToken = *query.NextToken @@ -211,8 +212,9 @@ func TestRows_Next(t *testing.T) { expectedError: dummyError, }, } + ctx := context.Background() for _, test := range tests { - r, _ := newRows(rowsConfig{ + r, _ := newRows(ctx, rowsConfig{ Athena: new(mockAthenaClient), QueryID: test.queryID, SkipHeader: test.skipHeader, diff --git a/value.go b/value.go index 077533a..3e46418 100644 --- a/value.go +++ b/value.go @@ -6,7 +6,7 @@ import ( "strconv" "time" - "github.com/aws/aws-sdk-go/service/athena" + "github.com/aws/aws-sdk-go-v2/service/athena/types" ) const ( @@ -16,7 +16,7 @@ const ( DateLayout = "2006-01-02" ) -func convertRow(columns []*athena.ColumnInfo, in []*athena.Datum, ret []driver.Value) error { +func convertRow(columns []types.ColumnInfo, in []types.Datum, ret []driver.Value) error { for i, val := range in { coerced, err := convertValue(*columns[i].Type, val.VarCharValue) if err != nil {