diff --git a/pkg/schema/base/base.go b/pkg/schema/base/base.go index 20be43e..1da1724 100644 --- a/pkg/schema/base/base.go +++ b/pkg/schema/base/base.go @@ -1,6 +1,7 @@ package base import ( + "database/sql" "fmt" "strings" @@ -165,3 +166,22 @@ func ColumnType(ctx *schema.MigratorContext, options schema.ColumnData) func() s return strBuilder.String() } } + +func (p *Schema) Query(query string, args []interface{}, rowHandler func(row *sql.Rows) error) { + rows, err := p.TX.QueryContext(p.Context.Context, query, args...) + if err != nil { + p.Context.RaiseError(fmt.Errorf("error while querying: %w", err)) + } + + defer rows.Close() + + for rows.Next() { + if err := rowHandler(rows); err != nil { + p.Context.RaiseError(fmt.Errorf("error from privded row handler: %w", err)) + } + } + + if err := rows.Err(); err != nil { + p.Context.RaiseError(fmt.Errorf("error after iterating rows: %w", err)) + } +} diff --git a/pkg/schema/pg/postgres_test.go b/pkg/schema/pg/postgres_test.go index e80ad57..f9ff18e 100644 --- a/pkg/schema/pg/postgres_test.go +++ b/pkg/schema/pg/postgres_test.go @@ -222,3 +222,116 @@ order by column_name;` return columns } + +func TestPostgres_Query(t *testing.T) { + t.Run("basic query execution", func(t *testing.T) { + p, _, schemaName := baseTest(t, "", "tst_pg_query") + + // Create a test table + p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) { + s.Integer("id") + s.String("name") + }) + + // Insert some test data + _, err := p.TX.ExecContext(context.Background(), fmt.Sprintf( + "INSERT INTO %s.test_query (id, name) VALUES ($1, $2), ($3, $4)", + schemaName, + ), 1, "Alice", 2, "Bob") + require.NoError(t, err) + + // Test Query function + var results []struct { + ID int + Name string + } + + p.Query( + fmt.Sprintf("SELECT id, name FROM %s.test_query ORDER BY id", schemaName), + []interface{}{}, + func(rows *sql.Rows) error { + var result struct { + ID int + Name string + } + if err := rows.Scan(&result.ID, &result.Name); err != nil { + return err + } + results = append(results, result) + return nil + }, + ) + + // Verify results + require.Len(t, results, 2) + require.Equal(t, 1, results[0].ID) + require.Equal(t, "Alice", results[0].Name) + require.Equal(t, 2, results[1].ID) + require.Equal(t, "Bob", results[1].Name) + }) + + t.Run("query with arguments", func(t *testing.T) { + p, _, schemaName := baseTest(t, "", "tst_pg_query_args") + + // Create a test table + p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) { + s.Integer("id") + s.String("name") + }) + + // Insert test data + _, err := p.TX.ExecContext(context.Background(), fmt.Sprintf( + "INSERT INTO %s.test_query (id, name) VALUES ($1, $2), ($3, $4)", + schemaName, + ), 1, "Alice", 2, "Bob") + require.NoError(t, err) + + // Test Query with arguments + var result string + p.Query( + fmt.Sprintf("SELECT name FROM %s.test_query WHERE id = $1", schemaName), + []interface{}{1}, + func(rows *sql.Rows) error { + return rows.Scan(&result) + }, + ) + + require.Equal(t, "Alice", result) + }) + + t.Run("query with error handling", func(t *testing.T) { + p, _, schemaName := baseTest(t, "", "tst_pg_query_error") + + // Test invalid query + require.Panics(t, func() { + p.Query( + fmt.Sprintf("SELECT * FROM %s.nonexistent_table", schemaName), + []interface{}{}, + func(rows *sql.Rows) error { + return nil + }, + ) + }) + + // Test error in row handler + p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) { + s.Integer("id") + }) + + _, err := p.TX.ExecContext(context.Background(), fmt.Sprintf( + "INSERT INTO %s.test_query (id) VALUES ($1)", + schemaName, + ), 1) + require.NoError(t, err) + + require.Panics(t, func() { + p.Query( + fmt.Sprintf("SELECT id FROM %s.test_query", schemaName), + []interface{}{}, + func(rows *sql.Rows) error { + return fmt.Errorf("test error in row handler") + }, + ) + }) + }) +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 5e2e62d..1eda97b 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -1,5 +1,7 @@ package schema +import "database/sql" + // Schema is the interface that need to be implemented to support migrations. type Schema interface { AddVersion(version string) @@ -7,4 +9,5 @@ type Schema interface { FindAppliedVersions() []string Exec(query string, args ...interface{}) + Query(query string, args []any, rowHandler func(row *sql.Rows) error) }