Commit 60a409b
Further reorganize artifact structure (octoml#156)
This PR reorganizes the artifact structure. There are now two separate kinds of directories for storing libs, weights, and related files: a "prebuilt" directory that holds all prebuilt libraries and weights downloaded from the internet, and per-model directories generated by local builds. The CLI and test scripts are updated accordingly.
1 parent 801f573 commit 60a409b
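
For illustration, the resulting on-disk layout might look like the sketch below (the model name, quantization code, platform suffixes, and file names are hypothetical examples; only the prebuilt/ versus per-model split and the params/ subdirectory come from this PR):

```text
dist/
├── prebuilt/                          # prebuilt libs and weights downloaded from the internet
│   ├── lib/
│   │   └── vicuna-v1-7b-q3f16_0-metal.so
│   └── vicuna-v1-7b-q3f16_0/
│       ├── mlc-llm-config.json
│       ├── ndarray-cache.json
│       └── tokenizer.model
└── vicuna-v1-7b-q3f16_0/              # generated by a local build
    ├── vicuna-v1-7b-q3f16_0-cuda.so
    └── params/
        ├── mlc-llm-config.json
        ├── ndarray-cache.json
        └── tokenizer.model
```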

File tree: 6 files changed (+114 −76 lines)


build.py

Lines changed: 38 additions & 11 deletions
```diff
@@ -16,7 +16,24 @@
 
 def _parse_args():
     args = argparse.ArgumentParser()
-    utils.argparse_add_common(args)
+    args.add_argument(
+        "--model-path",
+        type=str,
+        default=None,
+        help="Custom model path that contains params, tokenizer, and config",
+    )
+    args.add_argument(
+        "--hf-path",
+        type=str,
+        default=None,
+        help="Hugging Face path from which to download params, tokenizer, and config from",
+    )
+    args.add_argument(
+        "--quantization",
+        type=str,
+        choices=[*utils.quantization_dict.keys()],
+        default=list(utils.quantization_dict.keys())[0],
+    )
     args.add_argument("--max-seq-len", type=int, default=-1)
     args.add_argument("--target", type=str, default="auto")
     args.add_argument(
@@ -62,9 +79,12 @@ def _parse_args():
 
     return parsed
 
+
 def _setup_model_path(args):
     if args.model_path and args.hf_path:
-        assert (args.model_path and not args.hf_path) or (args.hf_path and not args.model_path), "You cannot specify both a model path and a HF path. Please select one to specify."
+        assert (args.model_path and not args.hf_path) or (
+            args.hf_path and not args.model_path
+        ), "You cannot specify both a model path and a HF path. Please select one to specify."
     if args.model_path:
         validate_config(args)
         with open(os.path.join(args.model_path, "config.json")) as f:
@@ -78,20 +98,30 @@ def _setup_model_path(args):
         else:
             os.makedirs(args.model_path, exist_ok=True)
             os.system("git lfs install")
-            os.system(f"git clone https://huggingface.co/{args.hf_path} {args.model_path}")
+            os.system(
+                f"git clone https://huggingface.co/{args.hf_path} {args.model_path}"
+            )
             print(f"Downloaded weights to {args.model_path}")
         validate_config(args)
     else:
         raise ValueError(f"Please specify either the model_path or the hf_path.")
     print(f"Using model path {args.model_path}")
     return args
 
+
 def validate_config(args):
-    assert os.path.exists(os.path.join(args.model_path, "config.json")), "Model path must contain valid config file."
+    assert os.path.exists(
+        os.path.join(args.model_path, "config.json")
+    ), "Model path must contain valid config file."
     with open(os.path.join(args.model_path, "config.json")) as f:
         config = json.load(f)
-    assert ("model_type" in config) and ("_name_or_path" in config), "Invalid config format."
-    assert config["model_type"] in utils.supported_model_types, f"Model type {config['model_type']} not supported."
+    assert ("model_type" in config) and (
+        "_name_or_path" in config
+    ), "Invalid config format."
+    assert (
+        config["model_type"] in utils.supported_model_types
+    ), f"Model type {config['model_type']} not supported."
+
 
 def debug_dump_script(mod, name, args):
     """Debug dump mode"""
@@ -177,7 +207,7 @@ def dump_default_mlc_llm_config(args):
     config["stream_interval"] = 2
     config["mean_gen_len"] = 128
     config["shift_fill_factor"] = 0.3
-    dump_path = os.path.join(args.artifact_path, "mlc_llm_config.json")
+    dump_path = os.path.join(args.artifact_path, "params", "mlc-llm-config.json")
     with open(dump_path, "w") as outfile:
         json.dump(config, outfile, indent=4)
     print(f"Finish exporting mlc_llm_config to {dump_path}")
@@ -255,10 +285,7 @@ def dump_split_tir(mod: tvm.IRModule):
         mod, params = llama.get_model(ARGS, config)
     elif ARGS.model_category == "gpt_neox":
         mod, params = gpt_neox.get_model(
-            ARGS.model,
-            ARGS.model_path,
-            ARGS.quantization.model_dtype,
-            config
+            ARGS.model, ARGS.model_path, ARGS.quantization.model_dtype, config
         )
     elif ARGS.model_category == "moss":
         mod, params = moss.get_model(ARGS, config)
```
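
With this change, the chat config produced by a local build is written under a params/ subdirectory of the per-model artifact directory instead of at its top level. A minimal Python sketch of the new config path, assuming a hypothetical artifact directory:

```python
import os

# Hypothetical per-model artifact directory produced by a local build.
artifact_path = "dist/vicuna-v1-7b-q3f16_0"

# Mirrors the updated destination in dump_default_mlc_llm_config().
dump_path = os.path.join(artifact_path, "params", "mlc-llm-config.json")
print(dump_path)  # dist/vicuna-v1-7b-q3f16_0/params/mlc-llm-config.json
```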

cpp/cli_main.cc

Lines changed: 61 additions & 38 deletions
```diff
@@ -239,71 +239,97 @@ int main(int argc, char* argv[]) {
   using namespace tvm::runtime;
   argparse::ArgumentParser args("mlc_chat");
 
+  args.add_argument("--local-id").default_value("");
+  args.add_argument("--model").default_value("vicuna-v1-7b");
+  args.add_argument("--quantization").default_value("auto");
   args.add_argument("--device-name").default_value("auto");
   args.add_argument("--device_id").default_value(0).scan<'i', int>();
   args.add_argument("--artifact-path").default_value("dist");
-  args.add_argument("--model").default_value("vicuna-v1-7b");
-  args.add_argument("--quantization").default_value("auto");
   args.add_argument("--params").default_value("auto");
   args.add_argument("--evaluate").default_value(false).implicit_value(true);
 
   try {
     args.parse_args(argc, argv);
   } catch (const std::runtime_error& err) {
     std::cerr << err.what() << std::endl;
-    std::cerr << args;
+    std::cerr << args << std::endl;
     return 1;
   }
 
+  std::string local_id = args.get<std::string>("--local-id");
+  std::string model = args.get<std::string>("--model");
+  std::string quantization = args.get<std::string>("--quantization");
   std::string device_name = DetectDeviceName(args.get<std::string>("--device-name"));
   int device_id = args.get<int>("--device_id");
   DLDevice device = GetDevice(device_name, device_id);
   std::string artifact_path = args.get<std::string>("--artifact-path");
-  std::string model = args.get<std::string>("--model");
-  std::string quantization = args.get<std::string>("--quantization");
   std::string params = args.get<std::string>("--params");
 
   std::string arch_suffix = GetArchSuffix();
 
-  std::optional<std::filesystem::path> lib_path_opt;
+  std::vector<std::string> local_id_candidates;
+  std::optional<std::filesystem::path> config_path_opt;
 
-  std::vector<std::string> quantization_candidates;
-  if (quantization == "auto") {
-    quantization_candidates = quantization_presets;
+  // Configure local id candidates.
+  if (local_id != "") {
+    local_id_candidates = {local_id};
   } else {
-    quantization_candidates = {quantization};
+    std::vector<std::string> quantization_candidates;
+    if (quantization == "auto") {
+      quantization_candidates = quantization_presets;
+    } else {
+      quantization_candidates = {quantization};
+    }
+    for (std::string quantization_candidate : quantization_candidates) {
+      local_id_candidates.push_back(model + "-" + quantization_candidate);
+    }
   }
 
-  std::optional<std::filesystem::path> lib_path;
-  for (auto candidate : quantization_candidates) {
-    std::string lib_name = model + "-" + candidate + "-" + device_name;
-    std::vector<std::string> search_paths = {artifact_path + "/" + model + "-" + candidate,
-                                             artifact_path + "/" + model, artifact_path + "/lib"};
-    // search for lib_x86_64 and lib
-    lib_path_opt = FindFile(search_paths, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
-    if (lib_path_opt) {
-      quantization = candidate;
+  // Search for mlc-llm-config.json.
+  for (auto local_id_candidate : local_id_candidates) {
+    std::vector<std::string> config_search_paths = {
+        artifact_path + "/" + local_id_candidate + "/params",  //
+        artifact_path + "/prebuilt/" + local_id_candidate};
+    config_path_opt = FindFile(config_search_paths, {"mlc-llm-config"}, {".json"});
+    if (config_path_opt) {
+      local_id = local_id_candidate;
       break;
     }
   }
+  if (!config_path_opt) {
+    std::cerr << "Cannot find \"mlc-llm-config.json\" in path \"" << artifact_path << "/"
+              << local_id_candidates[0] << "/params/\", \"" << artifact_path
+              << "/prebuilt/" + local_id_candidates[0] << "\" or other candidate paths.";
+    return 1;
+  }
+  std::cout << "Use config " << config_path_opt.value().string() << std::endl;
+  std::filesystem::path model_path = config_path_opt.value().parent_path();
+
+  // Locate the library.
+  std::string lib_name = local_id + "-" + device_name;
+  std::string lib_dir_path;
+  if (model_path.string().compare(model_path.string().length() - 7, 7, "/params") == 0) {
+    lib_dir_path = model_path.parent_path().string();
+  } else {
+    lib_dir_path = model_path.parent_path().string() + "/lib";
+  }
+  std::optional<std::filesystem::path> lib_path_opt =
+      FindFile({lib_dir_path}, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
   if (!lib_path_opt) {
-    std::cerr << "Cannot find " << model << " lib in preferred path \"" << artifact_path << "/"
-              << model << "-" << quantization_candidates[0] << "/" << model << "-"
-              << quantization_candidates[0] << "-" << device_name << GetLibSuffixes()[0]
-              << "\" or other candidate paths";
+    std::cerr << "Cannot find library \"" << lib_name << GetLibSuffixes().back()
+              << "\" and other library candidate in " << lib_dir_path << std::endl;
     return 1;
   }
   std::cout << "Use lib " << lib_path_opt.value().string() << std::endl;
-  std::string model_path = lib_path_opt.value().parent_path().string();
-  LOG(INFO) << "model_path = " << model_path;
-  // get artifact path lib name
+
+  // Locate the tokenizer files.
   std::optional<std::filesystem::path> tokenizer_path_opt =
-      FindFile({model_path, artifact_path + "/" + model}, {"tokenizer"}, {".model", ".json"});
+      FindFile({model_path.string()}, {"tokenizer"}, {".model", ".json"});
   if (!tokenizer_path_opt) {
     // Try ByteLevelBPETokenizer
-    tokenizer_path_opt = FindFile({model_path, artifact_path + "/" + model}, {"vocab"}, {".json"});
+    tokenizer_path_opt = FindFile({model_path.string()}, {"vocab"}, {".json"});
     if (!tokenizer_path_opt) {
-      std::cerr << "Cannot find tokenizer file in " << model_path;
+      std::cerr << "Cannot find tokenizer file in " << model_path.string() << std::endl;
       return 1;
     } else {
       // GPT2 styles tokenizer needs multiple files, we need to
@@ -312,19 +338,16 @@ int main(int argc, char* argv[]) {
     }
   }
 
+  // Locate the params.
   if (params == "auto") {
-    auto params_json_opt =
-        FindFile({model_path + "/params", artifact_path + "/" + model + "/params"},
-                 {"ndarray-cache"}, {".json"});
+    auto params_json_opt = FindFile({model_path}, {"ndarray-cache"}, {".json"});
    if (!params_json_opt) {
-      std::cerr << "Cannot find ndarray-cache.json for params in preferred path \"" << model_path
-                << "/params\" and \"" << artifact_path << "/" + model << "/params.";
+      std::cerr << "Cannot find ndarray-cache.json for params in " << model_path << std::endl;
      return 1;
    }
-    std::string params_json = params_json_opt.value().string();
-    params = params_json.substr(0, params_json.length() - 18);
+    params = params_json_opt.value().parent_path().string();
  } else if (!FindFile({params}, {"ndarray-cache"}, {".json"})) {
-    std::cerr << "Cannot find params/ndarray-cache.json in " << model_path;
+    std::cerr << "Cannot find ndarray-cache.json for params in " << params << std::endl;
    return 1;
  }
 
@@ -345,7 +368,7 @@ int main(int argc, char* argv[]) {
   } catch (const std::runtime_error& err) {
     // catch exception so error message
     // get reported here without silently quit.
-    std::cerr << err.what();
+    std::cerr << err.what() << std::endl;
     return 1;
   }
   return 0;
```
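
In short, the CLI now resolves a model in two steps: it builds a list of local-id candidates (the explicit --local-id if given, otherwise model + "-" + quantization for each quantization preset) and probes each candidate's params/ directory and the prebuilt/ directory for mlc-llm-config.json; the library and tokenizer are then located relative to wherever the config was found. A rough Python sketch of the same search, for illustration only (the preset list and paths are hypothetical):

```python
import os

def find_config(artifact_path, local_id, model, quantization, quantization_presets):
    """Illustrative mirror of the candidate search in cli_main.cc."""
    if local_id:
        candidates = [local_id]
    else:
        quants = quantization_presets if quantization == "auto" else [quantization]
        candidates = [f"{model}-{q}" for q in quants]
    for candidate in candidates:
        for dirpath in (
            os.path.join(artifact_path, candidate, "params"),    # local build
            os.path.join(artifact_path, "prebuilt", candidate),  # downloaded
        ):
            config = os.path.join(dirpath, "mlc-llm-config.json")
            if os.path.isfile(config):
                return candidate, config
    return None, None

# Example call with hypothetical values:
# find_config("dist", "", "vicuna-v1-7b", "auto", ["q3f16_0", "q4f16_0"])
```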

mlc_llm/utils.py

Lines changed: 4 additions & 20 deletions
```diff
@@ -46,25 +46,6 @@ class Quantization:
 
 supported_model_types = set(["llama", "gpt_neox", "moss"])
 
-def argparse_add_common(args: argparse.ArgumentParser) -> None:
-    args.add_argument(
-        "--quantization",
-        type=str,
-        choices=[*quantization_dict.keys()],
-        default=list(quantization_dict.keys())[0],
-    )
-    args.add_argument(
-        "--model-path",
-        type=str,
-        default=None,
-        help="Custom model path that contains params, tokenizer, and config"
-    )
-    args.add_argument(
-        "--hf-path",
-        type=str,
-        default=None,
-        help="Hugging Face path from which to download params, tokenizer, and config from"
-    )
 
 def argparse_postproc_common(args: argparse.Namespace) -> None:
     if hasattr(args, "device_name"):
@@ -208,7 +189,10 @@ def _is_static_shape_func(func: tvm.tir.PrimFunc):
 def copy_tokenizer(args: argparse.Namespace) -> None:
     for filename in os.listdir(args.model_path):
         if filename.startswith("tokenizer") or filename == "vocab.json":
-            shutil.copy(os.path.join(args.model_path, filename), args.artifact_path)
+            shutil.copy(
+                os.path.join(args.model_path, filename),
+                os.path.join(args.artifact_path, "params"),
+            )
 
 
 def parse_target(args: argparse.Namespace) -> None:
```

tests/chat.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -28,12 +28,13 @@ class Colors:
 
 def _parse_args():
     args = argparse.ArgumentParser()
-    utils.argparse_add_common(args)
+    args.add_argument("--local-id", type=str, required=True)
     args.add_argument("--device-name", type=str, default="auto")
     args.add_argument("--debug-dump", action="store_true", default=False)
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument("--max-gen-len", type=int, default=2048)
     parsed = args.parse_args()
+    parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
     utils.argparse_postproc_common(parsed)
     parsed.artifact_path = os.path.join(
         parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
@@ -223,7 +224,7 @@ def main():
     if ARGS.debug_dump:
         torch.manual_seed(12)
     tokenizer = AutoTokenizer.from_pretrained(
-        ARGS.artifact_path, trust_remote_code=True
+        os.path.join(ARGS.artifact_path, "params"), trust_remote_code=True
     )
     tokenizer.pad_token_id = tokenizer.eos_token_id
     if ARGS.model.startswith("dolly-"):
```
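
The test scripts now accept a single --local-id instead of relying on the removed argparse_add_common helper, and recover the model name and quantization scheme by splitting on the last hyphen. For example (the local-id value is illustrative):

```python
local_id = "vicuna-v1-7b-q3f16_0"  # hypothetical --local-id value
model, quantization = local_id.rsplit("-", 1)
print(model)         # vicuna-v1-7b
print(quantization)  # q3f16_0
```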

tests/debug/compare_lib.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -135,7 +135,9 @@ def deploy_to_pipeline(args) -> None:
     primary_device = tvm.device(args.primary_device)
     const_params = utils.load_params(args.artifact_path, primary_device)
     state = TestState(args)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        os.path.join(args.artifact_path, "params"), trust_remote_code=True
+    )
 
     print("Tokenizing...")
     inputs = tvm.nd.array(
@@ -177,17 +179,17 @@ def deploy_to_pipeline(args) -> None:
 
 def _parse_args():
     args = argparse.ArgumentParser()
-    utils.argparse_add_common(args)
+    args.add_argument("--local-id", type=str, required=True)
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument("--primary-device", type=str, default="auto")
     args.add_argument("--cmp-device", type=str, required=True)
     args.add_argument("--prompt", type=str, default="The capital of Canada is")
     args.add_argument("--time-eval", default=False, action="store_true")
     args.add_argument("--skip-rounds", type=int, default=0)
     parsed = args.parse_args()
+    parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
     utils.argparse_postproc_common(parsed)
 
-    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
     parsed.artifact_path = os.path.join(
         parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
     )
```

tests/evaluate.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -18,13 +18,14 @@
 
 def _parse_args():
     args = argparse.ArgumentParser()
-    utils.argparse_add_common(args)
+    args.add_argument("--local-id", type=str, required=True)
     args.add_argument("--device-name", type=str, default="auto")
     args.add_argument("--debug-dump", action="store_true", default=False)
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument("--prompt", type=str, default="The capital of Canada is")
     args.add_argument("--profile", action="store_true", default=False)
     parsed = args.parse_args()
+    parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
     utils.argparse_postproc_common(parsed)
     parsed.artifact_path = os.path.join(
         parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
@@ -91,7 +92,7 @@ def deploy_to_pipeline(args) -> None:
     vm = relax.VirtualMachine(ex, device)
 
     tokenizer = AutoTokenizer.from_pretrained(
-        args.artifact_path, trust_remote_code=True
+        os.path.join(args.artifact_path, "params"), trust_remote_code=True
     )
 
     print("Tokenizing...")
```
