diff --git a/conn.go b/conn.go index cb44c6929..3225468fa 100644 --- a/conn.go +++ b/conn.go @@ -621,7 +621,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err } func (cn *conn) simpleQuery(q string) (res *rows, err error) { - defer cn.errRecover(&err) + defer cn.errRecoverWithQuery(&err, q) b := cn.writeBuf('Q') b.string(q) @@ -763,7 +763,7 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { if cn.bad { return nil, driver.ErrBadConn } - defer cn.errRecover(&err) + defer cn.errRecoverWithQuery(&err, q) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { return cn.prepareCopyIn(q) @@ -794,7 +794,7 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err if cn.bad { return nil, driver.ErrBadConn } - defer cn.errRecover(&err) + defer cn.errRecoverWithQuery(&err, query) // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec @@ -828,7 +828,7 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err if cn.bad { return nil, driver.ErrBadConn } - defer cn.errRecover(&err) + defer cn.errRecoverWithQuery(&err, query) // Check to see if we can use the "simpleExec" interface, which is // *much* faster than going through prepare/exec diff --git a/error.go b/error.go index b4bb44cee..8f5dcc62b 100644 --- a/error.go +++ b/error.go @@ -6,6 +6,8 @@ import ( "io" "net" "runtime" + "strconv" + "strings" ) // Error severities @@ -40,6 +42,9 @@ type Error struct { File string Line string Routine string + + // For syntax error reporting. Not a field returned by Postgres. + originalQuery string } // ErrorCode is a five-character error code. @@ -444,9 +449,61 @@ func (err *Error) Get(k byte) (v string) { } func (err Error) Error() string { + switch err.Code { + case "42601": // syntax_error + return err.syntaxError() + default: + return err.normalError() + } +} + +func (err Error) normalError() string { return "pq: " + err.Message } +// syntaxError formats a syntax error the way psql does. +func (err Error) syntaxError() string { + if err.Position == "" || err.originalQuery == "" { + return err.normalError() // not enough information, fallback + } + + pos, e := strconv.Atoi(err.Position) + if e != nil { + return err.normalError() // Position is not a number, fallback + } + pos -= 1 // make zero-based + + if pos < 0 || pos >= len(err.originalQuery) { + return err.normalError() // Position is out of range, fallback + } + + lineStartPos := strings.LastIndex(err.originalQuery[:pos], "\n") + if lineStartPos == -1 { // error in first line? + lineStartPos = 0 + } else { + lineStartPos += 1 // remove \n + } + + lineEndPos := strings.Index(err.originalQuery[pos:], "\n") + if lineEndPos == -1 { // error in last line? + lineEndPos = len(err.originalQuery) + } else { + lineEndPos += pos // absolute position + } + + lineNo := strings.Count(err.originalQuery[:lineStartPos], "\n") + + queryLinePrefix := fmt.Sprintf("LINE %d: ", lineNo+1) + queryLine := err.originalQuery[lineStartPos:lineEndPos] + markerLinePrefix := strings.Repeat(" ", len(queryLinePrefix)) + markerLine := strings.Repeat(" ", pos-lineStartPos) + "^" + + return fmt.Sprintf("pq: %s\n%s%s\n%s%s", + err.Message, + queryLinePrefix, queryLine, + markerLinePrefix, markerLine) +} + // PGError is an interface used by previous versions of pq. It is provided // only to support legacy code. New code should use the Error type. type PGError interface { @@ -473,8 +530,20 @@ func errRecoverNoErrBadConn(err *error) { } func (c *conn) errRecover(err *error) { - e := recover() - switch v := e.(type) { + c.errHandleRecovered(recover(), err) +} + +func (c *conn) errRecoverWithQuery(err *error, query string) { + c.errHandleRecovered(recover(), err) + if *err != nil { + if pqErr, ok := (*err).(*Error); ok { + pqErr.originalQuery = query + } + } +} + +func (c *conn) errHandleRecovered(recovered interface{}, err *error) { + switch v := recovered.(type) { case nil: // Do nothing case runtime.Error: @@ -497,7 +566,7 @@ func (c *conn) errRecover(err *error) { default: c.bad = true - panic(fmt.Sprintf("unknown error: %#v", e)) + panic(fmt.Sprintf("unknown error: %#v", recovered)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't diff --git a/error_test.go b/error_test.go new file mode 100644 index 000000000..e816ea93d --- /dev/null +++ b/error_test.go @@ -0,0 +1,68 @@ +package pq + +import ( + "strings" + "testing" +) + +func TestSyntaxErrorFormatting(t *testing.T) { + for _, tt := range []struct { + err Error + expected string + }{ + // Single line + {Error{Message: "test", Position: "8", originalQuery: "SELECT *;"}, + "pq: test\nLINE 1: SELECT *;\n ^"}, + + // Syntax error in first line + {Error{Message: "test", Position: "1", originalQuery: "SELECT\n *;"}, + "pq: test\nLINE 1: SELECT\n ^"}, + + // Syntax error in last line + {Error{Message: "test", Position: "9", originalQuery: "SELECT\n *;"}, + "pq: test\nLINE 2: *;\n ^"}, + + // Bad input: position non-positive + {Error{Message: "test", Position: "0", originalQuery: "SELECT\n *;"}, + Error{Message: "test", Position: "0", originalQuery: "SELECT\n *;"}.normalError()}, + + // Bad input: position after end of string + {Error{Message: "test", Position: "11", originalQuery: "SELECT\n *;"}, + Error{Message: "test", Position: "11", originalQuery: "SELECT\n *;"}.normalError()}, + {Error{Message: "test", Position: "not a number", originalQuery: "SELECT\n *;"}, + Error{Message: "test", Position: "not a number", originalQuery: "SELECT\n *;"}.normalError()}, + } { + actual := tt.err.syntaxError() + if tt.expected != actual { + t.Errorf("bad message, expected %#v, got %#v", tt.expected, actual) + } + } +} + +func TestSyntaxErrorHandlingWithQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Query("SELECT *;") + if err == nil { + t.Fatal(err) + } + + if !strings.HasSuffix(err.Error(), " ^") { + t.Errorf("syntax error not formatted as such. got %#v", err.Error()) + } +} + +func TestSyntaxErrorHandlingWithPrepare(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Prepare("SELECT *;") + if err == nil { + t.Fatal(err) + } + + if !strings.HasSuffix(err.Error(), " ^") { + t.Errorf("syntax error not formatted as such. got %#v", err.Error()) + } +}