package cublas

import (
+	"reflect"
+	"unsafe"
+
	"github.com/pkg/errors"
+	"gonum.org/v1/gonum/blas"
	"gorgonia.org/cu"
+	"gorgonia.org/tensor"
)

func testSetup() (dev cu.Device, err error) {
@@ -16,3 +21,161 @@ func testSetup() (dev cu.Device, err error) {
	dev = cu.Device(0)
	return
}
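+
+// Engine is a CUDA-backed tensor engine. It embeds tensor.StdEng for the
+// operations it does not override, and *Standard for the cuBLAS routines.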
+type Engine struct {
+	tensor.StdEng
+	ctx cu.Context
+	*Standard
+}
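+
+// newEngine creates a CUDA context on device 0 and a cuBLAS implementation
+// bound to that context.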
+func newEngine() *Engine {
+	ctx := cu.NewContext(cu.Device(0), cu.SchedAuto)
+	blas := New(WithContext(ctx))
+	return &Engine{
+		ctx:      ctx,
+		Standard: blas,
+	}
+}
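+
+// AllocAccessible reports whether memory allocated by this engine is
+// accessible from the host. It is, because Alloc uses CUDA managed memory.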
+func (e *Engine) AllocAccessible() bool { return true }
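+
+// Alloc allocates size bytes of managed (unified) memory, which the CUDA
+// runtime migrates between host and device on demand.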
+func (e *Engine) Alloc(size int64) (tensor.Memory, error) {
+	return e.ctx.MemAllocManaged(size, cu.AttachGlobal)
+}
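+
+// AllocFlags marks the engine's memory as manually managed (the tensor
+// package must not reach into it directly) and column-major, to match cuBLAS.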
+func (e *Engine) AllocFlags() (tensor.MemoryFlag, tensor.DataOrder) {
+	return tensor.MakeMemoryFlag(tensor.ManuallyManaged), tensor.ColMajor
+}
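+
+// Free releases device memory previously returned by Alloc.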
+func (e *Engine) Free(mem tensor.Memory, size int64) error {
+	e.ctx.MemFree(mem.(cu.DevicePtr))
+	return nil
+}
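+
+// Memset, Memclr and Memcpy are not yet implemented for device memory.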
+func (e *Engine) Memset(mem tensor.Memory, val interface{}) error {
+	panic("not implemented")
+}
+
+func (e *Engine) Memclr(mem tensor.Memory) {
+	panic("not implemented")
+}
+
+func (e *Engine) Memcpy(dst tensor.Memory, src tensor.Memory) error {
+	panic("not implemented")
+}
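+
+// Accessible copies the device memory back to the host and returns it as a
+// host-accessible tensor.Memory. It currently assumes the data is float64,
+// hence the division by 8 below.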
+func (e *Engine) Accessible(mem tensor.Memory) (tensor.Memory, error) {
+	size := mem.MemSize()
+	retVal := make([]byte, int(size))
+	e.ctx.MemcpyDtoH(unsafe.Pointer(&retVal[0]), cu.DevicePtr(mem.Uintptr()), int64(size))
+	l := int(size / 8) // number of float64s in the buffer
+	foo2 := &reflect.SliceHeader{
+		Data: uintptr(unsafe.Pointer(&retVal[0])),
+		Len:  l,
+		Cap:  l,
+	}
+	return *(*foomem)(unsafe.Pointer(foo2)), e.ctx.Error()
+}
+
+func (e *Engine) WorksWith(order tensor.DataOrder) bool { return true }
+
+func (e *Engine) NonStdAlloc() {}
+
+func (e *Engine) ContextErr() error { return e.ctx.Error() }
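+
+// foomem is a test helper: a []float64 that satisfies tensor.Memory, used by
+// Accessible to hand device data back to the host.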
+type foomem []float64
+
+func (m foomem) Uintptr() uintptr        { return uintptr(unsafe.Pointer(&m[0])) }
+func (m foomem) Pointer() unsafe.Pointer { return unsafe.Pointer(&m[0]) }
+func (m foomem) MemSize() uintptr        { return uintptr(len(m) * 8) }
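+
+// checkThreeFloat checks that a, b and ret are all manually managed (i.e.
+// their memory lives on the device), share a Dtype, and are *tensor.Dense.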
+func (e *Engine) checkThreeFloat(a, b, ret tensor.Tensor) (ad, bd, retVal *tensor.Dense, err error) {
+	if /*a.IsNativelyAccessible() &&*/ !a.IsManuallyManaged() {
+		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). a isn't")
+	}
+
+	if /* b.IsNativelyAccessible() && */ !b.IsManuallyManaged() {
+		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). b isn't")
+	}
+
+	if /* ret.IsNativelyAccessible() && */ !ret.IsManuallyManaged() {
+		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). ret isn't")
+	}
+
+	if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() {
+		return nil, nil, nil, errors.New("Expected a, b and retVal to all have the same Dtype")
+	}
+	var ok bool
+	if ad, ok = a.(*tensor.Dense); !ok {
+		return nil, nil, nil, errors.New("Expected a to be a *tensor.Dense")
+	}
+	if bd, ok = b.(*tensor.Dense); !ok {
+		return nil, nil, nil, errors.New("Expected b to be a *tensor.Dense")
+	}
+	if retVal, ok = ret.(*tensor.Dense); !ok {
+		return nil, nil, nil, errors.New("Expected ret to be a *tensor.Dense")
+	}
+	return
+}
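+
+// MatVecMul computes a * b into prealloc by dispatching to cuBLAS gemv. The
+// transpose flag and lda are derived from a's data order, since cuBLAS
+// expects column-major data.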
122
+
123
+ func (e * Engine ) MatVecMul (a , b , prealloc tensor.Tensor ) (err error ) {
124
+ var ad , bd , pd * tensor.Dense = a .(* tensor.Dense ), b .(* tensor.Dense ), prealloc .(* tensor.Dense )
125
+
126
+ // if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil {
127
+ // return errors.Wrapf(err, "MatVecMul failed pre check")
128
+ // }
129
+
130
+ tA := blas .Trans
131
+ do := a .DataOrder ()
132
+ z := do .IsTransposed ()
133
+
134
+ m := a .Shape ()[0 ]
135
+ n := a .Shape ()[1 ]
136
+
137
+ var lda int
138
+ switch {
139
+ case do .IsRowMajor () && z :
140
+ tA = blas .NoTrans
141
+ lda = m
142
+ case do .IsRowMajor () && ! z :
143
+ lda = n
144
+ m , n = n , m
145
+ case do .IsColMajor () && z :
146
+ tA = blas .Trans
147
+ lda = n
148
+ m , n = n , m
149
+ case do .IsColMajor () && ! z :
150
+ lda = m
151
+ tA = blas .NoTrans
152
+ }
153
+
154
+ incX , incY := 1 , 1 // step size
155
+
156
+ // ASPIRATIONAL TODO: different incX and incY
157
+ // TECHNICAL DEBT. TECHDEBT. TECH DEBT
158
+ // Example use case:
159
+ // log.Printf("a %v %v", ad.Strides(), ad.ostrides())
160
+ // log.Printf("b %v", b.Strides())
161
+ // incX := a.Strides()[0]
162
+ // incY = b.Strides()[0]
163
+
164
+ switch ad .Dtype () {
165
+ case tensor .Float64 :
166
+ A := ad .Float64s ()
167
+ X := bd .Float64s ()
168
+ Y := pd .Float64s ()
169
+ alpha , beta := float64 (1 ), float64 (0 )
170
+ e .Standard .Dgemv (tA , m , n , alpha , A , lda , X , incX , beta , Y , incY )
171
+ case tensor .Float32 :
172
+ A := ad .Float32s ()
173
+ X := bd .Float32s ()
174
+ Y := pd .Float32s ()
175
+ alpha , beta := float32 (1 ), float32 (0 )
176
+ e .Standard .Sgemv (tA , m , n , alpha , A , lda , X , incX , beta , Y , incY )
177
+ default :
178
+ return errors .New ("Unsupported Dtype" )
179
+ }
180
+ return e .Standard .Err ()
181
+ }
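+
+// A rough usage sketch (not part of this change; it assumes a CUDA-capable
+// device 0 and float64 data already resident in managed memory):
+//
+//	e := newEngine()
+//	a := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2, 3), tensor.WithEngine(e))
+//	x := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(3), tensor.WithEngine(e))
+//	y := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(2), tensor.WithEngine(e))
+//	err := e.MatVecMul(a, x, y) // y = a * x via cuBLAS gemv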