Skip to content

Commit 78b014b

Browse files
authored
Feat/get input tensors (#172)
* feat: flag for console output * feat: GetInputTensors * fix: test and bug fix * feat: bump version * feat: more tests
1 parent c21037c commit 78b014b

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

doc/introduction/utils/draw.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"encoding/base64"
5+
"flag"
56
"fmt"
67
"image"
78
"image/color"
@@ -20,8 +21,13 @@ var (
2021
func main() {
2122
reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(img8))
2223
im, _, _ = image.Decode(reader)
23-
//outputConsole()
24-
outputValues()
24+
console := flag.Bool("c", false, "console output")
25+
flag.Parse()
26+
if *console {
27+
outputConsole()
28+
} else {
29+
outputValues()
30+
}
2531
}
2632

2733
func outputConsole() {

go.sum

+3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8
5858
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
5959
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
6060
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
61+
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU=
6162
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
6263
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81 h1:00VmoueYNlNz/aHIilyyQz/MHSqGoWJzpFv/HW8xpzI=
6364
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
@@ -81,9 +82,11 @@ gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW
8182
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM=
8283
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU=
8384
gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
85+
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
8486
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
8587
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
8688
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
89+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
8790
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
8891
gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
8992
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

io.go

+13
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,16 @@ func (m *Model) GetOutputTensors() ([]tensor.Tensor, error) {
3838
}
3939
return output, nil
4040
}
41+
42+
// GetInpuTensors from the graph. This function is useful to get informations if the tensor is a placeholder
43+
// and does not contain any data yet.
44+
func (m *Model) GetInputTensors() []tensor.Tensor {
45+
output := make([]tensor.Tensor, len(m.Input))
46+
for i := range m.Input {
47+
n := m.backend.Node(int64(m.Input[i]))
48+
if n != nil {
49+
output[i] = n.(DataCarrier).GetTensor()
50+
}
51+
}
52+
return output
53+
}

io_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,23 @@ func TestSetInput_nil_model(t *testing.T) {
1717
err := m.SetInput(0, tens)
1818
t.Fatal("should have paniced but have passed with error", err)
1919
}
20+
21+
func TestGetInputTensors(t *testing.T) {
22+
backend := newTestBackend()
23+
n1 := backend.NewNode()
24+
backend.AddNode(n1)
25+
n2 := backend.NewNode()
26+
backend.AddNode(n2)
27+
n2.(*nodeTest).SetTensor(tensor.NewDense(tensor.Float32, []int{1, 1}))
28+
model := &Model{
29+
Input: []int64{n1.ID(), n2.ID()},
30+
backend: backend,
31+
}
32+
input := model.GetInputTensors()
33+
if len(input) != 2 {
34+
t.FailNow()
35+
}
36+
if input[0] != nil || input[1] == nil {
37+
t.Fail()
38+
}
39+
}

0 commit comments

Comments
 (0)