cuda fallback bf16 for compute_cap < 8.0 (~x6 speed) #57

Merged · 1 commit into EricLBuehler:main · Jan 8, 2025

Conversation

@haricot commented Jan 7, 2025

tested and works with:

nvidia-smi   --query-gpu="compute_cap"  --format=csv
compute_cap
6.1 

and mistral.rs:

cargo run -F "cuda cudnn " -r -- --no-paged-attn -i plain -m meta-llama/Llama-3.2-1B-Instruct --dtype bf16
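If it helps, the same compute capability can also be read from code via the CUDA runtime API instead of nvidia-smi; a minimal standalone sketch (hypothetical, not part of mistral.rs):

#include <cstdio>
#include <cuda_runtime.h>

// Prints the compute capability of device 0, matching what
// `nvidia-smi --query-gpu=compute_cap` reports (e.g. 6.1 here).
int main() {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
        fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    printf("compute_cap %d.%d\n", prop.major, prop.minor);
    return 0;
}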

@EricLBuehler (Owner) commented Jan 7, 2025

@haricot this seems super interesting. Do you see a 6x speedup in T/s? That is amazing.

@EricLBuehler (Owner) left a comment

Thanks for the PR! It looks great.

I'll add checks for when the CUDA arch is >= 800, in which case these fallbacks should be automatically disabled.
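For reference, this kind of guard usually sits in the CUDA kernels themselves: sm_80+ can use the native bf16 intrinsics, while older architectures (like the 6.1 card above) round-trip through f32. A minimal illustrative sketch of that pattern, not the actual diff in this PR:

#include <cuda_bf16.h>

// Element-wise add over bf16 buffers. On compute capability >= 8.0 the
// native bf16 intrinsic is used; below that, the fallback converts to
// f32, computes, and converts back.
__global__ void add_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b,
                         __nv_bfloat16* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    out[i] = __hadd(a[i], b[i]);            // native bf16 math (Ampere+)
#else
    float fa = __bfloat162float(a[i]);      // pre-Ampere fallback via f32
    float fb = __bfloat162float(b[i]);
    out[i] = __float2bfloat16(fa + fb);
#endif
}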

EricLBuehler merged commit eab5cbb into EricLBuehler:main on Jan 8, 2025
@haricot (Author) commented Jan 8, 2025

Yes, that's what I observed: roughly 46 T/s with BF16 versus about 7 T/s with F16 on this GPU, so around 6.6x:

    Finished `release` profile [optimized] target(s) in 0.26s
       Running `target/release/mistralrs-server --no-paged-attn --throughput -i plain -m meta-llama/Llama-3.2-1B-Instruct --dtype bf16`
2025-01-08T06:08:09.525518Z  INFO mistralrs_server: avx: true, neon: false, simd128: false, f16c: true
2025-01-08T06:08:09.525539Z  INFO mistralrs_server: Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial
2025-01-08T06:08:09.525559Z  INFO mistralrs_server: Model kind is: normal (no adapters)
2025-01-08T06:08:09.525698Z  INFO mistralrs_core::pipeline::normal: Loading `tokenizer.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:09.525740Z  INFO mistralrs_core::pipeline::normal: Loading `config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:10.827451Z  INFO mistralrs_core::pipeline::paths: Found model weight filenames ["model.safetensors"]
2025-01-08T06:08:11.134528Z  INFO mistralrs_core::pipeline::normal: Loading `generation_config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:11.670317Z  INFO mistralrs_core::pipeline::normal: Loading `tokenizer_config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:11.953842Z  INFO mistralrs_core::pipeline::normal: Loading model `meta-llama/Llama-3.2-1B-Instruct` on cuda[0].
2025-01-08T06:08:11.953916Z  INFO mistralrs_core::utils::log: Automatic loader type determined to be `llama`
2025-01-08T06:08:11.953982Z  INFO mistralrs_core::utils::normal: DType selected is BF16.
2025-01-08T06:08:11.953994Z  INFO mistralrs_core::pipeline::normal: Model config: Config { hidden_size: 2048, intermediate_size: 8192, vocab_size: 128256, num_hidden_layers: 16, num_attention_heads: 32, num_key_value_heads: 8, use_flash_attn: false, rms_norm_eps: 1e-5, rope_theta: 500000.0, max_position_embeddings: 131072, rope_scaling: Some(Llama3RopeConfig { factor: 32.0, low_freq_factor: 1.0, high_freq_factor: 4.0, original_max_position_embeddings: 8192, rope_type: Llama3 }), quantization_config: None, tie_word_embeddings: true }
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 575.66it/s]
2025-01-08T06:08:13.044656Z  INFO mistralrs_core::pipeline::chat_template: bos_toks = "<|begin_of_text|>", eos_toks = "<|eot_id|>", "<|end_of_text|>", "<|eom_id|>", unk_tok = `None`
2025-01-08T06:08:13.076538Z  INFO mistralrs_server: Model loaded.
2025-01-08T06:08:13.079284Z  INFO mistralrs_core: Enabling GEMM reduced precision in BF16.
2025-01-08T06:08:13.081520Z  INFO mistralrs_core: Enabling GEMM reduced precision in F16.
2025-01-08T06:08:13.082448Z  INFO mistralrs_core::cublaslt: Initialized cuBLASlt handle
2025-01-08T06:08:13.082528Z  INFO mistralrs_core: Beginning dummy run.
2025-01-08T06:08:13.200032Z  INFO mistralrs_core: Dummy run completed in 0.117482336s.
2025-01-08T06:08:13.200072Z  INFO mistralrs_server::interactive_mode: Starting interactive loop with sampling params: SamplingParams { temperature: Some(0.1), top_k: Some(32), top_p: Some(0.1), min_p: Some(0.05), top_n_logprobs: 0, frequency_penalty: Some(0.1), presence_penalty: Some(0.1), stop_toks: None, max_len: Some(4096), logits_bias: None, n_choices: 1, dry_params: Some(DrySamplingParams { sequence_breakers: ["\n", ":", "\"", "*"], multiplier: 0.0, base: 1.75, allowed_length: 2 }) }
====================
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
====================
>can you tell me something joke?
Here's one:

What do you call a fake noodle?

An impasta.

I hope that made you laugh! Do you want to hear another one?
2025-01-08T06:08:18.617231Z  INFO mistralrs_server::interactive_mode: Average T/s: 46.27449340135961
    Finished `release` profile [optimized] target(s) in 0.25s
     Running `target/release/mistralrs-server --no-paged-attn --throughput -i plain -m meta-llama/Llama-3.2-1B-Instruct --dtype f16`
2025-01-08T06:08:41.992549Z  INFO mistralrs_server: avx: true, neon: false, simd128: false, f16c: true
2025-01-08T06:08:41.992569Z  INFO mistralrs_server: Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial
2025-01-08T06:08:41.992601Z  INFO mistralrs_server: Model kind is: normal (no adapters)
2025-01-08T06:08:41.992732Z  INFO mistralrs_core::pipeline::normal: Loading `tokenizer.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:41.992775Z  INFO mistralrs_core::pipeline::normal: Loading `config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:43.172236Z  INFO mistralrs_core::pipeline::paths: Found model weight filenames ["model.safetensors"]
2025-01-08T06:08:43.495398Z  INFO mistralrs_core::pipeline::normal: Loading `generation_config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:44.211818Z  INFO mistralrs_core::pipeline::normal: Loading `tokenizer_config.json` at `meta-llama/Llama-3.2-1B-Instruct`
2025-01-08T06:08:44.519284Z  INFO mistralrs_core::pipeline::normal: Loading model `meta-llama/Llama-3.2-1B-Instruct` on cuda[0].
2025-01-08T06:08:44.519379Z  INFO mistralrs_core::utils::log: Automatic loader type determined to be `llama`
2025-01-08T06:08:44.519414Z  INFO mistralrs_core::utils::normal: DType selected is F16.
2025-01-08T06:08:44.519430Z  INFO mistralrs_core::pipeline::normal: Model config: Config { hidden_size: 2048, intermediate_size: 8192, vocab_size: 128256, num_hidden_layers: 16, num_attention_heads: 32, num_key_value_heads: 8, use_flash_attn: false, rms_norm_eps: 1e-5, rope_theta: 500000.0, max_position_embeddings: 131072, rope_scaling: Some(Llama3RopeConfig { factor: 32.0, low_freq_factor: 1.0, high_freq_factor: 4.0, original_max_position_embeddings: 8192, rope_type: Llama3 }), quantization_config: None, tie_word_embeddings: true }
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 309.63it/s]
2025-01-08T06:08:45.671820Z  INFO mistralrs_core::pipeline::chat_template: bos_toks = "<|begin_of_text|>", eos_toks = "<|eot_id|>", "<|end_of_text|>", "<|eom_id|>", unk_tok = `None`
2025-01-08T06:08:45.710028Z  INFO mistralrs_server: Model loaded.
2025-01-08T06:08:45.712818Z  INFO mistralrs_core: Enabling GEMM reduced precision in BF16.
2025-01-08T06:08:45.715095Z  INFO mistralrs_core: Enabling GEMM reduced precision in F16.
2025-01-08T06:08:45.715965Z  INFO mistralrs_core::cublaslt: Initialized cuBLASlt handle
2025-01-08T06:08:45.716045Z  INFO mistralrs_core: Beginning dummy run.
2025-01-08T06:08:45.841990Z  INFO mistralrs_core: Dummy run completed in 0.125921814s.
2025-01-08T06:08:45.842031Z  INFO mistralrs_server::interactive_mode: Starting interactive loop with sampling params: SamplingParams { temperature: Some(0.1), top_k: Some(32), top_p: Some(0.1), min_p: Some(0.05), top_n_logprobs: 0, frequency_penalty: Some(0.1), presence_penalty: Some(0.1), stop_toks: None, max_len: Some(4096), logits_bias: None, n_choices: 1, dry_params: Some(DrySamplingParams { sequence_breakers: ["\n", ":", "\"", "*"], multiplier: 0.0, base: 1.75, allowed_length: 2 }) }
====================
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
====================
>can you tell me something joke?     
Here's one:

What do you call a fake noodle?

An impasta.

I hope that made you laugh! Do you want to hear another one?
2025-01-08T06:09:31.082653Z  INFO mistralrs_server::interactive_mode: Average T/s: 7.026543238073144
