Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions neo4j/db/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ type Connection interface {
// Implementation of this should be passive, no pinging or similair since it might be
// called rather frequently.
IsAlive() bool
// HasFailed returns true if the connection has received a recoverable error (``FAILURE``).
HasFailed() bool
// Returns the point in time when this connection was established.
Birthdate() time.Time
// Resets connection to same state as directly after a connect.
Expand Down
4 changes: 4 additions & 0 deletions neo4j/internal/bolt/bolt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,10 @@ func (b *bolt3) IsAlive() bool {
return b.state != bolt3_dead
}

func (b *bolt3) HasFailed() bool {
return b.state == bolt3_failed
}

func (b *bolt3) Birthdate() time.Time {
return b.birthDate
}
Expand Down
4 changes: 4 additions & 0 deletions neo4j/internal/bolt/bolt4.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,10 @@ func (b *bolt4) IsAlive() bool {
return b.state != bolt4_dead
}

func (b *bolt4) HasFailed() bool {
return b.state == bolt4_failed
}

func (b *bolt4) Birthdate() time.Time {
return b.birthDate
}
Expand Down
4 changes: 2 additions & 2 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,6 @@ func (p *Pool) Return(c db.Connection) {
return
}

c.SetBoltLogger(nil)

// Get the name of the server that the connection belongs to.
serverName := c.ServerName()
isAlive := c.IsAlive()
Expand Down Expand Up @@ -380,6 +378,8 @@ func (p *Pool) Return(c db.Connection) {
isAlive = c.IsAlive()
}

c.SetBoltLogger(nil)

// Shouldn't return a too old or dead connection back to the pool
if !isAlive || age >= p.maxAge {
p.unreg(serverName, c, now)
Expand Down
17 changes: 17 additions & 0 deletions neo4j/internal/testutil/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,20 @@ func AssertSameType(t *testing.T, x, y interface{}) {
t.Errorf("Expected types of %s and %s to be same but was %s and %s", x, y, t1, t2)
}
}

func AssertDeepEquals(t *testing.T, values ...interface{}) {
t.Helper()
count := len(values)
if count == 0 {
return
}
prev := values[0]
for i := 1; i < count; i++ {
current := values[i]
if !reflect.DeepEqual(prev, current) {
t.Errorf("Expected value %v (parameter %d) to equal value %v (parameter %d)", prev, i-1, current, i)
return
}
prev = current
}
}
17 changes: 13 additions & 4 deletions neo4j/internal/testutil/connfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (c *ConnFake) IsAlive() bool {
return c.Alive
}

func (c *ConnFake) HasFailed() bool {
return false
}

func (c *ConnFake) Reset() {
}

Expand Down Expand Up @@ -148,11 +152,16 @@ func (c *ConnFake) Keys(streamHandle db.StreamHandle) ([]string, error) {
}

func (c *ConnFake) Next(streamHandle db.StreamHandle) (*db.Record, *db.Summary, error) {
next := c.Nexts[0]
if len(c.Nexts) > 1 {
c.Nexts = c.Nexts[1:]
if len(c.Nexts) >= 1 {
next := c.Nexts[0]
// moves to next record only if the current record is not an error or summary
// this emulates the stream buffering of a real connection
if next.Err == nil && next.Summary == nil {
c.Nexts = c.Nexts[1:]
}
return next.Record, next.Summary, next.Err
}
return next.Record, next.Summary, next.Err
return nil, nil, nil
}

func (c *ConnFake) ForceReset() error {
Expand Down
55 changes: 43 additions & 12 deletions neo4j/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Result interface {
// NextRecord returns true if there is a record to be processed, record parameter is set
// to point to current record.
NextRecord(record **Record) bool
// PeekRecord returns true if there is a record after the current one to be processed without advancing the record
// stream, record parameter is set to point to that record if present.
PeekRecord(record **Record) bool
// Err returns the latest error that caused this Next to return false.
Err() error
// Record returns the current record.
Expand All @@ -46,13 +49,16 @@ type Result interface {
}

type result struct {
conn db.Connection
streamHandle db.StreamHandle
cypher string
params map[string]interface{}
record *Record
summary *db.Summary
err error
conn db.Connection
streamHandle db.StreamHandle
cypher string
params map[string]interface{}
record *Record
summary *db.Summary
err error
peekedRecord *Record
peekedSummary *db.Summary
peeked bool
}

func newResult(conn db.Connection, str db.StreamHandle, cypher string, params map[string]interface{}) *result {
Expand All @@ -69,18 +75,26 @@ func (r *result) Keys() ([]string, error) {
}

func (r *result) Next() bool {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
return r.record != nil
}

func (r *result) NextRecord(out **Record) bool {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if out != nil {
*out = r.record
}
return r.record != nil
}

func (r *result) PeekRecord(out **Record) bool {
r.peek()
if out != nil {
*out = r.peekedRecord
}
return r.peekedRecord != nil
}

func (r *result) Record() *Record {
return r.record
}
Expand All @@ -92,7 +106,7 @@ func (r *result) Err() error {
func (r *result) Collect() ([]*Record, error) {
recs := make([]*Record, 0, 1024)
for r.summary == nil && r.err == nil {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.record != nil {
recs = append(recs, r.record)
}
Expand All @@ -109,7 +123,7 @@ func (r *result) buffer() {

func (r *result) Single() (*Record, error) {
// Try retrieving the single record
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.err != nil {
return nil, wrapError(r.err)
}
Expand All @@ -122,7 +136,7 @@ func (r *result) Single() (*Record, error) {
single := r.record

// Probe connection for more records
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
r.advance()
if r.record != nil {
// There were more records, consume the stream since the user didn't
// expect more records and should therefore not use them.
Expand Down Expand Up @@ -165,3 +179,20 @@ func (r *result) Consume() (ResultSummary, error) {
}
return r.toResultSummary(), nil
}

func (r *result) advance() {
if r.peeked {
r.record, r.peekedRecord = r.peekedRecord, nil
r.summary, r.peekedSummary = r.peekedSummary, nil
r.peeked = false
} else {
r.record, r.summary, r.err = r.conn.Next(r.streamHandle)
}
}

func (r *result) peek() {
if !r.peeked {
r.peekedRecord, r.peekedSummary, r.err = r.conn.Next(r.streamHandle)
r.peeked = true
}
}
Loading