Skip to content

Commit

Permalink
Mr. zylo117's fix
Browse files Browse the repository at this point in the history
* fix typo

* add ret on init

* update README.md

* 适配旧版libax_engine.so没有AX_ENGINE_GetGroupIOInfoCount接口

* change device_no to device_id

* 利用axclrtSetDevice完成context的get/set,从而支持跨线程推理
  • Loading branch information
zylo117 authored Jan 2, 2025
1 parent 9371cff commit c83df64
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Class Index: 277, Score: 8.452778816223145
Class Index: 281, Score: 8.320704460144043
Class Index: 287, Score: 7.924479961395264
# 默认将自动检测计算设备,但也可以强制要求跑在AX650 M.2算力卡上,设备号是1
# 默认将自动检测计算设备,但也可以强制要求跑在AX650 M.2算力卡上,假设设备号是1,(设备号必须大于等于0,具体查看axcl-smi)
root@ax650:~/samples# python3 classification.py -b axcl -d 1
[INFO] SOC Name: AX650N
[INFO] Runtime version: 1.0.0
Expand Down
13 changes: 13 additions & 0 deletions axengine/_axcl_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""
#define AXCL_MAX_DEVICE_COUNT 256
typedef int32_t axclError;
typedef void *axclrtContext;
"""
)

Expand Down Expand Up @@ -149,6 +150,18 @@
"""
axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
axclError axclrtSetDevice(int32_t deviceId);
axclError axclrtResetDevice(int32_t deviceId);
"""
)

# axcl_rt_context.h
O.cdef(
"""
axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
axclError axclrtDestroyContext(axclrtContext context);
axclError axclrtSetCurrentContext(axclrtContext context);
axclError axclrtGetCurrentContext(axclrtContext *context);
axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
"""
)

Expand Down
6 changes: 5 additions & 1 deletion axengine/ax_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def __init__(
print(f"[INFO] Compiler version: {self._get_model_tool_version()}")

# get shape group count
self._shape_count = self._get_shape_count()
try:
self._shape_count = self._get_shape_count()
except AttributeError as e:
print(f"[WARNING] {e}")
self._shape_count = 1

# get model shape
self._info = self._get_info()
Expand Down
26 changes: 20 additions & 6 deletions axengine/axcl_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(

super(BaseInferenceSession).__init__()

self.device_id = device_id

# load shared library
self._rt_lib = _capi.R
self._rt_ffi = _capi.O
Expand All @@ -34,12 +36,18 @@ def __init__(
print(f"[INFO] SOC Name: {self.soc_name}")

# init axcl
self.axcl_device_id = -1 # axcl_device_id != device_id, device_id is just the index of the list of axcl_device_ids
ret = self._init(device_id)
if 0 != ret:
raise RuntimeError("Failed to initialize axclrt.")
print(f"[INFO] Runtime version: {self._get_version()}")

# handle, context, info, io
self._thread_context = self._rt_ffi.new("axclrtContext *")
ret = self._rt_lib.axclrtGetCurrentContext(self._thread_context)
if ret != 0:
raise RuntimeError("axclrtGetCurrentContext failed")

# model handle, context, info, io
self._handle = self._rt_ffi.new("uint64_t *")
self._context = self._rt_ffi.new("uint64_t *")
self.io_info = self._rt_ffi.new("axclrtEngineIOInfo *")
Expand Down Expand Up @@ -249,14 +257,15 @@ def _free_io(self, io_data):
def _init(self, device_id=0, vnpu=VNPUType.DISABLED): # vnpu type, the default is disabled
ret = self._rt_lib.axclInit([])
if ret != 0:
raise RuntimeError("Failed to initialize runtime.")
raise RuntimeError(f"Failed to initialize runtime. {ret}.")

lst = self._rt_ffi.new("axclrtDeviceList *")
ret = self._rt_lib.axclrtGetDeviceList(lst)
if ret != 0 or lst.num == 0:
raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.")

ret = self._rt_lib.axclrtSetDevice(lst.devices[device_id])
self.axcl_device_id = lst.devices[device_id]
ret = self._rt_lib.axclrtSetDevice(self.axcl_device_id)
if ret != 0 or lst.num == 0:
raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.")

Expand All @@ -269,6 +278,7 @@ def _init(self, device_id=0, vnpu=VNPUType.DISABLED): # vnpu type, the default
def _final(self):
if self._handle[0] is not None:
self._unload()
self._rt_lib.axclrtResetDevice(self.axcl_device_id)
self._rt_lib.axclFinalize()
return

Expand Down Expand Up @@ -331,6 +341,10 @@ def run(self, output_names, input_feed, run_options=None):
self._validate_input(list(input_feed.keys()))
self._validate_output(output_names)

ret = self._rt_lib.axclrtSetCurrentContext(self._thread_context[0])
if ret != 0:
raise RuntimeError("axclrtSetCurrentContext failed")

if None is output_names:
output_names = [o.name for o in self.get_outputs()]

Expand Down Expand Up @@ -384,8 +398,8 @@ def run(self, output_names, input_feed, run_options=None):
for i, output_tensor in enumerate(self.mgroup_output_tensors[grp_id])
if self.get_outputs()[i].name in output_names]

print(f'[INFO] cost time in host to device: {cost_host_to_device * 1009:.3f}ms, '
f'inference: {cost_inference * 1009:.3f}ms, '
f'device to host: {cost_device_to_host * 1009:.3f}ms')
print(f'[INFO] cost time in host to device: {cost_host_to_device * 1000:.3f}ms, '
f'inference: {cost_inference * 1000:.3f}ms, '
f'device to host: {cost_device_to_host * 1000:.3f}ms')

return outputs
2 changes: 1 addition & 1 deletion axengine/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def InferenceSession(path_or_bytes: str | bytes | os.PathLike, device_id: int =
print("axcl_rt not found, please install axcl_host driver")

if is_axcl:
print(f"Using axclrt backend, device_no: {device_id}")
print(f"Using axclrt backend, device_id: {device_id}")
return AXCLInferenceSession(path_or_bytes, device_id)
else:
print("Using ax backend with onboard npu")
Expand Down
23 changes: 14 additions & 9 deletions examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from PIL import Image


def load_model(model_path, backend='auto', device_no=-1):
def load_model(model_path, backend='auto', device_id=-1):
if backend == 'auto':
session = axe.InferenceSession(model_path, device_no)
session = axe.InferenceSession(model_path, device_id)
elif backend == 'ax':
session = axe.AXInferenceSession(model_path)
elif backend == 'axcl':
session = axe.AXCLInferenceSession(model_path, device_no)
session = axe.AXCLInferenceSession(model_path, device_id)
return session


Expand Down Expand Up @@ -62,16 +62,21 @@ def get_top_k_predictions(output, k=5):
return top_k_indices, top_k_scores


def main(model_path, image_path, target_size, crop_size, k, backend='auto', device_no=-1):
def main(model_path, image_path, target_size, crop_size, k, backend='auto', device_id=-1):
# Load the model
session = load_model(model_path, backend, device_no)
session = load_model(model_path, backend, device_id)

# Preprocess the image
input_tensor = preprocess_image(image_path, target_size, crop_size)

# Get input name and run inference
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: input_tensor})
import time
for i in range(10):
t1 = time.time()
output = session.run(None, {input_name: input_tensor})
t2 = time.time()
print(t2 - t1)

# Get top k predictions
top_k_indices, top_k_scores = get_top_k_predictions(output, k)
Expand All @@ -86,14 +91,14 @@ def main(model_path, image_path, target_size, crop_size, k, backend='auto', devi
import argparse
ap = argparse.ArgumentParser()
ap.add_argument('-b', '--backend', type=str, help='auto/ax/axcl', default='auto')
ap.add_argument('-d', '--device_no', type=int, help='axcl device no, -1: onboard npu, >0: axcl devices', default=0)
ap.add_argument('-d', '--device_id', type=int, help='axcl device no, -1: onboard npu, >0: axcl devices', default=0)
args = ap.parse_args()
assert args.backend in ['auto', 'ax', 'axcl'], "backend must be auto/ax/axcl"
assert args.device_no >= -1, "device_no must be greater than -1"
assert args.device_id >= -1, "device_id must be greater than -1"

MODEL_PATH = "../mobilenetv2.axmodel"
IMAGE_PATH = "../cat.jpg"
TARGET_SIZE = (256, 256) # Resize to 256x256
CROP_SIZE = (224, 224) # Crop to 224x224
K = 5 # Top K predictions
main(MODEL_PATH, IMAGE_PATH, TARGET_SIZE, CROP_SIZE, K, args.backend, args.device_no)
main(MODEL_PATH, IMAGE_PATH, TARGET_SIZE, CROP_SIZE, K, args.backend, args.device_id)
14 changes: 7 additions & 7 deletions examples/yolov5_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,14 @@ def post_processing(outputs, origin_shape, input_shape):
return pred


def detect_yolov5(model_path, image_path, save_path, backend='auto', device_no=-1):
def detect_yolov5(model_path, image_path, save_path, backend='auto', device_id=-1):

if backend == 'auto':
session = axe.InferenceSession(model_path, device_no)
session = axe.InferenceSession(model_path, device_id)
elif backend == 'ax':
session = axe.AXInferenceSession(model_path)
elif backend == 'axcl':
session = axe.AXCLInferenceSession(model_path, device_no)
session = axe.AXCLInferenceSession(model_path, device_id)
image_data = cv2.imread(image_path)
inputs, origin_shape = pre_processing(image_data, (640, 640))
inputs = np.ascontiguousarray(inputs)
Expand All @@ -475,13 +475,13 @@ def parse_args() -> argparse.ArgumentParser:
parser.add_argument("--model", type=str, required=True, help="axmodel path")
parser.add_argument("--image_path", type=str, required=True, help="image path")
parser.add_argument('-b', '--backend', type=str, help='auto/ax/axcl', default='auto')
parser.add_argument('-d', '--device_no', type=int, help='axcl device no, -1: onboard npu, >0: axcl devices', default=0)
parser.add_argument('-d', '--device_id', type=int, help='axcl device no, -1: onboard npu, >0: axcl devices', default=0)
parser.add_argument(
"--save_path", type=str, default="save.jpg", help="save image path"
)
args = parser.parse_args()
assert args.backend in ['auto', 'ax', 'axcl'], "backend must be ax or axcl"
assert args.device_no >= -1, "device_no must be greater than -1"
assert args.device_id >= -1, "device_id must be greater than -1"
return args


Expand All @@ -490,7 +490,7 @@ def parse_args() -> argparse.ArgumentParser:
print(f"model : {args.model}")
print(f"image path : {args.image_path}")
print(f"backend : {args.backend}")
print(f"device_no : {args.device_no}")
print(f"device_id : {args.device_id}")
print(f"save draw image to: {args.save_path}")
detect_yolov5(args.model, args.image_path, args.save_path, args.backend, args.device_no)
detect_yolov5(args.model, args.image_path, args.save_path, args.backend, args.device_id)
# python3 yolov5_example.py --model /opt/data/npu/models/yolov5s.axmodel --image_path /opt/data/npu/images/dog.jpg --save_path ./detect_dog.jpg

0 comments on commit c83df64

Please sign in to comment.