Skip to content

Commit 751f731

Browse files
committed
Major refactoring inspired by Chad
1 parent 12e8b10 commit 751f731

20 files changed

+1132
-3283
lines changed

array.go

+33-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ package cu
22

33
// #include <cuda.h>
44
import "C"
5-
import "unsafe"
5+
import (
6+
"unsafe"
7+
8+
"github.com/pkg/errors"
9+
)
610

711
// Format is the type of array (think array types)
812
type Format byte
@@ -18,10 +22,11 @@ const (
1822
Float32 Format = C.CU_AD_FORMAT_FLOAT // 32-bit floating point
1923
)
2024

21-
// Array is the pointer to a CUDA array. The name is a bit of a misnomer, as it would lead one to imply that it's rangeable. It's not.
25+
// Array is the pointer to a CUDA array. The name is a bit of a misnomer,
26+
// as it would lead one to imply that it's rangeable. It's not.
2227
type Array uintptr
2328

24-
func (arr Array) cuda() C.CUarray {
29+
func (arr Array) c() C.CUarray {
2530
return *(*C.CUarray)(unsafe.Pointer(uintptr(arr)))
2631
}
2732

@@ -129,3 +134,28 @@ func goArrayDesc(desc *C.CUDA_ARRAY_DESCRIPTOR) ArrayDesc {
129134
NumChannels: uint(ad.NumChannels),
130135
}
131136
}
137+
138+
// Descriptor3 get a 3D CUDA array descriptor
139+
func (arr Array) Descriptor3() (Array3Desc, error) {
140+
hArray := arr.c()
141+
var desc C.CUDA_ARRAY3D_DESCRIPTOR
142+
if err := result(C.cuArray3DGetDescriptor(&desc, hArray)); err != nil {
143+
return Array3Desc{}, errors.Wrapf(err, "Array3DGetDescriptor")
144+
}
145+
return goArray3Desc(&desc), nil
146+
}
147+
148+
// Descriptor gets a 1D or 2D CUDA array descriptor
149+
func (arr Array) Descriptor() (ArrayDesc, error) {
150+
hArray := arr.c()
151+
var desc C.CUDA_ARRAY_DESCRIPTOR
152+
if err := result(C.cuArrayGetDescriptor(&desc, hArray)); err != nil {
153+
return ArrayDesc{}, errors.Wrapf(err, "ArrayGetDescriptor")
154+
}
155+
return goArrayDesc(&desc), nil
156+
157+
}
158+
159+
func cuArrayToArray(arr *C.CUarray) Array {
160+
return Array(uintptr(unsafe.Pointer(arr)))
161+
}

batch.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ CUresult processFn(fnargs_t* args){
5757
abort();
5858
break;
5959
case fn_mallocManaged:
60-
abort();
60+
ret = cuMemAllocManaged(&args->devPtr0, args->size, CU_MEM_ATTACH_GLOBAL);
6161
break;
6262
case fn_memfreeD:
6363
// fprintf(stderr, "memfree %p\n", args->devPtr0);

batch.go

+17-6
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func (fn *fnargs) c() C.uintptr_t {
133133
// For the moment, BatchedContext only supports a limited number of CUDA Runtime APIs.
134134
// Feel free to send a pull request with more APIs.
135135
type BatchedContext struct {
136-
Context
136+
CUContext
137137
Device
138138

139139
workAvailable chan struct{} // an empty struct is sent down workAvailable when there is work
@@ -149,10 +149,10 @@ type BatchedContext struct {
149149
}
150150

151151
// NewBatchedContext creates a batched CUDA context.
152-
func NewBatchedContext(c Context, d Device) *BatchedContext {
152+
func NewBatchedContext(c CUContext, d Device) *BatchedContext {
153153
return &BatchedContext{
154-
Context: c,
155-
Device: d,
154+
CUContext: c,
155+
Device: d,
156156

157157
workAvailable: make(chan struct{}, 1),
158158
work: make(chan call, workBufLen),
@@ -232,7 +232,7 @@ func (ctx *BatchedContext) DoWork() {
232232
addQueueLength(len(ctx.queue))
233233
addBlockingCallers()
234234

235-
cctx := C.CUcontext(unsafe.Pointer(uintptr(ctx.Context)))
235+
cctx := C.CUcontext(unsafe.Pointer(uintptr(ctx.CUContext)))
236236
ctx.results = ctx.results[:cap(ctx.results)] // make sure of the maximum availability for ctx.results
237237
C.process(cctx, &ctx.fns[0], &ctx.results[0], C.int(len(ctx.queue))) // process the queue
238238
ctx.results = ctx.results[:len(ctx.queue)] // then truncate it to the len of queue for reporting purposes
@@ -252,6 +252,8 @@ func (ctx *BatchedContext) DoWork() {
252252
ctx.retVal <- DevicePtr(retVal.devptr0)
253253
case C.fn_mallocH:
254254
case C.fn_mallocManaged:
255+
retVal = (*fnargs)(unsafe.Pointer(uintptr(ctx.fns[len(ctx.fns)-1])))
256+
ctx.retVal <- DevicePtr(retVal.devptr0)
255257
case C.fn_allocAndCopy:
256258
retVal = (*fnargs)(unsafe.Pointer(uintptr(ctx.fns[len(ctx.fns)-1])))
257259
ctx.retVal <- DevicePtr(retVal.devptr0)
@@ -295,7 +297,7 @@ func (ctx *BatchedContext) FirstError() error {
295297
func (ctx *BatchedContext) SetCurrent() {
296298
fn := &fnargs{
297299
fn: C.fn_setCurrent,
298-
ctx: C.CUcontext(unsafe.Pointer(uintptr(ctx.Context))),
300+
ctx: C.CUcontext(unsafe.Pointer(uintptr(ctx.CUContext))),
299301
}
300302
c := call{fn, false}
301303
ctx.enqueue(c)
@@ -311,6 +313,15 @@ func (ctx *BatchedContext) MemAlloc(bytesize int64) (retVal DevicePtr, err error
311313
return ctx.enqueue(c)
312314
}
313315

316+
func (ctx *BatchedContext) MemAllocManaged(bytesize int64, flags MemAttachFlags) (retVal DevicePtr, err error) {
317+
fn := &fnargs{
318+
fn: C.fn_mallocManaged,
319+
size: C.size_t(bytesize),
320+
}
321+
c := call{fn, true}
322+
return ctx.enqueue(c)
323+
}
324+
314325
func (ctx *BatchedContext) Memcpy(dst, src DevicePtr, byteCount int64) {
315326
// pc, _, _, _ := runtime.Caller(1)
316327
// logf("Memcpy %v %v| called by %v", dst, src, runtime.FuncForPC(pc).Name())

batch_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
func TestBatchContext(t *testing.T) {
1010
var err error
1111
var dev Device
12-
var ctx Context
12+
var ctx CUContext
1313
var mod Module
1414
var fn Function
1515

@@ -97,7 +97,7 @@ loop:
9797
func TestLargeBatch(t *testing.T) {
9898
var err error
9999
var dev Device
100-
var ctx Context
100+
var ctx CUContext
101101
var mod Module
102102
var fn Function
103103

@@ -210,7 +210,7 @@ func BenchmarkNoBatching(bench *testing.B) {
210210
defer runtime.UnlockOSThread()
211211

212212
var err error
213-
var ctx Context
213+
var ctx CUContext
214214
var mod Module
215215
var fn Function
216216

@@ -289,7 +289,7 @@ func BenchmarkBatching(bench *testing.B) {
289289

290290
var err error
291291
var dev Device
292-
var ctx Context
292+
var ctx CUContext
293293
var mod Module
294294
var fn Function
295295

batchedPatterns_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88

99
func TestAttributes(t *testing.T) {
1010
var dev Device
11-
var ctx Context
11+
var ctx CUContext
1212
var err error
1313

1414
if dev, ctx, err = testSetup(); err != nil {
@@ -51,7 +51,7 @@ func TestAttributes(t *testing.T) {
5151

5252
func TestLaunchAndSync(t *testing.T) {
5353
var err error
54-
var ctx Context
54+
var ctx CUContext
5555
var mod Module
5656
var fn Function
5757

@@ -128,7 +128,7 @@ func TestLaunchAndSync(t *testing.T) {
128128

129129
func TestAllocAndCopy(t *testing.T) {
130130
var err error
131-
var ctx Context
131+
var ctx CUContext
132132
var mem DevicePtr
133133

134134
if _, ctx, err = testSetup(); err != nil {
File renamed without changes.

cublas/cmd/cublasgen/main.go cmd/gencublas/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ var (
2828
)
2929

3030
const (
31+
typ = "impl *Standalone"
3132
header = "cublasgen.h"
32-
typ = "impl *Implementation"
3333
prefix = "cublas"
3434
warning = "Float32 implementations are autogenerated and not directly tested."
3535
)

cublas/cmd/cublasgen/templates.go cmd/gencublas/templates.go

+20-20
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ type drotmParams struct {
3737
h [4]float64
3838
}
3939
40-
func (impl *Implementation) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) {
40+
func (impl *Standalone) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) {
4141
impl.e = status(C.cublasSrotg(C.cublasHandle_t(impl.h), (*C.float)(&a), (*C.float)(&b), (*C.float)(&c), (*C.float)(&s)))
4242
return c, s, a, b
4343
}
44-
func (impl *Implementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) {
44+
func (impl *Standalone) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) {
4545
if impl.e != nil {
4646
return
4747
}
@@ -50,7 +50,7 @@ func (impl *Implementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float3
5050
return blas.SrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
5151
}
5252
53-
func (impl *Implementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
53+
func (impl *Standalone) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
5454
if impl.e != nil {
5555
return
5656
}
@@ -83,15 +83,15 @@ func (impl *Implementation) Srotm(n int, x []float32, incX int, y []float32, inc
8383
impl.e = status(C.cublasSrotm(C.cublasHandle_t(impl.h), C.int(n), (*C.float)(&x[0]), C.int(incX), (*C.float)(&y[0]), C.int(incY), (*C.float)(unsafe.Pointer(&pi))))
8484
}
8585
86-
func (impl *Implementation) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) {
86+
func (impl *Standalone) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) {
8787
if impl.e != nil {
8888
return
8989
}
9090
impl.e = status(C.cublasDrotg(C.cublasHandle_t(impl.h), (*C.double)(&a), (*C.double)(&b), (*C.double)(&c), (*C.double)(&s)))
9191
return c, s, a, b
9292
}
9393
94-
func (impl *Implementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) {
94+
func (impl *Standalone) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) {
9595
if impl.e != nil {
9696
return
9797
}
@@ -100,7 +100,7 @@ func (impl *Implementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float6
100100
return blas.DrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
101101
}
102102
103-
func (impl *Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
103+
func (impl *Standalone) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
104104
if impl.e != nil {
105105
return
106106
}
@@ -132,7 +132,7 @@ func (impl *Implementation) Drotm(n int, x []float64, incX int, y []float64, inc
132132
impl.e = status(C.cublasDrotm(C.cublasHandle_t(impl.h), C.int(n), (*C.double)(&x[0]), C.int(incX), (*C.double)(&y[0]), C.int(incY), (*C.double)(unsafe.Pointer(&pi))))
133133
}
134134
135-
func (impl *Implementation) Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64) {
135+
func (impl *Standalone) Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64) {
136136
if impl.e != nil {
137137
return
138138
}
@@ -157,7 +157,7 @@ func (impl *Implementation) Cdotu(n int, x []complex64, incX int, y []complex64,
157157
impl.e = status(C.cublasCdotu(C.cublasHandle_t(impl.h), C.int(n), (*C.cuComplex)(unsafe.Pointer(&x[0])), C.int(incX), (*C.cuComplex)(unsafe.Pointer(&y[0])), C.int(incY), (*C.cuComplex)(unsafe.Pointer(&dotu))))
158158
return dotu
159159
}
160-
func (impl *Implementation) Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64) {
160+
func (impl *Standalone) Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64) {
161161
if impl.e != nil {
162162
return
163163
}
@@ -183,7 +183,7 @@ func (impl *Implementation) Cdotc(n int, x []complex64, incX int, y []complex64,
183183
impl.e = status(C.cublasCdotc(C.cublasHandle_t(impl.h), C.int(n), (*C.cuComplex)(unsafe.Pointer(&x[0])), C.int(incX), (*C.cuComplex)(unsafe.Pointer(&y[0])), C.int(incY), (*C.cuComplex)(unsafe.Pointer(&dotc))))
184184
return dotc
185185
}
186-
func (impl *Implementation) Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128) {
186+
func (impl *Standalone) Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128) {
187187
if impl.e != nil {
188188
return
189189
}
@@ -209,7 +209,7 @@ func (impl *Implementation) Zdotu(n int, x []complex128, incX int, y []complex12
209209
impl.e = status(C.cublasZdotu(C.cublasHandle_t(impl.h), C.int(n), (*C.cuDoubleComplex)(unsafe.Pointer(&x[0])), C.int(incX), (*C.cuDoubleComplex)(unsafe.Pointer(&y[0])), C.int(incY), (*C.cuDoubleComplex)(unsafe.Pointer(&dotu))))
210210
return dotu
211211
}
212-
func (impl *Implementation) Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128) {
212+
func (impl *Standalone) Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128) {
213213
if impl.e != nil {
214214
return
215215
}
@@ -235,27 +235,27 @@ func (impl *Implementation) Zdotc(n int, x []complex128, incX int, y []complex12
235235
return dotc
236236
}
237237
238-
func (impl *Implementation) Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32 {
238+
func (impl *Standalone) Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32 {
239239
panic("Unimplemented in cuBLAS. Please contact nvidia.")
240240
}
241241
242-
func (impl *Implementation) Dsdot(n int, x []float32, incX int, y []float32, incY int) float64 {
242+
func (impl *Standalone) Dsdot(n int, x []float32, incX int, y []float32, incY int) float64 {
243243
panic("Unimplemented in cuBLAS. Please contact nvidia.")
244244
}
245245
246-
func (impl *Implementation) Strmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int){
246+
func (impl *Standalone) Strmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int){
247247
panic("Unimplemented in cuBLAS. Please contact nvidia.")
248248
}
249249
250-
func (impl *Implementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int){
250+
func (impl *Standalone) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int){
251251
panic("Unimplemented in cuBLAS. Please contact nvidia.")
252252
}
253253
254-
func (impl *Implementation) Ctrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int){
254+
func (impl *Standalone) Ctrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int){
255255
panic("Unimplemented in cuBLAS. Please contact nvidia.")
256256
}
257257
258-
func (impl *Implementation) Ztrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int){
258+
func (impl *Standalone) Ztrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int){
259259
panic("Unimplemented in cuBLAS. Please contact nvidia.")
260260
}
261261
@@ -265,10 +265,10 @@ func (impl *Implementation) Ztrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose,
265265

266266
// TODO: complex scale
267267
const complexScales = `
268-
func (impl *Implementation) Cscal(n int, alpha complex64, x []complex64, incX int) {}
269-
func (impl *Implementation) Zscal(n int, alpha complex64, x []complex128, incX int){}
270-
func (impl *Implementation) Csscal(n int, alpha float32, x []complex64, incX int) {}
271-
func (impl *Implementation) Zsscal(n int, alpha float64, x []complex128, incX int){}
268+
func (impl *Standalone) Cscal(n int, alpha complex64, x []complex64, incX int) {}
269+
func (impl *Standalone) Zscal(n int, alpha complex64, x []complex128, incX int){}
270+
func (impl *Standalone) Csscal(n int, alpha float32, x []complex64, incX int) {}
271+
func (impl *Standalone) Zsscal(n int, alpha float64, x []complex128, incX int){}
272272
`
273273

274274
const amaxRaw = `

0 commit comments

Comments
 (0)