Commit a49599f

Merge pull request #36 from gorgonia/v0.9.0-working

Added tests for illustrating the failing case of #35

2 parents: 64484d7 + 3c14e8f
2 files changed: +164, -0

README.md (+1)

@@ -77,5 +77,6 @@ This author loves pull requests from everyone. Here's how to contribute to this
 
 We understand that this package is an interfacing package with a third-party API. As such, tests may not always be viable. However, please do try to include as many tests as possible.
 
+
 # Licence #
 The package is licenced with a MIT-like licence. There is one file (`cgoflags.go`) where code is directly copied and two files (`execution.go` and `memory.go`) where code was partially copied from Arne Vansteenkiste's package, which is unlicenced (but to be safe, just assume a GPL-like licence, as [mumax/3](https://github.com/mumax/3) is licenced under GPL).

blas/test_test.go (+163)

The new imports and everything after testSetup are additions; the file is shown in its resulting state below.
@@ -1,8 +1,13 @@
package cublas

import (
	"reflect"
	"unsafe"

	"github.com/pkg/errors"
	"gonum.org/v1/gonum/blas"
	"gorgonia.org/cu"
	"gorgonia.org/tensor"
)

func testSetup() (dev cu.Device, err error) {
@@ -16,3 +21,161 @@ func testSetup() (dev cu.Device, err error) {
	dev = cu.Device(0)
	return
}
// Engine implements a tensor engine backed by a CUDA context and the
// cublas Standard implementation; tensor.StdEng supplies the rest.
type Engine struct {
	tensor.StdEng
	ctx cu.Context
	*Standard
}

func newEngine() *Engine {
	ctx := cu.NewContext(cu.Device(0), cu.SchedAuto)
	blas := New(WithContext(ctx))
	return &Engine{
		ctx:      ctx,
		Standard: blas,
	}
}
func (e *Engine) AllocAccessible() bool { return true }

func (e *Engine) Alloc(size int64) (tensor.Memory, error) {
	return e.ctx.MemAllocManaged(size, cu.AttachGlobal)
}

func (e *Engine) AllocFlags() (tensor.MemoryFlag, tensor.DataOrder) {
	// ManuallyManaged: allocations live on the device, not in Go memory.
	// ColMajor: cuBLAS operates on column-major data.
	return tensor.MakeMemoryFlag(tensor.ManuallyManaged), tensor.ColMajor
}

func (e *Engine) Free(mem tensor.Memory, size int64) error {
	e.ctx.MemFree(mem.(cu.DevicePtr))
	return nil
}

func (e *Engine) Memset(mem tensor.Memory, val interface{}) error {
	panic("not implemented")
}

func (e *Engine) Memclr(mem tensor.Memory) {
	panic("not implemented")
}

func (e *Engine) Memcpy(dst tensor.Memory, src tensor.Memory) error {
	panic("not implemented")
}
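For orientation, here is a minimal sketch (not part of the commit) of what these allocator hooks imply: because AllocFlags reports ManuallyManaged, a tensor created with this engine should get its backing memory from Engine.Alloc, i.e. CUDA managed memory, rather than an ordinary Go slice. The construction options are assumptions based on gorgonia's tensor package:

// Hypothetical usage sketch; assumes a CUDA-capable device 0.
e := newEngine()
T := tensor.New(
	tensor.Of(tensor.Float64),
	tensor.WithShape(2, 2),
	tensor.WithEngine(e), // allocation is expected to route through e.Alloc
)
_ = T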
// Accessible copies the device memory back to the host and reinterprets
// the bytes as a foomem ([]float64) so the values can be inspected.
func (e *Engine) Accessible(mem tensor.Memory) (tensor.Memory, error) {
	size := mem.MemSize()
	retVal := make([]byte, int(size))
	e.ctx.MemcpyDtoH(unsafe.Pointer(&retVal[0]), cu.DevicePtr(mem.Uintptr()), int64(size))
	l := int(size / 8) // number of float64s in the buffer
	foo2 := &reflect.SliceHeader{
		Data: uintptr(unsafe.Pointer(&retVal[0])),
		Len:  l,
		Cap:  l,
	}
	return *(*foomem)(unsafe.Pointer(foo2)), e.ctx.Error()
}

func (e *Engine) WorksWith(order tensor.DataOrder) bool { return true }

func (e *Engine) NonStdAlloc() {}

func (e *Engine) ContextErr() error { return e.ctx.Error() }

// foomem is a host-side []float64 that satisfies tensor.Memory.
type foomem []float64

func (m foomem) Uintptr() uintptr        { return uintptr(unsafe.Pointer(&m[0])) }
func (m foomem) Pointer() unsafe.Pointer { return unsafe.Pointer(&m[0]) }
func (m foomem) MemSize() uintptr        { return uintptr(len(m) * 8) }
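A hedged sketch (again, not in the commit) of the round trip that Accessible enables: write host data into device memory, then read it back as a foomem. The MemcpyHtoD call is assumed to mirror the MemcpyDtoH call used above:

e := newEngine()
data := []float64{1, 2, 3, 4}
nbytes := int64(len(data) * 8) // 8 bytes per float64
mem, err := e.Alloc(nbytes)    // managed device memory
if err != nil {
	panic(err)
}
// Host -> device; assumed counterpart of the MemcpyDtoH call above.
e.ctx.MemcpyHtoD(cu.DevicePtr(mem.Uintptr()), unsafe.Pointer(&data[0]), nbytes)
// Device -> host: Accessible returns the bytes viewed as a foomem.
acc, err := e.Accessible(mem)
if err != nil {
	panic(err)
}
if acc.(foomem)[3] != 4 {
	panic("device round trip failed")
}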
// checkThreeFloat verifies that a, b and ret are manually managed
// (device-resident) *tensor.Dense values sharing one Dtype.
func (e *Engine) checkThreeFloat(a, b, ret tensor.Tensor) (ad, bd, retVal *tensor.Dense, err error) {
	if /*a.IsNativelyAccessible() &&*/ !a.IsManuallyManaged() {
		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). a isn't")
	}

	if /* b.IsNativelyAccessible() && */ !b.IsManuallyManaged() {
		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). b isn't")
	}

	if /* ret.IsNativelyAccessible() && */ !ret.IsManuallyManaged() {
		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). ret isn't")
	}

	if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() {
		return nil, nil, nil, errors.New("Expected a, b and retVal to all have the same Dtype")
	}
	var ok bool
	if ad, ok = a.(*tensor.Dense); !ok {
		return nil, nil, nil, errors.New("Expected a to be a *tensor.Dense")
	}
	if bd, ok = b.(*tensor.Dense); !ok {
		return nil, nil, nil, errors.New("Expected b to be a *tensor.Dense")
	}
	if retVal, ok = ret.(*tensor.Dense); !ok {
		return nil, nil, nil, errors.New("Expected ret to be a *tensor.Dense")
	}
	return
}
// MatVecMul computes prealloc = a * b via cuBLAS GEMV. Note that the
// checkThreeFloat precheck is bypassed (left commented out) here.
func (e *Engine) MatVecMul(a, b, prealloc tensor.Tensor) (err error) {
	ad, bd, pd := a.(*tensor.Dense), b.(*tensor.Dense), prealloc.(*tensor.Dense)

	// if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil {
	// 	return errors.Wrapf(err, "MatVecMul failed pre check")
	// }

	tA := blas.Trans
	do := a.DataOrder()
	z := do.IsTransposed()

	m := a.Shape()[0]
	n := a.Shape()[1]

	// cuBLAS expects column-major storage, so a row-major, untransposed
	// matrix is handed over as the transpose of its column-major view.
	var lda int
	switch {
	case do.IsRowMajor() && z:
		tA = blas.NoTrans
		lda = m
	case do.IsRowMajor() && !z:
		lda = n
		m, n = n, m
	case do.IsColMajor() && z:
		tA = blas.Trans
		lda = n
		m, n = n, m
	case do.IsColMajor() && !z:
		lda = m
		tA = blas.NoTrans
	}

	incX, incY := 1, 1 // step size

	// ASPIRATIONAL TODO: support different incX and incY.
	// TECHNICAL DEBT. TECHDEBT. TECH DEBT
	// Example use case:
	// log.Printf("a %v %v", ad.Strides(), ad.ostrides())
	// log.Printf("b %v", b.Strides())
	// incX := a.Strides()[0]
	// incY = b.Strides()[0]

	switch ad.Dtype() {
	case tensor.Float64:
		A := ad.Float64s()
		X := bd.Float64s()
		Y := pd.Float64s()
		alpha, beta := float64(1), float64(0)
		e.Standard.Dgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY)
	case tensor.Float32:
		A := ad.Float32s()
		X := bd.Float32s()
		Y := pd.Float32s()
		alpha, beta := float32(1), float32(0)
		e.Standard.Sgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY)
	default:
		return errors.New("Unsupported Dtype")
	}
	return e.Standard.Err()
}
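To tie the pieces together, a hedged sketch of how this engine would presumably be driven when exercising the failing case; the tensor construction and device-side initialisation are assumptions, since the commit itself only adds the engine and MatVecMul:

e := newEngine()

// A is 2x3, x has length 3, y (the preallocated result) has length 2.
// All three are allocated through the engine, so they live in CUDA
// managed memory and report IsManuallyManaged() == true.
A := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2, 3), tensor.WithEngine(e))
x := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(3), tensor.WithEngine(e))
y := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2), tensor.WithEngine(e))

// Fill A and x on the device first (e.g. with MemcpyHtoD as sketched
// earlier), then run the GEMV path.
if err := e.MatVecMul(A, x, y); err != nil {
	panic(err)
}
// Read the result back on the host for assertions.
out, err := e.Accessible(y)
if err != nil {
	panic(err)
}
_ = out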
