Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
}

func (cn *conn) simpleQuery(q string) (res *rows, err error) {
defer cn.errRecover(&err)
defer cn.errRecoverWithQuery(&err, q)

b := cn.writeBuf('Q')
b.string(q)
Expand Down Expand Up @@ -763,7 +763,7 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
if cn.bad {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
defer cn.errRecoverWithQuery(&err, q)

if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
return cn.prepareCopyIn(q)
Expand Down Expand Up @@ -794,7 +794,7 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err
if cn.bad {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
defer cn.errRecoverWithQuery(&err, query)

// Check to see if we can use the "simpleQuery" interface, which is
// *much* faster than going through prepare/exec
Expand Down Expand Up @@ -828,7 +828,7 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
if cn.bad {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
defer cn.errRecoverWithQuery(&err, query)

// Check to see if we can use the "simpleExec" interface, which is
// *much* faster than going through prepare/exec
Expand Down
75 changes: 72 additions & 3 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"net"
"runtime"
"strconv"
"strings"
)

// Error severities
Expand Down Expand Up @@ -40,6 +42,9 @@ type Error struct {
File string
Line string
Routine string

// For syntax error reporting. Not a field returned by Postgres.
originalQuery string
}

// ErrorCode is a five-character error code.
Expand Down Expand Up @@ -444,9 +449,61 @@ func (err *Error) Get(k byte) (v string) {
}

func (err Error) Error() string {
switch err.Code {
case "42601": // syntax_error
return err.syntaxError()
default:
return err.normalError()
}
}

func (err Error) normalError() string {
return "pq: " + err.Message
}

// syntaxError formats a syntax error the way psql does.
func (err Error) syntaxError() string {
if err.Position == "" || err.originalQuery == "" {
return err.normalError() // not enough information, fallback
}

pos, e := strconv.Atoi(err.Position)
if e != nil {
return err.normalError() // Position is not a number, fallback
}
pos -= 1 // make zero-based

if pos < 0 || pos >= len(err.originalQuery) {
return err.normalError() // Position is out of range, fallback
}

lineStartPos := strings.LastIndex(err.originalQuery[:pos], "\n")
if lineStartPos == -1 { // error in first line?
lineStartPos = 0
} else {
lineStartPos += 1 // remove \n
}

lineEndPos := strings.Index(err.originalQuery[pos:], "\n")
if lineEndPos == -1 { // error in last line?
lineEndPos = len(err.originalQuery)
} else {
lineEndPos += pos // absolute position
}

lineNo := strings.Count(err.originalQuery[:lineStartPos], "\n")

queryLinePrefix := fmt.Sprintf("LINE %d: ", lineNo+1)
queryLine := err.originalQuery[lineStartPos:lineEndPos]
markerLinePrefix := strings.Repeat(" ", len(queryLinePrefix))
markerLine := strings.Repeat(" ", pos-lineStartPos) + "^"

return fmt.Sprintf("pq: %s\n%s%s\n%s%s",
err.Message,
queryLinePrefix, queryLine,
markerLinePrefix, markerLine)
}

// PGError is an interface used by previous versions of pq. It is provided
// only to support legacy code. New code should use the Error type.
type PGError interface {
Expand All @@ -473,8 +530,20 @@ func errRecoverNoErrBadConn(err *error) {
}

func (c *conn) errRecover(err *error) {
e := recover()
switch v := e.(type) {
c.errHandleRecovered(recover(), err)
}

func (c *conn) errRecoverWithQuery(err *error, query string) {
c.errHandleRecovered(recover(), err)
if *err != nil {
if pqErr, ok := (*err).(*Error); ok {
pqErr.originalQuery = query
}
}
}

func (c *conn) errHandleRecovered(recovered interface{}, err *error) {
switch v := recovered.(type) {
case nil:
// Do nothing
case runtime.Error:
Expand All @@ -497,7 +566,7 @@ func (c *conn) errRecover(err *error) {

default:
c.bad = true
panic(fmt.Sprintf("unknown error: %#v", e))
panic(fmt.Sprintf("unknown error: %#v", recovered))
}

// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
Expand Down
68 changes: 68 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package pq

import (
"strings"
"testing"
)

func TestSyntaxErrorFormatting(t *testing.T) {
for _, tt := range []struct {
err Error
expected string
}{
// Single line
{Error{Message: "test", Position: "8", originalQuery: "SELECT *;"},
"pq: test\nLINE 1: SELECT *;\n ^"},

// Syntax error in first line
{Error{Message: "test", Position: "1", originalQuery: "SELECT\n *;"},
"pq: test\nLINE 1: SELECT\n ^"},

// Syntax error in last line
{Error{Message: "test", Position: "9", originalQuery: "SELECT\n *;"},
"pq: test\nLINE 2: *;\n ^"},

// Bad input: position non-positive
{Error{Message: "test", Position: "0", originalQuery: "SELECT\n *;"},
Error{Message: "test", Position: "0", originalQuery: "SELECT\n *;"}.normalError()},

// Bad input: position after end of string
{Error{Message: "test", Position: "11", originalQuery: "SELECT\n *;"},
Error{Message: "test", Position: "11", originalQuery: "SELECT\n *;"}.normalError()},
{Error{Message: "test", Position: "not a number", originalQuery: "SELECT\n *;"},
Error{Message: "test", Position: "not a number", originalQuery: "SELECT\n *;"}.normalError()},
} {
actual := tt.err.syntaxError()
if tt.expected != actual {
t.Errorf("bad message, expected %#v, got %#v", tt.expected, actual)
}
}
}

func TestSyntaxErrorHandlingWithQuery(t *testing.T) {
db := openTestConn(t)
defer db.Close()

_, err := db.Query("SELECT *;")
if err == nil {
t.Fatal(err)
}

if !strings.HasSuffix(err.Error(), " ^") {
t.Errorf("syntax error not formatted as such. got %#v", err.Error())
}
}

func TestSyntaxErrorHandlingWithPrepare(t *testing.T) {
db := openTestConn(t)
defer db.Close()

_, err := db.Prepare("SELECT *;")
if err == nil {
t.Fatal(err)
}

if !strings.HasSuffix(err.Error(), " ^") {
t.Errorf("syntax error not formatted as such. got %#v", err.Error())
}
}