diff --git a/conn.go b/conn.go index 5eb17df..556d661 100644 --- a/conn.go +++ b/conn.go @@ -15,6 +15,7 @@ type conn struct { athena athenaiface.AthenaAPI db string OutputLocation string + workgroup string pollFrequency time.Duration } @@ -65,6 +66,7 @@ func (c *conn) startQuery(query string) (string, error) { ResultConfiguration: &athena.ResultConfiguration{ OutputLocation: aws.String(c.OutputLocation), }, + WorkGroup: aws.String(c.workgroup), }) if err != nil { return "", err diff --git a/driver.go b/driver.go index 2942bb7..90d6242 100644 --- a/driver.go +++ b/driver.go @@ -58,6 +58,9 @@ func init() { // - `region` (optional) // Override AWS region. Useful if it is not set with environment variable. // +// - `workgroup` (optional) +// Athena's workgroup. This defaults to "primary". +// // Credentials must be accessible via the SDK's Default Credential Provider Chain. // For more advanced AWS credentials/session/config management, please supply // a custom AWS session directly via `athena.Open()`. @@ -80,6 +83,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { db: cfg.Database, OutputLocation: cfg.OutputLocation, pollFrequency: cfg.PollFrequency, + workgroup: cfg.WorkGroup, }, nil } @@ -99,6 +103,10 @@ func Open(cfg Config) (*sql.DB, error) { return nil, errors.New("session is required") } + if cfg.WorkGroup == "" { + cfg.WorkGroup = "primary" + } + // This hack was copied from jackc/pgx. Sorry :( // https://github.com/jackc/pgx/blob/70a284f4f33a9cc28fd1223f6b83fb00deecfe33/stdlib/sql.go#L130-L136 openFromSessionMutex.Lock() @@ -115,6 +123,7 @@ type Config struct { Session *session.Session Database string OutputLocation string + WorkGroup string PollFrequency time.Duration } @@ -138,6 +147,10 @@ func configFromConnectionString(connStr string) (*Config, error) { cfg.Database = args.Get("db") cfg.OutputLocation = args.Get("output_location") + cfg.WorkGroup = args.Get("workgroup") + if cfg.WorkGroup == "" { + cfg.WorkGroup = "primary" + } frequencyStr := args.Get("poll_frequency") if frequencyStr != "" {