Skip to content

Commit 98d52db

Browse files
committed
graph context & manual op config
1 parent ce96890 commit 98d52db

31 files changed

+455
-183
lines changed

ark/api/model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
namespace ark {
1111

12-
Model Model::compress() const {
12+
Model Model::compress(bool merge_nodes) const {
1313
Model model(*this);
14-
model.compress_nodes();
14+
model.compress_nodes(merge_nodes);
1515
return model;
1616
}
1717

ark/api/model_graph.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,30 @@ int ModelGraph::rank() const { return impl_->rank(); }
3333

3434
int ModelGraph::world_size() const { return impl_->world_size(); }
3535

36-
void ModelGraph::compress_nodes() { impl_->compress_nodes(); }
36+
/// Forwards the node-compression request to the implementation object.
/// @param merge_nodes Passed through unchanged to the implementation's
///                    `compress_nodes()`.
void ModelGraph::compress_nodes(bool merge_nodes)
{
    impl_->compress_nodes(merge_nodes);
}
3739

3840
bool ModelGraph::compressed() const { return impl_->compressed(); }
3941

4042
bool ModelGraph::verify() const { return impl_->verify(); }
4143

44+
/// Creates a context manager for the given key/value pair by delegating to
/// the implementation's `context_manager()`.
/// @param key   Context key forwarded to the implementation.
/// @param value Context value forwarded to the implementation.
/// @return The `ContextManager` produced by the implementation.
ModelGraph::ContextManager ModelGraph::context(const std::string& key,
                                               const std::string& value)
{
    ModelGraph::ContextManager manager = impl_->context_manager(key, value);
    return manager;
}
48+
49+
/// Stores the context snapshot, the key this manager is responsible for, and
/// a shared handle to the per-key context stacks.
/// @param context        Context map copied into `context_`.
/// @param key            Key copied into `key_`; the destructor uses it to
///                       locate the stack to pop.
/// @param context_stacks Shared handle to the key -> value-stack map, stored
///                       in `context_stacks_`.
ModelGraph::ContextManager::ContextManager(
    const std::map<std::string, std::string>& context,
    const std::string& key,
    std::shared_ptr<std::map<std::string, std::vector<std::string>>>
        context_stacks)
    : context_(context),
      key_(key),
      context_stacks_(context_stacks)
{
}
54+
55+
/// Pops the most recent value from the context stack registered under `key_`.
///
/// The stack is looked up once. A missing stack is reported through `ERR`, as
/// before. An empty stack is now reported as well instead of calling
/// `pop_back()` on an empty vector, which is undefined behavior.
///
/// NOTE(review): `ERR` presumably raises `ModelError` (confirm against its
/// definition). Destructors are implicitly `noexcept`, so an exception
/// escaping from here terminates the program; treat both error paths as fatal
/// invariant violations.
ModelGraph::ContextManager::~ContextManager() {
    auto stack_it = context_stacks_->find(key_);
    if (stack_it == context_stacks_->end()) {
        ERR(ModelError, "context stack not found: {}", key_);
    } else if (stack_it->second.empty()) {
        // Guard: popping an empty std::vector is undefined behavior.
        ERR(ModelError, "context stack is empty: {}", key_);
    } else {
        stack_it->second.pop_back();
    }
}
61+
4262
} // namespace ark

ark/api/model_test.cpp

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ ark::unittest::State test_model_basics() {
3636
// (AddOp,)
3737
//
3838

39-
compressed = model.compress();
39+
compressed = model.compress(true);
4040
UNITTEST_TRUE(compressed.verify());
4141
UNITTEST_TRUE(compressed.compressed());
4242
UNITTEST_EQ(compressed.nodes().size(), 1);
@@ -70,7 +70,7 @@ ark::unittest::State test_model_basics() {
7070
// (AddOp,AddOp,)
7171
//
7272

73-
compressed = model.compress();
73+
compressed = model.compress(true);
7474
UNITTEST_TRUE(compressed.verify());
7575
UNITTEST_EQ(compressed.nodes().size(), 1);
7676

@@ -104,7 +104,7 @@ ark::unittest::State test_model_basics() {
104104
// (AddOp,AddOp,ReluOp,)
105105
//
106106

107-
compressed = model.compress();
107+
compressed = model.compress(true);
108108
UNITTEST_TRUE(compressed.verify());
109109
UNITTEST_EQ(compressed.nodes().size(), 1);
110110

@@ -143,7 +143,7 @@ ark::unittest::State test_model_basics() {
143143
// (AddOp,AddOp,ReluOp,AddOp,)
144144
//
145145

146-
compressed = model.compress();
146+
compressed = model.compress(true);
147147
UNITTEST_TRUE(compressed.verify());
148148

149149
auto nodes = compressed.nodes();
@@ -190,7 +190,7 @@ ark::unittest::State test_model_basics() {
190190
// (AddOp,) --+--> (AddOp,)
191191
//
192192

193-
compressed = model.compress();
193+
compressed = model.compress(true);
194194
UNITTEST_TRUE(compressed.verify());
195195

196196
nodes = compressed.nodes();
@@ -250,7 +250,7 @@ ark::unittest::State test_model_basics() {
250250
// (AddOp,)
251251
//
252252

253-
compressed = model.compress();
253+
compressed = model.compress(true);
254254
UNITTEST_TRUE(compressed.verify());
255255

256256
nodes = compressed.nodes();
@@ -312,7 +312,7 @@ ark::unittest::State test_model_basics() {
312312
// (AddOp,)
313313
//
314314

315-
compressed = model.compress();
315+
compressed = model.compress(true);
316316
UNITTEST_TRUE(compressed.verify());
317317

318318
nodes = compressed.nodes();
@@ -353,7 +353,7 @@ ark::unittest::State test_model_dependent_inputs() {
353353
ark::Tensor x4 = m.mul(x2, x3);
354354
ark::Tensor y = m.add(x0, x4);
355355

356-
auto compressed = m.compress();
356+
auto compressed = m.compress(true);
357357
auto nodes = compressed.nodes();
358358
UNITTEST_EQ(nodes.size(), 4);
359359
auto nodes_iter = nodes.begin();
@@ -399,7 +399,7 @@ ark::unittest::State test_model_noop() {
399399

400400
UNITTEST_TRUE(model.verify());
401401

402-
auto compressed = model.compress();
402+
auto compressed = model.compress(true);
403403
UNITTEST_TRUE(compressed.verify());
404404
UNITTEST_EQ(compressed.nodes().size(), 0);
405405
return ark::unittest::SUCCESS;
@@ -425,7 +425,7 @@ ark::unittest::State test_model_identity() {
425425
ark::Tensor t4 = model.relu(t3);
426426
UNITTEST_TRUE(model.verify());
427427

428-
auto compressed = model.compress();
428+
auto compressed = model.compress(true);
429429
UNITTEST_TRUE(compressed.verify());
430430
auto nodes = compressed.nodes();
431431
UNITTEST_EQ(nodes.size(), 3);
@@ -478,7 +478,7 @@ ark::unittest::State test_model_sharding() {
478478
ark::Tensor t5 = model.relu(t4);
479479
UNITTEST_TRUE(model.verify());
480480

481-
auto compressed = model.compress();
481+
auto compressed = model.compress(true);
482482
UNITTEST_TRUE(compressed.verify());
483483
auto nodes = compressed.nodes();
484484
UNITTEST_EQ(nodes.size(), 4);
@@ -526,7 +526,7 @@ ark::unittest::State test_model_cumulate() {
526526

527527
UNITTEST_TRUE(model.verify());
528528

529-
auto compressed = model.compress();
529+
auto compressed = model.compress(true);
530530
auto nodes = compressed.nodes();
531531
UNITTEST_EQ(nodes.size(), 5);
532532

@@ -538,12 +538,54 @@ ark::unittest::State test_model_cumulate() {
538538
return ark::unittest::SUCCESS;
539539
}
540540

541+
// Checks that operators created while context managers are alive carry the
// expected key/value context on their nodes after a non-merging compress.
ark::unittest::State test_model_context() {
    ark::Model model;

    // Operator created outside of any context scope.
    ark::Tensor lhs = model.tensor({1}, ark::FP32);
    ark::Tensor rhs = model.tensor({1}, ark::FP32);
    ark::Tensor sum = model.add(lhs, rhs);

    ark::Tensor relu_out;
    ark::Tensor sqrt_out;
    ark::Tensor exp_out;

    {
        // First scope: "lev0" is active for both operators, "lev1" only for
        // the second one. Both managers are destroyed at the closing brace.
        ark::Model::ContextManager outer_ctx = model.context("lev0", "1");
        relu_out = model.relu(sum);

        ark::Model::ContextManager inner_ctx = model.context("lev1", "2");
        sqrt_out = model.sqrt(relu_out);
    }

    {
        // Second scope: "lev0" is entered again with a different value.
        ark::Model::ContextManager second_ctx = model.context("lev0", "3");
        exp_out = model.exp(sum);
    }

    UNITTEST_TRUE(model.verify());

    // Compress without merging nodes.
    auto packed = model.compress(false);
    UNITTEST_TRUE(packed.verify());

    auto graph_nodes = packed.nodes();
    UNITTEST_EQ(graph_nodes.size(), 4);

    // Node 0: expected to have no context.
    UNITTEST_EQ(graph_nodes[0]->context.size(), 0);

    // Node 1: expected to carry only lev0=1.
    UNITTEST_EQ(graph_nodes[1]->context.size(), 1);
    UNITTEST_EQ(graph_nodes[1]->context.at("lev0"), "1");

    // Node 2: expected to carry lev0=1 and lev1=2.
    UNITTEST_EQ(graph_nodes[2]->context.size(), 2);
    UNITTEST_EQ(graph_nodes[2]->context.at("lev0"), "1");
    UNITTEST_EQ(graph_nodes[2]->context.at("lev1"), "2");

    // Node 3: expected to carry only lev0=3.
    UNITTEST_EQ(graph_nodes[3]->context.size(), 1);
    UNITTEST_EQ(graph_nodes[3]->context.at("lev0"), "3");

    return ark::unittest::SUCCESS;
}
581+
541582
// Test driver: passes each model test function to the UNITTEST macro in
// order, then returns 0.
int main() {
542583
UNITTEST(test_model_basics);
543584
UNITTEST(test_model_dependent_inputs);
544585
UNITTEST(test_model_noop);
545586
UNITTEST(test_model_identity);
546587
UNITTEST(test_model_sharding);
547588
UNITTEST(test_model_cumulate);
589+
UNITTEST(test_model_context);
548590
return 0;
549591
}

ark/api/planner.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const {
6969
task_info["Id"] = next_node_id++;
7070

7171
Json config;
72-
if (!config_rules_.empty()) {
72+
if (!op->config().empty()) {
73+
config = op->config();
74+
} else if (!config_rules_.empty()) {
7375
const std::string op_str = op->serialize().dump();
7476
for (auto &rule : config_rules_) {
7577
auto config_str = rule(op_str, gpu_info.arch->name());

ark/include/ark/model.hpp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ class Model : public ModelGraph {
2626

2727
Model &operator=(const Model &other) = default;
2828

29-
Model compress() const;
29+
Model compress(bool merge_nodes = false) const;
30+
31+
using ContextManager = ModelGraph::ContextManager;
3032

3133
int unique_tag();
3234

@@ -87,23 +89,29 @@ class Model : public ModelGraph {
8789
// result in `output`.
8890
// Currently, only reduction along the last dimension is supported.
8991
Tensor reduce_sum(Tensor input, int axis, bool keepdims = true,
90-
Tensor output = NullTensor, const std::string &name = "");
92+
Tensor output = NullTensor,
93+
const std::string &config = "",
94+
const std::string &name = "");
9195
Tensor reduce_mean(Tensor input, int axis, bool keepdims = true,
9296
Tensor output = NullTensor,
97+
const std::string &config = "",
9398
const std::string &name = "");
9499
Tensor reduce_max(Tensor input, int axis, bool keepdims = true,
95-
Tensor output = NullTensor, const std::string &name = "");
100+
Tensor output = NullTensor,
101+
const std::string &config = "",
102+
const std::string &name = "");
96103

97104
// Transposes the `input` tensor according to the given `permutation`.
98105
// For example, transpose(input, {0, 1, 3, 2}) will swap the last two
99106
// dimensions of the input tensor. Currently, only 4D tensors are supported.
100107
Tensor transpose(Tensor input, const std::vector<int64_t> &permutation,
101-
Tensor output = NullTensor, const std::string &name = "");
108+
Tensor output = NullTensor, const std::string &config = "",
109+
const std::string &name = "");
102110
// Performs matrix multiplication between the `input` tensor and another
103111
// `other` tensor, storing the result in `output`.
104112
Tensor matmul(Tensor input, Tensor other, Tensor output = NullTensor,
105113
bool trans_input = false, bool trans_other = false,
106-
const std::string &name = "");
114+
const std::string &config = "", const std::string &name = "");
107115
// Implements the 'im2col' method for 2D convolution layers, which takes an
108116
// `input` tensor and reshapes it to a 2D matrix by extracting image patches
109117
// from the input tensor based on the provided parameters.
@@ -120,72 +128,76 @@ class Model : public ModelGraph {
120128
Tensor output = NullTensor, const std::string &name = "");
121129
// Calculates the exponential of the `input` tensor, element-wise.
122130
Tensor exp(Tensor input, Tensor output = NullTensor,
123-
const std::string &name = "");
131+
const std::string &config = "", const std::string &name = "");
124132
// Calculates the square root of the `input` tensor, element-wise.
125133
Tensor sqrt(Tensor input, Tensor output = NullTensor,
126-
const std::string &name = "");
134+
const std::string &config = "", const std::string &name = "");
127135
// Calculates the reciprocal square root of the `input` tensor, element-wise.
128136
Tensor rsqrt(Tensor input, Tensor output = NullTensor,
129-
const std::string &name = "");
137+
const std::string &config = "", const std::string &name = "");
130138
// ReLU activation
131139
Tensor relu(Tensor input, Tensor output = NullTensor,
132-
const std::string &name = "");
140+
const std::string &config = "", const std::string &name = "");
133141
// Copy the `input` tensor to `output` tensor
134142
Tensor copy(Tensor input, Tensor output = NullTensor,
135-
const std::string &name = "");
143+
const std::string &config = "", const std::string &name = "");
136144
Tensor copy(float val, Tensor output = NullTensor,
137-
const std::string &name = "");
145+
const std::string &config = "", const std::string &name = "");
138146
// Applies the Gaussian Error Linear Unit (GELU) activation function to the
139147
// `input` tensor, element-wise. GELU is a smooth approximation of the
140148
// rectifier function and is widely used in deep learning models.
141149
Tensor gelu(Tensor input, Tensor output = NullTensor,
142-
const std::string &name = "");
150+
const std::string &config = "", const std::string &name = "");
143151
// Sigmoid activation
144152
Tensor sigmoid(Tensor input, Tensor output = NullTensor,
153+
const std::string &config = "",
145154
const std::string &name = "");
146155
// Performs rotary position embedding (RoPE) on the `input` tensor
147156
Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor,
148-
const std::string &name = "");
157+
const std::string &config = "", const std::string &name = "");
149158

150159
// Performs an element-wise addition operator between the `input` tensor
151160
// and the `other` tensor
152161
Tensor add(Tensor input, Tensor other, Tensor output = NullTensor,
153-
const std::string &name = "");
162+
const std::string &config = "", const std::string &name = "");
154163
Tensor add(Tensor input, float value, Tensor output = NullTensor,
155-
const std::string &name = "");
164+
const std::string &config = "", const std::string &name = "");
156165
// Performs an element-wise subtraction operator between the `input` tensor
157166
// and the `other` tensor
158167
Tensor sub(Tensor input, Tensor other, Tensor output = NullTensor,
159-
const std::string &name = "");
168+
const std::string &config = "", const std::string &name = "");
160169
Tensor sub(Tensor input, float value, Tensor output = NullTensor,
161-
const std::string &name = "");
170+
const std::string &config = "", const std::string &name = "");
162171
// Performs an element-wise multiplication operator between the `input`
163172
// tensor and the `other` tensor.
164173
Tensor mul(Tensor input, Tensor other, Tensor output = NullTensor,
165-
const std::string &name = "");
174+
const std::string &config = "", const std::string &name = "");
166175
Tensor mul(Tensor input, float value, Tensor output = NullTensor,
167-
const std::string &name = "");
176+
const std::string &config = "", const std::string &name = "");
168177
// Performs an element-wise division operator between the `input`
169178
// tensor and the `other` tensor.
170179
Tensor div(Tensor input, Tensor other, Tensor output = NullTensor,
171-
const std::string &name = "");
180+
const std::string &config = "", const std::string &name = "");
172181
Tensor div(Tensor input, float value, Tensor output = NullTensor,
173-
const std::string &name = "");
182+
const std::string &config = "", const std::string &name = "");
174183

175184
Tensor send(Tensor input, int remote_rank, int tag,
176-
Tensor output = NullTensor, const std::string &name = "");
185+
Tensor output = NullTensor, const std::string &config = "",
186+
const std::string &name = "");
177187
// Blocks the execution until the corresponding 'send' operator with the
178188
// specified `id` is completed.
179-
Tensor send_done(Tensor input, const std::string &name = "");
189+
Tensor send_done(Tensor input, const std::string &config = "",
190+
const std::string &name = "");
180191
// Receives a tensor from a source rank (@p remote_rank), identified by the
181192
// `tag` parameter. Blocks the execution until the corresponding 'recv'
182193
// operator is completed.
183194
Tensor recv(Tensor output, int remote_rank, int tag,
184-
const std::string &name = "");
195+
const std::string &config = "", const std::string &name = "");
185196
//
186197
Tensor put_packet(Tensor input, Tensor local_tmp_buf, Tensor recv_buf,
187198
int id, int rank, int dst_rank, size_t dst_offset,
188-
int flag, const std::string &name = "");
199+
int flag, const std::string &config = "",
200+
const std::string &name = "");
189201
// Performs an all-reduce operator across all ranks, aggregating the input
190202
// tensors. Takes the `input` tensor, the current GPU's rank, and the
191203
// total number of ranks `rank_num`.
@@ -200,10 +212,12 @@ class Model : public ModelGraph {
200212
const std::string &name = "");
201213
/// Embedding layer.
202214
Tensor embedding(Tensor input, Tensor weight, Tensor output = NullTensor,
215+
const std::string &config = "",
203216
const std::string &name = "");
204217
/// Tensor type casting.
205218
Tensor cast(Tensor input, const DataType &data_type,
206-
Tensor output = NullTensor, const std::string &name = "");
219+
Tensor output = NullTensor, const std::string &config = "",
220+
const std::string &name = "");
207221

208222
// sync across multi devices
209223
Tensor device_sync(Tensor input, int npeers, const std::string &name = "");

0 commit comments

Comments
 (0)