Skip to content

Commit 5bdd11a

Browse files
committed
Use Stream wrapping instead of casting
1 parent 800e9a0 commit 5bdd11a

9 files changed

+41
-54
lines changed

addressing.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (d DevicePtr) MemAdvise(count int64, advice MemAdvice, dev Device) error {
8181
func (d DevicePtr) MemPrefetchAsync(count int64, dst Device, hStream Stream) error {
8282
devPtr := C.CUdeviceptr(d)
8383
cc := C.size_t(count)
84-
str := C.CUstream(unsafe.Pointer(uintptr(hStream)))
84+
str := hStream.s
8585
dv := C.CUdevice(dst)
8686
return result(C.cuMemPrefetchAsync(devPtr, cc, dv, str))
8787
}

batch.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ func (ctx *BatchedContext) LaunchKernel(function Function, gridDimX, gridDimY, g
413413
blockDimY: C.uint(blockDimY),
414414
blockDimZ: C.uint(blockDimZ),
415415
sharedMemBytes: C.uint(sharedMemBytes),
416-
stream: C.CUstream(unsafe.Pointer(uintptr(stream))),
416+
stream: stream.c(),
417417
kernelParams: (*unsafe.Pointer)(argp),
418418
extra: (*unsafe.Pointer)(unsafe.Pointer(uintptr(0))),
419419
}

batch_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func TestBatchContext(t *testing.T) {
6464
}
6565

6666
bctx.MemcpyHtoD(memB, unsafe.Pointer(&b[0]), size)
67-
bctx.LaunchKernel(fn, 1, 1, 1, len(a), 1, 1, 0, Stream(0), args)
67+
bctx.LaunchKernel(fn, 1, 1, 1, len(a), 1, 1, 0, Stream{}, args)
6868
bctx.Synchronize()
6969
bctx.MemcpyDtoH(unsafe.Pointer(&a[0]), memA, size)
7070
bctx.MemcpyDtoH(unsafe.Pointer(&b[0]), memB, size)
@@ -161,7 +161,7 @@ func TestLargeBatch(t *testing.T) {
161161
}
162162

163163
bctx.MemcpyHtoD(memB, unsafe.Pointer(&b[0]), size)
164-
bctx.LaunchKernel(fn, 1, 1, 1, len(a), 1, 1, 0, Stream(0), args)
164+
bctx.LaunchKernel(fn, 1, 1, 1, len(a), 1, 1, 0, Stream{}, args)
165165
bctx.Synchronize()
166166

167167
if i%13 == 0 {
@@ -273,7 +273,7 @@ func BenchmarkNoBatching(bench *testing.B) {
273273
bench.Fatalf("Failed to copy memory from b: %v", err)
274274
}
275275

276-
if err = fn.LaunchAndSync(100, 10, 1, 1000, 1, 1, 1, Stream(0), args); err != nil {
276+
if err = fn.LaunchAndSync(100, 10, 1, 1000, 1, 1, 1, Stream{}, args); err != nil {
277277
bench.Errorf("Launch and Sync Failed: %v", err)
278278
}
279279

@@ -353,7 +353,7 @@ func BenchmarkBatching(bench *testing.B) {
353353
default:
354354
bctx.MemcpyHtoD(memA, unsafe.Pointer(&a[0]), size)
355355
bctx.MemcpyHtoD(memB, unsafe.Pointer(&b[0]), size)
356-
bctx.LaunchKernel(fn, 100, 10, 1, 1000, 1, 1, 0, Stream(0), args)
356+
bctx.LaunchKernel(fn, 100, 10, 1, 1000, 1, 1, 0, Stream{}, args)
357357
bctx.Synchronize()
358358
bctx.MemcpyDtoH(unsafe.Pointer(&a[0]), memA, size)
359359
bctx.MemcpyDtoH(unsafe.Pointer(&b[0]), memB, size)

batchedPatterns.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (fn Function) LaunchAndSync(gridDimX, gridDimY, gridDimZ, blockDimX, blockD
5353
C.uint(blockDimY),
5454
C.uint(blockDimZ),
5555
C.uint(sharedMemBytes),
56-
C.CUstream(unsafe.Pointer(uintptr(stream))),
56+
stream.c(),
5757
(*unsafe.Pointer)(argp),
5858
(*unsafe.Pointer)(unsafe.Pointer(uintptr(0)))))
5959
return err

batchedPatterns_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func TestLaunchAndSync(t *testing.T) {
101101
unsafe.Pointer(&size),
102102
}
103103

104-
if err = fn.LaunchAndSync(1, 1, 1, len(a), 1, 1, 1, Stream(0), args); err != nil {
104+
if err = fn.LaunchAndSync(1, 1, 1, len(a), 1, 1, 1, Stream{}, args); err != nil {
105105
t.Errorf("Launch and Sync Failed: %v", err)
106106
}
107107

cucontext_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func TestMultipleContextSingleHostThread(t *testing.T) {
173173
unsafe.Pointer(&size),
174174
}
175175

176-
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream(0), args); err != nil {
176+
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream{}, args); err != nil {
177177
t.Errorf("Failed to launcj add32: %v", err)
178178
}
179179

@@ -204,7 +204,7 @@ func TestMultipleContextSingleHostThread(t *testing.T) {
204204
unsafe.Pointer(&size),
205205
}
206206

207-
if err = fn1.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream(0), args); err != nil {
207+
if err = fn1.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream{}, args); err != nil {
208208
t.Errorf("Failed to launcj add32: %v", err)
209209
}
210210

@@ -221,7 +221,7 @@ func TestMultipleContextSingleHostThread(t *testing.T) {
221221
unsafe.Pointer(&size),
222222
}
223223

224-
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream(0), args); err == nil {
224+
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream{}, args); err == nil {
225225
t.Errorf("Expected error when launching a kernel defined in a different context")
226226
}
227227
t.Log(err)
@@ -230,7 +230,7 @@ func TestMultipleContextSingleHostThread(t *testing.T) {
230230
if err = SetCurrentContext(ctx0); err != nil {
231231
t.Errorf("Failed to swtch to ctx0 %v", err)
232232
}
233-
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream(0), args); err != nil {
233+
if err = fn0.LaunchAndSync(1, 1, 1, len(data), 1, 1, 0, Stream{}, args); err != nil {
234234
t.Errorf("fn0 errored while using memory declared in ctx1: %v", err)
235235
}
236236
t.Log(err)

execution.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (fn Function) LaunchKernel(gridDimX, gridDimY, gridDimZ int, blockDimX, blo
3535
C.uint(blockDimY),
3636
C.uint(blockDimZ),
3737
C.uint(sharedMemBytes),
38-
C.CUstream(unsafe.Pointer(uintptr(stream))),
38+
stream.c(),
3939
(*unsafe.Pointer)(argp),
4040
(*unsafe.Pointer)(unsafe.Pointer(uintptr(0)))))
4141
return err
@@ -67,7 +67,7 @@ func (ctx *Ctx) LaunchKernel(fn Function, gridDimX, gridDimY, gridDimZ int, bloc
6767
C.uint(blockDimY),
6868
C.uint(blockDimZ),
6969
C.uint(sharedMemBytes),
70-
C.CUstream(unsafe.Pointer(uintptr(stream))),
70+
stream.c(),
7171
(*unsafe.Pointer)(argp),
7272
(*unsafe.Pointer)(unsafe.Pointer(uintptr(0)))))
7373
}

module_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestModule(t *testing.T) {
5353
grid := DivUp(N, block)
5454
shmem := 0
5555
args := []unsafe.Pointer{unsafe.Pointer(&A), unsafe.Pointer(&value), unsafe.Pointer(&n)}
56-
if err = f.LaunchKernel(grid, 1, 1, block, 1, 1, shmem, 0, args); err != nil {
56+
if err = f.LaunchKernel(grid, 1, 1, block, 1, 1, shmem, Stream{}, args); err != nil {
5757
t.Fatal(err)
5858
}
5959

stream.go

+26-39
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,25 @@ package cu
33
// #include <cuda.h>
44
import "C"
55
import (
6-
"unsafe"
7-
86
"github.com/pkg/errors"
97
)
108

119
// Stream represents a CUDA stream.
12-
type Stream uintptr
10+
type Stream struct {
11+
s C.CUstream
12+
}
1313

14-
func makeStream(s C.CUstream) Stream { return Stream(uintptr(unsafe.Pointer(s))) }
15-
func (s Stream) c() C.CUstream { return C.CUstream(unsafe.Pointer(uintptr(s))) }
14+
func makeStream(s C.CUstream) Stream { return Stream{s} }
15+
func (s Stream) c() C.CUstream { return s.s }
1616

1717
// C is the exported version of the c method
1818
func (s Stream) C() C.CUstream { return s.c() }
1919

2020
// MakeStream creates a stream. The flags determines the behaviors of the stream.
2121
func MakeStream(flags StreamFlags) (Stream, error) {
22-
var s C.CUstream
23-
if err := result(C.cuStreamCreate(&s, C.uint(flags))); err != nil {
24-
return 0, err
25-
}
26-
return makeStream(s), nil
22+
var s Stream
23+
err := result(C.cuStreamCreate(&s.s, C.uint(flags)))
24+
return s, err
2725
}
2826

2927
// MakeStreamWithPriority creates a stream with the given priority. The flags determines the behaviors of the stream.
@@ -35,52 +33,41 @@ func MakeStream(flags StreamFlags) (Stream, error) {
3533
// If the specified priority is outside the numerical range returned by `StreamPriorityRange`,
3634
// it will automatically be clamped to the lowest or the highest number in the range.
3735
func MakeStreamWithPriority(priority int, flags StreamFlags) (Stream, error) {
38-
var s C.CUstream
39-
if err := result(C.cuStreamCreateWithPriority(&s, C.uint(flags), C.int(priority))); err != nil {
40-
return 0, err
41-
}
42-
return makeStream(s), nil
36+
var s Stream
37+
err := result(C.cuStreamCreateWithPriority(&s.s, C.uint(flags), C.int(priority)))
38+
return s, err
4339
}
4440

4541
// DestroyStream destroys the stream specified by hStream.
4642
//
4743
// In case the device is still doing work in the stream hStream when DestroyStrea() is called,
4844
// the function will return immediately and the resources associated with hStream will be released automatically once the device has completed all work in hStream.
49-
func DestroyStream(hStream *Stream) error {
50-
stream := *hStream
51-
s := stream.c()
52-
*hStream = 0
53-
return result(C.cuStreamDestroy(s))
45+
func (hStream *Stream) Destroy() error {
46+
err := result(C.cuStreamDestroy(hStream.s))
47+
*hStream = Stream{}
48+
return err
5449
}
5550

5651
func (ctx *Ctx) MakeStream(flags StreamFlags) (stream Stream, err error) {
57-
var s C.CUstream
52+
var s Stream
5853

59-
f := func() error { return result(C.cuStreamCreate(&s, C.uint(flags))) }
54+
f := func() error { return result(C.cuStreamCreate(&s.s, C.uint(flags))) }
6055
if err = ctx.Do(f); err != nil {
61-
err = errors.Wrap(err, "MakeStream")
62-
return
56+
return s, errors.Wrap(err, "MakeStream")
6357
}
64-
stream = makeStream(s)
65-
return
58+
return s, nil
6659
}
6760

68-
func (ctx *Ctx) MakeStreamWithPriority(priority int, flags StreamFlags) (stream Stream, err error) {
69-
var s C.CUstream
70-
71-
f := func() error { return result(C.cuStreamCreateWithPriority(&s, C.uint(flags), C.int(priority))) }
72-
if err = ctx.Do(f); err != nil {
73-
err = errors.Wrap(err, "MakeStream With Priority")
74-
return
61+
func (ctx *Ctx) MakeStreamWithPriority(priority int, flags StreamFlags) (Stream, error) {
62+
var s Stream
63+
f := func() error { return result(C.cuStreamCreateWithPriority(&s.s, C.uint(flags), C.int(priority))) }
64+
if err := ctx.Do(f); err != nil {
65+
return s, errors.Wrap(err, "MakeStream With Priority")
7566
}
76-
stream = makeStream(s)
77-
return
67+
return s, nil
7868
}
7969

8070
func (ctx *Ctx) DestroyStream(hStream *Stream) {
81-
stream := *hStream
82-
s := stream.c()
83-
84-
f := func() error { return result(C.cuStreamDestroy(s)) }
71+
f := func() error { return result(C.cuStreamDestroy(hStream.s)) }
8572
ctx.err = ctx.Do(f)
8673
}

0 commit comments

Comments
 (0)