Skip to content

Commit

Permalink
WIP: Support PostgreSQL Databases (#437)
Browse files Browse the repository at this point in the history
feat: add postgres driver support

---------

Co-authored-by: Raphael Santo Domingo <[email protected]>
Co-authored-by: JJ Philipp <[email protected]>
Co-authored-by: Adam Shannon <[email protected]>
  • Loading branch information
4 people authored Sep 30, 2024
1 parent c9858a5 commit 12d4ae9
Show file tree
Hide file tree
Showing 18 changed files with 625 additions and 251 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/cgo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ jobs:
run: make check
env:
GOTEST_FLAGS: "-short"

- name: Logs
if: failure() && runner.os == 'Linux'
run: docker compose logs
4 changes: 4 additions & 0 deletions .github/workflows/nocgo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ jobs:
env:
CGO_ENABLED: "0"
GOTEST_FLAGS: "-short"

- name: Logs
if: failure() && runner.os == 'Linux'
run: docker compose logs
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ coverage.txt

*.pyc

.idea/*
.idea/*

testcerts/*
3 changes: 3 additions & 0 deletions .gitleaksignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
testcerts/client.key:private-key:1
testcerts/root.key:private-key:1
testcerts/server.key:private-key:1
18 changes: 14 additions & 4 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,28 @@ func New(ctx context.Context, logger log.Logger, config DatabaseConfig) (*sql.DB
if config.MySQL != nil {
preppedDb, err := mysqlConnection(logger, config.MySQL, config.DatabaseName)
if err != nil {
return nil, err
return nil, fmt.Errorf("configuring mysql connection: %v", err)
}

db, err := preppedDb.Connect(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("connecting to mysql: %w", err)
}

return ApplyConnectionsConfig(db, &config.MySQL.Connections, logger), nil

} else if config.Spanner != nil {
return spannerConnection(logger, *config.Spanner, config.DatabaseName)
db, err := spannerConnection(logger, *config.Spanner, config.DatabaseName)
if err != nil {
return nil, fmt.Errorf("connecting to spanner: %w", err)
}
return db, nil
} else if config.Postgres != nil {
db, err := postgresConnection(ctx, logger, *config.Postgres, config.DatabaseName)
if err != nil {
return nil, fmt.Errorf("connecting to postgres: %w", err)
}
return ApplyConnectionsConfig(db, &config.Postgres.Connections, logger), nil
}

return nil, fmt.Errorf("database config not defined")
Expand Down Expand Up @@ -61,7 +71,7 @@ func NewAndMigrate(ctx context.Context, logger log.Logger, config DatabaseConfig
// UniqueViolation returns true when the provided error matches a database error
// for duplicate entries (violating a unique table constraint).
func UniqueViolation(err error) bool {
return MySQLUniqueViolation(err) || SpannerUniqueViolation(err)
return MySQLUniqueViolation(err) || SpannerUniqueViolation(err) || PostgresUniqueViolation(err)
}

func DataTooLong(err error) bool {
Expand Down
28 changes: 28 additions & 0 deletions database/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
migmysql "github.com/golang-migrate/migrate/v4/database/mysql"
migpostgres "github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/iofs"

Expand Down Expand Up @@ -155,6 +156,29 @@ func getDriver(logger log.Logger, config DatabaseConfig, opts *migrateOptions) (
return nil, nil, err
}
}
} else if config.Postgres != nil {
if opts.source == nil {
src, err := NewPkgerSource("postgres", false)
if err != nil {
return nil, nil, err
}
opts.source = &SourceDriver{
name: "pkger-postgres",
Driver: src,
}
}

if opts.driver == nil {
db, err := New(context.Background(), logger, config)
if err != nil {
return nil, nil, err
}

opts.driver, err = PostgresDriver(db)
if err != nil {
return nil, nil, err
}
}
}

if opts.source == nil || opts.driver == nil {
Expand All @@ -172,6 +196,10 @@ func SpannerDriver(config DatabaseConfig) (database.Driver, error) {
return SpannerMigrationDriver(*config.Spanner, config.DatabaseName)
}

func PostgresDriver(db *sql.DB) (database.Driver, error) {
return migpostgres.WithInstance(db, &migpostgres.Config{})
}

type MigrateOption func(o *migrateOptions) error

type SourceDriver struct {
Expand Down
22 changes: 22 additions & 0 deletions database/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
type DatabaseConfig struct {
MySQL *MySQLConfig
Spanner *SpannerConfig
Postgres *PostgresConfig
DatabaseName string
}

Expand All @@ -23,6 +24,27 @@ type SpannerConfig struct {
DisableCleanStatements bool
}

type PostgresConfig struct {
Address string
User string
Password string
Connections ConnectionsConfig
TLS *PostgresTLSConfig
Alloy *PostgresAlloyConfig
}

type PostgresTLSConfig struct {
CACertFile string
ClientKeyFile string
ClientCertFile string
}

type PostgresAlloyConfig struct {
InstanceURI string
UseIAM bool
UsePSC bool
}

type MySQLConfig struct {
Address string
User string
Expand Down
145 changes: 145 additions & 0 deletions database/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package database

import (
"context"
"database/sql"
"errors"
"fmt"
"net"

"cloud.google.com/go/alloydbconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/moov-io/base/log"
)

const (
// PostgreSQL Error Codes
// https://www.postgresql.org/docs/current/errcodes-appendix.html
postgresErrUniqueViolation = "23505"
)

func postgresConnection(ctx context.Context, logger log.Logger, config PostgresConfig, databaseName string) (*sql.DB, error) {
var connStr string
if config.Alloy != nil {
c, err := getAlloyDBConnectorConnStr(ctx, config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating alloydb connection: %w", err).Err()
}
connStr = c
} else {
c, err := getPostgresConnStr(config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating postgres connection: %w", err).Err()
}
connStr = c
}

db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, logger.LogErrorf("opening database: %w", err).Err()
}

err = db.Ping()
if err != nil {
_ = db.Close()
return nil, logger.LogErrorf("connecting to database: %w", err).Err()
}

return db, nil
}

func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) {
url := fmt.Sprintf("postgres://%s:%s@%s/%s", config.User, config.Password, config.Address, databaseName)

params := ""

if config.TLS != nil {
params += "sslmode=verify-full"

if config.TLS.CACertFile == "" {
return "", fmt.Errorf("missing TLS CA file")
}
params += "&sslrootcert=" + config.TLS.CACertFile

if config.TLS.ClientCertFile != "" {
params += "&sslcert=" + config.TLS.ClientCertFile
}

if config.TLS.ClientKeyFile != "" {
params += "&sslkey=" + config.TLS.ClientKeyFile
}
}

connStr := fmt.Sprintf("%s?%s", url, params)
return connStr, nil
}

func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, databaseName string) (string, error) {
if config.Alloy == nil {
return "", fmt.Errorf("missing alloy config")
}

var dialer *alloydbconn.Dialer
var dsn string

if config.Alloy.UseIAM {
d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN())
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
// sslmode is disabled because the alloy db connection dialer will handle it
// no password is used with IAM
"user=%s dbname=%s sslmode=disable",
config.User, databaseName,
)
} else {
d, err := alloydbconn.NewDialer(ctx)
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
// sslmode is disabled because the alloy db connection dialer will handle it
"user=%s password=%s dbname=%s sslmode=disable",
config.User, config.Password, databaseName,
)
}

// TODO
//cleanup := func() error { return d.Close() }

connConfig, err := pgx.ParseConfig(dsn)
if err != nil {
return "", fmt.Errorf("failed to parse pgx config: %v", err)
}

var connOptions []alloydbconn.DialOption
if config.Alloy.UsePSC {
connOptions = append(connOptions, alloydbconn.WithPSC())
}

connConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
return dialer.Dial(ctx, config.Alloy.InstanceURI, connOptions...)
}

connStr := stdlib.RegisterConnConfig(connConfig)
return connStr, nil
}

func PostgresUniqueViolation(err error) bool {
if err == nil {
return false
}
var pgError *pgconn.PgError
if errors.As(err, &pgError) {
if pgError.Code == postgresErrUniqueViolation {
return true
}
}
return false
}
Loading

0 comments on commit 12d4ae9

Please sign in to comment.