diff --git a/utils/misc/dbutils.go b/utils/misc/dbutils.go index d4bc215dde..f0f3b86607 100644 --- a/utils/misc/dbutils.go +++ b/utils/misc/dbutils.go @@ -5,6 +5,7 @@ import ( "fmt" "net/url" "os" + "time" "github.com/lib/pq" @@ -19,6 +20,8 @@ func GetConnectionString(c *config.Config, componentName string) string { port := c.GetInt("DB.port", 5432) password := c.GetString("DB.password", "ubuntu") // Reading secrets from sslmode := c.GetString("DB.sslMode", "disable") + idleTxTimeout := c.GetDuration("DB.IdleTxTimeout", 5, time.Minute) + // Application Name can be any string of less than NAMEDATALEN characters (64 characters in a standard PostgreSQL build). // There is no need to truncate the string on our own though since PostgreSQL auto-truncates this identifier and issues a relevant notice if necessary. appName := DefaultString("rudder-server").OnError(os.Hostname()) @@ -26,8 +29,11 @@ func GetConnectionString(c *config.Config, componentName string) string { appName = fmt.Sprintf("%s-%s", componentName, appName) } return fmt.Sprintf("host=%s port=%d user=%s "+ - "password=%s dbname=%s sslmode=%s application_name=%s", - host, port, user, password, dbname, sslmode, appName) + "password=%s dbname=%s sslmode=%s application_name=%s "+ + " options='-c idle_in_transaction_session_timeout=%d'", + host, port, user, password, dbname, sslmode, appName, + idleTxTimeout.Milliseconds(), + ) } // SetAppNameInDBConnURL sets application name in db connection url diff --git a/utils/misc/dbutils_test.go b/utils/misc/dbutils_test.go index 44ba1bf765..9a3b74b830 100644 --- a/utils/misc/dbutils_test.go +++ b/utils/misc/dbutils_test.go @@ -1,8 +1,16 @@ package misc_test import ( + "database/sql" + "fmt" "testing" + "time" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" "github.com/rudderlabs/rudder-server/utils/misc" ) @@ -58,3 +66,65 @@ func TestSetApplicationNameInDBConnectionURL(t *testing.T) { }) } } + +func TestIdleTxTimeout(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + conf := config.New() + conf.Set("DB.host", postgresContainer.Host) + conf.Set("DB.user", postgresContainer.User) + conf.Set("DB.name", postgresContainer.Database) + conf.Set("DB.port", postgresContainer.Port) + conf.Set("DB.password", postgresContainer.Password) + + txTimeout := 2 * time.Millisecond + + conf.Set("DB.IdleTxTimeout", txTimeout) + + dsn := misc.GetConnectionString(conf, "test") + + db, err := sql.Open("postgres", dsn) + require.NoError(t, err) + + var sessionTimeout string + err = db.QueryRow("SHOW idle_in_transaction_session_timeout;").Scan(&sessionTimeout) + require.NoError(t, err) + require.Equal(t, txTimeout.String(), sessionTimeout) + + t.Run("timeout tx", func(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + + var pid int + err = tx.QueryRow(`select pg_backend_pid();`).Scan(&pid) + require.NoError(t, err) + + _, err = tx.Exec("select 1") + require.NoError(t, err) + t.Log("sleep double the timeout to close connection") + time.Sleep(2 * txTimeout) + + err = tx.Commit() + require.EqualError(t, err, "driver: bad connection") + + var count int + err = db.QueryRow(`SELECT count(*) FROM pg_stat_activity WHERE pid = $1`, pid).Scan(&count) + require.NoError(t, err) + + require.Zero(t, count) + }) + + t.Run("successful tx", func(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + _, err = tx.Exec("select 1") + require.NoError(t, err) + _, err = tx.Exec(fmt.Sprintf("select pg_sleep(%f)", txTimeout.Seconds())) + require.NoError(t, err) + + require.NoError(t, tx.Commit()) + }) +}