diff --git a/internal/proto/writer.go b/internal/proto/writer.go index 78595cc4f..fed169b67 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "reflect" "strconv" "time" @@ -140,6 +141,36 @@ func (w *Writer) WriteArg(v interface{}) error { return w.bytes(b) case net.IP: return w.bytes(v) + default: + return w.writeArgExtra(v) + } +} + +func (w *Writer) writeArgExtra(v interface{}) error { + var ( + rfValue = reflect.ValueOf(v) + rfKind = rfValue.Kind() + ) + + switch rfKind { + case reflect.Bool: + if rfValue.Bool() { + return w.int(1) + } + return w.int(0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return w.int(rfValue.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return w.uint(rfValue.Uint()) + case reflect.Float32, reflect.Float64: + return w.float(rfValue.Float()) + case reflect.String: + return w.string(rfValue.String()) + case reflect.Slice: + if rfValue.Type().Elem().Kind() == reflect.Uint8 { + return w.bytes(rfValue.Bytes()) + } + fallthrough default: return fmt.Errorf( "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) diff --git a/redis_test.go b/redis_test.go index ef2125452..89a1ad53f 100644 --- a/redis_test.go +++ b/redis_test.go @@ -362,6 +362,20 @@ var _ = Describe("Client", func() { Expect(ip2).To(Equal(ip)) }) + + It("should set and scan custom type", func() { + type customString string + + val := customString("hello") + err := client.Set(ctx, "custom", val, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + var val2 string + err = client.Get(ctx, "custom").Scan(&val2) + Expect(err).NotTo(HaveOccurred()) + + Expect(customString(val2)).To(Equal(val)) + }) }) var _ = Describe("Client timeout", func() {