Skip to content

Commit 8a4a879

Browse files
committed
added tests for LaunchAndSync
1 parent 04787f7 commit 8a4a879

File tree

2 files changed

+150
-1
lines changed

2 files changed

+150
-1
lines changed

batch_test.go

+83-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package cu
22

3-
import "testing"
3+
import (
4+
"testing"
5+
"unsafe"
6+
)
47

58
func TestAttributes(t *testing.T) {
69
devices, _ := NumDevices()
@@ -40,3 +43,82 @@ func TestAttributes(t *testing.T) {
4043
t.Errorf("Expected ComputeCapabilityMinor to be %v. Got %v instead", min, attrs[2])
4144
}
4245
}
46+
47+
func TestLaunchAndSync(t *testing.T) {
48+
devices, _ := NumDevices()
49+
50+
if devices == 0 {
51+
return
52+
}
53+
54+
var err error
55+
var ctx Context
56+
var mod Module
57+
var fn Function
58+
59+
d := Device(0)
60+
if ctx, err = d.MakeContext(SchedAuto); err != nil {
61+
t.Fatal(err)
62+
}
63+
64+
a := make([]float32, 1000)
65+
b := make([]float32, 1000)
66+
for i := range b {
67+
a[i] = 1
68+
b[i] = 1
69+
}
70+
71+
size := int64(len(a) * 4)
72+
73+
var memA, memB DevicePtr
74+
if memA, err = MemAlloc(size); err != nil {
75+
t.Fatalf("Failed to allocate for a: %v", err)
76+
}
77+
if memB, err = MemAlloc(size); err != nil {
78+
t.Fatalf("Failed to allocate for b: %v", err)
79+
}
80+
81+
if err = MemcpyHtoD(memA, unsafe.Pointer(&a[0]), size); err != nil {
82+
t.Fatalf("Failed to copy memory from a: %v", err)
83+
}
84+
85+
if err = MemcpyHtoD(memB, unsafe.Pointer(&b[0]), size); err != nil {
86+
t.Fatalf("Failed to copy memory from b: %v", err)
87+
}
88+
89+
if mod, err = LoadData(add32PTX); err != nil {
90+
t.Fatalf("Cannot load add32: %v", err)
91+
}
92+
93+
if fn, err = mod.Function("add32"); err != nil {
94+
t.Fatalf("Cannot get add32(): %v", err)
95+
}
96+
97+
args := []unsafe.Pointer{
98+
unsafe.Pointer(&memA),
99+
unsafe.Pointer(&memB),
100+
unsafe.Pointer(&size),
101+
}
102+
103+
if err = fn.LaunchAndSync(1, 1, 1, len(a), 1, 1, 1, Stream(0), args); err != nil {
104+
t.Error("Launch and Sync Failed: %v", err)
105+
}
106+
107+
if err = MemcpyDtoH(unsafe.Pointer(&a[0]), memA, size); err != nil {
108+
t.Fatalf("Failed to copy memory to a: %v", err)
109+
}
110+
111+
if err = MemcpyDtoH(unsafe.Pointer(&b[0]), memA, size); err != nil {
112+
t.Fatalf("Failed to copy memory to b: %v", err)
113+
}
114+
115+
for _, v := range a {
116+
if v != float32(2) {
117+
t.Error("Expected all values to be 2.")
118+
break
119+
}
120+
}
121+
122+
Unload(mod)
123+
DestroyContext(&ctx)
124+
}

test_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package cu
2+
3+
const add32PTX = `//
4+
// Generated by NVIDIA NVVM Compiler
5+
//
6+
// Compiler Build ID: CL-21554848
7+
// Cuda compilation tools, release 8.0, V8.0.61
8+
// Based on LLVM 3.4svn
9+
//
10+
11+
.version 5.0
12+
.target sm_30
13+
.address_size 64
14+
15+
// .globl add32
16+
17+
.visible .entry add32(
18+
.param .u64 add32_param_0,
19+
.param .u64 add32_param_1,
20+
.param .u32 add32_param_2
21+
)
22+
{
23+
.reg .pred %p<2>;
24+
.reg .f32 %f<4>;
25+
.reg .b32 %r<19>;
26+
.reg .b64 %rd<8>;
27+
28+
29+
ld.param.u64 %rd1, [add32_param_0];
30+
ld.param.u64 %rd2, [add32_param_1];
31+
ld.param.u32 %r2, [add32_param_2];
32+
mov.u32 %r3, %ctaid.x;
33+
mov.u32 %r4, %ctaid.z;
34+
mov.u32 %r5, %nctaid.y;
35+
mov.u32 %r6, %ctaid.y;
36+
mad.lo.s32 %r7, %r4, %r5, %r6;
37+
mov.u32 %r8, %nctaid.x;
38+
mad.lo.s32 %r9, %r7, %r8, %r3;
39+
mov.u32 %r10, %ntid.y;
40+
mov.u32 %r11, %ntid.x;
41+
mul.lo.s32 %r12, %r10, %r11;
42+
mov.u32 %r13, %ntid.z;
43+
mov.u32 %r14, %tid.y;
44+
mov.u32 %r15, %tid.z;
45+
mad.lo.s32 %r16, %r9, %r13, %r15;
46+
mov.u32 %r17, %tid.x;
47+
mad.lo.s32 %r18, %r14, %r11, %r17;
48+
mad.lo.s32 %r1, %r12, %r16, %r18;
49+
setp.ge.s32 %p1, %r1, %r2;
50+
@%p1 bra BB0_2;
51+
52+
cvta.to.global.u64 %rd3, %rd1;
53+
mul.wide.s32 %rd4, %r1, 4;
54+
add.s64 %rd5, %rd3, %rd4;
55+
cvta.to.global.u64 %rd6, %rd2;
56+
add.s64 %rd7, %rd6, %rd4;
57+
ld.global.f32 %f1, [%rd7];
58+
ld.global.f32 %f2, [%rd5];
59+
add.rn.f32 %f3, %f2, %f1;
60+
st.global.f32 [%rd5], %f3;
61+
62+
BB0_2:
63+
ret;
64+
}
65+
66+
67+
`

0 commit comments

Comments
 (0)