From 69005e6c3f61bbba01fdca4836e0baafe87e26e9 Mon Sep 17 00:00:00 2001 From: oldme Date: Wed, 4 Sep 2024 14:33:32 +0800 Subject: [PATCH 1/2] up --- internal/proto/writer.go | 31 +++++++++++++++++++++++++++++++ redis_test.go | 14 ++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/internal/proto/writer.go b/internal/proto/writer.go index 78595cc4f..fed169b67 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "reflect" "strconv" "time" @@ -140,6 +141,36 @@ func (w *Writer) WriteArg(v interface{}) error { return w.bytes(b) case net.IP: return w.bytes(v) + default: + return w.writeArgExtra(v) + } +} + +func (w *Writer) writeArgExtra(v interface{}) error { + var ( + rfValue = reflect.ValueOf(v) + rfKind = rfValue.Kind() + ) + + switch rfKind { + case reflect.Bool: + if rfValue.Bool() { + return w.int(1) + } + return w.int(0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return w.int(rfValue.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return w.uint(rfValue.Uint()) + case reflect.Float32, reflect.Float64: + return w.float(rfValue.Float()) + case reflect.String: + return w.string(rfValue.String()) + case reflect.Slice: + if rfValue.Type().Elem().Kind() == reflect.Uint8 { + return w.bytes(rfValue.Bytes()) + } + fallthrough default: return fmt.Errorf( "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) diff --git a/redis_test.go b/redis_test.go index ef2125452..5b86de2a5 100644 --- a/redis_test.go +++ b/redis_test.go @@ -362,6 +362,20 @@ var _ = Describe("Client", func() { Expect(ip2).To(Equal(ip)) }) + + It("should set and scan custom type", func() { + type customString string + + val := customString("hello") + err := client.Set(ctx, "custom", val, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + var val2 customString + err = client.Get(ctx, "custom").Scan(&val2) + Expect(err).NotTo(HaveOccurred()) + + Expect(val2).To(Equal(val)) + }) }) var _ = Describe("Client timeout", func() { From b121fb83994b48a67bd9ad44226a12cb2d4180dc Mon Sep 17 00:00:00 2001 From: oldme Date: Wed, 4 Sep 2024 15:13:35 +0800 Subject: [PATCH 2/2] up --- redis_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_test.go b/redis_test.go index 5b86de2a5..89a1ad53f 100644 --- a/redis_test.go +++ b/redis_test.go @@ -370,11 +370,11 @@ var _ = Describe("Client", func() { err := client.Set(ctx, "custom", val, 0).Err() Expect(err).NotTo(HaveOccurred()) - var val2 customString + var val2 string err = client.Get(ctx, "custom").Scan(&val2) Expect(err).NotTo(HaveOccurred()) - Expect(val2).To(Equal(val)) + Expect(customString(val2)).To(Equal(val)) }) })