cpu: aarch64: matmul: Move allocation of temporary tensors to scratchpad in acl_matmul

Introduce 3 new scratchpad memory key names.
annop-w authored and vpirogov committed Jun 25, 2024
1 parent 5806809 commit 6f14365
Showing 5 changed files with 40 additions and 8 deletions.
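For context on the diffs below: oneDNN's scratchpad follows a two-phase pattern. At primitive-descriptor creation, init_scratchpad() books a byte size for each temporary under a named key; at execution time, the same key is used to fetch the backing pointer, which is handed to an ACL tensor via import_memory() instead of letting ACL allocate its own buffer. The following is a minimal, self-contained sketch of that book-then-get pattern; registrar_t, grabber_t, and the key_t enum here are simplified stand-ins modeled on the diff, not oneDNN's actual memory_tracking implementation.

// Minimal, self-contained model of the book-then-get scratchpad pattern
// used in this commit. registrar_t/grabber_t are simplified stand-ins for
// oneDNN's memory_tracking machinery, not the real implementation.
#include <cstddef>
#include <map>
#include <vector>

enum class key_t { matmul_src_trans, matmul_wei_trans, matmul_dst_trans };

struct registrar_t {
    // Phase 1 (pd creation): record how many bytes each temporary needs.
    void book(key_t key, std::size_t nelems, std::size_t dt_size) {
        sizes[key] = nelems * dt_size;
    }
    std::map<key_t, std::size_t> sizes;
};

struct grabber_t {
    // Lay every booking out in one arena, keyed by offset.
    explicit grabber_t(const registrar_t &reg) {
        std::size_t total = 0;
        for (const auto &kv : reg.sizes) {
            offsets[kv.first] = total;
            total += kv.second;
        }
        arena.resize(total);
    }
    // Phase 2 (execution): fetch the pointer booked under a key.
    template <typename T>
    T *get(key_t key) {
        return reinterpret_cast<T *>(arena.data() + offsets.at(key));
    }
    std::map<key_t, std::size_t> offsets;
    std::vector<char> arena;
};

int main() {
    registrar_t scratchpad;
    // Mirrors init_scratchpad(): one booking per transposed tensor.
    scratchpad.book(key_t::matmul_src_trans, /*nelems=*/64, sizeof(float));
    scratchpad.book(key_t::matmul_wei_trans, /*nelems=*/64, sizeof(float));

    grabber_t grabber(scratchpad);
    // Mirrors execute_forward(): this pointer would be passed to an ACL
    // tensor's allocator()->import_memory() rather than allocate()'d.
    void *transA_scratch = grabber.get<void>(key_t::matmul_src_trans);
    (void)transA_scratch;
    return 0;
}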
1 change: 1 addition & 0 deletions src/common/memory_tracking.hpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
 * Copyright 2018-2024 Intel Corporation
+* Copyright 2024 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
23 changes: 18 additions & 5 deletions src/cpu/aarch64/matmul/acl_matmul.cpp
@@ -44,24 +44,32 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
 
     // Run transpose kernel
     if (is_transA && !is_transB) {
-        acl_obj.src_tensor.allocator()->allocate();
+        auto transA_scratch = scratchpad.get<void>(
+                memory_tracking::names::key_matmul_src_trans);
+        acl_obj.src_tensor.allocator()->import_memory(transA_scratch);
         acl_obj.src_acc_tensor.allocator()->import_memory(
                 const_cast<data_t *>(src_base));
         acl_obj.transA.run();
         acl_obj.wei_tensor.allocator()->import_memory(
                 const_cast<data_t *>(wei_base));
     } else if (is_transB && !is_transA) {
-        acl_obj.wei_tensor.allocator()->allocate();
+        auto transB_scratch = scratchpad.get<void>(
+                memory_tracking::names::key_matmul_wei_trans);
+        acl_obj.wei_tensor.allocator()->import_memory(transB_scratch);
         acl_obj.wei_acc_tensor.allocator()->import_memory(
                 const_cast<data_t *>(wei_base));
         acl_obj.transB.run();
         acl_obj.src_tensor.allocator()->import_memory(
                 const_cast<data_t *>(src_base));
     } else if (is_transA && is_transB && !do_transC) {
-        acl_obj.src_tensor.allocator()->allocate();
+        auto transA_scratch = scratchpad.get<void>(
+                memory_tracking::names::key_matmul_src_trans);
+        auto transB_scratch = scratchpad.get<void>(
+                memory_tracking::names::key_matmul_wei_trans);
+        acl_obj.src_tensor.allocator()->import_memory(transA_scratch);
         acl_obj.src_acc_tensor.allocator()->import_memory(
                 const_cast<data_t *>(src_base));
-        acl_obj.wei_tensor.allocator()->allocate();
+        acl_obj.wei_tensor.allocator()->import_memory(transB_scratch);
         acl_obj.wei_acc_tensor.allocator()->import_memory(
                 const_cast<data_t *>(wei_base));
         acl_obj.transA.run();
@@ -71,7 +79,11 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
                 const_cast<data_t *>(src_base));
         acl_obj.wei_tensor.allocator()->import_memory(
                 const_cast<data_t *>(wei_base));
-        if (do_transC) { acl_obj.dst_acc_tensor.allocator()->allocate(); }
+        if (do_transC) {
+            auto transC_scratch = scratchpad.get<void>(
+                    memory_tracking::names::key_matmul_dst_trans);
+            acl_obj.dst_acc_tensor.allocator()->import_memory(transC_scratch);
+        }
     }
 
     // If we have an unfused sum post op, put the result in a scratchpad tensor.
@@ -94,6 +106,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
     pd()->acl_post_ops.execute(ctx, dst);
 
     acl_obj.dst_tensor.allocator()->free();
+    if (do_transC) acl_obj.dst_acc_tensor.allocator()->free();
 
     return status;
 }
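The execute-side change in isolation: a hedged sketch (assuming Arm Compute Library headers are available; status checks and the surrounding oneDNN plumbing omitted) of the difference between the old allocate() path and the new import_memory() path. back_with_scratchpad() is a hypothetical helper for illustration, not a function from the commit.

// Hedged sketch of the execute-time change, assuming the Arm Compute
// Library. back_with_scratchpad() is a hypothetical helper.
#include "arm_compute/runtime/Tensor.h"

void back_with_scratchpad(arm_compute::Tensor &t, void *scratch) {
    // Before this commit: t.allocator()->allocate();  // ACL-owned buffer
    t.allocator()->import_memory(scratch); // backed by oneDNN's scratchpad
    // ... run the transpose/gemm kernels that use t ...
    t.allocator()->free(); // releases the import, not the scratchpad memory
}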
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/acl_matmul.hpp
@@ -170,7 +170,8 @@ struct acl_matmul_t : public primitive_t {
         }
 
         auto scratchpad = scratchpad_registry().registrar();
-        CHECK(acl_matmul_utils::init_scratchpad(scratchpad, amp_, dst_md_));
+        CHECK(acl_matmul_utils::init_scratchpad(
+                scratchpad, amp_, src_md_, weights_md_, dst_md_));
 
         return status::success;
     }
18 changes: 17 additions & 1 deletion src/cpu/aarch64/matmul/acl_matmul_utils.cpp
@@ -174,12 +174,28 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
 }
 
 status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
-        acl_matmul_conf_t &amp, memory_desc_t &dst_md) {
+        const acl_matmul_conf_t &amp, const memory_desc_t &src_md,
+        const memory_desc_t &weights_md, const memory_desc_t &dst_md) {
     if (amp.use_dst_acc_for_sum) {
         const memory_desc_wrapper dst_d(&dst_md);
         scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
                 dst_d.nelems(), dst_d.data_type_size());
     }
+    if (amp.is_transA) {
+        const memory_desc_wrapper src_d(&src_md);
+        scratchpad.book(memory_tracking::names::key_matmul_src_trans,
+                src_d.nelems(), src_d.data_type_size());
+    }
+    if (amp.is_transB) {
+        const memory_desc_wrapper wei_d(&weights_md);
+        scratchpad.book(memory_tracking::names::key_matmul_wei_trans,
+                wei_d.nelems(), wei_d.data_type_size());
+    }
+    if (amp.do_transC) {
+        const memory_desc_wrapper dst_d(&dst_md);
+        scratchpad.book(memory_tracking::names::key_matmul_dst_trans,
+                dst_d.nelems(), dst_d.data_type_size());
+    }
     return status::success;
 }
 
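As a concrete sizing check for the bookings above (hypothetical shape, not taken from the commit): for an f32 source of 128 x 256, nelems() is 32768 and data_type_size() is 4, so key_matmul_src_trans books 131072 bytes.

// Hypothetical sizing example for the bookings above; the shape is
// illustrative, not from the commit.
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t nelems = 128 * 256; // memory_desc_wrapper::nelems()
    const std::size_t dt_size = 4;        // data_type_size() for f32
    std::printf("key_matmul_src_trans books %zu bytes\n", nelems * dt_size);
    return 0;
}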
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/acl_matmul_utils.hpp
@@ -63,7 +63,8 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
         const primitive_attr_t &attr);
 
 status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
-        acl_matmul_conf_t &amp, memory_desc_t &dst_md);
+        const acl_matmul_conf_t &amp, const memory_desc_t &src_md,
+        const memory_desc_t &weights_md, const memory_desc_t &dst_md);
 
 } // namespace acl_matmul_utils
 
