Skip to content

Commit

Permalink
feat: add Query api
Browse files Browse the repository at this point in the history
  • Loading branch information
alexisvisco committed Nov 4, 2024
1 parent 911faf4 commit 9fd2bf0
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
20 changes: 20 additions & 0 deletions pkg/schema/base/base.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package base

import (
"database/sql"
"fmt"
"strings"

Expand Down Expand Up @@ -165,3 +166,22 @@ func ColumnType(ctx *schema.MigratorContext, options schema.ColumnData) func() s
return strBuilder.String()
}
}

func (p *Schema) Query(query string, args []interface{}, rowHandler func(row *sql.Rows) error) {
rows, err := p.TX.QueryContext(p.Context.Context, query, args...)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while querying: %w", err))
}

defer rows.Close()

for rows.Next() {
if err := rowHandler(rows); err != nil {
p.Context.RaiseError(fmt.Errorf("error from privded row handler: %w", err))
}
}

if err := rows.Err(); err != nil {
p.Context.RaiseError(fmt.Errorf("error after iterating rows: %w", err))
}
}
113 changes: 113 additions & 0 deletions pkg/schema/pg/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,116 @@ order by column_name;`

return columns
}

func TestPostgres_Query(t *testing.T) {
t.Run("basic query execution", func(t *testing.T) {
p, _, schemaName := baseTest(t, "", "tst_pg_query")

// Create a test table
p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) {
s.Integer("id")
s.String("name")
})

// Insert some test data
_, err := p.TX.ExecContext(context.Background(), fmt.Sprintf(
"INSERT INTO %s.test_query (id, name) VALUES ($1, $2), ($3, $4)",
schemaName,
), 1, "Alice", 2, "Bob")
require.NoError(t, err)

// Test Query function
var results []struct {
ID int
Name string
}

p.Query(
fmt.Sprintf("SELECT id, name FROM %s.test_query ORDER BY id", schemaName),
[]interface{}{},
func(rows *sql.Rows) error {
var result struct {
ID int
Name string
}
if err := rows.Scan(&result.ID, &result.Name); err != nil {
return err
}
results = append(results, result)
return nil
},
)

// Verify results
require.Len(t, results, 2)
require.Equal(t, 1, results[0].ID)
require.Equal(t, "Alice", results[0].Name)
require.Equal(t, 2, results[1].ID)
require.Equal(t, "Bob", results[1].Name)
})

t.Run("query with arguments", func(t *testing.T) {
p, _, schemaName := baseTest(t, "", "tst_pg_query_args")

// Create a test table
p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) {
s.Integer("id")
s.String("name")
})

// Insert test data
_, err := p.TX.ExecContext(context.Background(), fmt.Sprintf(
"INSERT INTO %s.test_query (id, name) VALUES ($1, $2), ($3, $4)",
schemaName,
), 1, "Alice", 2, "Bob")
require.NoError(t, err)

// Test Query with arguments
var result string
p.Query(
fmt.Sprintf("SELECT name FROM %s.test_query WHERE id = $1", schemaName),
[]interface{}{1},
func(rows *sql.Rows) error {
return rows.Scan(&result)
},
)

require.Equal(t, "Alice", result)
})

t.Run("query with error handling", func(t *testing.T) {
p, _, schemaName := baseTest(t, "", "tst_pg_query_error")

// Test invalid query
require.Panics(t, func() {
p.Query(
fmt.Sprintf("SELECT * FROM %s.nonexistent_table", schemaName),
[]interface{}{},
func(rows *sql.Rows) error {
return nil
},
)
})

// Test error in row handler
p.CreateTable(schema.Table("test_query", schemaName), func(s *PostgresTableDef) {
s.Integer("id")
})

_, err := p.TX.ExecContext(context.Background(), fmt.Sprintf(
"INSERT INTO %s.test_query (id) VALUES ($1)",
schemaName,
), 1)
require.NoError(t, err)

require.Panics(t, func() {
p.Query(
fmt.Sprintf("SELECT id FROM %s.test_query", schemaName),
[]interface{}{},
func(rows *sql.Rows) error {
return fmt.Errorf("test error in row handler")
},
)
})
})
}
3 changes: 3 additions & 0 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package schema

import "database/sql"

// Schema is the interface that need to be implemented to support migrations.
type Schema interface {
AddVersion(version string)
RemoveVersion(version string)
FindAppliedVersions() []string

Exec(query string, args ...interface{})
Query(query string, args []any, rowHandler func(row *sql.Rows) error)
}

0 comments on commit 9fd2bf0

Please sign in to comment.