Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that checks if all async mem copies are not strided.
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
Expand Down
15 changes: 13 additions & 2 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,15 +726,26 @@ TVM_DLL const Op& texture2d_store();
TVM_DLL const Op& texture2d_load();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
* \brief Initiate a non-blocking DMA copy from source to destination; a DMA copy outside of a group
* has a defacto group size of one
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \brief Wait until the number of DMA groups in flight is less than or equal to some maximum
*/
TVM_DLL const Op& dma_wait();

/*!
* \brief Start a group of DMA copies
*/
TVM_DLL const Op& dma_start_group();

/*!
* \brief End a group of DMA copies
*/
TVM_DLL const Op& dma_end_group();

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@

@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
class DisallowAsyncStridedMemCopy(Postproc):
"""A postprocessor that disallows schedules that use async strided mem copies.
"""A postprocessor that disallows schedules that use async strided mem copies."""

Parameters
----------
merge_async_commit_queue_scope : bool
Whether or not to merge the async commit queue scope.
"""

def __init__(self, merge_async_commit_queue_scope=True) -> None:
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
merge_async_commit_queue_scope,
)
35 changes: 21 additions & 14 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,27 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy",
-1, # Use QueueId of -1 to not interfere with async copies.
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0, # Do not use experimental bypass mode.
dtype="int32",
)
)
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_wait",
-1,
0, # Wait for the sync queue (-1) to have 0 messages.
"device_api.hexagon.dma_copy_dltensor",
T.tvm_stack_make_array(
T.address_of(C[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
C.dtype,
0,
dtype="handle",
),
T.tvm_stack_make_array(
T.address_of(A[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
A.dtype,
0,
dtype="handle",
),
T.cast(size, dtype="int"),
False, # Do not use experimental bypass mode.
dtype="int32",
)
)
Expand Down
1 change: 0 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
Bool(merge_async_commit_queue_scope));
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
Expand All @@ -169,15 +166,12 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
return Postproc(n);
}

bool merge_async_commit_queue_scope = true;

static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};

Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
Postproc Postproc::DisallowAsyncStridedMemCopy() {
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
return Postproc(n);
}

Expand Down
13 changes: 13 additions & 0 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,19 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: When the input/output types are fixed, the .set_body_typed() method can be used to avoid needing manual argument wrangling.

.set_body_typed([](int queue_id) -> int32_t {
      return HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
    });

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chose not to implement this for this iteration as it seems like we could / should redo the entire Hexagon Device API with this change.

int queue_id = args[0];
HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
12 changes: 7 additions & 5 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ unsigned int HexagonUserDMA::Init() {
return status;
}

int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
bool bypass_cache) {
// length limited to 24 bits
if (length > DESC_LENGTH_MASK) {
return DMA_FAILURE;
Expand Down Expand Up @@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bo
return DMA_SUCCESS;
}

void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
// wait (forever) until max DMAs in flight <= actual DMAs in flight
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
}
}

uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAsInFlight(queue_id); }

uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
uint32_t HexagonUserDMA::DMAsInFlight(uint32_t queue_id) {
dmpoll(); // update DMA engine status
return descriptors_->InFlight(queue_id);
}
Expand All @@ -125,7 +126,8 @@ HexagonUserDMA::HexagonUserDMA() {
unsigned int done = dma_desc_get_done(dma_desc);
return (done != DESC_DONE_COMPLETE);
};
descriptors_ = new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_DESCRIPTORS, desc_in_flight);
descriptors_ =
new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
}

HexagonUserDMA::~HexagonUserDMA() {
Expand Down
31 changes: 25 additions & 6 deletions src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace hexagon {
#define DMA_FAILURE -1
#define DMA_RETRY 1
#define MAX_DMA_DESCRIPTORS 100
#define SYNC_DMA_QUEUE -1
#define MAX_DMA_QUEUES 10
#define SYNC_DMA_QUEUE MAX_DMA_QUEUES - 1

class HexagonUserDMA {
public:
Expand All @@ -47,32 +48,50 @@ class HexagonUserDMA {

/*!
* \brief Initiate DMA to copy memory from source to destination address
* \param queue_id The virtual DMA queue
* \param dst Destination address
* \param src Source address
* \param length Length in bytes to copy
* \returns Status: DMA_SUCCESS or DMA_FAILURE
*/
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \param queue_id The virtual DMA queue
* \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
* to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete
*/
void Wait(int queue_id, uint32_t max_dmas_in_flight);
void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);

/*!
* \brief Poll the number of DMAs in flight
* \param queue_id The virtual DMA queue
* \returns Number of DMAs in flight
*/
uint32_t Poll(int queue_id);
uint32_t Poll(uint32_t queue_id);

/*!
* \brief Start a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }

/*!
* \brief End a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }

private:
//! \brief Initializes the Hexagon User DMA engine
unsigned int Init();

//! \brief Calculates and returns the number of DMAs in flight
uint32_t DMAsInFlight(int queue_id);
/*!
* \brief Calculates and returns the number of DMAs in flight
* \param queue_id The virtual DMA queue
*/
uint32_t DMAsInFlight(uint32_t queue_id);

//! \brief Tracks whether the very first DMA has been executed
bool first_dma_ = true;
Expand Down
76 changes: 70 additions & 6 deletions src/runtime/hexagon/ring_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#define TVM_RUNTIME_HEXAGON_RING_BUFFER_H_

#include <functional>
#include <queue>
#include <unordered_map>
#include <vector>

#include "hexagon_common.h"
Expand Down Expand Up @@ -94,17 +96,33 @@ class RingBuffer {
template <class T>
class QueuedRingBuffer : RingBuffer<T> {
public:
QueuedRingBuffer(uint32_t ring_buff_size, std::function<bool(T*)> in_flight)
: RingBuffer<T>(ring_buff_size, in_flight) {}
QueuedRingBuffer(uint32_t max_queues, uint32_t ring_buff_size, std::function<bool(T*)> in_flight)
: RingBuffer<T>(ring_buff_size, in_flight), max_queues_(max_queues) {
queue_descriptors_.resize(max_queues_);
}

//! \brief Returns pointer to next T; add the queue ID for tracking
T* Next(int queue_id) {
T* Next(uint32_t queue_id) {
CHECK_LT(queue_id, max_queues_);
queue_ids_.push_back(queue_id);
queue_descriptor* d = &queue_descriptors_[queue_id];
if (d->group_started) {
// if we have a group started just update then pending count
d->pending_in_group++;
} else {
// else create group with size one
d->groups.push(1);
d->pending_total++;
}
return RingBuffer<T>::Next();
}

//! \brief Returns the number of Ts in flight for a given queue ID
uint32_t InFlight(int queue_id) {
//! \brief Returns the number of groups of Ts in flight for a given queue ID
uint32_t InFlight(uint32_t queue_id) {
CHECK_LT(queue_id, max_queues_);
queue_descriptor* d = &queue_descriptors_[queue_id];
CHECK(!d->group_started);

uint32_t in_flight = 0;
// look at the queue IDs for the RingBuffer entries in flight
for (size_t i = queue_ids_.size() - RingBuffer<T>::InFlight(); i < queue_ids_.size(); ++i) {
Expand All @@ -113,11 +131,57 @@ class QueuedRingBuffer : RingBuffer<T> {
in_flight++;
}
}
return in_flight;

// calculate number of groups in flight
while (!d->groups.empty() && d->pending_total - d->groups.front() >= in_flight) {
d->pending_total -= d->groups.front();
d->groups.pop();
}

// return the number of groups in flight
return d->groups.size();
}

//! \brief Start a group of Ts, if not called the deafault group size is one
void StartGroup(uint32_t queue_id) {
CHECK_LT(queue_id, max_queues_);
queue_descriptor* d = &queue_descriptors_[queue_id];
CHECK(!d->group_started);

// start group
d->group_started = true;
d->pending_in_group = 0;
}

//! \brief End a group of Ts
void EndGroup(uint32_t queue_id) {
CHECK_LT(queue_id, max_queues_);
queue_descriptor* d = &queue_descriptors_[queue_id];
CHECK(d->group_started);
CHECK(d->pending_in_group);

// create group
if (d->pending_in_group) {
d->groups.emplace(d->pending_in_group);
}
d->pending_total += d->pending_in_group;

// end group
d->group_started = false;
d->pending_in_group = 0;
}

private:
struct queue_descriptor {
uint32_t pending_total = 0;
uint32_t pending_in_group = 0;
bool group_started = false;
std::queue<int> groups;
};

const int max_queues_;
std::vector<int> queue_ids_;
std::vector<queue_descriptor> queue_descriptors_;
};

} // namespace hexagon
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(dma_start_group)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(dma_end_group)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);
Expand Down
Loading