diff --git a/README.md b/README.md index b372944..5614cf3 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ func main() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error @@ -98,12 +98,12 @@ func main() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - err := client.TensorSet("foo1", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + err := client.TensorSet("foo1", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) if err != nil { log.Fatal(err) } // AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1 - err = client.TensorSet("foo2", redisai.TypeFloat, []int{1, 1}, []float32{1.1}) + err = client.TensorSet("foo2", redisai.TypeFloat, []int64{1, 1}, []float32{1.1}) if err != nil { log.Fatal(err) } diff --git a/examples/redisai_pipelined_client/main.go b/examples/redisai_pipelined_client/main.go index 6171366..659c4cc 100644 --- a/examples/redisai_pipelined_client/main.go +++ b/examples/redisai_pipelined_client/main.go @@ -16,12 +16,12 @@ func main() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - err := client.TensorSet("foo1", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + err := client.TensorSet("foo1", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) if err != nil { log.Fatal(err) } // AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1 - err = client.TensorSet("foo2", redisai.TypeFloat, []int{1, 1}, []float32{1.1}) + err = client.TensorSet("foo2", redisai.TypeFloat, []int64{1, 1}, []float32{1.1}) if err != nil { log.Fatal(err) } diff --git a/examples/redisai_simple_client/main.go b/examples/redisai_simple_client/main.go index 98fb964..24d3c9b 100644 --- a/examples/redisai_simple_client/main.go +++ b/examples/redisai_simple_client/main.go @@ -13,7 +13,7 @@ func main() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error diff --git a/examples/redisai_tls_client/redisai_tls_client.go b/examples/redisai_tls_client/redisai_tls_client.go index 579d7f6..992ed8a 100644 --- a/examples/redisai_tls_client/redisai_tls_client.go +++ b/examples/redisai_tls_client/redisai_tls_client.go @@ -82,7 +82,7 @@ func main() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error diff --git a/redisai/client_test.go b/redisai/client_test.go index 6c1ea85..6f01e31 100644 --- a/redisai/client_test.go +++ b/redisai/client_test.go @@ -238,12 +238,12 @@ func TestClient_DisablePipeline(t *testing.T) { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - err := client.TensorSet("foo1", TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + err := client.TensorSet("foo1", TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) if err != nil { t.Errorf("TensorSet() error = %v", err) } // AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1 - err = client.TensorSet("foo2", TypeFloat, []int{1, 1}, []float32{1.1}) + err = client.TensorSet("foo2", TypeFloat, []int64{1, 1}, []float32{1.1}) if err != nil { t.Errorf("TensorSet() error = %v", err) } diff --git a/redisai/commands.go b/redisai/commands.go index 419c115..25d36ab 100644 --- a/redisai/commands.go +++ b/redisai/commands.go @@ -8,7 +8,7 @@ import ( ) // TensorSet sets a tensor -func (c *Client) TensorSet(keyName, dt string, dims []int, data interface{}) (err error) { +func (c *Client) TensorSet(keyName, dt string, dims []int64, data interface{}) (err error) { args, err := tensorSetFlatArgs(keyName, dt, dims, data) _, err = c.DoOrSend("AI.TENSORSET", args, err) return @@ -45,7 +45,7 @@ func (c *Client) TensorGetToTensor(name, format string, tensor TensorInterface) } // TensorGetValues gets a tensor's values -func (c *Client) TensorGetValues(name string) (dt string, shape []int, data interface{}, err error) { +func (c *Client) TensorGetValues(name string) (dt string, shape []int64, data interface{}, err error) { args := redis.Args{}.Add(name, TensorContentTypeMeta, TensorContentTypeValues) var reply interface{} reply, err = c.DoOrSend("AI.TENSORGET", args, nil) @@ -57,7 +57,7 @@ func (c *Client) TensorGetValues(name string) (dt string, shape []int, data inte } // TensorGetValues gets a tensor's values -func (c *Client) TensorGetMeta(name string) (dt string, shape []int, err error) { +func (c *Client) TensorGetMeta(name string) (dt string, shape []int64, err error) { args := redis.Args{}.Add(name, TensorContentTypeMeta) var reply interface{} reply, err = c.DoOrSend("AI.TENSORGET", args, nil) @@ -69,7 +69,7 @@ func (c *Client) TensorGetMeta(name string) (dt string, shape []int, err error) } // TensorGetValues gets a tensor's values -func (c *Client) TensorGetBlob(name string) (dt string, shape []int, data []byte, err error) { +func (c *Client) TensorGetBlob(name string) (dt string, shape []int64, data []byte, err error) { args := redis.Args{}.Add(name, TensorContentTypeMeta, TensorContentTypeBlob) var reply interface{} reply, err = c.DoOrSend("AI.TENSORGET", args, nil) diff --git a/redisai/commands_test.go b/redisai/commands_test.go index ba6a4c4..ed83fee 100644 --- a/redisai/commands_test.go +++ b/redisai/commands_test.go @@ -17,7 +17,7 @@ func TestCommand_TensorSet(t *testing.T) { valuesInt8 := []int8{1} valuesInt16 := []int16{1} - valuesInt32 := []int{1} + valuesInt32 := []int64{1} valuesInt64 := []int64{1} valuesUint8 := []uint8{1} @@ -46,12 +46,12 @@ func TestCommand_TensorSet(t *testing.T) { keyError1 := "test:TestCommand_TensorSet:1:FaultyDims" - shp := []int{1} + shp := []int64{1} type args struct { name string dt string - dims []int + dims []int64 data interface{} } @@ -78,7 +78,7 @@ func TestCommand_TensorSet(t *testing.T) { {keyByte, args{keyByte, TypeUint8, shp, valuesByte}, false}, - {keyError1, args{keyError1, TypeFloat, []int{1, 10}, []float32{1}}, true}, + {keyError1, args{keyError1, TypeFloat, []int64{1, 10}, []float32{1}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,7 +91,7 @@ func TestCommand_TensorSet(t *testing.T) { } func TestCommand_FullFromTensor(t *testing.T) { - tensor := implementations.NewAiTensorWithShape([]int{1}) + tensor := implementations.NewAiTensorWithShape([]int64{1}) tensor.SetData([]float32{1.0}) client := createTestClient() err := client.TensorSetFromTensor("tensor1", tensor) @@ -101,7 +101,7 @@ func TestCommand_FullFromTensor(t *testing.T) { if diff := cmp.Diff(TypeFloat32, gotResp[0]); diff != "" { t.Errorf("TestCommand_FullFromTensor() mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff([]int{1}, gotResp[1]); diff != "" { + if diff := cmp.Diff([]int64{1}, gotResp[1]); diff != "" { t.Errorf("TestCommand_FullFromTensor() mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff([]float32{1.0}, gotResp[2]); diff != "" { @@ -149,8 +149,8 @@ func TestCommand_TensorGet(t *testing.T) { keyUint8 := "test:TensorGet:TypeUint8:1" keyUint16 := "test:TensorGet:TypeUint16:1" - shp := []int{1} - shpByteSlice := []int{1, 5} + shp := []int64{1} + shpByteSlice := []int64{1, 5} simpleClient := createTestClient() simpleClient.TensorSet(keyByteSlice, TypeUint8, shpByteSlice, valuesByteSlice) @@ -173,7 +173,7 @@ func TestCommand_TensorGet(t *testing.T) { name string args args wantDt string - wantShape []int + wantShape []int64 wantData interface{} compareDt bool compareShape bool @@ -222,7 +222,7 @@ func TestCommand_TensorGetBlob(t *testing.T) { keyByte := "test:TensorGetBlog:[]byte:1" keyUnexistant := "test:TensorGetMeta:Unexistant" - shp := []int{1, 4} + shp := []int64{1, 4} simpleClient := Connect("", createPool()) simpleClient.TensorSet(keyByte, TypeInt8, shp, valuesByte) @@ -233,7 +233,7 @@ func TestCommand_TensorGetBlob(t *testing.T) { name string args args wantDt string - wantShape []int + wantShape []int64 wantData []byte wantErr bool }{ @@ -279,7 +279,7 @@ func TestCommand_TensorGetMeta(t *testing.T) { keyUnexistant := "test:TensorGetMeta:Unexistant" - shp := []int{1, 2} + shp := []int64{1, 2} simpleClient := Connect("", createPool()) simpleClient.TensorSet(keyFloat32, TypeFloat32, shp, nil) simpleClient.TensorSet(keyFloat64, TypeFloat64, shp, nil) @@ -299,7 +299,7 @@ func TestCommand_TensorGetMeta(t *testing.T) { name string args args wantDt string - wantShape []int + wantShape []int64 wantErr bool }{ {keyFloat32, args{keyFloat32}, TypeFloat32, shp, false}, @@ -354,8 +354,8 @@ func TestCommand_TensorGetValues(t *testing.T) { keyUint16 := "test:TensorGetValues:TypeUint16:1" keyUnexistant := "test:TensorGetValues:Unexistant" - shp := []int{1} - shp2 := []int{1, 5} + shp := []int64{1} + shp2 := []int64{1, 5} simpleClient := Connect("", createPool()) simpleClient.TensorSet(keyFloat32, TypeFloat32, shp2, valuesFloat32) simpleClient.TensorSet(keyFloat64, TypeFloat64, shp, valuesFloat64) @@ -375,7 +375,7 @@ func TestCommand_TensorGetValues(t *testing.T) { name string args args wantDt string - wantShape []int + wantShape []int64 wantData interface{} wantErr bool }{ @@ -574,12 +574,12 @@ func TestCommand_ModelRun(t *testing.T) { return } - errortset := simpleClient.TensorSet(keyTransaction1, TypeFloat, []int{1, 30}, nil) + errortset := simpleClient.TensorSet(keyTransaction1, TypeFloat, []int64{1, 30}, nil) if errortset != nil { t.Error(errortset) } - errortsetReference := simpleClient.TensorSet(keyReference1, TypeFloat, []int{256}, nil) + errortsetReference := simpleClient.TensorSet(keyReference1, TypeFloat, []int64{256}, nil) if errortsetReference != nil { t.Error(errortsetReference) } @@ -848,9 +848,9 @@ func TestCommand_Info(t *testing.T) { assert.Equal(t, BackendTF, info["backend"]) assert.Equal(t, "0", info["calls"]) - err = c.TensorSet("a", TypeFloat32, []int{1}, []float32{1.1}) + err = c.TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1}) assert.Nil(t, err) - err = c.TensorSet("b", TypeFloat32, []int{1}, []float32{4.4}) + err = c.TensorSet("b", TypeFloat32, []int64{1}, []float32{4.4}) assert.Nil(t, err) err = c.ModelRun(keyModel1, []string{"a", "b"}, []string{"mul"}) assert.Nil(t, err) @@ -878,7 +878,7 @@ func TestCommand_DagRun(t *testing.T) { return } err = c.ModelSet(keyModel1, BackendTF, DeviceCPU, data, []string{"a", "b"}, []string{"mul"}) - err = c.TensorSet("persisted_tensor_1", TypeFloat32, []int{1, 2}, []float32{5, 10}) + err = c.TensorSet("persisted_tensor_1", TypeFloat32, []int64{1, 2}, []float32{5, 10}) assert.Nil(t, err) type args struct { @@ -891,12 +891,12 @@ func TestCommand_DagRun(t *testing.T) { args args wantErr bool }{ - {"t_wrong_number", args{[]string{"notnumber"}, nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true}, - {"t_load", args{[]string{"persisted_tensor_1"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, false}, - {"t_load_err", args{[]string{"not_exits_tensor"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true}, - {"t1", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1})}, false}, - {"t_blob", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1}).TensorSet("b", TypeFloat32, []int{1}, []float32{4.4}).ModelRun("test:DagRun:mymodel:1", []string{"a", "b"}, []string{"mul"}).TensorGet("mul", TensorContentTypeBlob)}, false}, - {"t_values", args{nil, nil, NewDag().TensorSet("mytensor", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorGet("mytensor", TensorContentTypeValues)}, false}, + {"t_wrong_number", args{[]string{"notnumber"}, nil, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true}, + {"t_load", args{[]string{"persisted_tensor_1"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, false}, + {"t_load_err", args{[]string{"not_exits_tensor"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true}, + {"t1", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1})}, false}, + {"t_blob", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1}).TensorSet("b", TypeFloat32, []int64{1}, []float32{4.4}).ModelRun("test:DagRun:mymodel:1", []string{"a", "b"}, []string{"mul"}).TensorGet("mul", TensorContentTypeBlob)}, false}, + {"t_values", args{nil, nil, NewDag().TensorSet("mytensor", TypeFloat32, []int64{1, 2}, []int64{5, 10}).TensorGet("mytensor", TensorContentTypeValues)}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -932,7 +932,7 @@ func TestCommand_DagRun(t *testing.T) { func TestCommand_DagRunRO(t *testing.T) { c := createTestClient() - err := c.TensorSet("persisted_tensor", TypeFloat32, []int{1, 2}, []float32{5, 10}) + err := c.TensorSet("persisted_tensor", TypeFloat32, []int64{1, 2}, []float32{5, 10}) assert.Nil(t, err) type args struct { loadKeys []string @@ -944,8 +944,8 @@ func TestCommand_DagRunRO(t *testing.T) { wantErr bool }{ {"t_1", args{[]string{"persisted_tensor"}, NewDag().TensorGet("persisted_tensor", TensorContentTypeValues)}, false}, - {"t_2", args{nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorSet("tensor2", TypeFloat32, []int{1, 2}, []int{5, 10})}, false}, - {"t_err1", args{[]string{"notnumber"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true}, + {"t_2", args{nil, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10}).TensorSet("tensor2", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, false}, + {"t_err1", args{[]string{"notnumber"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -978,3 +978,40 @@ func TestCommand_DagRunRO(t *testing.T) { }) } } + +func TestClient_ModelRun(t *testing.T) { + type fields struct { + Pool *redis.Pool + PipelineActive bool + PipelineAutoFlushSize uint32 + PipelinePos uint32 + ActiveConn redis.Conn + } + type args struct { + name string + inputTensorNames []string + outputTensorNames []string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + Pool: tt.fields.Pool, + PipelineActive: tt.fields.PipelineActive, + PipelineAutoFlushSize: tt.fields.PipelineAutoFlushSize, + PipelinePos: tt.fields.PipelinePos, + ActiveConn: tt.fields.ActiveConn, + } + if err := c.ModelRun(tt.args.name, tt.args.inputTensorNames, tt.args.outputTensorNames); (err != nil) != tt.wantErr { + t.Errorf("ModelRun() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} \ No newline at end of file diff --git a/redisai/dag.go b/redisai/dag.go index fe04abf..ac17c46 100644 --- a/redisai/dag.go +++ b/redisai/dag.go @@ -5,7 +5,7 @@ import "github.com/gomodule/redigo/redis" // DagCommandInterface is an interface that represents the skeleton of DAG supported commands // needed to map it to a RedisAI DAGRUN and DAGURN_RO commands type DagCommandInterface interface { - TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface + TensorSet(keyName, dt string, dims []int64, data interface{}) DagCommandInterface TensorGet(name, format string) DagCommandInterface ModelRun(name string, inputTensorNames, outputTensorNames []string) DagCommandInterface FlatArgs() (redis.Args, error) @@ -22,7 +22,7 @@ func NewDag() *Dag { } } -func (d *Dag) TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface { +func (d *Dag) TensorSet(keyName, dt string, dims []int64, data interface{}) DagCommandInterface { args := redis.Args{"AI.TENSORSET"} setFlatArgs, err := tensorSetFlatArgs(keyName, dt, dims, data) if err == nil { diff --git a/redisai/example_client_test.go b/redisai/example_client_test.go index 4da93b3..4533ab8 100644 --- a/redisai/example_client_test.go +++ b/redisai/example_client_test.go @@ -19,7 +19,7 @@ func ExampleConnect() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error @@ -49,7 +49,7 @@ func ExampleConnect_pool() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error @@ -120,7 +120,7 @@ func ExampleConnect_ssl() { // Set a tensor // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 - _ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) // Get a tensor content as a slice of values // dt DataType, shape []int, data interface{}, err error diff --git a/redisai/example_commands_test.go b/redisai/example_commands_test.go new file mode 100644 index 0000000..49a5c86 --- /dev/null +++ b/redisai/example_commands_test.go @@ -0,0 +1,220 @@ +package redisai_test + +import ( + "fmt" + "github.com/RedisAI/redisai-go/redisai" + "github.com/RedisAI/redisai-go/redisai/implementations" + "io/ioutil" +) + +func ExampleClient_TensorSet() { + // Create a simple client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Set a tensor + // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 + err := client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + + // print the error (should be ) + fmt.Println(err) + // Output: +} + +func ExampleClient_TensorSetFromTensor() { + // Create a simple client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Build a tensor + tensor := implementations.NewAiTensor() + tensor.SetShape([]int64{2, 2}) + tensor.SetData([]float32{1.1, 2.2, 3.3, 4.4}) + + // Set a tensor + // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 + err := client.TensorSetFromTensor("foo", tensor) + + // print the error (should be ) + fmt.Println(err) + // Output: +} + +func ExampleClient_TensorGet() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Set a tensor + // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + + // Get a tensor content as a slice of values + // AI.TENSORGET foo VALUES + fooTensorValues, err := client.TensorGet("foo", redisai.TensorContentTypeValues) + + fmt.Println(fooTensorValues, err) + // Output: [FLOAT [2 2] [1.1 2.2 3.3 4.4]] +} + +func ExampleClient_TensorGetToTensor() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Set a tensor + // AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4 + _ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4}) + + // Get a tensor content as a slice of values + // AI.TENSORGET foo VALUES + // Allocate an empty tensor + fooTensor := implementations.NewAiTensor() + err := client.TensorGetToTensor("foo", redisai.TensorContentTypeValues, fooTensor) + + // Print the tensor data + fmt.Println(fooTensor.Data(), err) + // Output: [1.1 2.2 3.3 4.4] +} + +func ExampleClient_ModelSet() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + data, _ := ioutil.ReadFile("./../tests/test_data/creditcardfraud.pb") + err := client.ModelSet("financialNet", redisai.BackendTF, redisai.DeviceCPU, data, []string{"transaction", "reference"}, []string{"output"}) + + // Print the error, which should be in case of sucessfull modelset + fmt.Println(err) + // Output: +} + +func ExampleClient_ModelGet() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + data, _ := ioutil.ReadFile("./../tests/test_data/creditcardfraud.pb") + err := client.ModelSet("financialNet", redisai.BackendTF, redisai.DeviceCPU, data, []string{"transaction", "reference"}, []string{"output"}) + + // Print the error, which should be in case of sucessfull modelset + fmt.Println(err) + + ///////////////////////////////////////////////////////////// + // The important part of ModelGet example starts here + reply, err := client.ModelGet("financialNet") + backend := reply[0] + device := reply[1] + // print the error (should be ) + fmt.Println(err) + fmt.Println(backend,device) + + // Output: + // + // + // TF CPU +} + +func ExampleClient_ModelSetFromModel() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Create a model + model := implementations.NewModel("TF", "CPU") + model.SetInputs([]string{"transaction", "reference"}) + model.SetOutputs([]string{"output"}) + model.SetBlobFromFile("./../tests/test_data/creditcardfraud.pb") + + err := client.ModelSetFromModel("financialNet", model) + + // Print the error, which should be in case of successful modelset + fmt.Println(err) + // Output: +} + +func ExampleClient_ModelGetToModel() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // Create a model + model := implementations.NewModel("TF", "CPU") + model.SetInputs([]string{"transaction", "reference"}) + model.SetOutputs([]string{"output"}) + + // Read the model from file + model.SetBlobFromFile("./../tests/test_data/creditcardfraud.pb") + + // Set the model to RedisAI so that we can afterwards test the modelget + err := client.ModelSetFromModel("financialNet", model) + // print the error (should be ) + fmt.Println(err) + + ///////////////////////////////////////////////////////////// + // The important part of ModelGetToModel example starts here + // Create an empty load to store the model from RedisAI + model1 := implementations.NewEmptyModel() + err = client.ModelGetToModel("financialNet", model1) + // print the error (should be ) + fmt.Println(err) + + // print the backend and device info of the model + fmt.Println(model1.Backend(), model1.Device()) + + // Output: + // + // + // TF CPU +} + +func ExampleClient_ModelRun() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // read the model from file + data, err := ioutil.ReadFile("./../tests/test_data/graph.pb") + + // set the model to RedisAI + err = client.ModelSet("example-model", redisai.BackendTF, redisai.DeviceCPU, data, []string{"a", "b"}, []string{"mul"}) + // print the error (should be ) + fmt.Println(err) + + // set the input tensors + err = client.TensorSet("a", redisai.TypeFloat32, []int64{1}, []float32{1.1}) + err = client.TensorSet("b", redisai.TypeFloat32, []int64{1}, []float32{4.4}) + + // run the model + err = client.ModelRun("example-model", []string{"a", "b"}, []string{"mul"}) + // print the error (should be ) + fmt.Println(err) + + // Output: + // + // +} + + +func ExampleClient_Info() { + // Create a client. + client := redisai.Connect("redis://localhost:6379", nil) + + // read the model from file + data, err := ioutil.ReadFile("./../tests/test_data/graph.pb") + + // set the model to RedisAI + err = client.ModelSet("example-info", redisai.BackendTF, redisai.DeviceCPU, data, []string{"a", "b"}, []string{"mul"}) + // print the error (should be ) + fmt.Println(err) + + // set the input tensors + err = client.TensorSet("a", redisai.TypeFloat32, []int64{1}, []float32{1.1}) + err = client.TensorSet("b", redisai.TypeFloat32, []int64{1}, []float32{4.4}) + + // run the model + err = client.ModelRun("example-info", []string{"a", "b"}, []string{"mul"}) + // print the error (should be ) + fmt.Println(err) + + // get the model run info + info, err := client.Info("example-info") + + // one model runs + fmt.Println(fmt.Sprintf("Total runs: %s", info["calls"])) + + // Output: + // + // + // Total runs: 1 +} \ No newline at end of file diff --git a/redisai/implementations/AIModel.go b/redisai/implementations/AIModel.go index be476e6..aaf7dd3 100644 --- a/redisai/implementations/AIModel.go +++ b/redisai/implementations/AIModel.go @@ -2,7 +2,7 @@ package implementations import "io/ioutil" -type AiModel struct { +type AIModel struct { backend string device string blob []byte @@ -10,55 +10,55 @@ type AiModel struct { outputs []string } -func (m *AiModel) Outputs() []string { +func (m *AIModel) Outputs() []string { return m.outputs } -func (m *AiModel) SetOutputs(outputs []string) { +func (m *AIModel) SetOutputs(outputs []string) { m.outputs = outputs } -func (m *AiModel) Inputs() []string { +func (m *AIModel) Inputs() []string { return m.inputs } -func (m *AiModel) SetInputs(inputs []string) { +func (m *AIModel) SetInputs(inputs []string) { m.inputs = inputs } -func (m *AiModel) Blob() []byte { +func (m *AIModel) Blob() []byte { return m.blob } -func (m *AiModel) SetBlob(blob []byte) { +func (m *AIModel) SetBlob(blob []byte) { m.blob = blob } -func (m *AiModel) Device() string { +func (m *AIModel) Device() string { return m.device } -func (m *AiModel) SetDevice(device string) { +func (m *AIModel) SetDevice(device string) { m.device = device } -func (m *AiModel) Backend() string { +func (m *AIModel) Backend() string { return m.backend } -func (m *AiModel) SetBackend(backend string) { +func (m *AIModel) SetBackend(backend string) { m.backend = backend } -func NewModel(backend string, device string) *AiModel { - return &AiModel{backend: backend, device: device} +func NewModel(backend string, device string) *AIModel { + return &AIModel{backend: backend, device: device} } -func NewEmptyModel() *AiModel { - return &AiModel{} +func NewEmptyModel() *AIModel { + return &AIModel{} } -func (m *AiModel) SetBlobFromFile(path string) (err error) { +func (m *AIModel) SetBlobFromFile(path string) (err error) { var data []byte data, err = ioutil.ReadFile(path) if err != nil { diff --git a/redisai/implementations/AITensor.go b/redisai/implementations/AITensor.go index 5dccb6f..088de6d 100644 --- a/redisai/implementations/AITensor.go +++ b/redisai/implementations/AITensor.go @@ -4,55 +4,55 @@ import "reflect" // TensorInterface is an interface that represents the skeleton of a tensor ( n-dimensional array of numerical data ) // needed to map it to a RedisAI Model with the proper operations -type AiTensor struct { +type AITensor struct { // the size - in each dimension - of the tensor. - shape []int + shape []int64 data interface{} } -func (t *AiTensor) Dtype() reflect.Type { +func (t *AITensor) Dtype() reflect.Type { return reflect.TypeOf(t.data) } -func NewAiTensor() *AiTensor { - return &AiTensor{} +func NewAiTensor() *AITensor { + return &AITensor{} } -func (t *AiTensor) NumDims() int { - return len(t.Shape()) +func (t *AITensor) NumDims() int64 { + return int64(len(t.Shape())) } -func (t *AiTensor) Len() int { - result := 0 +func (t *AITensor) Len() int64 { + var result int64 = 0 for _, v := range t.shape { result += v } return result } -func (m *AiTensor) Shape() []int { +func (m *AITensor) Shape() []int64 { return m.shape } -func (m *AiTensor) SetShape(shape []int) { +func (m *AITensor) SetShape(shape []int64) { m.shape = shape } -func NewAiTensorWithShape(shape []int) *AiTensor { - return &AiTensor{shape: shape} +func NewAiTensorWithShape(shape []int64) *AITensor { + return &AITensor{shape: shape} } -func NewAiTensorWithData(typestr string, shape []int, data interface{}) *AiTensor { +func NewAiTensorWithData(typestr string, shape []int64, data interface{}) *AITensor { tensor := NewAiTensorWithShape(shape) tensor.SetData(data) return tensor } -func (m *AiTensor) SetData(data interface{}) { +func (m *AITensor) SetData(data interface{}) { m.data = data } -func (m *AiTensor) Data() interface{} { +func (m *AITensor) Data() interface{} { return m.data } diff --git a/redisai/tensor.go b/redisai/tensor.go index d8599a0..64707e2 100644 --- a/redisai/tensor.go +++ b/redisai/tensor.go @@ -12,15 +12,15 @@ import ( type TensorInterface interface { // Shape returns the size - in each dimension - of the tensor. - Shape() []int + Shape() []int64 - SetShape(shape []int) + SetShape(shape []int64) // NumDims returns the number of dimensions of the tensor. - NumDims() int + NumDims() int64 // Len returns the number of elements in the tensor. - Len() int + Len() int64 Dtype() reflect.Type @@ -66,7 +66,7 @@ func TensorGetTypeStrFromType(dtype reflect.Type) (typestr string, err error) { return } -func tensorSetFlatArgs(name string, dt string, dims []int, data interface{}) (redis.Args, error) { +func tensorSetFlatArgs(name string, dt string, dims []int64, data interface{}) (redis.Args, error) { args := redis.Args{} var err error = nil args = args.Add(name, dt).AddFlat(dims) @@ -148,7 +148,7 @@ func ProcessTensorReplyValues(dtype string, reply interface{}) (data interface{} return data, err } -func ProcessTensorGetReply(reply interface{}, errIn error) (err error, dtype string, shape []int, data interface{}) { +func ProcessTensorGetReply(reply interface{}, errIn error) (err error, dtype string, shape []int64, data interface{}) { var replySlice []interface{} var key string err = errIn @@ -168,7 +168,7 @@ func ProcessTensorGetReply(reply interface{}, errIn error) (err error, dtype str return } case "shape": - shape, err = redis.Ints(replySlice[pos+1], err) + shape, err = redis.Int64s(replySlice[pos+1], err) if err != nil { return } diff --git a/redisai/tensor_test.go b/redisai/tensor_test.go index 3e797a5..daff229 100644 --- a/redisai/tensor_test.go +++ b/redisai/tensor_test.go @@ -14,7 +14,7 @@ func Test_tensorSetFlatArgs(t *testing.T) { type args struct { name string dt string - dims []int + dims []int64 data interface{} } tests := []struct { @@ -23,18 +23,18 @@ func Test_tensorSetFlatArgs(t *testing.T) { want string wantErr bool }{ - {"test:TestTensorSetArgs:[]float32:1", args{"test:TestTensorSetArgs:1", TypeFloat, []int{1}, []float32{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]byte:1", args{"test:TestTensorSetArgs:1", TypeFloat, []int{1}, f32Bytes}, string(TensorContentTypeBlob), false}, - {"test:TestTensorSetArgs:[]int:1", args{"test:TestTensorSetArgs:1", TypeInt32, []int{1}, []int{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]int8:1", args{"test:TestTensorSetArgs:1", TypeInt8, []int{1}, []int8{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]int16:1", args{"test:TestTensorSetArgs:1", TypeInt16, []int{1}, []int16{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]int64:1", args{"test:TestTensorSetArgs:1", TypeInt64, []int{1}, []int64{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]uint8:1", args{"test:TestTensorSetArgs:1", TypeUint8, []int{1}, []uint8{1}}, string(TensorContentTypeBlob), false}, - {"test:TestTensorSetArgs:[]uint16:1", args{"test:TestTensorSetArgs:1", TypeUint16, []int{1}, []uint16{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]uint32:1", args{"test:TestTensorSetArgs:1", TypeUint8, []int{1}, []uint32{1}}, string(TensorContentTypeBlob), true}, - {"test:TestTensorSetArgs:[]uint64:1", args{"test:TestTensorSetArgs:1", TypeUint16, []int{1}, []uint64{1}}, string(TensorContentTypeValues), true}, - {"test:TestTensorSetArgs:[]float32:1", args{"test:TestTensorSetArgs:1", TypeFloat32, []int{1}, []float32{1}}, string(TensorContentTypeValues), false}, - {"test:TestTensorSetArgs:[]float64:1", args{"test:TestTensorSetArgs:1", TypeFloat64, []int{1}, []float64{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]float32:1", args{"test:TestTensorSetArgs:1", TypeFloat, []int64{1}, []float32{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]byte:1", args{"test:TestTensorSetArgs:1", TypeFloat, []int64{1}, f32Bytes}, string(TensorContentTypeBlob), false}, + {"test:TestTensorSetArgs:[]int:1", args{"test:TestTensorSetArgs:1", TypeInt32, []int64{1}, []int64{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]int8:1", args{"test:TestTensorSetArgs:1", TypeInt8, []int64{1}, []int8{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]int16:1", args{"test:TestTensorSetArgs:1", TypeInt16, []int64{1}, []int16{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]int64:1", args{"test:TestTensorSetArgs:1", TypeInt64, []int64{1}, []int64{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]uint8:1", args{"test:TestTensorSetArgs:1", TypeUint8, []int64{1}, []uint8{1}}, string(TensorContentTypeBlob), false}, + {"test:TestTensorSetArgs:[]uint16:1", args{"test:TestTensorSetArgs:1", TypeUint16, []int64{1}, []uint16{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]uint32:1", args{"test:TestTensorSetArgs:1", TypeUint8, []int64{1}, []uint32{1}}, string(TensorContentTypeBlob), true}, + {"test:TestTensorSetArgs:[]uint64:1", args{"test:TestTensorSetArgs:1", TypeUint16, []int64{1}, []uint64{1}}, string(TensorContentTypeValues), true}, + {"test:TestTensorSetArgs:[]float32:1", args{"test:TestTensorSetArgs:1", TypeFloat32, []int64{1}, []float32{1}}, string(TensorContentTypeValues), false}, + {"test:TestTensorSetArgs:[]float64:1", args{"test:TestTensorSetArgs:1", TypeFloat64, []int64{1}, []float64{1}}, string(TensorContentTypeValues), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -99,7 +99,7 @@ func TestProcessTensorGetReply(t *testing.T) { name string args args wantDtype string - wantShape []int + wantShape []int64 wantData interface{} wantErr bool }{