Support glm3 and glm4. #8031

Merged: 39 commits, Jul 7, 2024

Conversation

youth123
Contributor

I have fixed the issues mentioned in #6999. This code fully supports the glm3 and glm4 model architectures and can be embedded in the ollama server. This PR is based on https://github.com/mnlife/llama.cpp/tree/glm4 and https://github.com/mnlife/llama.cpp/tree/chatglm3, by @mnlife and @xingxingqiao.

@github-actions github-actions bot added the testing (Everything test related) and python (python script changes) labels on Jun 20, 2024
@xunkai55

Thanks for the great work!

@youth123 youth123 marked this pull request as draft June 20, 2024 10:07
@youth123 youth123 marked this pull request as ready for review June 20, 2024 10:15
@arch-btw
Contributor

This is so great! Thank you 👍 !

There are only a couple of things that I ran into:

During compile, a small note:

llama.cpp: In function ‘int32_t llama_tokenize(const llama_model*, const char*, int32_t, llama_token*, int32_t, bool, bool)’:
llama.cpp:18603:28: warning: moving a temporary object prevents copy elision [-Wpessimizing-move]
18603 |     auto prompt = std::move(std::string(text, text_len));
      |                   ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
llama.cpp:18603:28: note: remove ‘std::move’ call

And convert-hf-to-gguf.py doesn't work with bf16.

But it does work with f32.

bf16 log:

INFO:hf-to-gguf:output.weight,             torch.bfloat16 --> BF16, shape = {4096, 151552}
Writing:   0%| | 0.00/18.8G [00:00<?, ?byte/s]
Traceback (most recent call last):
  File "/home/glm4/convert-hf-to-gguf.py", line 3072, in <module>
    main()
  File "/home/glm4/convert-hf-to-gguf.py", line 3066, in main
    model_instance.write()
  File "/home/glm4/convert-hf-to-gguf.py", line 331, in write
    self.gguf_writer.write_tensors_to_file(progress=True)
  File "/home/glm4/gguf-py/gguf/gguf_writer.py", line 312, in write_tensors_to_file
    ti.tensor.tofile(self.fout)
  File "/home/glm4/gguf-py/gguf/lazy.py", line 233, in tofile
    eager = LazyNumpyTensor.to_eager(self)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 193, in to_eager
    return cls._recurse_apply(t, simple_to_eager)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 109, in _recurse_apply
    return fn(o)
           ^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 185, in simple_to_eager
    lt._data = lt._func(lt._args)
               ^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 158, in <lambda>
    return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
                                                                                        ^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 52, in __quantize_bf16_array
    return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.int16, oshape=n.shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 47, in __apply_over_grouped_rows
    np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
                    ^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 30, in __compute_fp32_to_bf16
    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
                                                 ~~^~~~~~~~~~~~
OverflowError: Python integer 4294901760 out of bounds for int32
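
For context: this OverflowError matches the stricter integer handling introduced in NumPy 2.x, where a Python int literal that does not fit the array's dtype now raises instead of being silently promoted, which would also explain why it depends on the pip environment. A minimal sketch of the behaviour (my own illustration, not the gguf-py code), with a hypothetical unsigned-view workaround:

import numpy as np

# Hypothetical illustration, not the gguf-py code: mask an int32 array with a
# constant that only fits in 32 bits when treated as unsigned.
n = np.array([0x7fc00000], dtype=np.int32)

try:
    n & 0xffff0000                      # 4294901760 does not fit in int32
except OverflowError as e:
    print("NumPy >= 2.0 raises:", e)    # older NumPy promoted and succeeded

# One way around it: do the masking on an unsigned view of the same width.
masked = n.view(np.uint32) & np.uint32(0xffff0000)
print(masked)                           # [2143289344] == 0x7fc00000

If that is indeed the cause, pinning numpy<2 in the conversion environment should also sidestep it.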

Other than that it works great.

Prompt:

./llama-cli -m glm4-Q6_K.gguf --color -p "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\n"

Output:

You are a helpful assistant
Hello

Hello! How can I assist you today? [end of text]

@youth123
Contributor Author

youth123 commented Jun 21, 2024

(quoting arch-btw's bf16 conversion report above)

I reran the conversion for GLM3 and GLM4, but did not encounter the issue you mentioned.
Here are my run commands and model weight links.

python convert-hf-to-gguf.py   /root/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/75792d7ee58a335df6943c5d719cc559b64f8e2a/ --outtype bf16 --outfile test.gguf

https://huggingface.co/THUDM/glm-4-9b-chat

@arch-btw
Contributor

Thanks, I think it's related to my pip environment and not a problem with the code.

@mofosyne mofosyne added the Review Complexity : Medium label (Generally require more time to grok but manageable by beginner to medium expertise level) on Jun 21, 2024
llama.cpp Outdated
@@ -18324,6 +18550,19 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to
}

bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
    auto arch_name = llama_model_arch_name(model->arch);
    auto vocab_type = model->vocab.type;
    if (strcmp(arch_name, "chatglm") == 0) {
Collaborator

llama_token_is_eog is called quite often; doing a string compare here may have an impact on performance.

Collaborator

@ngxson ngxson Jun 22, 2024

Looking at tokenizer_config.json, I think that it's safe to stop at EOS (<|endoftext|>), so no need to hard-code token IDs here

Collaborator

Edit: looking at the chat template, it seems the model does not have the notion of an end-of-turn token (strange!). Maybe we need to introduce the EOT token as a list instead of a single value. This will require adding metadata to gguf (CC @ggerganov)

Contributor Author

Alright, I will add an eot list to the gguf metadata. Then, during vocab initialization, I will put all the eot entries into this variable, so the check will only need to traverse this list.
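
A minimal sketch of that idea (in Python for brevity; this is not the actual llama.cpp code): collect every end-of-generation token id once when the vocab is loaded, so the per-token check becomes a cheap membership test instead of a per-call string compare on the architecture name. The token ids below are the ones the glm-4-9b-chat logs later in this thread report.

# Sketch only, not the llama.cpp implementation: build the EOG id set once at
# vocab-init time, then llama_token_is_eog-style checks are a set lookup.
class Vocab:
    def __init__(self, eos_id: int | None, eot_ids: list[int]):
        self.eog_ids: set[int] = set(eot_ids)
        if eos_id is not None:
            self.eog_ids.add(eos_id)

    def token_is_eog(self, token_id: int) -> bool:
        return token_id in self.eog_ids

# glm-4-9b-chat: <|endoftext|> (151329) as EOS, <|user|> (151336) as EOT
vocab = Vocab(eos_id=151329, eot_ids=[151336])
assert vocab.token_is_eog(151336) and not vocab.token_is_eog(0)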

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 8, 2024
* add chatglm3-6b model support huggingface model:
 https://hf-mirror.com/THUDM/chatglm3-6b

Signed-off-by: XingXing Qiao <[email protected]>

* remove .rotary_pos_emb.inv_freq and unuse code for chatglm3 model

Signed-off-by: XingXing Qiao <[email protected]>

* fix lint error

Signed-off-by: XingXing Qiao <[email protected]>

* optimize convert-hf-to-gguf.py for chatglm model

Signed-off-by: XingXing Qiao <[email protected]>

* support glm-4-9b-chat

Signed-off-by: XingXing Qiao <[email protected]>

* fix eos tokens to glm4

* remove unused log

* add preprocess to chatglm3 and chatglm4

* add eos_id_list to llama.cpp

* fix code style

* fix code style

* fix conflicts

* fix conflicts

* Revert "add eos_id_list to llama.cpp"

This reverts commit 3a4d579.

* set <|endoftext|> as eos and <|user|> as eot

* fix chat template bug

* add comment to glm prefix and suffix

* fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration

* fix chat template bug

* fix codestyle

* fix conflicts

* modified the general name of glm model

* fix conflicts

* remove prefix and suffix

* use normal glm4 chattempalte & use LLM_FFN_SWIGLU in phi3

* fix: resolve Flake8 errors in `convert-hf-to-gguf.py`

- Fix E302 by adding two blank lines before top-level function definitions
- Replace print statements to fix NP100
- Fix E303 by ensuring only one blank line between lines of code

* fix rope ratio to solve incorrect answers

* fix by comments

---------

Signed-off-by: XingXing Qiao <[email protected]>
Co-authored-by: XingXing Qiao <[email protected]>
Co-authored-by: Umpire2018 <[email protected]>
@CsBoBoNice

Hello, I built and ran version b3333, which already includes the merged code for this feature:

cmake -B build -DGGML_CUDA=ON && cmake --build build --config Release -j
python3 convert_hf_to_gguf.py /root/autodl-tmp/glm-4-9b-chat-1m
./build/bin/llama-cli -m /root/autodl-tmp/glm-4-9b-chat-1m/ggml-model-f16.gguf -ngl 20 --color -i

The following error appears after running:

# ./build/bin/llama-cli -m /root/autodl-tmp/glm-4-9b-chat-1m/ggml-model-f16.gguf -ngl 20 --color -i
Log start
main: build = 3333 (905942ab)
main: built with cc (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0 for x86_64-linux-gnu
main: seed  = 1720430939
llama_model_loader: loaded meta data with 24 key-value pairs and 283 tensors from /root/autodl-tmp/glm-4-9b-chat-1m/ggml-model-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = chatglm
llama_model_loader: - kv   1:                               general.name str              = glm-4-9b-chat-1m
llama_model_loader: - kv   2:                     chatglm.context_length u32              = 1048576
llama_model_loader: - kv   3:                   chatglm.embedding_length u32              = 4096
llama_model_loader: - kv   4:                chatglm.feed_forward_length u32              = 13696
llama_model_loader: - kv   5:                        chatglm.block_count u32              = 40
llama_model_loader: - kv   6:               chatglm.attention.head_count u32              = 32
llama_model_loader: - kv   7:            chatglm.attention.head_count_kv u32              = 4
llama_model_loader: - kv   8:   chatglm.attention.layer_norm_rms_epsilon f32              = 0.000000
llama_model_loader: - kv   9:                          general.file_type u32              = 1
llama_model_loader: - kv  10:               chatglm.rope.dimension_count u32              = 64
llama_model_loader: - kv  11:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  12:                     chatglm.rope.freq_base f32              = 100000000.000000
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                         tokenizer.ggml.pre str              = chatglm-bpe
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,151552]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,151552]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  17:                      tokenizer.ggml.merges arr[str,151073]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  18:            tokenizer.ggml.padding_token_id u32              = 151329
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 151329
llama_model_loader: - kv  20:                tokenizer.ggml.eot_token_id u32              = 151336
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 151329
llama_model_loader: - kv  22:                    tokenizer.chat_template str              = [gMASK]<sop>{% for item in messages %...
llama_model_loader: - kv  23:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type  f16:  162 tensors
llm_load_vocab: special tokens cache size = 223
llm_load_vocab: token to piece cache size = 0.9732 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = chatglm
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 151552
llm_load_print_meta: n_merges         = 151073
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 1048576
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 4
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 512
llm_load_print_meta: n_embd_v_gqa     = 512
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.6e-07
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 13696
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 100000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 1048576
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 9B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 9.48 B
llm_load_print_meta: model size       = 17.67 GiB (16.00 BPW)
llm_load_print_meta: general.name     = glm-4-9b-chat-1m
llm_load_print_meta: EOS token        = 151329 '<|endoftext|>'
llm_load_print_meta: UNK token        = 151329 '<|endoftext|>'
llm_load_print_meta: PAD token        = 151329 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 151336 '<|user|>'
llm_load_print_meta: max token length = 1024
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2080 Ti, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size =    0.28 MiB
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_qkv.weight' has wrong shape; expected  4096,  4608, got  4096,  5120,     1,     1
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/root/autodl-tmp/glm-4-9b-chat-1m/ggml-model-f16.gguf'
main: error: unable to load model

Is there something wrong with my steps that prevents it from running?

@chuangfengwang

My tests show:

  • glm-4-9b-chat is supported
  • glm-4-9b-chat-1m is not supported

It looks like some model layers of glm-4-9b-chat-1m are different from glm-4-9b-chat.

Let's adapt the glm-4-9b-chat-1m model.
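
A quick sanity check on where the mismatch comes from, assuming the fused QKV weight has n_embd rows for Q plus two K/V projections of equal width (number of KV heads times the KV head size); the numbers are taken from the loader error above:

# Rough arithmetic for the "expected 4096,4608, got 4096,5120" error above,
# assuming qkv rows = n_embd + 2 * kv_width.
n_embd        = 4096
expected_rows = 4608
actual_rows   = 5120

print("expected K/V width per projection:", (expected_rows - n_embd) // 2)  # 256
print("actual   K/V width per projection:", (actual_rows   - n_embd) // 2)  # 512

The 1m checkpoint's K/V projections are twice as wide as what this build computes for glm-4-9b-chat, so either the KV-head count or the KV head size differs between the two checkpoints and needs its own handling in the converter/loader.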

@youth123
Contributor Author

youth123 commented Jul 8, 2024

(quoting chuangfengwang's report above that glm-4-9b-chat-1m is not supported)

I will take a look at the model.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 8, 2024
(same commit message as above)
@yesmycar

yesmycar commented Jul 8, 2024 via email

@CsBoBoNice

Hello, I'm glad to be able to use the glm-4-9b-chat model with this project.

I am using version b3333 with the glm-4-9b-chat model, and I find that it keeps generating output and never stops properly.

Built with the following command:

cmake -B build -DGGML_CUDA=ON && cmake --build build --config Release -j

Run with the following command:

./build/bin/llama-cli -m /root/autodl-tmp/glm-4-9b-chat/ggml-model-f16.gguf -ngl 41 --color -i -n 1024 -c 2048 --file ./prompts/chat-with-qwen.txt -r "User:\n" --in-suffix "Assistant:\n" --in-prefix "\n" --interactive-first

Partial log after running:

.................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 5000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =    80.00 MiB
llama_new_context_with_model: KV self size  =   80.00 MiB, K (f16):   40.00 MiB, V (f16):   40.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.58 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   304.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    12.01 MiB
llama_new_context_with_model: graph nodes  = 1606
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 64 / 128 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 0 |
main: interactive mode on.
Reverse prompt: 'User:
'
Input prefix: '
'
Input suffix: 'Assistant:
'
sampling:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 2048, n_batch = 2048, n_predict = 1024, n_keep = 0


== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.

You are a helpful assistant.
Output the first 10 even natural numbers in a row
Assistant:
2, 4, 6, 8, 10, 12, 14, 16, 18, 20. Here are the first 10 even natural numbers in a row. They are multiples of 2. Do you need any further assistance? Feel free to ask! 🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻🤖👨‍💻🌟✨👩‍💻

What command should I use to have a normal conversation with the model?

Looking forward to your reply.

@taozhiyuai

ollama/ollama#5553

@chuangfengwang

(quoting CsBoBoNice's question above about the model not stopping)

You need a specially formatted prompt.

If you have an initial user prompt, your command may look like this:

./llama-cli -m path/to/glm-4-9b-chat-f16.gguf --color \
 -p "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\n"

(screenshot of the resulting output)

Otherwise, your command may look like this:

./llama-cli -m path/to/glm-4-9b-chat-f16.gguf --color -i \
 -p "[gMASK]<|system|>\nYou are a helpful assistant<|user|>"

and you can interact as follows:
(screenshot of the interactive session)
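
For scripted use, the same prompt layout can be assembled from a message list. This is a small helper of my own (not part of llama.cpp), mirroring the format of the commands above:

# Small helper (my own sketch, not from llama.cpp): build the GLM-4 style
# prompt string used in the commands above from a list of chat messages.
def glm4_prompt(messages: list[dict], add_assistant: bool = True) -> str:
    out = "[gMASK]"
    for m in messages:
        out += f"<|{m['role']}|>\n{m['content']}"
    if add_assistant:
        out += "<|assistant|>\n"
    return out

print(glm4_prompt([
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user",   "content": "Hello"},
]))
# [gMASK]<|system|>
# You are a helpful assistant<|user|>
# Hello<|assistant|>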

@icetech233

Speed?

@lzs0603

lzs0603 commented Jul 10, 2024

I've encountered an issue while setting up multiple inferencing pools with GLM4. The model continuously outputs "GGGGGGGGGGGGGGGG..." without stopping.

The script I'm using is as follows:

#!/bin/bash
./llama-server -m /models/THUDM_glm-4-9b-chat/ggml-model-Q4_K_M.gguf -c 8192 --port 10094 --n-gpu-layers 41 -np 2 --threads 4 --host 172.17.0.1 -cb

When I modify -np to 1, the problem is resolved. Could you please help identify this issue?

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 11, 2024
(same commit message as above)
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 11, 2024
(same commit message as above)
@hp027

hp027 commented Jul 12, 2024

Same here: "GGGGGG" output with no end.
Run with cmd: ollama run glm4

{
    "title_english": "Repository Installation",
    "keyword_english": ["repository installation", "Red Hat Enterprise Linux", "CentOS", "RHEL", "Oracle Linux", "Debian", "Ubuntu", "package manager", "Zabbix", "software package"],
    "summary_english": "This document provides instructions for installing repository configuration packages on various Linux distributions to set up Zabbix server, agent, and proxy installations. It covers the process specifically for Red Hat Enterprise Linux/CentOS, Debian, and Ubuntu with different versions of supported software package managers and configurations.",
    "faq_english": [
        "What are the steps to install repository configuration packages for Zabbix on RHEL 7?",
        "How do I configure apt-get preferences for Debian installation of Zabbix repository?",
        "What versions of Linux distributions are supported for ZGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
CPU times: user 69.9 ms, sys: 3.7 ms, total: 73.6 ms
Wall time: 7.76 s

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 13, 2024
(same commit message as above)
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 13, 2024
(same commit message as above)
@JG-Adams

JG-Adams commented Jul 16, 2024

I'm getting this too.
It declared, "Good Game" and gave up!
I would prefer this model over llama3 if it could be fully stable! :)

@ggerganov
Owner

ggerganov commented Jul 16, 2024

Try a similar solution as in #8412

@JG-Adams

Try a similar solution as in #8412

I'm using Ollama. I don't have a clue how to update. Do I just have to wait?

@tobias-varden

For me the model seems fixed: Ollama 0.2.7

@JG-Adams

JG-Adams commented Jul 20, 2024

For me the model seems fixed: Ollama 0.2.7

I have updated Ollama. It seems to be a little better, I guess, but it still has the problem.
I think it's the model itself that needs to be updated on Ollama.

@Sakura4036

Sakura4036 commented Jul 24, 2024

For me the model seems fixed: Ollama 0.2.7

I have updated Ollama. It seems to be a little better, I guess, but it still has the problem. I think it's the model itself that needs to be updated on Ollama.

For me, Ollama 0.2.8 does not fix this problem.

@piDack
Contributor

piDack commented Aug 22, 2024

For me the model seems fixed: Ollama 0.2.7

I have updated Ollama. It seems to be a little better, I guess, but it still has the problem. I think it's the model itself that needs to be updated on Ollama.

For me, Ollama 0.2.8 does not fix this problem.

I tried the solution from #8412 and it seems to resolve the issue for me.
I've submitted the code. You can try the modified code on my branch to see if it resolves the issue.
