Skip to content

Commit 3ad528e

Browse files
committed
Cleaned up debugging messages
Added Run() method for batched context for managed runs. Added a final Synchronize() call to CUContext.Unlock(). Added checks to ensure BatchedContext and Ctx are Contexts. Fixed tests to conform to the new paradigm.
1 parent e530e8c commit 3ad528e

13 files changed

+231
-59
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ The work to fully represent the CUDA Driver API is a work in progress. At the mo
4646

4747
## Roadmap ##
4848

49-
* [ ] All texture, surface and graphics related API have an equivalent Go prototype.
49+
* [ ] Remaining API to be ported over
50+
* [x] All texture, surface and graphics related API have an equivalent Go prototype.
5051
* [x] Batching of common operations (see for example `Device.Attributes(...)`
5152
* [x] Generic queueing/batching of API calls (by some definition of generic)
5253

api.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ func MemFreeHost(p unsafe.Pointer) (err error) {
224224
return result(C.cuMemFreeHost(Cp))
225225
}
226226

227-
func MemAllocManaged(bytesize int64, flags uint) (dptr DevicePtr, err error) {
227+
func MemAllocManaged(bytesize int64, flags MemAttachFlags) (dptr DevicePtr, err error) {
228228
Cbytesize := C.size_t(bytesize)
229229
Cflags := C.uint(flags)
230230
var Cdptr C.CUdeviceptr

batch.go

+24-15
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,6 @@ func (ctx *BatchedContext) WorkAvailable() <-chan struct{} { return ctx.workAvai
195195
// DoWork waits for work to come in from the queue. If it's blocking, the entire queue will be processed immediately.
196196
// Otherwise it will be added to the batch queue.
197197
func (ctx *BatchedContext) DoWork() {
198-
// ctx.Lock()
199-
// defer ctx.Unlock()
200198
for {
201199
select {
202200
case w := <-ctx.work:
@@ -225,9 +223,7 @@ func (ctx *BatchedContext) DoWork() {
225223
}
226224

227225
// debug and instrumentation related stuff
228-
logf("GOING TO PROCESS")
229-
pc, _, _, _ := runtime.Caller(1)
230-
logf("Called by %v", runtime.FuncForPC(pc).Name())
226+
logCaller("DoWork()")
231227
logf(ctx.introspect())
232228
addQueueLength(len(ctx.queue))
233229
addBlockingCallers()
@@ -267,6 +263,29 @@ func (ctx *BatchedContext) DoWork() {
267263
}
268264
}
269265

266+
// Run manages the running of the BatchedContext. Because it is expected to run in a goroutine, an error channel should be passed in.
267+
func (ctx *BatchedContext) Run(errChan chan error) error {
268+
runtime.LockOSThread()
269+
for {
270+
select {
271+
case <-ctx.workAvailable:
272+
ctx.DoWork()
273+
if err := ctx.Errors(); err != nil {
274+
if errChan == nil {
275+
runtime.UnlockOSThread()
276+
return err
277+
}
278+
errChan <- err
279+
280+
}
281+
case w := <-ctx.Work():
282+
ctx.ErrChan() <- w()
283+
}
284+
}
285+
runtime.UnlockOSThread()
286+
return nil
287+
}
288+
270289
// Cleanup is the cleanup function. It cleans up all the ancillary allocations that have happened for all the batched calls.
271290
// This method should be called when the context is done with - otherwise there'd be a lot of leaked memory.
272291
//
@@ -323,9 +342,6 @@ func (ctx *BatchedContext) MemAllocManaged(bytesize int64, flags MemAttachFlags)
323342
}
324343

325344
func (ctx *BatchedContext) Memcpy(dst, src DevicePtr, byteCount int64) {
326-
// pc, _, _, _ := runtime.Caller(1)
327-
// logf("Memcpy %v %v| called by %v", dst, src, runtime.FuncForPC(pc).Name())
328-
329345
fn := &fnargs{
330346
fn: C.fn_memcpy,
331347
devptr0: C.CUdeviceptr(dst),
@@ -337,8 +353,6 @@ func (ctx *BatchedContext) Memcpy(dst, src DevicePtr, byteCount int64) {
337353
}
338354

339355
func (ctx *BatchedContext) MemcpyHtoD(dst DevicePtr, src unsafe.Pointer, byteCount int64) {
340-
// logf("Memcpy H2D: 0x%v, %v", dst, src)
341-
// log.Printf("Memcpy H2D: 0x%v, %v", dst, src)
342356
fn := &fnargs{
343357
fn: C.fn_memcpyHtoD,
344358
devptr0: C.CUdeviceptr(dst),
@@ -350,9 +364,6 @@ func (ctx *BatchedContext) MemcpyHtoD(dst DevicePtr, src unsafe.Pointer, byteCou
350364
}
351365

352366
func (ctx *BatchedContext) MemcpyDtoH(dst unsafe.Pointer, src DevicePtr, byteCount int64) {
353-
// pc, _, _, _ := runtime.Caller(2)
354-
// log.Printf("MemcpyD2H %v %v| called by %v", dst, src, runtime.FuncForPC(pc).Name())
355-
// logf("Memcpy D2H: %v 0x%v", dst, src)
356367
fn := &fnargs{
357368
fn: C.fn_memcpyDtoH,
358369
devptr0: C.CUdeviceptr(src),
@@ -364,8 +375,6 @@ func (ctx *BatchedContext) MemcpyDtoH(dst unsafe.Pointer, src DevicePtr, byteCou
364375
}
365376

366377
func (ctx *BatchedContext) MemFree(mem DevicePtr) {
367-
// pc, _, _, _ := runtime.Caller(1)
368-
// logf("MEMFREE %v CALLED BY %v", mem, runtime.FuncForPC(pc).Name())
369378
fn := &fnargs{
370379
fn: C.fn_memfreeD,
371380
devptr0: C.CUdeviceptr(mem),

batch_test.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ loop:
8181
bctx.DoWork()
8282
case <-doneChan:
8383
break loop
84-
default:
8584
}
8685
}
86+
if err = Synchronize(); err != nil {
87+
t.Errorf("Failed to Sync %v", err)
88+
}
8789

8890
for _, v := range a {
8991
if v != float32(2) {
@@ -93,7 +95,7 @@ loop:
9395
}
9496

9597
Unload(mod)
96-
// DestroyContext(&ctx)
98+
DestroyContext(&cuctx)
9799
}
98100

99101
func TestLargeBatch(t *testing.T) {
@@ -138,9 +140,9 @@ func TestLargeBatch(t *testing.T) {
138140
}
139141
size := int64(len(a) * 4)
140142

141-
var frees []DevicePtr
142143
go func() {
143144
var memA, memB DevicePtr
145+
var frees []DevicePtr
144146

145147
for i := 0; i < 104729; i++ {
146148
if memA, err = bctx.AllocAndCopy(unsafe.Pointer(&a[0]), size); err != nil {
@@ -173,6 +175,7 @@ func TestLargeBatch(t *testing.T) {
173175

174176
bctx.MemcpyDtoH(unsafe.Pointer(&a[0]), memA, size)
175177
bctx.MemcpyDtoH(unsafe.Pointer(&b[0]), memB, size)
178+
log.Printf("Number of frees %v", len(frees))
176179
for _, free := range frees {
177180
bctx.MemFree(free)
178181
}
@@ -191,6 +194,11 @@ loop:
191194
}
192195
}
193196

197+
bctx.DoWork()
198+
if err = Synchronize(); err != nil {
199+
t.Errorf("Failed to Sync %v", err)
200+
}
201+
194202
for _, v := range a {
195203
if v != float32(2) {
196204
t.Errorf("Expected all values to be 2. %v", a)
@@ -201,12 +209,10 @@ loop:
201209
afterFree, _, _ := MemInfo()
202210

203211
if afterFree != beforeFree {
204-
t.Errorf("Before: Freemem: %v. After %v", beforeFree, afterFree)
212+
t.Errorf("Before: Freemem: %v. After %v | Diff %v", beforeFree, afterFree, (beforeFree-afterFree)/1024)
205213
}
206-
207214
Unload(mod)
208-
// DestroyContext(&ctx)
209-
215+
DestroyContext(&cuctx)
210216
}
211217

212218
func BenchmarkNoBatching(bench *testing.B) {
@@ -359,6 +365,6 @@ func BenchmarkBatching(bench *testing.B) {
359365
MemFree(memA)
360366
MemFree(memB)
361367
Unload(mod)
362-
// DestroyContext(&ctx)
368+
DestroyContext(&cuctx)
363369

364370
}

cmd/genlib/README.md

+3
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,8 @@ The first line preprocesses all the macros, leaving a singular header file. The
5050
* `CurrentContext` - deleted
5151
* `CurrentDevice`
5252
* `CurrentFlags`
53+
* `CanAccessPeer` - deleted
54+
* `P2PAttribute` - deleted
55+
* `MemAllocManaged`
5356

5457
## Ctx related methods - manually written ##

context.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@ package cu
22

33
// #include <cuda.h>
44
import "C"
5-
import (
6-
"sync"
7-
"unsafe"
8-
)
5+
import "unsafe"
96

10-
var contextLock = new(sync.Mutex)
11-
var pkgContext CUContext
7+
var (
8+
_ Context = &Ctx{}
9+
_ Context = &BatchedContext{}
10+
)
1211

1312
// Context interface. Typically you'd just embed *Ctx. Rarely do you need to use CUContext
1413
type Context interface {
@@ -17,15 +16,15 @@ type Context interface {
1716
Error() error
1817
Run(chan error) error
1918
Do(fn func() error) error
20-
Work() chan func() error
19+
Work() <-chan func() error
20+
ErrChan() chan<- error
2121

2222
// actual methods
2323
Address(hTexRef TexRef) (pdptr DevicePtr, err error)
2424
AddressMode(hTexRef TexRef, dim int) (pam AddressMode, err error)
2525
Array(hTexRef TexRef) (phArray Array, err error)
2626
AttachMemAsync(hStream Stream, dptr DevicePtr, length int64, flags uint)
2727
BorderColor(hTexRef TexRef) (pBorderColor [3]float32, err error)
28-
CanAccessPeer(dev Device, peerDev Device) (canAccessPeer int, err error)
2928
CurrentCacheConfig() (pconfig FuncCacheConfig, err error)
3029
CurrentDevice() (device Device, err error)
3130
CurrentFlags() (flags ContextFlags, err error)
@@ -49,7 +48,7 @@ type Context interface {
4948
MakeStreamWithPriority(priority int, flags StreamFlags) (stream Stream, err error)
5049
MaxAnisotropy(hTexRef TexRef) (pmaxAniso int, err error)
5150
MemAlloc(bytesize int64) (dptr DevicePtr, err error)
52-
MemAllocManaged(bytesize int64, flags uint) (dptr DevicePtr, err error)
51+
MemAllocManaged(bytesize int64, flags MemAttachFlags) (dptr DevicePtr, err error)
5352
MemAllocPitch(WidthInBytes int64, Height int64, ElementSizeBytes uint) (dptr DevicePtr, pPitch int64, err error)
5453
MemFree(dptr DevicePtr)
5554
MemFreeHost(p unsafe.Pointer)
@@ -92,7 +91,6 @@ type Context interface {
9291
MemsetD8Async(dstDevice DevicePtr, uc byte, N int64, hStream Stream)
9392
ModuleFunction(m Module, name string) (function Function, err error)
9493
ModuleGlobal(m Module, name string) (dptr DevicePtr, size int64, err error)
95-
P2PAttribute(srcDevice Device, attrib P2PAttribute, dstDevice Device) (value int, err error)
9694
Priority(hStream Stream) (priority int, err error)
9795
QueryEvent(hEvent Event)
9896
QueryStream(hStream Stream)

ctx.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package cu
55
// #include <cuda.h>
66
import "C"
77
import (
8-
"log"
98
"runtime"
109
"unsafe"
1110
)
@@ -41,9 +40,6 @@ func newContext(c CUContext) *Ctx {
4140
work: make(chan func() error),
4241
errChan: make(chan error),
4342
}
44-
pc, _, _, _ := runtime.Caller(2)
45-
46-
log.Printf("Created %p by %v", ctx, runtime.FuncForPC(pc).Name())
4743
runtime.SetFinalizer(ctx, finalizeCtx)
4844
return ctx
4945

@@ -61,7 +57,10 @@ func (ctx *Ctx) CUDAContext() CUContext { return ctx.CUContext }
6157
func (ctx *Ctx) Error() error { return ctx.err }
6258

6359
// Work returns the channel where work will be passed in. In most cases you don't need this. Use Run instead.
64-
func (ctx *Ctx) Work() chan func() error { return ctx.work }
60+
func (ctx *Ctx) Work() <-chan func() error { return ctx.work }
61+
62+
// ErrChan returns the internal error channel used
63+
func (ctx *Ctx) ErrChan() chan<- error { return ctx.errChan }
6564

6665
// Run locks the goroutine to the OS thread and ties the CUDA context to the OS thread. For most cases, this would suffice
6766
//
@@ -115,7 +114,6 @@ func (ctx *Ctx) Run(errChan chan error) error {
115114
}
116115

117116
func finalizeCtx(ctx *Ctx) {
118-
log.Printf("Finalizing %p", ctx)
119117
if ctx.CUContext == 0 {
120118
close(ctx.errChan)
121119
close(ctx.work)

ctx_api.go

+1-16
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (ctx *Ctx) MemFreeHost(p unsafe.Pointer) {
169169
ctx.err = ctx.Do(f)
170170
}
171171

172-
func (ctx *Ctx) MemAllocManaged(bytesize int64, flags uint) (dptr DevicePtr, err error) {
172+
func (ctx *Ctx) MemAllocManaged(bytesize int64, flags MemAttachFlags) (dptr DevicePtr, err error) {
173173
Cbytesize := C.size_t(bytesize)
174174
Cflags := C.uint(flags)
175175
var Cdptr C.CUdeviceptr
@@ -1022,21 +1022,6 @@ func (ctx *Ctx) CanAccessPeer(dev Device, peerDev Device) (canAccessPeer int, er
10221022
return
10231023
}
10241024

1025-
func (ctx *Ctx) P2PAttribute(srcDevice Device, attrib P2PAttribute, dstDevice Device) (value int, err error) {
1026-
CsrcDevice := C.CUdevice(srcDevice)
1027-
Cattrib := C.CUdevice_P2PAttribute(attrib)
1028-
CdstDevice := C.CUdevice(dstDevice)
1029-
var Cvalue C.int
1030-
f := func() error {
1031-
return result(C.cuDeviceGetP2PAttribute(&Cvalue, Cattrib, CsrcDevice, CdstDevice))
1032-
}
1033-
if err = ctx.Do(f); err != nil {
1034-
err = errors.Wrap(err, "P2PAttribute")
1035-
}
1036-
value = int(Cvalue)
1037-
return
1038-
}
1039-
10401025
func (ctx *Ctx) EnablePeerAccess(peerContext CUContext, Flags uint) {
10411026
CpeerContext := peerContext.c()
10421027
CFlags := C.uint(Flags)

cucontext.go

+39
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package cu
44
import "C"
55
import (
66
"fmt"
7+
"runtime"
78
"unsafe"
89
)
910

@@ -27,6 +28,44 @@ func (d Device) MakeContext(flags ContextFlags) (CUContext, error) {
2728
return makeContext(ctx), nil
2829
}
2930

31+
// Lock ties the calling goroutine to an OS thread, then ties the CUDA context to the thread.
32+
// Do not call in a goroutine.
33+
//
34+
// Good:
35+
/*
36+
func main() {
37+
dev, _ := GetDevice(0)
38+
ctx, _ := dev.MakeContext()
39+
if err := ctx.Lock(); err != nil{
40+
// handle error
41+
}
42+
43+
mem, _ := MemAlloc(1024)
44+
}
45+
*/
46+
// Bad:
47+
/*
48+
func main() {
49+
dev, _ := GetDevice(0)
50+
ctx, _ := dev.MakeContext()
51+
go ctx.Lock() // this will tie the goroutine that calls ctx.Lock to the OS thread, while the main thread does not get the lock
52+
mem, _ := MemAlloc(1024)
53+
}
54+
*/
55+
func (ctx CUContext) Lock() error {
56+
runtime.LockOSThread()
57+
return SetCurrentContext(ctx)
58+
}
59+
60+
// Unlock unbinds the goroutine from the OS thread.
61+
func (ctx CUContext) Unlock() error {
62+
if err := Synchronize(); err != nil {
63+
return err
64+
}
65+
runtime.UnlockOSThread()
66+
return nil
67+
}
68+
3069
// DestroyContext destroys the context. It returns an error if it wasn't properly destroyed
3170
//
3271
// Wrapper over cuCtxDestroy: http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e

0 commit comments

Comments
 (0)