Skip to content

Commit c2175fe

Browse files
authored
Merge pull request #2213 from moukoublen/fix_2204
Fix #2204
2 parents 659823f + 61a0227 commit c2175fe

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

pgtype/json.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
161161
//
162162
// https://github.com/jackc/pgx/issues/2146
163163
func isSQLScanner(v any) bool {
164+
if _, is := v.(sql.Scanner); is {
165+
return true
166+
}
167+
164168
val := reflect.ValueOf(v)
165169
for val.Kind() == reflect.Ptr {
166170
if _, ok := val.Interface().(sql.Scanner); ok {
@@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
212216
return fmt.Errorf("cannot scan NULL into %T", dst)
213217
}
214218

215-
elem := reflect.ValueOf(dst).Elem()
219+
v := reflect.ValueOf(dst)
220+
if v.Kind() != reflect.Pointer || v.IsNil() {
221+
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
222+
}
223+
224+
elem := v.Elem()
216225
elem.Set(reflect.Zero(elem.Type()))
217226

218227
return s.unmarshal(src, dst)

pgtype/json_test.go

+44-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func TestJSONCodec(t *testing.T) {
4848
Age int `json:"age"`
4949
}
5050

51+
var str string
5152
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
5253
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
5354
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
@@ -65,6 +66,9 @@ func TestJSONCodec(t *testing.T) {
6566
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
6667
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
6768
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
69+
70+
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
71+
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
6872
})
6973

7074
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@@ -136,6 +140,27 @@ func (i Issue2146) Value() (driver.Value, error) {
136140
return string(b), err
137141
}
138142

143+
type NonPointerJSONScanner struct {
144+
V *string
145+
}
146+
147+
func (i NonPointerJSONScanner) Scan(src any) error {
148+
switch c := src.(type) {
149+
case string:
150+
*i.V = c
151+
case []byte:
152+
*i.V = string(c)
153+
default:
154+
return errors.New("unknown source type")
155+
}
156+
157+
return nil
158+
}
159+
160+
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
161+
return i.V, nil
162+
}
163+
139164
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
140165
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
141166
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
@@ -267,7 +292,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
267292
Unmarshal: func(data []byte, v any) error {
268293
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
269294
},
270-
}})
295+
},
296+
})
271297
}
272298

273299
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@@ -278,3 +304,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
278304
}},
279305
})
280306
}
307+
308+
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
309+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
310+
n := 44
311+
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
312+
require.Error(t, err)
313+
314+
var i *int
315+
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
316+
require.Error(t, err)
317+
318+
m := 0
319+
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
320+
require.NoError(t, err)
321+
require.Equal(t, 42, m)
322+
})
323+
}

pgtype/pgtype.go

+4
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
415415

416416
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
417417
func getSQLScanner(target any) sql.Scanner {
418+
if sc, is := target.(sql.Scanner); is {
419+
return sc
420+
}
421+
418422
val := reflect.ValueOf(target)
419423
for val.Kind() == reflect.Ptr {
420424
if _, ok := val.Interface().(sql.Scanner); ok {

0 commit comments

Comments
 (0)