Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func main() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down Expand Up @@ -98,12 +98,12 @@ func main() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
err := client.TensorSet("foo1", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
err := client.TensorSet("foo1", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
if err != nil {
log.Fatal(err)
}
// AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1
err = client.TensorSet("foo2", redisai.TypeFloat, []int{1, 1}, []float32{1.1})
err = client.TensorSet("foo2", redisai.TypeFloat, []int64{1, 1}, []float32{1.1})
if err != nil {
log.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions examples/redisai_pipelined_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ func main() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
err := client.TensorSet("foo1", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
err := client.TensorSet("foo1", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
if err != nil {
log.Fatal(err)
}
// AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1
err = client.TensorSet("foo2", redisai.TypeFloat, []int{1, 1}, []float32{1.1})
err = client.TensorSet("foo2", redisai.TypeFloat, []int64{1, 1}, []float32{1.1})
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/redisai_simple_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func main() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down
2 changes: 1 addition & 1 deletion examples/redisai_tls_client/redisai_tls_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func main() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down
4 changes: 2 additions & 2 deletions redisai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,12 @@ func TestClient_DisablePipeline(t *testing.T) {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
err := client.TensorSet("foo1", TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
err := client.TensorSet("foo1", TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
if err != nil {
t.Errorf("TensorSet() error = %v", err)
}
// AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1
err = client.TensorSet("foo2", TypeFloat, []int{1, 1}, []float32{1.1})
err = client.TensorSet("foo2", TypeFloat, []int64{1, 1}, []float32{1.1})
if err != nil {
t.Errorf("TensorSet() error = %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions redisai/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// TensorSet sets a tensor
func (c *Client) TensorSet(keyName, dt string, dims []int, data interface{}) (err error) {
func (c *Client) TensorSet(keyName, dt string, dims []int64, data interface{}) (err error) {
args, err := tensorSetFlatArgs(keyName, dt, dims, data)
_, err = c.DoOrSend("AI.TENSORSET", args, err)
return
Expand Down Expand Up @@ -45,7 +45,7 @@ func (c *Client) TensorGetToTensor(name, format string, tensor TensorInterface)
}

// TensorGetValues gets a tensor's values
func (c *Client) TensorGetValues(name string) (dt string, shape []int, data interface{}, err error) {
func (c *Client) TensorGetValues(name string) (dt string, shape []int64, data interface{}, err error) {
args := redis.Args{}.Add(name, TensorContentTypeMeta, TensorContentTypeValues)
var reply interface{}
reply, err = c.DoOrSend("AI.TENSORGET", args, nil)
Expand All @@ -57,7 +57,7 @@ func (c *Client) TensorGetValues(name string) (dt string, shape []int, data inte
}

// TensorGetValues gets a tensor's values
func (c *Client) TensorGetMeta(name string) (dt string, shape []int, err error) {
func (c *Client) TensorGetMeta(name string) (dt string, shape []int64, err error) {
args := redis.Args{}.Add(name, TensorContentTypeMeta)
var reply interface{}
reply, err = c.DoOrSend("AI.TENSORGET", args, nil)
Expand All @@ -69,7 +69,7 @@ func (c *Client) TensorGetMeta(name string) (dt string, shape []int, err error)
}

// TensorGetValues gets a tensor's values
func (c *Client) TensorGetBlob(name string) (dt string, shape []int, data []byte, err error) {
func (c *Client) TensorGetBlob(name string) (dt string, shape []int64, data []byte, err error) {
args := redis.Args{}.Add(name, TensorContentTypeMeta, TensorContentTypeBlob)
var reply interface{}
reply, err = c.DoOrSend("AI.TENSORGET", args, nil)
Expand Down
97 changes: 67 additions & 30 deletions redisai/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestCommand_TensorSet(t *testing.T) {

valuesInt8 := []int8{1}
valuesInt16 := []int16{1}
valuesInt32 := []int{1}
valuesInt32 := []int64{1}
valuesInt64 := []int64{1}

valuesUint8 := []uint8{1}
Expand Down Expand Up @@ -46,12 +46,12 @@ func TestCommand_TensorSet(t *testing.T) {

keyError1 := "test:TestCommand_TensorSet:1:FaultyDims"

shp := []int{1}
shp := []int64{1}

type args struct {
name string
dt string
dims []int
dims []int64
data interface{}
}

Expand All @@ -78,7 +78,7 @@ func TestCommand_TensorSet(t *testing.T) {

{keyByte, args{keyByte, TypeUint8, shp, valuesByte}, false},

{keyError1, args{keyError1, TypeFloat, []int{1, 10}, []float32{1}}, true},
{keyError1, args{keyError1, TypeFloat, []int64{1, 10}, []float32{1}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -91,7 +91,7 @@ func TestCommand_TensorSet(t *testing.T) {
}

func TestCommand_FullFromTensor(t *testing.T) {
tensor := implementations.NewAiTensorWithShape([]int{1})
tensor := implementations.NewAiTensorWithShape([]int64{1})
tensor.SetData([]float32{1.0})
client := createTestClient()
err := client.TensorSetFromTensor("tensor1", tensor)
Expand All @@ -101,7 +101,7 @@ func TestCommand_FullFromTensor(t *testing.T) {
if diff := cmp.Diff(TypeFloat32, gotResp[0]); diff != "" {
t.Errorf("TestCommand_FullFromTensor() mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([]int{1}, gotResp[1]); diff != "" {
if diff := cmp.Diff([]int64{1}, gotResp[1]); diff != "" {
t.Errorf("TestCommand_FullFromTensor() mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([]float32{1.0}, gotResp[2]); diff != "" {
Expand Down Expand Up @@ -149,8 +149,8 @@ func TestCommand_TensorGet(t *testing.T) {
keyUint8 := "test:TensorGet:TypeUint8:1"
keyUint16 := "test:TensorGet:TypeUint16:1"

shp := []int{1}
shpByteSlice := []int{1, 5}
shp := []int64{1}
shpByteSlice := []int64{1, 5}
simpleClient := createTestClient()
simpleClient.TensorSet(keyByteSlice, TypeUint8, shpByteSlice, valuesByteSlice)

Expand All @@ -173,7 +173,7 @@ func TestCommand_TensorGet(t *testing.T) {
name string
args args
wantDt string
wantShape []int
wantShape []int64
wantData interface{}
compareDt bool
compareShape bool
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestCommand_TensorGetBlob(t *testing.T) {
keyByte := "test:TensorGetBlog:[]byte:1"
keyUnexistant := "test:TensorGetMeta:Unexistant"

shp := []int{1, 4}
shp := []int64{1, 4}
simpleClient := Connect("", createPool())
simpleClient.TensorSet(keyByte, TypeInt8, shp, valuesByte)

Expand All @@ -233,7 +233,7 @@ func TestCommand_TensorGetBlob(t *testing.T) {
name string
args args
wantDt string
wantShape []int
wantShape []int64
wantData []byte
wantErr bool
}{
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestCommand_TensorGetMeta(t *testing.T) {

keyUnexistant := "test:TensorGetMeta:Unexistant"

shp := []int{1, 2}
shp := []int64{1, 2}
simpleClient := Connect("", createPool())
simpleClient.TensorSet(keyFloat32, TypeFloat32, shp, nil)
simpleClient.TensorSet(keyFloat64, TypeFloat64, shp, nil)
Expand All @@ -299,7 +299,7 @@ func TestCommand_TensorGetMeta(t *testing.T) {
name string
args args
wantDt string
wantShape []int
wantShape []int64
wantErr bool
}{
{keyFloat32, args{keyFloat32}, TypeFloat32, shp, false},
Expand Down Expand Up @@ -354,8 +354,8 @@ func TestCommand_TensorGetValues(t *testing.T) {
keyUint16 := "test:TensorGetValues:TypeUint16:1"
keyUnexistant := "test:TensorGetValues:Unexistant"

shp := []int{1}
shp2 := []int{1, 5}
shp := []int64{1}
shp2 := []int64{1, 5}
simpleClient := Connect("", createPool())
simpleClient.TensorSet(keyFloat32, TypeFloat32, shp2, valuesFloat32)
simpleClient.TensorSet(keyFloat64, TypeFloat64, shp, valuesFloat64)
Expand All @@ -375,7 +375,7 @@ func TestCommand_TensorGetValues(t *testing.T) {
name string
args args
wantDt string
wantShape []int
wantShape []int64
wantData interface{}
wantErr bool
}{
Expand Down Expand Up @@ -574,12 +574,12 @@ func TestCommand_ModelRun(t *testing.T) {
return
}

errortset := simpleClient.TensorSet(keyTransaction1, TypeFloat, []int{1, 30}, nil)
errortset := simpleClient.TensorSet(keyTransaction1, TypeFloat, []int64{1, 30}, nil)
if errortset != nil {
t.Error(errortset)
}

errortsetReference := simpleClient.TensorSet(keyReference1, TypeFloat, []int{256}, nil)
errortsetReference := simpleClient.TensorSet(keyReference1, TypeFloat, []int64{256}, nil)
if errortsetReference != nil {
t.Error(errortsetReference)
}
Expand Down Expand Up @@ -848,9 +848,9 @@ func TestCommand_Info(t *testing.T) {
assert.Equal(t, BackendTF, info["backend"])
assert.Equal(t, "0", info["calls"])

err = c.TensorSet("a", TypeFloat32, []int{1}, []float32{1.1})
err = c.TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1})
assert.Nil(t, err)
err = c.TensorSet("b", TypeFloat32, []int{1}, []float32{4.4})
err = c.TensorSet("b", TypeFloat32, []int64{1}, []float32{4.4})
assert.Nil(t, err)
err = c.ModelRun(keyModel1, []string{"a", "b"}, []string{"mul"})
assert.Nil(t, err)
Expand Down Expand Up @@ -878,7 +878,7 @@ func TestCommand_DagRun(t *testing.T) {
return
}
err = c.ModelSet(keyModel1, BackendTF, DeviceCPU, data, []string{"a", "b"}, []string{"mul"})
err = c.TensorSet("persisted_tensor_1", TypeFloat32, []int{1, 2}, []float32{5, 10})
err = c.TensorSet("persisted_tensor_1", TypeFloat32, []int64{1, 2}, []float32{5, 10})
assert.Nil(t, err)

type args struct {
Expand All @@ -891,12 +891,12 @@ func TestCommand_DagRun(t *testing.T) {
args args
wantErr bool
}{
{"t_wrong_number", args{[]string{"notnumber"}, nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
{"t_load", args{[]string{"persisted_tensor_1"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, false},
{"t_load_err", args{[]string{"not_exits_tensor"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
{"t1", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1})}, false},
{"t_blob", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1}).TensorSet("b", TypeFloat32, []int{1}, []float32{4.4}).ModelRun("test:DagRun:mymodel:1", []string{"a", "b"}, []string{"mul"}).TensorGet("mul", TensorContentTypeBlob)}, false},
{"t_values", args{nil, nil, NewDag().TensorSet("mytensor", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorGet("mytensor", TensorContentTypeValues)}, false},
{"t_wrong_number", args{[]string{"notnumber"}, nil, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true},
{"t_load", args{[]string{"persisted_tensor_1"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, false},
{"t_load_err", args{[]string{"not_exits_tensor"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true},
{"t1", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1})}, false},
{"t_blob", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int64{1}, []float32{1.1}).TensorSet("b", TypeFloat32, []int64{1}, []float32{4.4}).ModelRun("test:DagRun:mymodel:1", []string{"a", "b"}, []string{"mul"}).TensorGet("mul", TensorContentTypeBlob)}, false},
{"t_values", args{nil, nil, NewDag().TensorSet("mytensor", TypeFloat32, []int64{1, 2}, []int64{5, 10}).TensorGet("mytensor", TensorContentTypeValues)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -932,7 +932,7 @@ func TestCommand_DagRun(t *testing.T) {

func TestCommand_DagRunRO(t *testing.T) {
c := createTestClient()
err := c.TensorSet("persisted_tensor", TypeFloat32, []int{1, 2}, []float32{5, 10})
err := c.TensorSet("persisted_tensor", TypeFloat32, []int64{1, 2}, []float32{5, 10})
assert.Nil(t, err)
type args struct {
loadKeys []string
Expand All @@ -944,8 +944,8 @@ func TestCommand_DagRunRO(t *testing.T) {
wantErr bool
}{
{"t_1", args{[]string{"persisted_tensor"}, NewDag().TensorGet("persisted_tensor", TensorContentTypeValues)}, false},
{"t_2", args{nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorSet("tensor2", TypeFloat32, []int{1, 2}, []int{5, 10})}, false},
{"t_err1", args{[]string{"notnumber"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
{"t_2", args{nil, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10}).TensorSet("tensor2", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, false},
{"t_err1", args{[]string{"notnumber"}, NewDag().TensorSet("tensor1", TypeFloat32, []int64{1, 2}, []int64{5, 10})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -978,3 +978,40 @@ func TestCommand_DagRunRO(t *testing.T) {
})
}
}

func TestClient_ModelRun(t *testing.T) {
type fields struct {
Pool *redis.Pool
PipelineActive bool
PipelineAutoFlushSize uint32
PipelinePos uint32
ActiveConn redis.Conn
}
type args struct {
name string
inputTensorNames []string
outputTensorNames []string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
Pool: tt.fields.Pool,
PipelineActive: tt.fields.PipelineActive,
PipelineAutoFlushSize: tt.fields.PipelineAutoFlushSize,
PipelinePos: tt.fields.PipelinePos,
ActiveConn: tt.fields.ActiveConn,
}
if err := c.ModelRun(tt.args.name, tt.args.inputTensorNames, tt.args.outputTensorNames); (err != nil) != tt.wantErr {
t.Errorf("ModelRun() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
4 changes: 2 additions & 2 deletions redisai/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "github.com/gomodule/redigo/redis"
// DagCommandInterface is an interface that represents the skeleton of DAG supported commands
// needed to map it to a RedisAI DAGRUN and DAGURN_RO commands
type DagCommandInterface interface {
TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface
TensorSet(keyName, dt string, dims []int64, data interface{}) DagCommandInterface
TensorGet(name, format string) DagCommandInterface
ModelRun(name string, inputTensorNames, outputTensorNames []string) DagCommandInterface
FlatArgs() (redis.Args, error)
Expand All @@ -22,7 +22,7 @@ func NewDag() *Dag {
}
}

func (d *Dag) TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface {
func (d *Dag) TensorSet(keyName, dt string, dims []int64, data interface{}) DagCommandInterface {
args := redis.Args{"AI.TENSORSET"}
setFlatArgs, err := tensorSetFlatArgs(keyName, dt, dims, data)
if err == nil {
Expand Down
6 changes: 3 additions & 3 deletions redisai/example_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func ExampleConnect() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down Expand Up @@ -49,7 +49,7 @@ func ExampleConnect_pool() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down Expand Up @@ -120,7 +120,7 @@ func ExampleConnect_ssl() {

// Set a tensor
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
_ = client.TensorSet("foo", redisai.TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
_ = client.TensorSet("foo", redisai.TypeFloat, []int64{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})

// Get a tensor content as a slice of values
// dt DataType, shape []int, data interface{}, err error
Expand Down
Loading