Skip to content

Commit a21986c

Browse files
committed
test: add unit tests for get_data_ptr_ipc
1 parent 2d8b07a commit a21986c

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
5+
import multiprocessing as mp
6+
import os
7+
import queue
8+
import unittest
9+
from multiprocessing import Process, Queue
10+
11+
import numpy as np
12+
import paddle
13+
14+
from fastdeploy.model_executor.ops.gpu import (
15+
get_data_ptr_ipc,
16+
read_data_ipc,
17+
set_data_ipc,
18+
)
19+
20+
21+
def _create_test_tensor(shape, dtype):
22+
# Create GPU tensor with deterministic data type
23+
paddle.device.set_device("gpu:0")
24+
return paddle.rand(shape=shape, dtype=dtype)
25+
26+
27+
def _producer_set_ipc(shm_name, shape, dtype, ready_q, done_q, error_q):
28+
try:
29+
paddle.device.set_device("gpu:0")
30+
t = _create_test_tensor(shape, dtype)
31+
set_data_ipc(t, shm_name)
32+
ready_q.put(("ready", True))
33+
_ = done_q.get(timeout=20)
34+
except Exception as e:
35+
error_q.put(("producer_error", str(e)))
36+
37+
38+
def _consumer_get_and_read(shm_name, shape, dtype, result_q, error_q):
39+
try:
40+
paddle.device.set_device("gpu:0")
41+
ptr_tensor = get_data_ptr_ipc(paddle.zeros([1], dtype=paddle.float32), shm_name)
42+
ptr_val = int(ptr_tensor.numpy()[0])
43+
if ptr_val == 0:
44+
raise RuntimeError("get_data_ptr_ipc returned null pointer")
45+
46+
# Note(ooooo): Because it can print the ptr of shm, make it to check in stdout with `hex(ptr_val)`.
47+
read_data_ipc(paddle.zeros([1], dtype=paddle.float32), ptr_val, shm_name)
48+
49+
result_q.put(("ok", ptr_val))
50+
except Exception as e:
51+
error_q.put(("consumer_error", str(e)))
52+
53+
54+
# Ensure spawn to avoid inheriting CUDA context
55+
try:
56+
mp.set_start_method("spawn", force=True)
57+
except RuntimeError:
58+
pass
59+
60+
61+
class TestGetReadDataIPC(unittest.TestCase):
62+
def setUp(self):
63+
paddle.seed(2024)
64+
np.random.seed(42)
65+
if not paddle.device.is_compiled_with_cuda():
66+
self.skipTest("CUDA not available, skipping GPU tests")
67+
paddle.device.set_device("gpu:0")
68+
69+
# ensure >= 10 elems since read_data_ipc prints first 10
70+
self.shape = [4, 8]
71+
self.dtype = paddle.float32
72+
self.shm_name = f"test_get_read_ipc_{os.getpid()}"
73+
74+
def test_get_then_read_ipc_cross_process(self):
75+
ready_q, result_q, error_q, done_q = Queue(), Queue(), Queue(), Queue()
76+
77+
# producer: set IPC and hold until done
78+
p = Process(target=_producer_set_ipc, args=(self.shm_name, self.shape, self.dtype, ready_q, done_q, error_q))
79+
p.start()
80+
81+
try:
82+
status, _ = ready_q.get(timeout=20)
83+
self.assertEqual(status, "ready")
84+
except Exception:
85+
p.terminate()
86+
self.fail("Producer did not become ready in time")
87+
88+
# consumer: get ptr and invoke read
89+
c = Process(target=_consumer_get_and_read, args=(self.shm_name, self.shape, self.dtype, result_q, error_q))
90+
c.start()
91+
c.join(timeout=30)
92+
93+
# let producer exit
94+
done_q.put("done")
95+
p.join(timeout=30)
96+
97+
# collect errors
98+
errors = []
99+
try:
100+
while True:
101+
errors.append(error_q.get_nowait())
102+
except queue.Empty:
103+
pass
104+
self.assertFalse(errors, f"Errors occurred: {errors}")
105+
106+
# result check
107+
self.assertFalse(result_q.empty(), "No result from consumer")
108+
status, ptr_val = result_q.get()
109+
self.assertEqual(status, "ok")
110+
self.assertTrue(isinstance(ptr_val, int) and ptr_val != 0)
111+
print(hex(ptr_val))
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)