diff --git a/.gitignore b/.gitignore index b7f27fbfa8880..397f5b5a0e9af 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,10 @@ models/* /main /quantize /server +/result +/perplexity +/embedding +/Pipfile arm_neon.h compile_commands.json diff --git a/convert-ggml-to-pth.py b/convert-ggml-to-pth.py index 20158c9ca8650..8ab17410d7929 100644 --- a/convert-ggml-to-pth.py +++ b/convert-ggml-to-pth.py @@ -84,6 +84,11 @@ def read_variables(fin): shape = shape[::-1] name = fin.read(name_length).decode("utf-8") + # ensure tensor data is aligned + tensor_data_offset = fin.tell() + tensor_data_offset = (tensor_data_offset + 31) & -32 + fin.seek(tensor_data_offset) + if ftype_cur == 2: # 4-bit quantized weights dtype = np.uint8 diff --git a/convert-gptq-to-ggml.py b/convert-gptq-to-ggml.py index 6c77808fcd186..860eb148b2a86 100644 --- a/convert-gptq-to-ggml.py +++ b/convert-gptq-to-ggml.py @@ -72,6 +72,11 @@ def write_header(shape, dst_name, ftype_cur): fout.write(struct.pack("i" * len(shape), *shape[::-1])) fout.write(sname) + # ensure tensor data is aligned + tensor_data_offset = fout.tell() + tensor_data_offset = (tensor_data_offset + 31) & -32 + fout.seek(tensor_data_offset) + def convert_non_q4(src_name, dst_name): v = model[src_name] shape = v.shape diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d83f8a1373251..7d461157b3702 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -24,8 +24,57 @@ from sentencepiece import SentencePieceProcessor -def parse_args(): +QK = 32 + +GGML_TYPE_Q4_0 = 0 +GGML_TYPE_Q4_1 = 1 +GGML_TYPE_I8 = 2 +GGML_TYPE_I16 = 3 +GGML_TYPE_I32 = 4 +GGML_TYPE_F16 = 5 +GGML_TYPE_F32 = 6 + +WTYPES = { + 0: GGML_TYPE_F32, + 1: GGML_TYPE_F16, + 2: GGML_TYPE_Q4_0, + 3: GGML_TYPE_Q4_1, +} + +GGML_BLCK_SIZE = { + GGML_TYPE_Q4_0: QK, + GGML_TYPE_Q4_1: QK, + GGML_TYPE_I8: 1, + GGML_TYPE_I16: 1, + GGML_TYPE_I32: 1, + GGML_TYPE_F16: 1, + GGML_TYPE_F32: 1, +} + +GGML_TYPE_SIZE = { + GGML_TYPE_Q4_0: 4 + QK/2, + GGML_TYPE_Q4_1: 4*2 + QK/2, + GGML_TYPE_I8: 1, + GGML_TYPE_I16: 2, + GGML_TYPE_I32: 4, + GGML_TYPE_F16: 2, + GGML_TYPE_F32: 4, +} + +def ggml_nelements(shape): + r = 1 + for i in shape: + r *= i + return r + +def ggml_nbytes(shape, ftype): + x = ggml_nelements(shape) + t = WTYPES[ftype] + x *= GGML_TYPE_SIZE[t] + x //= GGML_BLCK_SIZE[t] + return x +def parse_args(): parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') parser.add_argument('dir_model', help='directory containing the model checkpoint') parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1) @@ -33,7 +82,6 @@ def parse_args(): return parser.parse_args() def get_n_parts(dim): - mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} n_parts = mappings.get(dim) if n_parts is None: @@ -44,30 +92,24 @@ def get_n_parts(dim): return n_parts def load_hparams_and_tokenizer(dir_model): - # `dir_model` is something like `models/7B` or `models/7B/`. # "tokenizer.model" is expected under model's parent dir. # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found. # Let's use the model's parent dir directly. model_parent_dir = os.path.dirname(os.path.normpath(dir_model)) - fname_hparams = f"{dir_model}/params.json" fname_tokenizer = f"{model_parent_dir}/tokenizer.model" - with open(fname_hparams, "r") as f: hparams = json.load(f) print(hparams) - tokenizer = SentencePieceProcessor(fname_tokenizer) hparams.update({"vocab_size": tokenizer.vocab_size()}) - return hparams, tokenizer def write_header(fout, hparams, ftype): - keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] values = [ - 0x67676d66, # magic: ggmf in hex + 0x67676a74, # magic: ggjt in hex 1, # file version *[hparams[key] for key in keys], hparams["dim"] // hparams["n_heads"], # rot (obsolete) @@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype): fout.write(struct.pack("i" * len(values), *values)) def write_tokens(fout, tokenizer): - for i in range(tokenizer.vocab_size()): if tokenizer.is_unknown(i): text = " \u2047 ".encode("utf-8") @@ -95,85 +136,141 @@ def write_tokens(fout, tokenizer): fout.write(text) fout.write(struct.pack("f", tokenizer.get_score(i))) -def process_and_write_variables(fout, model, ftype): - +def process_and_write_variables(fout, model, ftype, part_id, n_parts): for name, datao in model.items(): - if name.endswith("freqs"): continue - shape = datao.shape - - print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}") - + # remove dimensions with a single element data = datao.numpy().squeeze() - n_dims = len(shape) + partshape = data.shape + n_dims = len(data.shape) + assert n_dims in (1, 2) + + print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}") - # default type is fp16 + # coerce single-dimensional tensors from float16 to float32 ftype_cur = 1 if ftype == 0 or n_dims == 1: print(" Converting to float32") data = data.astype(np.float32) ftype_cur = 0 - - # header + blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]] + type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]] + + # determine dimension along which multipart tensor is sharded + # + # split_dim 0 regex: + # - output.* + # - layers.*.attention.wq.weight + # - layers.*.attention.wk.weight + # - layers.*.attention.wv.weight + # - layers.*.feed_forward.w1.weight + # - layers.*.feed_forward.w3.weight + # + # split_dim 1 regex: + # - tok_embeddings.* + # - layers.*.attention.wo.weight + # - layers.*.feed_forward.w2.weight + # + if n_dims > 1: + split_dim = 1 + if "tok_embeddings" in name: + split_dim = 1 + elif "layers" in name: + if "attention.wo.weight" in name: + split_dim = 1 + elif "feed_forward.w2.weight" in name: + split_dim = 1 + else: + split_dim = 0 + elif "output" in name: + split_dim = 0 + + # output tensor header + fullshape = list(partshape) + if n_dims > 1: + fullshape[split_dim] *= n_parts sname = name.encode('utf-8') - fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur)) - for dim in reversed(data.shape): + fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) + for dim in reversed(fullshape): fout.write(struct.pack("i", dim)) fout.write(sname) - # data output to file - data.tofile(fout) + # ensure tensor data is aligned + tensor_data_offset = fout.tell() + while tensor_data_offset % QK != 0: + fout.write(struct.pack("B", 0)) + tensor_data_offset += 1 + + # output unified mappable tensor data + if n_dims == 1 or n_parts == 1: + # copy tensor which we thankfully received in one piece + if part_id == 0: + data.tofile(fout) + elif split_dim == 0: + # reassemble multifile tensor containing some of the rows + rows_per_chunk = partshape[0] + current_row = part_id * rows_per_chunk + bytes_per_row = fullshape[1] // blck_size * type_size + offset = current_row * bytes_per_row + fout.seek(tensor_data_offset + offset) + data.tofile(fout) + elif split_dim == 1: + # reassemble multifile tensor containing some of the cols + cols_per_chunk = partshape[1] + current_col = part_id * cols_per_chunk + bytes_per_row = fullshape[1] // blck_size * type_size + offset_current_col = current_col // blck_size * type_size + for row in range(partshape[0]): + offset_row = row * bytes_per_row + offset = offset_row + offset_current_col + fout.seek(tensor_data_offset + offset) + data[row].tofile(fout) + + # advance file position to next tensor + fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur)) def main(): - args = parse_args() dir_model = args.dir_model ftype = args.ftype ftype_str = ["f32", "f16"] - hparams, tokenizer = load_hparams_and_tokenizer(dir_model) print(args) # if only writing vocab to file if args.vocab_only: - fname_model = f"{dir_model}/consolidated.00.pth" fname_out = f"{dir_model}/ggml-vocab.bin" - print(f"Extracting only the vocab from '{fname_model}'\n") - - + model = torch.load(fname_model, map_location="cpu") with open(fname_out, "wb") as fout: write_header(fout, hparams, ftype) write_tokens(fout, tokenizer) - - + del model print(f"Done. Output file: {fname_out}\n") - return n_parts = get_n_parts(hparams["dim"]) - - for p in range(n_parts): - - print(f"Processing part {p+1} of {n_parts}\n") - - fname_model = f"{dir_model}/consolidated.0{p}.pth" - fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" - - model = torch.load(fname_model, map_location="cpu") - - with open(fname_out, "wb") as fout: - write_header(fout, hparams, ftype) - write_tokens(fout, tokenizer) - process_and_write_variables(fout, model, ftype) - - del model - - print(f"Done. Output file: {fname_out}, (part {p})\n") + fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin" + + # we output a single file for ggml + with open(fname_out, "wb") as fout: + write_header(fout, hparams, ftype) + write_tokens(fout, tokenizer) + offset_of_tensors = fout.tell() + # the tensors we load could be split across multiple files + for part_id in range(n_parts): + fout.seek(offset_of_tensors) + print(f"Processing part {part_id+1} of {n_parts}\n") + fname_model = f"{dir_model}/consolidated.0{part_id}.pth" + model = torch.load(fname_model, map_location="cpu") + process_and_write_variables(fout, model, ftype, part_id, n_parts) + del model + + print(f"Done. Output file: {fname_out}\n") if __name__ == "__main__": main() diff --git a/llama.cpp b/llama.cpp index 87633f9726b21..b00e065230433 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12,17 +12,19 @@ #include #include -// mmap -#if defined (__unix__) || defined (__APPLE__) -# include -# include -# include -#elif defined(_WIN32) -# define WIN32_LEAN_AND_MEAN -# include -//#include +#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) +#define WIN32_LEAN_AND_MEAN +#include +#else +#include +#include +#include +#include #endif +#define Min(X, Y) ((Y) > (X) ? (X) : (Y)) +#define Max(X, Y) ((Y) < (X) ? (X) : (Y)) + #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -155,7 +157,7 @@ struct llama_model { // model memory mapped file void * mm_addr = NULL; - size_t mm_length = 0; + uint64_t mm_length = 0; // tensors int n_loaded; @@ -180,6 +182,7 @@ struct llama_context { int64_t t_load_us = 0; int64_t t_start_us = 0; + bool has_evaluated_once = false; int64_t t_sample_us = 0; int64_t t_eval_us = 0; @@ -221,7 +224,7 @@ struct llama_context { } if (buf_last >= 0) { - buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size); } buf_last = i; @@ -304,59 +307,57 @@ struct llama_context_params llama_context_default_params() { // model loading // -static void mmap_file(const char* fname, void * &mm_addr, size_t &mm_length) { -#if defined(MAP_FAILED) - // POSIX - int fd = open(fname, O_RDONLY); - mm_length = lseek(fd, 0, SEEK_END); - mm_addr = mmap(NULL, mm_length, PROT_READ, MAP_SHARED, fd, 0); - close(fd); - if (mm_addr == MAP_FAILED) { - perror("mmap failed"); - mm_addr = NULL; - mm_length = 0; - } -#elif defined(_WIN32) - mm_addr = NULL; - - HANDLE hFile = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - return; - } - - // not really necessary +static void *mmap_file(const char *fname, uint64_t *mm_length) { +#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) + HANDLE hFile = CreateFileA(fname, + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + NULL, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED, + NULL); + if (hFile == INVALID_HANDLE_VALUE) return 0; LARGE_INTEGER fileSize; + fileSize.QuadPart = -1; GetFileSizeEx(hFile, &fileSize); - mm_length = fileSize; - + int64_t length = fileSize.QuadPart; HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); CloseHandle(hFile); - - if (hMapping == NULL) { - return; - } - - mm_addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + if (!hMapping) return 0; + void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); CloseHandle(hMapping); + if (!addr) return 0; #else - mm_addr = NULL; - mm_length = 0; - (void)(fname); // suppress warnings + int fd = open(fname, O_RDONLY); + if (fd == -1) return 0; + int64_t length = lseek(fd, 0, SEEK_END); + void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0); + close(fd); + if (addr == MAP_FAILED) return 0; #endif + *mm_length = length; + return addr; } static void munmap_file(void * addr, size_t length) { -#if defined(MAP_FAILED) - // POSIX - munmap(addr, length); -#elif defined(_WIN32) +#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) UnmapViewOfFile(addr); #else - (void)(addr); // suppress warnings - (void)(length); + munmap(addr, length); #endif } +static bool report_bad_magic(const char *path) { + fprintf(stderr, + "%s: invalid model file (bad magic)\n" + "you most likely need to regenerate your ggml files\n" + "the benefit is you'll get 10-100x faster load times\n" + "see https://github.com/ggerganov/llama.cpp/issues/91\n" + "use convert-pth-to-ggml.py on your llama model files\n", + path); + return false; +} + static bool llama_model_load( const std::string & fname, llama_context & lctx, @@ -368,23 +369,24 @@ static bool llama_model_load( void *progress_callback_user_data) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); - const int64_t t_start_us = ggml_time_us(); - - lctx.t_start_us = t_start_us; - - // TODO: this could probably be smaller when using mmap - std::vector f_buf(1024*1024); + lctx.t_start_us = ggml_time_us(); auto & model = lctx.model; auto & vocab = lctx.vocab; auto fin = std::ifstream(fname, std::ios::binary); - fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); return false; } + std::vector f_buf(1024*1024); + fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); + + fin.seekg(0, fin.end); + const size_t file_size = fin.tellg(); + fin.seekg(0); + // verify magic { uint32_t magic; @@ -395,8 +397,7 @@ static bool llama_model_load( return false; } if (magic != LLAMA_FILE_MAGIC) { - fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); - return false; + return report_bad_magic(fname.c_str()); } uint32_t format_version; @@ -519,54 +520,24 @@ static bool llama_model_load( } } - bool use_mmap = (n_parts == 1); - - // try to memory map the model file - void * mm_addr = NULL; - if (use_mmap) { - mmap_file(fname.c_str(), model.mm_addr, model.mm_length); - if (model.mm_addr == NULL) { - use_mmap = false; - } - else { - mm_addr = model.mm_addr; - } + // map model into memory + char *mm_addr = NULL; + model.mm_addr = mmap_file(fname.c_str(), &model.mm_length); + if (model.mm_addr == NULL) { + fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str()); + return false; } + mm_addr = (char *)model.mm_addr; + fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0)); auto & ctx = model.ctx; size_t ctx_size = 0; { - const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; + const auto &hparams = model.hparams; const int n_layer = hparams.n_layer; - const int n_vocab = hparams.n_vocab; - - if (!use_mmap) { - ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings - - ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm - - ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output - - ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm - - ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq - ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk - ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv - ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo - - ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm - - ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1 - ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2 - ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3 - } - ctx_size += (5 + 10*n_layer)*256; // object overhead - - fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0); } // print memory requirements @@ -576,6 +547,7 @@ static bool llama_model_load( // this is the total memory required to run the inference const size_t mem_required = ctx_size + + model.mm_length + MEM_REQ_SCRATCH0.at(model.type) + MEM_REQ_SCRATCH1.at(model.type) + MEM_REQ_EVAL.at (model.type); @@ -595,7 +567,7 @@ static bool llama_model_load( struct ggml_init_params params = { /*.mem_size =*/ lctx.model.buf.size(), /*.mem_buffer =*/ lctx.model.buf.data(), - /*.no_alloc =*/ use_mmap, + /*.no_alloc =*/ true, }; model.ctx = ggml_init(params); @@ -658,241 +630,106 @@ static bool llama_model_load( } } - const size_t file_offset = fin.tellg(); - - fin.close(); - std::vector tmp; if (progress_callback) { progress_callback(0.0, progress_callback_user_data); } - for (int i = 0; i < n_parts; ++i) { - const int part_id = i; - //const int part_id = n_parts - i - 1; - - std::string fname_part = fname; - if (i > 0) { - fname_part += "." + std::to_string(i); - } + fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str()); - fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : ""); + // load weights + { + size_t total_size = 0; + model.n_loaded = 0; - fin = std::ifstream(fname_part, std::ios::binary); - fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; - fin.seekg(0, fin.end); - const size_t file_size = fin.tellg(); + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - fin.seekg(file_offset); + if (fin.eof()) { + break; + } - // load weights - { - size_t total_size = 0; + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } - model.n_loaded = 0; + std::string name(length, 0); + fin.read(&name[0], length); - fprintf(stderr, "%s: ", __func__); + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; + auto tensor = model.tensors[name.data()]; - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return false; + } + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]); + } - if (fin.eof()) { + switch (ftype) { + case 0: // f32 + case 1: // f16 break; - } - - int32_t nelements = 1; - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); - nelements *= ne[i]; - } - - std::string name(length, 0); - fin.read(&name[0], length); - - if (model.tensors.find(name.data()) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + case 2: // q4_0 + case 3: // q4_1 + assert(ne[0] % 64 == 0); + break; + default: + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); return false; - } - - // split_type = 0: split by columns - // split_type = 1: split by rows - int split_type = 0; - - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - if (name.find("tok_embeddings") != std::string::npos) { - split_type = 0; - } else if (name.find("layers") != std::string::npos) { - if (name.find("attention.wo.weight") != std::string::npos) { - split_type = 0; - } else if (name.find("feed_forward.w2.weight") != std::string::npos) { - split_type = 0; - } else { - split_type = 1; - } - } else if (name.find("output") != std::string::npos) { - split_type = 1; - } - - auto tensor = model.tensors[name.data()]; - - if (n_dims == 1) { - if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - return false; - } - } else { - if (ggml_nelements(tensor)/n_parts != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - return false; - } - } - - if (n_dims == 1) { - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); - return false; - } - } else { - if (split_type == 0) { - if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]); - return false; - } - } else { - if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]); - return false; - } - } - } - - if (0) { - static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; - fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type); - } - - size_t bpe = 0; - - switch (ftype) { - case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; - case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; - case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; - case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; - default: - { - fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); - return false; - } - }; - - if (n_dims == 1 || n_parts == 1) { - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; - } - - if (part_id == 0) { - if (mm_addr) { - off_t offset = fin.tellg(); - tensor->data = (char *) mm_addr + offset; - fin.seekg(ggml_nbytes(tensor), std::ios::cur); - } - else { - fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); - } - } else { - fin.seekg(ggml_nbytes(tensor), std::ios::cur); - } - - total_size += ggml_nbytes(tensor); - } else { - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe); - return false; - } - - if (split_type == 0) { - const int np0 = ne[0]; - - const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); - assert(row_size == tensor->nb[1]); - - for (int i1 = 0; i1 < ne[1]; ++i1) { - const size_t offset_row = i1*row_size; - const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); - fin.read(reinterpret_cast(tensor->data) + offset, row_size/n_parts); - } - } else { - const int np1 = ne[1]; - - const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); - - for (int i1 = 0; i1 < ne[1]; ++i1) { - const size_t offset_row = (i1 + part_id*np1)*row_size; - fin.read(reinterpret_cast(tensor->data) + offset_row, row_size); - } - } - - total_size += ggml_nbytes(tensor)/n_parts; - } - - //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); - model.n_loaded++; - - // progress - if (progress_callback) { - float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset); - float current_progress = (float(i) + current_file_progress) / float(n_parts); - progress_callback(current_progress, progress_callback_user_data); - } - if (model.n_loaded % 8 == 0) { - fprintf(stderr, "."); - fflush(stderr); - } - } - - fprintf(stderr, " done\n"); + }; - fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded); - if (model.n_loaded == 0) { - fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); - } else if (model.n_loaded != (int) model.tensors.size()) { - fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); - return false; + // load the tensor data into memory without copying or reading it + size_t offset = fin.tellg(); + size_t tensor_data_size = ggml_nbytes(tensor); + offset = (offset + 31) & -32; + tensor->data = mm_addr + offset; + fin.seekg(offset + tensor_data_size); + total_size += tensor_data_size; + model.n_loaded++; + + // progress + if (progress_callback) { + double current_progress = size_t(fin.tellg()) / double(file_size); + progress_callback(current_progress, progress_callback_user_data); } } fin.close(); + + fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded); + if (model.n_loaded == 0) { + fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } } - lctx.t_load_us = ggml_time_us() - t_start_us; + // loading time will be recalculate after the first eval, so + // we take page faults deferred by mmap() into consideration + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; if (progress_callback) { progress_callback(1.0, progress_callback_user_data); @@ -1216,7 +1053,7 @@ struct llama_tokenizer { size_t offs = 0; while (offs < text.size()) { llama_sp_symbol sym; - size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); + size_t char_len = Min(text.size() - offs, utf8_len(text[offs])); sym.text = text.c_str() + offs; sym.n = char_len; offs += char_len; @@ -1381,7 +1218,7 @@ static llama_vocab::id llama_sample_top_p_top_k( float maxl = -std::numeric_limits::infinity(); for (const auto & kv : logits_id) { - maxl = std::max(maxl, kv.first); + maxl = Max(maxl, kv.first); } // compute probs for the top k tokens @@ -1475,8 +1312,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s return false; } if (magic != LLAMA_FILE_MAGIC) { - fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); - return false; + return report_bad_magic(fname_inp.c_str()); } fout.write((char *) &magic, sizeof(magic)); @@ -1542,8 +1378,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s fout.write((char *) &len, sizeof(len)); word.resize(len); - finp.read ((char *) word.data(), len); - fout.write((char *) word.data(), len); + finp.read ((char *) &word[0], len); + fout.write((char *) &word[0], len); float score; finp.read ((char *) &score, sizeof(score)); @@ -1593,6 +1429,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s std::string name(length, 0); finp.read (&name[0], length); + { + // ensure tensor data is aligned + uint64_t offset = finp.tellg(); + offset = (offset + 31) & -32; + finp.seekg(offset); + } + { static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]); @@ -1648,6 +1491,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s } fout.write(&name[0], length); + { + // ensure tensor data is aligned + uint64_t offset = fout.tellp(); + offset = (offset + 31) & -32; + fout.seekp(offset); + } + if (quantize) { printf("quantizing .. "); work.resize(nelements); // for quantization @@ -1824,7 +1674,11 @@ int llama_eval( fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } - + // get a more accurate load time, upon first eval + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } return 0; } @@ -1917,9 +1771,9 @@ llama_token llama_sample_top_p_top_k( void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_eval = std::max(1, ctx->n_eval); - const int32_t n_p_eval = std::max(1, ctx->n_p_eval); + const int32_t n_sample = Max(1, ctx->n_sample); + const int32_t n_eval = Max(1, ctx->n_eval); + const int32_t n_p_eval = Max(1, ctx->n_p_eval); fprintf(stderr, "\n"); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); @@ -1931,7 +1785,6 @@ void llama_print_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx) { ctx->t_start_us = ggml_time_us(); - ctx->t_sample_us = ctx->n_sample = 0; ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; diff --git a/llama.h b/llama.h index 3368de3e00f58..258de5a944976 100644 --- a/llama.h +++ b/llama.h @@ -20,7 +20,7 @@ #endif #define LLAMA_FILE_VERSION 1 -#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex +#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files #ifdef __cplusplus diff --git a/models/ggml-vocab.bin b/models/ggml-vocab.bin index 3651f708e80ea..38f63493a97a7 100644 Binary files a/models/ggml-vocab.bin and b/models/ggml-vocab.bin differ