@@ -239,71 +239,97 @@ int main(int argc, char* argv[]) {
   using namespace tvm::runtime;
   argparse::ArgumentParser args("mlc_chat");

+  args.add_argument("--local-id").default_value("");
+  args.add_argument("--model").default_value("vicuna-v1-7b");
+  args.add_argument("--quantization").default_value("auto");
   args.add_argument("--device-name").default_value("auto");
   args.add_argument("--device_id").default_value(0).scan<'i', int>();
   args.add_argument("--artifact-path").default_value("dist");
-  args.add_argument("--model").default_value("vicuna-v1-7b");
-  args.add_argument("--quantization").default_value("auto");
   args.add_argument("--params").default_value("auto");
   args.add_argument("--evaluate").default_value(false).implicit_value(true);

   try {
     args.parse_args(argc, argv);
   } catch (const std::runtime_error& err) {
     std::cerr << err.what() << std::endl;
-    std::cerr << args;
+    std::cerr << args << std::endl;
     return 1;
   }

+  std::string local_id = args.get<std::string>("--local-id");
+  std::string model = args.get<std::string>("--model");
+  std::string quantization = args.get<std::string>("--quantization");
   std::string device_name = DetectDeviceName(args.get<std::string>("--device-name"));
   int device_id = args.get<int>("--device_id");
   DLDevice device = GetDevice(device_name, device_id);
   std::string artifact_path = args.get<std::string>("--artifact-path");
-  std::string model = args.get<std::string>("--model");
-  std::string quantization = args.get<std::string>("--quantization");
   std::string params = args.get<std::string>("--params");

   std::string arch_suffix = GetArchSuffix();

-  std::optional<std::filesystem::path> lib_path_opt;
+  std::vector<std::string> local_id_candidates;
+  std::optional<std::filesystem::path> config_path_opt;

-  std::vector<std::string> quantization_candidates;
-  if (quantization == "auto") {
-    quantization_candidates = quantization_presets;
+  // Configure local id candidates.
+  if (local_id != "") {
+    local_id_candidates = {local_id};
   } else {
-    quantization_candidates = {quantization};
+    std::vector<std::string> quantization_candidates;
+    if (quantization == "auto") {
+      quantization_candidates = quantization_presets;
+    } else {
+      quantization_candidates = {quantization};
+    }
+    for (std::string quantization_candidate : quantization_candidates) {
+      local_id_candidates.push_back(model + "-" + quantization_candidate);
+    }
   }

-  std::optional<std::filesystem::path> lib_path;
-  for (auto candidate : quantization_candidates) {
-    std::string lib_name = model + "-" + candidate + "-" + device_name;
-    std::vector<std::string> search_paths = {artifact_path + "/" + model + "-" + candidate,
-                                             artifact_path + "/" + model, artifact_path + "/lib"};
-    // search for lib_x86_64 and lib
-    lib_path_opt = FindFile(search_paths, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
-    if (lib_path_opt) {
-      quantization = candidate;
+  // Search for mlc-llm-config.json.
+  for (auto local_id_candidate : local_id_candidates) {
+    std::vector<std::string> config_search_paths = {
+        artifact_path + "/" + local_id_candidate + "/params",  //
+        artifact_path + "/prebuilt/" + local_id_candidate};
+    config_path_opt = FindFile(config_search_paths, {"mlc-llm-config"}, {".json"});
+    if (config_path_opt) {
+      local_id = local_id_candidate;
       break;
     }
   }
+  if (!config_path_opt) {
+    std::cerr << "Cannot find \"mlc-llm-config.json\" in path \"" << artifact_path << "/"
+              << local_id_candidates[0] << "/params/\", \"" << artifact_path
+              << "/prebuilt/" + local_id_candidates[0] << "\" or other candidate paths.";
+    return 1;
+  }
+  std::cout << "Use config " << config_path_opt.value().string() << std::endl;
+  std::filesystem::path model_path = config_path_opt.value().parent_path();
+
+  // Locate the library.
+  std::string lib_name = local_id + "-" + device_name;
+  std::string lib_dir_path;
+  if (model_path.string().compare(model_path.string().length() - 7, 7, "/params") == 0) {
+    lib_dir_path = model_path.parent_path().string();
+  } else {
+    lib_dir_path = model_path.parent_path().string() + "/lib";
+  }
+  std::optional<std::filesystem::path> lib_path_opt =
+      FindFile({lib_dir_path}, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
   if (!lib_path_opt) {
-    std::cerr << "Cannot find " << model << " lib in preferred path \"" << artifact_path << "/"
-              << model << "-" << quantization_candidates[0] << "/" << model << "-"
-              << quantization_candidates[0] << "-" << device_name << GetLibSuffixes()[0]
-              << "\" or other candidate paths";
+    std::cerr << "Cannot find library \"" << lib_name << GetLibSuffixes().back()
+              << "\" and other library candidate in " << lib_dir_path << std::endl;
     return 1;
   }
   std::cout << "Use lib " << lib_path_opt.value().string() << std::endl;
-  std::string model_path = lib_path_opt.value().parent_path().string();
-  LOG(INFO) << "model_path = " << model_path;
-  // get artifact path lib name
+
+  // Locate the tokenizer files.
   std::optional<std::filesystem::path> tokenizer_path_opt =
-      FindFile({model_path, artifact_path + "/" + model}, {"tokenizer"}, {".model", ".json"});
+      FindFile({model_path.string()}, {"tokenizer"}, {".model", ".json"});
   if (!tokenizer_path_opt) {
     // Try ByteLevelBPETokenizer
-    tokenizer_path_opt = FindFile({model_path, artifact_path + "/" + model}, {"vocab"}, {".json"});
+    tokenizer_path_opt = FindFile({model_path.string()}, {"vocab"}, {".json"});
     if (!tokenizer_path_opt) {
-      std::cerr << "Cannot find tokenizer file in " << model_path;
+      std::cerr << "Cannot find tokenizer file in " << model_path.string() << std::endl;
      return 1;
     } else {
       // GPT2 styles tokenizer needs multiple files, we need to
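Aside: with the defaults above, local_id_candidates becomes vicuna-v1-7b-<q> for each entry of quantization_presets, and the config is probed first under dist/<candidate>/params/ and then dist/prebuilt/<candidate>/. The probing goes through the FindFile helper defined elsewhere in this file; as a reading aid, here is a minimal sketch of what such a helper could look like. The signature and first-match semantics are inferred from the call sites here, not copied from the repository.

    #include <filesystem>
    #include <optional>
    #include <string>
    #include <vector>

    // Sketch (assumed, not the in-tree implementation): probe each directory
    // for <dir>/<name><suffix>, in order, and return the first existing path.
    std::optional<std::filesystem::path> FindFile(const std::vector<std::string>& search_paths,
                                                  const std::vector<std::string>& names,
                                                  const std::vector<std::string>& suffixes) {
      for (const std::string& dir : search_paths) {
        for (const std::string& name : names) {
          for (const std::string& suffix : suffixes) {
            std::filesystem::path path(dir + "/" + name + suffix);
            if (std::filesystem::exists(path)) {
              return path;
            }
          }
        }
      }
      return std::nullopt;
    }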
@@ -312,19 +338,16 @@ int main(int argc, char* argv[]) {
     }
   }

+  // Locate the params.
   if (params == "auto") {
-    auto params_json_opt =
-        FindFile({model_path + "/params", artifact_path + "/" + model + "/params"},
-                 {"ndarray-cache"}, {".json"});
+    auto params_json_opt = FindFile({model_path}, {"ndarray-cache"}, {".json"});
     if (!params_json_opt) {
-      std::cerr << "Cannot find ndarray-cache.json for params in preferred path \"" << model_path
-                << "/params\" and \"" << artifact_path << "/" + model << "/params.";
+      std::cerr << "Cannot find ndarray-cache.json for params in " << model_path << std::endl;
       return 1;
     }
-    std::string params_json = params_json_opt.value().string();
-    params = params_json.substr(0, params_json.length() - 18);
+    params = params_json_opt.value().parent_path().string();
   } else if (!FindFile({params}, {"ndarray-cache"}, {".json"})) {
-    std::cerr << "Cannot find params/ndarray-cache.json in " << model_path;
+    std::cerr << "Cannot find ndarray-cache.json for params in " << params << std::endl;
     return 1;
   }

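Aside: the params change above replaces a substr() that trimmed the final 18 characters (the length of "ndarray-cache.json") and therefore left a trailing slash, with parent_path(), which drops the filename component cleanly. A standalone illustration, using a made-up path:

    #include <filesystem>
    #include <iostream>
    #include <string>

    int main() {
      // Hypothetical cache path, for illustration only.
      std::string params_json = "dist/vicuna-v1-7b-q3f16_0/params/ndarray-cache.json";
      // Old approach: strip the 18-char filename; keeps the trailing '/'.
      std::cout << params_json.substr(0, params_json.length() - 18) << "\n";
      // New approach: take the parent directory; no trailing '/'.
      std::cout << std::filesystem::path(params_json).parent_path().string() << "\n";
    }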
@@ -345,7 +368,7 @@ int main(int argc, char* argv[]) {
   } catch (const std::runtime_error& err) {
     // Catch the exception so the error message
     // gets reported here instead of silently quitting.
-    std::cerr << err.what();
+    std::cerr << err.what() << std::endl;
     return 1;
   }
   return 0;
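Aside: the "/params" suffix check in the first hunk decides where the compiled library is expected to live: a config found under <local-id>/params maps back to dist/<local-id>/, while a prebuilt config maps to the shared dist/prebuilt/lib/. A self-contained sketch of that rule follows; the example paths are illustrative assumptions, and the sketch adds a length guard that the in-tree compare() call omits.

    #include <filesystem>
    #include <iostream>
    #include <string>

    // Mirror of the lib-dir rule, guarded against paths shorter than "/params".
    std::string LibDirFor(const std::filesystem::path& model_path) {
      const std::string s = model_path.string();
      const std::string tail = "/params";
      const bool ends_with_params =
          s.size() >= tail.size() && s.compare(s.size() - tail.size(), tail.size(), tail) == 0;
      return ends_with_params ? model_path.parent_path().string()
                              : model_path.parent_path().string() + "/lib";
    }

    int main() {
      std::cout << LibDirFor("dist/vicuna-v1-7b-q3f16_0/params") << "\n";   // dist/vicuna-v1-7b-q3f16_0
      std::cout << LibDirFor("dist/prebuilt/vicuna-v1-7b-q3f16_0") << "\n"; // dist/prebuilt/lib
    }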