Should use mmap for model loading #91

Closed
l29ah opened this issue Mar 13, 2023 · 58 comments
Labels
enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@l29ah
Contributor

l29ah commented Mar 13, 2023

So it doesn't create an extra copy in RAM and lives in the kernel page cache happily, loading instantly on subsequent runs.

ggerganov added the enhancement and good first issue labels Mar 13, 2023
@apaz-cli
Contributor

@ggerganov I'm working on putting together the PR. Almost done.

I don't know anything about the order that ggml accesses the weights in. Would you say that it's sequential? If so, there's also madvise().

@apage43
Contributor

apage43 commented Mar 14, 2023

You probably don't want to use madvise() with MADV_SEQUENTIAL: in addition to increasing the amount of readahead, it also causes pages to be evicted after they've been read. The entire model is going to be executed at least once per output token, reading all the weights, so MADV_SEQUENTIAL could kick them all out and force them to be reread repeatedly.

What may be more appropriate is to use MADV_WILLNEED on the whole model to kick off paging it all in without needing to wait for it to finish. But mmap can be tricky, and you would probably want to make it an option rather than the default: it may not be a perf improvement on all setups, and it can wind up being slower than regular I/O by causing lots of TLB shootdowns. You would want to benchmark it, as it's not unlikely you'd be trading improved time-to-first-token for worse overall throughput.
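For readers who want to see what that looks like in practice, here is a minimal POSIX-only sketch (illustrative, not code from this repo) of mapping the whole model read-only and using MADV_WILLNEED to kick off readahead; the helper name and error handling are assumptions:

// Minimal POSIX sketch: map the model file read-only and hint the kernel
// to start faulting it in asynchronously.
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstdio>

void * map_model(const char * path, size_t * out_size) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) { perror("open"); return nullptr; }

    struct stat st;
    if (fstat(fd, &st) != 0) { perror("fstat"); close(fd); return nullptr; }

    void * addr = mmap(nullptr, (size_t) st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    close(fd); // the mapping stays valid after the descriptor is closed
    if (addr == MAP_FAILED) { perror("mmap"); return nullptr; }

    // Start paging the whole file in without blocking on it.
    madvise(addr, (size_t) st.st_size, MADV_WILLNEED);

    *out_size = (size_t) st.st_size;
    return addr;
}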

@jart
Contributor

jart commented Mar 16, 2023

That will definitely happen with posix_fadvise(sequential), which has a very gentle impact on file caches on Linux. What we might end up wanting here is madvise(random). In order to do that though, we first would need to find a way to avoid the loading and deserialization process where c/c++ data structures are constructed in memory, and instead have the runtime data structures just be mapped directly from the file. That would ensure 100% reduction in startup time, which means we can start generating tokens asap, and pages get loaded off disk on an as-needed basis. Once we're able to implement that design pattern, madvise(random) vs. madvise(sequential) would be a tool that lets the kernel know how to utilize an under-utilized disk, to make predictions on avoiding page faults.

I'm still getting up to speed on this codebase, so I'd like to hear everyone's ideas on how best we'd ensure object data structures (or their floating point content at the very least) could be made directly mappable, thus side-stepping the loading process. One dirty hack for example I've been considering, would be overriding the memory allocators to get all objects at a fixed address, and persisting that to disk. That way, when all the C/C++ objects are loaded back into memory using MAP_FIXED, no relocations would need to be performed. That's obviously a less portable and non-ideal solution, but it'd help us get instant loading happening as quickly as possible, and furthermore permit us an opportunity to explore precisely how sparse the model's memory usage patterns actually are.
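As a very rough sketch of that dirty hack (purely illustrative; HEAP_BASE, the snapshot file, and restore_heap() are invented names, and hard-coding an address is exactly the portability problem acknowledged above), the restore side could look something like this:

#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>

// Assumed to be the address range the allocator reserved on the run that wrote the snapshot.
static void * const HEAP_BASE = (void *) 0x200000000000ULL;

void * restore_heap(const char * snapshot_path) {
    int fd = open(snapshot_path, O_RDONLY);
    if (fd < 0) return nullptr;

    struct stat st;
    if (fstat(fd, &st) != 0) { close(fd); return nullptr; }

    // MAP_FIXED maps the snapshot back at the exact address it was captured from,
    // so every pointer stored inside it is still valid -- no relocations needed.
    void * p = mmap(HEAP_BASE, (size_t) st.st_size, PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_FIXED, fd, 0);
    close(fd);
    return p == MAP_FAILED ? nullptr : p;
}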

@ggerganov
Owner

ggerganov commented Mar 16, 2023

@jart
Thanks for stepping in. I will briefly share an idea that might be useful. Just keep in mind that I haven't looked into the details of the discussion - I will in a few days once things cool off a bit here.

I think ggml_context is an extremely good fit for mmap, if I understand how it works. The ggml_context uses an externally provided buffer of memory with a pre-determined size:

llama.cpp/main.cpp

Lines 569 to 575 in 7213110

struct ggml_init_params params = {
    /*.mem_size   =*/ buf_size,
    /*.mem_buffer =*/ buf,
};

struct ggml_context * ctx0 = ggml_init(params);

All tensors and model parameters are "emplaced" into this buffer. There are no extra allocations occurring.
Once you load the model, you can simply dump the memory buffer provided to ggml_context and next time, you can simply load this buffer instead of constructing it. Everything should work.

Edit: So above I incorrectly referenced the "eval" ggml_context, which has its own buffer. The "model" ggml_context is here:

llama.cpp/main.cpp

Lines 228 to 240 in 7213110

// create the ggml context
{
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ NULL,
    };

    model.ctx = ggml_init(params);
    if (!model.ctx) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        return false;
    }
}

Same stuff. If the pointer is NULL it's allocated inside ggml for convenience.
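In its most naive form, the caching idea above would be a straight dump and reload of that buffer, something like the sketch below (hypothetical helper names). As discussed further down in the thread, the buffer also contains absolute pointers, so a reload like this only works if the buffer lands back at the same address; treat it as an illustration of the idea, not a complete solution.

#include <cstddef>
#include <cstdio>

// Dump the ggml_context's backing buffer so a later run can skip constructing it.
bool save_ctx_buffer(const char * path, const void * buf, size_t size) {
    FILE * f = fopen(path, "wb");
    if (!f) return false;
    const bool ok = fwrite(buf, 1, size, f) == size;
    fclose(f);
    return ok;
}

// Load a previously dumped buffer into the memory that will be handed to ggml_init().
bool load_ctx_buffer(const char * path, void * buf, size_t size) {
    FILE * f = fopen(path, "rb");
    if (!f) return false;
    const bool ok = fread(buf, 1, size, f) == size;
    fclose(f);
    return ok;
}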

@nazthelizard122

All tensors and model parameters are "emplaced" into this buffer. There are no extra allocations occurring.

In other words, there is a set memory cap?
Like many other people I have basically no experience with AI or memory buffers.
I'm just a guy pushing buttons until either something explodes or an AI becomes self-aware lol

@bitRAKE
Contributor

bitRAKE commented Mar 16, 2023

For loading from a physical/network drive, resharding the larger models into a single file might help, imho, whereas loading multiple files in parallel would be slower.

https://github.com/jankais3r/LLaMA_MPS/blob/main/reshard.py

@apaz-cli
Contributor

I'm getting ready to take another swing at it. My idea of what to do so far:

  1. Create functions in utils.h called llama_load_buffer(), llama_save_buffer(), and llama_destroy_buffer(). These will mmap() (or just malloc() and read), save, and munmap() (or just free()) the buffers respectively. So, files saved on one machine can't necessarily be loaded by another. These files would be stored in some folder, and have the names of the original files. Either in models, or /tmp, or a new folder. This will also hopefully be useful for implementing saving the model state.

  2. Add a new command line argument that tells llama_model_load() to look in this cache folder first. If it finds the file, llama_load_buffer() the file to get your ggml_init_params. Then do whatever else needs to be done (initialize the vocab, get hparams, etc) and exit the function. If the argument is present, call llama_save_buffer() first. Also, call llama_destroy_buffer() at the appropriate location.

I can do 1, and I'll submit a PR for that shortly, but it isn't clear enough to me how memory is laid out for me to do 2. In particular, I'm wondering about the "whatever else needs to be done" part. I'm certain that I'm missing something, and it probably wouldn't be obvious to me what I'm breaking even after hours of monkeying.
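For reference, a rough sketch of the helpers proposed in step 1, under the stated assumptions (the llama_buffer struct and the fallback path are illustrative; the function names are the ones proposed above, not an existing API):

#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstdlib>

struct llama_buffer {
    void * data   = nullptr;
    size_t size   = 0;
    bool   mapped = false; // true if backed by mmap(), false if malloc()'d
};

bool llama_load_buffer(const char * path, llama_buffer * out) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) return false;

    struct stat st;
    if (fstat(fd, &st) != 0) { close(fd); return false; }
    out->size = (size_t) st.st_size;

    void * p = mmap(nullptr, out->size, PROT_READ, MAP_PRIVATE, fd, 0);
    if (p != MAP_FAILED) {
        out->data   = p;
        out->mapped = true;
        close(fd);
        return true;
    }

    // Fallback: plain allocation + read() for setups where mmap() isn't usable.
    out->data   = malloc(out->size);
    out->mapped = false;
    size_t done = 0;
    while (out->data && done < out->size) {
        ssize_t n = read(fd, (char *) out->data + done, out->size - done);
        if (n <= 0) break;
        done += (size_t) n;
    }
    close(fd);
    if (done != out->size) { free(out->data); out->data = nullptr; return false; }
    return true;
}

void llama_destroy_buffer(llama_buffer * buf) {
    if (buf->mapped) munmap(buf->data, buf->size);
    else             free(buf->data);
    buf->data = nullptr;
    buf->size = 0;
}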

@jart
Contributor

jart commented Mar 16, 2023

My concern with doing that is, wouldn't it effectively double the disk usage? LLaMA is big enough that folks are already likely to be stretched thin on disk space after creating the second copy needed to quantize the model. I'm still working on studying the codebase to determine exactly which transformations need to be made at runtime. For example, if it's just a bunch of float16's on disk, and we're using a bunch of float16's in memory, then I don't see why the buffer field of these tensors couldn't just be populated with pointers to the appropriate positions in the file. Unless of course it needed to be reshaped or shifted to meet things like AVX alignment requirements. In that case, we'd ideally want to modify the quantizer script so that it generates something suitable for our purposes, so that only a single conversion step needs to be performed.

@bitRAKE
Contributor

bitRAKE commented Mar 16, 2023

(comment is more a #202 thing)

This is the way I was thinking about it:

After the model loads and the prompt is tokenized, create a hash of the context. If that hash exists in the cache directory (and flag is set), load state.

Saving the state has two modes, imho: post-prompt (which only works when state hasn't been loaded) and at the end of generation. The post-prompt mode allows jump-starting the model. End-of-generation saving lets you pick up where you left off.

As to the memory organization - I'd leave that to existing code. The hash would act as pseudo-verification that it's okay to load a buffer of bytes. The model and prompt would need to be the same, maybe even other options too.

Just throwing out some head-space. I haven't started coding anything.

@jart
Contributor

jart commented Mar 16, 2023

Wait. What problem are we trying to solve here exactly? Are we trying to (1) eliminate the three second startup delay? Or are we trying to (2) store the changes made to memory back to disk? Because if your goal is to solve (2) then the only thing you need to save are the random seed and the prompt, since that would restore the state deterministically. Right now I'm focusing on (1) since having fast mmap() loading would not change llama.cpp's behavior, and would instead simply make it go faster. If you want (2) then this change could be the stepping stone you need. All you'd have to do is change MAP_PRIVATE to be MAP_SHARED instead, and whatever mutations are made to the tensors in memory will be transparently remembered on disk. However that's orthogonal to my intended goals at the moment.

@bitRAKE
Contributor

bitRAKE commented Mar 16, 2023

I did not intend to expand the meaning of the thread. (2) should probably be addressed elsewhere. #202

@apaz-cli
Contributor

@jart It would double the disk usage, yes. But so does converting the model, and so does quantizing it. I think people are prepared for this.

You're right though in that the scripts that convert the model are probably the best way to do this. I was only thinking about implementing a cache for ggml_init_params as originally suggested. Ideally though, everything should just be one call, for everything from vocab/hparams to model weights.

@jart
Contributor

jart commented Mar 16, 2023

So I added some logging statements to track the read() operations that are happening. It's 200k+ lines that look like this:

moving 0x640 bytes from offset 0x4a607 to offset 0 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4ac47 to offset 0xc80 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4b287 to offset 0x1900 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4b8c7 to offset 0x2580 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4bf07 to offset 0x3200 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4c547 to offset 0x3e80 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4cb87 to offset 0x4b00 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4d1c7 to offset 0x5780 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4d807 to offset 0x6400 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4de47 to offset 0x7080 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4e487 to offset 0x7d00 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4eac7 to offset 0x8980 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4f107 to offset 0x9600 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4f747 to offset 0xa280 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x4fd87 to offset 0xaf00 (n_dims=2 n_parts=2)
moving 0x640 bytes from offset 0x503c7 to offset 0xbb80 (n_dims=2 n_parts=2)

All it's doing is (1) reshaping and (2) aligning the data in the file. That's why llama.cpp takes several seconds to start. It wouldn't make sense to cache a bunch of memcpy() operations. The quickest thing we could do is introduce a third conversion step that creates a new file format, where the data is in the appropriate shape and alignment ahead of time. Then we could work our way backwards through the conversion tools, to reduce the number of pipeline chores from 3 to 1.

@jart
Contributor

jart commented Mar 17, 2023

Here's another reason why this issue is so important. I just ran the 13B model with F16C on my workstation with 32GB of RAM. The model, once loaded, comes very close to hitting the physical memory limit, using maybe ~30GB peak RSS. Bringing memory up to the edge of swapping effectively compounds tragedy, since the kernel reacts by dropping its file caches. If we were using mmap() then the kernel would know that the loaded pages and the file pages are the same thing. But since we're copying the memory, the file cache goes away, and loading ends up taking a minute each time.

@jart
Contributor

jart commented Mar 17, 2023

@apaz-cli Have you attempted to implement the thing you proposed yet? It might work if you use MAP_FIXED when reloading it, since GGML appears to allocate objects with pointers too.

@apaz-cli
Contributor

@jart I have no idea how to support that in a portable way. I haven't dug too deep into it. I'm halfway through implementing part 1.

The troubling thing is actually the default way of opening files with the C/C++ stdlib. There is no portable way in C++11 to check the size of a file or binary stream, not even with fseek()/seekg() and ftell()/tellg(). C++17 resolves this with std::filesystem, but other versions of the standard are out of luck. You have to guess, and resize/copy if you're wrong, which seems unacceptable. The other way to do it is to read all the bytes of the file once just to get the size, and then read them again. That also seems unacceptable, unless the compiler is somehow magically able to see through it. I haven't checked, but it doesn't seem that likely.

See this link to the C standard. The C++ standard says the same about its own streams.

Although it's UB, the fseek()/ftell() dance is a classic, and is supported on almost all platforms. So we could just do it anyway.

@apage43
Contributor

apage43 commented Mar 17, 2023

The mmap operation itself is going to have its own portability issues; supporting all platforms on a first pass with no #ifdefs is unlikely here. mmap also requires the fd being mapped to actually be a file on a filesystem, which is probably implied if it's seekable, but fstat() (or equivalent) is probably the better way to check for that and get the size at the same time.
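A small sketch of the fstat() approach (illustrative helper name): one call both confirms that the descriptor refers to a regular file, so it is reasonable to mmap(), and reports its size, avoiding the fseek()/ftell() dance entirely:

#include <sys/stat.h>

bool regular_file_size(int fd, size_t * out_size) {
    struct stat st;
    if (fstat(fd, &st) != 0) return false;   // not stat-able at all
    if (!S_ISREG(st.st_mode)) return false;  // pipe, socket, device, ... -> don't mmap
    *out_size = (size_t) st.st_size;
    return true;
}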

@jart
Contributor

jart commented Mar 17, 2023

I've implemented a working prototype for UNIX systems. 5b8023d

Your LLaMA AGI models will now load instantly without any user-visible latency.

llama-mmap2.mp4

The above video should be all the proof you need.

I did this by writing an append-only malloc() that lets us transactionally capture heap allocations, so they can be restored on subsequent runs. It's a ~200 LOC change that only took me a few hours and it worked the first time. Much easier than the alternative, which likely would have entailed serializing C++ STL objects. This change could be productionized as it stands. I'd need to add the WIN32 mmap() code. I'd also need to store the flags and possibly file timestamps for comparison in the serialized object too, since right now the state change can only happen when magic.dat is deleted. We'd also want to put this behind a flag.

However, I still firmly believe this change is not the right thing to do. The file format should be fixed so that it's aligned and doesn't need to be reshaped. Doing that will take 1k+ lines of refactorings, plus a data migration, plus changes to the GGML API. I don't think it's possible to use the ggml_init_params::mem_buffer field, because that memory region contains pointers. That makes it basically the same as my malloc() capturing code, except less generalized. If you wanted to mmap() that field in a portable way, you'd have to do what linkers do, and apply fixups to all the pointers. What I thought might make more sense is doing a tensor->data = mmap() from a given file offset for each tensor (since I'm assuming there aren't that many of them?)

I'll also note that the gains here are mostly due to not copying memory anymore, and better cooperation with the kernel's page manager. We unfortunately aren't getting any additional gains from lazy page loading, since this is a dense model. To generate a single token, every single page in the model file needs to be loaded. What this means is that first runs that load from spinning disk are still going to be slow, even though the average case has greatly improved. I don't view that as a problem, since having the better cooperation with the page manager ensures that the kernel file caches are much less likely to be evicted.
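Since mmap() offsets have to be page-aligned, the per-tensor idea in practice means mapping the whole file once and pointing each tensor's data field at an offset inside that mapping, roughly like this (illustrative helper; the offsets are assumed to come from the model's metadata):

#include <cstdint>
#include "ggml.h" // struct ggml_tensor has a public `void * data` member

// map_base is the address returned by mmap()-ing the whole model file;
// file_offset is where this tensor's weights begin inside the file.
static void attach_tensor(struct ggml_tensor * t, void * map_base, size_t file_offset) {
    // No copy: the tensor now reads its weights straight out of the page cache.
    t->data = (uint8_t *) map_base + file_offset;
}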

@j-f1
Collaborator

j-f1 commented Mar 17, 2023

I'll also note that the gains here are mostly due to not copying memory anymore, and better cooperation with the kernel's page manager. We unfortunately aren't getting any additional gains from lazy page loading, since this is a dense model. To generate a single token, every single page in the model file needs to be loaded. What this means is that first runs that load from spinning disk are still going to be slow, even though the average case has greatly improved. I don't view that as a problem, since having the better cooperation with the page manager ensures that the kernel file caches are much less likely to be evicted.

On devices with less RAM and no swap (iOS), will this allow the inference to proceed without hitting the memory limit by evicting weights from the page cache during inference?

For the pointers issue, you could make a custom smart pointer type that uses a global variable to do the fixups at runtime (not sure if this would have a perf impact though):

#include <cassert>
#include <cstddef>
#include <cstdio>

void *ctx_base; // base address of the current mapping

template<typename T>
class ctx_ptr {
  ptrdiff_t offset;
public:
  explicit ctx_ptr(T *value) : offset((char *)value - (char *)ctx_base) {
    assert(offset >= 0);
  }
  T *operator->() const {
    return (T *)((char *)ctx_base + offset);
  }
};

// usage (addresses are illustrative):
ctx_base = (void *)0x1234;
ctx_ptr<ggml_whatever> ptr(the_raw_ptr);

ctx_base = (void *)0x5678;
printf("%s\n", ptr->some_field);

@hmage

hmage commented Mar 17, 2023

mmap(2) allows using files larger than available RAM.

@ggerganov
Owner

@jart

I don't think it's possible to use the ggml_init_params::mem_buffer field, because that memory region contains pointers.

Sorry, I missed that.

What I thought might make more sense is doing a tensor->data = mmap() from a given file offset for each tensor (since I'm assuming there aren't that many of them?)

This is possible for non-sharded models. For example, if you take a look at the gpt-2 example, the model loading is a straight read from the file into each tensor:

https://github.com/ggerganov/ggml/blob/4c2f924553312c490e79e6e1739c6f4aa9bbd450/examples/gpt-2/main.cpp#L329-L331

However, the larger LLaMA models ( >7B ) are split into parts, and one has to merge them. Each tensor is split across all the parts, either by rows or by columns, so the reading logic becomes complicated because of that:

llama.cpp/main.cpp

Lines 457 to 503 in 367946c

if (n_dims == 1 || n_parts == 1) {
    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
        return false;
    }

    if (part_id == 0) {
        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
    } else {
        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
    }

    total_size += ggml_nbytes(tensor);
} else {
    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
        return false;
    }

    if (split_type == 0) {
        const int np0 = ne[0];

        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
        assert(row_size == tensor->nb[1]);

        for (int i1 = 0; i1 < ne[1]; ++i1) {
            const size_t offset_row = i1*row_size;
            const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
            fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
        }
    } else {
        const int np1 = ne[1];

        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);

        for (int i1 = 0; i1 < ne[1]; ++i1) {
            const size_t offset_row = (i1 + part_id*np1)*row_size;
            fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
        }
    }

    total_size += ggml_nbytes(tensor)/n_parts;
}

We could create a combined ggml model file as part of the setup process, but this way we go back to the issue of needing double disk space.

Anyway, impressive stuff!

@j-f1
Collaborator

j-f1 commented Mar 17, 2023

We could create a combined ggml model file as part of the setup process, but this way we go back to the issue of needing double disk space.

Maybe the conversion script could combine all the .pth files into one GGML file?

@IkerAriz

@ggerganov

However, the larger LLaMA models ( >7B ) are split into parts, and one has to merge them. Each tensor is split across all the parts. Either by rows or by columns. So the reading logic becomes complicated because of that:

Perhaps the merging can be addressed with multiple mmaps stitched together into a contiguous region as described here:

https://stackoverflow.com/a/34560306

@jart
Contributor

jart commented Mar 28, 2023

I've just tested @slaren's change locally. It works great for 7B (which only has 1-dimensional tensors), loads instantly, and has zero performance regression on evaluation. My only worry is that tensor->data is no longer aligned to GGML_MEM_ALIGN because the file offsets aren't aligned. Does anyone know if that matters? It runs fine for me, unaligned, on AVX and SSE, and Apple M1 is fine too. I can't speak for other ISAs like POWER though.

@l29ah
Contributor Author

l29ah commented Mar 28, 2023

Proof of concept here: slaren@fc68512

Tried it out on the single-part alpaca-13B-ggml/ggml-model-q4_0.bin I've grabbed from a torrent, and it works like a charm, thank you!

jart added a commit that referenced this issue Mar 28, 2023
- We have pretty high quality POSIX polyfills now
- We no longer need to override malloc()

Tracked by issue #91
Improves upon #341
@jart
Contributor

jart commented Mar 28, 2023

Tests are all green on the mmap branch. It now works on Windows with MSVC, thanks to some pretty good POSIX polyfills that @oKatanaaa and I put a lot of thought into creating. See https://github.com/ggerganov/llama.cpp/blob/mmap/mmap.c The mmap branch no longer relies on overriding malloc().

Unfortunately I no longer see a smooth path for merging the mmap branch into master, because #370 introduced refactorings that don't mesh with its design. @slaren's suggestion combined with the new mmap() win32 polyfill would probably be our best bet right now for getting a measurable improvement into the master branch.

@slaren
Collaborator

slaren commented Mar 28, 2023

I can try to integrate the Windows patch on my branch and cleanup the code a bit so that at least it can be used with 7B for now.

@ggerganov what do you think would be the best way to modify ggml to allow it to create tensors without reserving memory? (like the ggml_nomem() hack in slaren@c13eaf3)

@ggerganov
Owner

ggerganov commented Mar 28, 2023

@jart

I've just tested @slaren's change locally. It works great for 7B (which only has 1-dimensional tensors), loads instantly, and has zero performance regression on evaluation. My only worry is that tensor->data is no longer aligned to GGML_MEM_ALIGN because the file offsets aren't aligned. Does anyone know if that matters? It runs fine for me, unaligned, on AVX and SSE, and Apple M1 is fine too. I can't speak for other ISAs like POWER though.

Last time I researched this topic, I concluded that there is no real difference in performance between aligned and non-aligned memory. This probably mattered in the past, but nowadays I think the differences are negligible.

I think ggml.c uses only non-aligned SIMD intrinsics, unless I missed some recent change on this. So if that is all true, I think it should not really matter if the memory buffer is aligned to GGML_MEM_ALIGN or not.

And even if I am wrong, given that all tensors in the model have sizes that are multiples of 32 I think it wouldn't be difficult to guarantee the alignment, right?
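For context, the distinction in AVX terms is just which load intrinsic gets used; the unaligned form accepts any address, while the aligned form faults if the pointer isn't 32-byte aligned (a tiny illustration, not code from ggml.c):

#include <immintrin.h>

// Unaligned load: valid for any float pointer, which is why code built on
// _mm256_loadu_ps keeps working even when tensor->data isn't 32-byte aligned.
__m256 load_any(const float * p)     { return _mm256_loadu_ps(p); }

// Aligned load: requires p to be 32-byte aligned, otherwise it faults.
// Using this would require the alignment guarantee discussed here.
__m256 load_aligned(const float * p) { return _mm256_load_ps(p); }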

@slaren

I can try to integrate the Windows patch on my branch and cleanup the code a bit so that at least it can be used with 7B for now.

@ggerganov what do you think would be the best way to modify ggml to allow it to create tensors without reserving memory? (like the ggml_nomem() hack in slaren@c13eaf3)

It has to be a new variable in the struct ggml_init_params:

llama.cpp/ggml.h

Lines 314 to 320 in 2a98bc1

struct ggml_init_params {
    // memory pool
    size_t mem_size;   // bytes
    void * mem_buffer; // if NULL, memory will be allocated internally
};

Name it bool no_alloc or whatever you prefer
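That is, the suggested change would look roughly like this (a sketch of the proposal, not the final API):

struct ggml_init_params {
    // memory pool
    size_t mem_size;   // bytes
    void * mem_buffer; // if NULL, memory will be allocated internally
    bool   no_alloc;   // if true, create tensor metadata but don't allocate the tensor data
};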

@jart
Contributor

jart commented Mar 28, 2023

And even if I am wrong, given that all tensors in the model have sizes that are multiples of 32 I think it wouldn't be difficult to guarantee the alignment, right?

I can't imagine it'd be difficult at all. I think the main thing that misaligns the format is those nul-terminated token strings at the beginning of the file. If you do a roundup(offset, 32) after producing those, then you might luck out and everything else will magically align.
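The roundup in question is the usual power-of-two trick, e.g.:

// Round offset up to the next multiple of align (align must be a power of two).
static size_t roundup(size_t offset, size_t align) {
    return (offset + align - 1) & ~(align - 1);
}

// After writing the token strings in the converter:
// offset = roundup(offset, 32); // tensor data then starts 32-byte aligned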

@ggerganov
Owner

Totally - at this point we can easily make changes to the ggml model format and not worry too much.
Having a 32-byte rounded header sounds like a good idea anyway - regardless of whether we really need it or not.

jart added a commit to jart/llama.cpp that referenced this issue Mar 29, 2023
This is a breaking change that's going to give you three benefits:

1. Your inference commands should load 100x faster
2. You may be able to safely load models 2x larger
3. You can run many concurrent inference processes

This was accomplished by changing the file format so we can mmap()
weights directly into memory without having to read() or copy them
thereby ensuring the kernel can make its file cache pages directly
accessible to our inference processes; and secondly, that the file
cache pages are much less likely to get evicted (which would force
loads to hit disk) because they're no longer competing with memory
pages that were needlessly created by gigabytes of standard i/o.

Furthermore, this change ensures that tensors are aligned properly
on a 32-byte boundary. That opens the door to seeing if we can get
additional performance gains on some microprocessors, by using ops
that require memory alignment.

Lastly note that both POSIX and the Windows platform are supported

Fixes ggerganov#91
jart added a commit to jart/llama.cpp that referenced this issue Mar 29, 2023
This is a breaking change that's going to give you three benefits:

1. Your inference commands should load 100x faster
2. You may be able to safely load models 2x larger
3. You can run many concurrent inference processes

This was accomplished by changing the file format so we can mmap()
weights directly into memory without having to read() or copy them
thereby ensuring the kernel can make its file cache pages directly
accessible to our inference processes; and secondly, that the file
cache pages are much less likely to get evicted (which would force
loads to hit disk) because they're no longer competing with memory
pages that were needlessly created by gigabytes of standard i/o.

The new file format supports single-file models like LLaMA 7b, and
it also supports multi-file models like LLaMA 13B. Our Python tool
now merges the foo.1, foo.2, etc. files back into a single file so
that the C++ code which maps it doesn't need to reshape data every
time. That's made llama.cpp so much simpler. Much of its load code
has now been deleted.

Furthermore, this change ensures that tensors are aligned properly
on a 32-byte boundary. That opens the door to seeing if we can get
additional performance gains on some microprocessors, by using ops
that require memory alignment.

Lastly note that both POSIX and the Windows platform are supported

Fixes ggerganov#91
jart closed this as completed in 78ca983 Mar 30, 2023
Nuked88 pushed a commit to Nuked88/llama.http that referenced this issue Mar 31, 2023
@Ar57m

Ar57m commented Aug 5, 2023

@ggerganov is it possible for you guys to make mmap work on gptj (pyg6b)?

@apaz-cli
Contributor

apaz-cli commented Aug 5, 2023

@Ar57m Why would it not be working for a specific model? Do you have an error message to share? We need more to go on. Also, this is not the right place. You should look up the error message in the issues, and if you can't find it, create a new one.

Your other interaction on GitHub asks if it's possible to support loading big models on low-memory devices. Yes and no. Memory mapping requires that you have enough memory to map into. So if you're running out of RAM, consider allocating swap space. This allows the OS to use the reserved disk space as though it were RAM. It isn't RAM, so expect it to run much slower. But it could make the difference between the program running and not running.

@l29ah
Contributor Author

l29ah commented Aug 5, 2023

Memory mapping requires that you have enough memory to map into.

It doesn't. It's just like swap.

@apaz-cli
Contributor

apaz-cli commented Aug 5, 2023

Enough virtual memory. There also has to be enough physical memory backing it, or it will fault. So you are technically correct. Possibly the best kind of correct, but this is pedantry.

@Ar57m

Ar57m commented Aug 5, 2023

@apaz-cli I'm trying to run pyg6b q4_0 on Termux (Android), and when the model is loading it crashes Termux (because it's using too much RAM). I think it's because it's not mmapping; I heard somewhere that gptj doesn't support mmap. I can run almost any size of LLaMA-based model (big models very slowly) with mmap on.
