Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
179 commits
Select commit Hold shift + click to select a range
7884b0e
sampling : add support for backend sampling
danbev Nov 17, 2025
9fe9a00
llama-cli : add backend sampler configuration
danbev Nov 17, 2025
f1f3e68
server : add backend sampling options/configuration
danbev Nov 17, 2025
a3eb847
webui : add backend sampling options
danbev Nov 17, 2025
67d3b8e
ggml : add initial cumsum implementation for CUDA
danbev Nov 17, 2025
71574f9
sampling : enable all backend sampler tests
danbev Nov 18, 2025
4b52e59
graph : do not include llama-model.h
ggerganov Nov 18, 2025
82957a9
sampling : always expose sampled_ids
danbev Nov 18, 2025
311c1a3
sampling : ensure at most one output token per seq
danbev Nov 18, 2025
26be108
CUDA: Optimize argsort for gpu-based token sampling
ORippler Nov 18, 2025
0da7e7d
sampling : remove version from sampler chain
danbev Nov 19, 2025
51fee29
sampling : always populate logits for sampled probs
danbev Nov 19, 2025
7e98ebc
sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
d74eb61
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
38f408c
common : fix regression caused by extra memory allocations during sam…
ggerganov Nov 19, 2025
18ed4d8
squash! sampling : simplify backend sampling logic decode
danbev Nov 19, 2025
0c660e7
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 20, 2025
ed4345b
squash! common : fix regression caused by extra memory allocations du…
danbev Nov 20, 2025
0d28b16
sampling : introduce sampling_info struct
danbev Nov 20, 2025
c162562
sampling : return early if backend sampling is disabled
danbev Nov 21, 2025
61ffe41
sampling : use pinned memory for backend sampling buffers
danbev Nov 21, 2025
9b24393
common, tools : refactor model loading to support backend samplers
danbev Nov 21, 2025
79b8cf2
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 21, 2025
65500d0
sampling : add stride variable for clarity
danbev Nov 23, 2025
ae23d2d
sampling: clarify candidate ids usage in comments
danbev Nov 23, 2025
9e273f7
sampling : fix copying both sampled tokens and logits/probs from backend
danbev Nov 23, 2025
50d21aa
tests : cleanup test-backend-sampler.cpp
danbev Nov 24, 2025
7816f0b
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 24, 2025
d88ba18
common : remove build-info.cpp from commit [no ci]
danbev Nov 24, 2025
4a90583
sampling : cleanup and clarify output_reserve
danbev Nov 24, 2025
8eb9b47
sampling : remove redundant checks for stride and size [no ci]
danbev Nov 24, 2025
25f3380
sampling : add debug log when backend sampler selects token
danbev Nov 24, 2025
d0bea21
examples : update batched to use backend sampling
danbev Nov 24, 2025
e2d4f08
llama-cli : fix dangling reference to sampler config
ggerganov Nov 24, 2025
b26c706
common : initialize backend samplers
ggerganov Nov 24, 2025
883a870
samplers : add missing cont
ggerganov Nov 24, 2025
a02adf4
sampling : add assertions for contiguous tensors in async copy functions
danbev Nov 24, 2025
2b4c792
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
0f17ccd
examples : add info about hybrid sampling in batched [no ci]
danbev Nov 25, 2025
53dca56
Merge remote-tracking branch 'upstream/master' into gpu-sampling
danbev Nov 25, 2025
9e5e09d
sampling : remove backend-dist option (wip)
danbev Nov 25, 2025
ec047e1
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 25, 2025
f23b306
CUDA: Add top-k implementation
ORippler Nov 21, 2025
b45d504
sampling : add min-p backend sampler
danbev Nov 26, 2025
4fea191
Use `FetchContent` over CPM as it's bundled with CMake
ORippler Nov 26, 2025
0f7805f
common : add get_active_samplers function to check enabled samplers
danbev Nov 26, 2025
90a3aff
cuda : fix editorconfig-checker warning
danbev Nov 26, 2025
7c2bfb3
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Nov 26, 2025
d9d7361
sampling : use argmax for min-p sampling
danbev Nov 27, 2025
51107a0
sampling : fix temperature check to allow zero temperature
danbev Nov 27, 2025
5ea3be2
cuda : fix top-k compilation when CUB is unavailable
danbev Nov 27, 2025
172208a
sampling : add comments about backend sampler [no ci]
danbev Nov 27, 2025
e9d0709
sampling : remove backend sampling chain from common_sampler
danbev Nov 27, 2025
f9889cf
Fix top-k comp & behavior for non-CUB path
ORippler Nov 27, 2025
74be332
sampling : support intermixed backend/cpu samplers
danbev Nov 27, 2025
9ad6522
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
459b7ae
squash! sampling : support intermixed backend/cpu samplers
danbev Nov 28, 2025
117e207
refactor : simplify and improve memory management
ggerganov Nov 28, 2025
333da80
Add initial version for top-p sampling
ORippler Nov 28, 2025
8cac9de
sampling : use logits directly for min-p filtering
danbev Nov 28, 2025
2464d1b
sampling : simplify
ggerganov Nov 28, 2025
fbc8f49
llama : simplify
ggerganov Nov 29, 2025
9028ebf
llama : cleanup + naming
ggerganov Nov 29, 2025
d8d98bb
Merge branch 'master' into HEAD
ggerganov Nov 29, 2025
ff7b0bf
llama : call backend_init once
ggerganov Nov 29, 2025
467746e
Merge branch 'master' into HEAD
ggerganov Nov 29, 2025
1760bd6
llama : reserve graphs with samplers
ggerganov Nov 29, 2025
c187003
llama : naming
ggerganov Nov 29, 2025
80742cb
cont : naming
ggerganov Nov 29, 2025
cf0e147
sampling : lower log level for output buffer reallocations [no ci]
danbev Dec 1, 2025
8bee483
Fix backend_top_p_sampler
ORippler Dec 1, 2025
16451d6
Merge branch 'master' into HEAD
ggerganov Dec 1, 2025
ae0bb6a
Factor out `ggml_sort` into its own function
ORippler Dec 1, 2025
217469f
Make backend's top_p sampler inclusive
ORippler Dec 1, 2025
4032ce2
common : simplify sampler chain initialization
ggerganov Dec 1, 2025
04f2822
sampling : do not create empty samplers
ggerganov Dec 1, 2025
88cca45
sampling : fix top_p empty condition
ggerganov Dec 1, 2025
988261b
examples : remove outdated backend sampling section
danbev Dec 1, 2025
739b597
sampling : fix backend temp sampler for zero temperature
danbev Dec 2, 2025
3e9a258
Merge remote-tracking branch 'upstream/master' into gpu-sampling
danbev Dec 2, 2025
559d058
CUDA: Move cccl fetch to after cuda has been enabled in CMakeLists.txt
ORippler Dec 1, 2025
244880a
CUDA: Use standard-compliant preprocessor for MSVC builds
ORippler Dec 2, 2025
516af33
CUDA: Update CCCL's rc candidate
ORippler Dec 2, 2025
db8972e
squash! sampling : fix backend temp sampler for zero temperature
danbev Dec 2, 2025
2595818
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 2, 2025
aad5a6a
sampling : implement temp_ext_backend sampling
danbev Dec 2, 2025
cce3b2a
sampling : minor cleanup
ggerganov Dec 3, 2025
87b2719
sampling : stop short if backend sampler sampled a token
danbev Dec 4, 2025
c0b182f
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 4, 2025
10bd640
Revert "sampling : stop short if backend sampler sampled a token"
danbev Dec 4, 2025
ac9e164
sampling : fix backend temp sampling to use logits masking
danbev Dec 4, 2025
fce571e
sampling : simplify temp sampling
ggerganov Dec 4, 2025
1bde707
sampling : remove redundant calls to ggml_build_forward_expand
ggerganov Dec 4, 2025
6958d41
sampling : check backend support during init
ggerganov Dec 4, 2025
abc1963
cont : keep backend sampling disabled for now
ggerganov Dec 4, 2025
7864074
sampling : fix outputs and device checks
ggerganov Dec 4, 2025
cf74b1a
sampling : fix candidates logic
ggerganov Dec 5, 2025
dd11f6e
Add perf-tests for CUMSUM
ORippler Dec 5, 2025
7668999
Merge branch 'master' into gpu-sampling
ORippler Dec 5, 2025
e652566
Readd `cub::DeviceScan::InclusiveSum`-based CumSum
ORippler Dec 5, 2025
30742a6
sampling : expand support (wip)
ggerganov Dec 5, 2025
fdac968
Merge branch 'master' into HEAD
ggerganov Dec 6, 2025
5225818
tests : fix memory leaks
ggerganov Dec 6, 2025
8ef5f90
cont : fixes
ggerganov Dec 7, 2025
42125f0
tests : check temp back to 0.0
ggerganov Dec 7, 2025
72e3681
sampling : fix top-p
ggerganov Dec 7, 2025
6d38db5
Merge branch 'master' into HEAD
ggerganov Dec 8, 2025
f3beb22
sampling : handle n_probs case
ggerganov Dec 8, 2025
560ac16
server : handle unsupported cases
ggerganov Dec 9, 2025
d62b580
metal : print node names for debugging
ggerganov Dec 9, 2025
62d1b00
ggml : remove redundant src in ggml_cast
ggerganov Dec 9, 2025
9f6681c
ggml-alloc : fix reuse-parent logic for misaligned sizes
ggerganov Dec 9, 2025
7ab6f51
Revert "ggml : remove redundant src in ggml_cast"
ggerganov Dec 9, 2025
a84dfd3
CUDA: Add Cooperative-Groups-based parallelization of ncols in softmax
ORippler Dec 8, 2025
886c366
Add TODOs to and adjust heuristics of row-wise soft_max in CUDA
ORippler Dec 9, 2025
07003f1
Fix compiler warnings by casting `const` away
ORippler Dec 9, 2025
92ff767
llama : require backend samplers to be of type llama_sampler_chain
ggerganov Dec 9, 2025
34b407b
sampling : use host buffer type for inputs
ggerganov Dec 9, 2025
3f0594a
Try fixing HIP build errors by adding corresponding #defines
ORippler Dec 9, 2025
a25fda5
Fix launch logic when supports_cooperative_launch=false
ORippler Dec 9, 2025
6dc6614
Disable cooperative groups for musa
ORippler Dec 9, 2025
81cb578
Merge branch 'master' into HEAD
ggerganov Dec 10, 2025
0ecee8b
server : reconnect the backend_sampling setting in the WebUI
ggerganov Dec 10, 2025
c02654e
graph : make the compute graph constant with respect to active samplers
ggerganov Dec 10, 2025
3888224
Merge branch 'master' into HEAD
ggerganov Dec 10, 2025
44d5c4b
batch : fix sequence id ownage
ggerganov Dec 10, 2025
804e7e3
graph : respect sampler order for graph reuse
ggerganov Dec 10, 2025
42cf5c0
HIP/MUSA: fix build for backend sampling
JohannesGaessler Dec 10, 2025
56720f8
Merge pull request #1 from JohannesGaessler/gpu-sampling-hip
danbev Dec 11, 2025
54e9054
sampling : optimize logit_bias sampler
ggerganov Dec 11, 2025
d5d1665
cont : fix build
ggerganov Dec 11, 2025
8544aba
sampling : generic ggml op support detection
ggerganov Dec 11, 2025
74b112e
sampling : fix greedy
ggerganov Dec 11, 2025
ab65b47
tests : run backend sampler tests always on the CPU
ggerganov Dec 11, 2025
4d10b78
Merge branch 'master' into HEAD
ggerganov Dec 11, 2025
07b809b
Apply suggestions from code review
ORippler Dec 12, 2025
22c7f85
Merge branch 'master' into HEAD
ggerganov Dec 14, 2025
0086c24
Merge branch 'master' into HEAD
ggerganov Dec 14, 2025
2652e74
webui : fix lint
ggerganov Dec 14, 2025
3732b85
Fix data-race in `soft_max_f32_parallelize_cols_single_row`
ORippler Dec 15, 2025
e5737f6
Apply automated code-formating to softmax.cu
ORippler Dec 15, 2025
ad1b60a
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 16, 2025
68a1c4d
llama : clarify backend_accept/backend_set_input comments [no ci]
danbev Dec 17, 2025
c5d44b8
llama : fix typo in comment [no ci]
danbev Dec 17, 2025
9a9ea2f
tests : use smart pointers for backend samplers
danbev Dec 17, 2025
9845996
tests : use smart pointers for model and context
danbev Dec 17, 2025
76a1b7f
tests : remove vocab member from test_model_context
danbev Dec 17, 2025
cc31e6a
tests : extract batch info update to separate method
danbev Dec 17, 2025
a519aea
tests : fix batch token position tracking in test_backend_sampler.cpp
danbev Dec 17, 2025
981475f
tests : add --device option support to backend sampler tests
danbev Dec 17, 2025
eefdb0d
Merge branch 'master' into HEAD
ggerganov Dec 18, 2025
3b3f5fe
common : disable backend sampling when grammar is involved
ggerganov Dec 18, 2025
bc5195c
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 19, 2025
1750917
Fix different RNG-states between backend-sampling and llama-sampling
ORippler Dec 19, 2025
0a17687
Make backend dist sampler use same rnd's as dist sampler
ORippler Dec 19, 2025
b5ec0fd
Update CCCL version to v3.2.0-rc2
ORippler Dec 19, 2025
1da013c
Build with CCCL 3.2 for CUDA backends
ORippler Dec 19, 2025
f1310ab
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 22, 2025
0ce0359
Merge branch 'master' into HEAD
ggerganov Dec 24, 2025
c0a351c
tests : revert server test changes (no longer needed)
ggerganov Dec 24, 2025
82c2600
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 28, 2025
060c0a5
ggml : include cub/cub.cuh instead of block_scan.cuh
danbev Dec 28, 2025
ebfe545
Merge remote-tracking branch 'upstream/master' into backend-sampling
danbev Dec 30, 2025
23e8bb4
arg : add shorthand for --backend-sampling
ggerganov Dec 30, 2025
5d2156e
ci : add server workflow with backend sampling
ggerganov Dec 30, 2025
610e50a
sampling : fix reshapes
ggerganov Dec 30, 2025
588299c
server : remove printfs
ggerganov Dec 30, 2025
c5de759
Merge branch 'master' into HEAD
ggerganov Dec 30, 2025
791ecb9
sampling : zero-initialize input buffers
ggerganov Dec 30, 2025
4c3d542
minor : add comments + some cleanup
ggerganov Dec 31, 2025
435c967
llama : assert at most one output token per sequence
ggerganov Dec 31, 2025
0d85c5c
tests : add more top_k tests
ggerganov Jan 1, 2026
8071a57
Merge branch 'master' into HEAD
ggerganov Jan 1, 2026
b3cf4eb
CUDA: Fix non-determinism of CUB-based Top-K
ORippler Jan 4, 2026
6975bda
CUDA: Optimize index of top_k_cub
ORippler Jan 4, 2026
194401a
Apply code-formatting to top-k.cu
ORippler Jan 4, 2026
9f6c1f3
Merge remote-tracking branch 'origin/master' into gpu-sampling
ORippler Jan 4, 2026
03454de
CUDA: Remove obsolete temp_keys from CUB
ORippler Jan 4, 2026
2e54b1d
minor : cleanup, TODOs, etc.
ggerganov Jan 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ jobs:
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

- name: Build with CMake
# TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
run: |
cmake -S . -B build -G Ninja \
-DLLAMA_CURL=OFF \
Expand All @@ -1107,7 +1108,8 @@ jobs:
-DCMAKE_CUDA_ARCHITECTURES=89-real \
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
-DGGML_NATIVE=OFF \
-DGGML_CUDA=ON
-DGGML_CUDA=ON \
-DGGML_CUDA_CUB_3DOT2=ON
cmake --build build

windows-2022-cmake-cuda:
Expand Down Expand Up @@ -1143,6 +1145,7 @@ jobs:
- name: Build
id: cmake_build
shell: cmd
# TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
cmake -S . -B build -G "Ninja Multi-Config" ^
Expand All @@ -1153,7 +1156,8 @@ jobs:
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=ON ^
-DGGML_CUDA=ON ^
-DGGML_RPC=ON
-DGGML_RPC=ON ^
-DGGML_CUDA_CUB_3DOT2=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -420,14 +420,16 @@ jobs:
- name: Build
id: cmake_build
shell: cmd
# TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
cmake -S . -B build -G "Ninja Multi-Config" ^
-DGGML_BACKEND_DL=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_CPU=OFF ^
-DGGML_CUDA=ON ^
-DLLAMA_CURL=OFF
-DLLAMA_CURL=OFF ^
-DGGML_CUDA_CUB_3DOT2=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda

Expand Down
18 changes: 18 additions & 0 deletions .github/workflows/server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ jobs:
include:
- build_type: Release
sanitizer: ""
extra_args: ""
- build_type: Release
sanitizer: ""
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken

steps:
Expand All @@ -65,6 +69,12 @@ jobs:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

- name: Build
id: cmake_build
run: |
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server

- name: Python setup
id: setup_python
uses: actions/setup-python@v5
Expand All @@ -76,6 +86,14 @@ jobs:
run: |
pip install -r tools/server/tests/requirements.txt

- name: Tests
id: server_integration_tests
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"

server-windows:
runs-on: windows-2022

Expand Down
3 changes: 2 additions & 1 deletion ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ if [ ! -z ${GG_BUILD_METAL} ]; then
fi

if [ ! -z ${GG_BUILD_CUDA} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON"
# TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DGGML_CUDA_CUB_3DOT2=ON"

if command -v nvidia-smi >/dev/null 2>&1; then
CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.')
Expand Down
7 changes: 7 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
}
).set_sparam());
add_opt(common_arg(
{"-bs", "--backend-sampling"},
"enable backend sampling (experimental) (default: disabled)",
[](common_params & params) {
params.sampling.backend_sampling = true;
}
).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
Expand Down
19 changes: 19 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ struct common_init_result::impl {
std::vector<llama_adapter_lora_ptr> lora;

std::vector<common_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
};

common_init_result::common_init_result(common_params & params) :
Expand Down Expand Up @@ -1162,10 +1163,19 @@ common_init_result::common_init_result(common_params & params) :
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}

// init the backend samplers as part of the context creation
pimpl->samplers.resize(cparams.n_seq_max);
pimpl->samplers_seq_config.resize(cparams.n_seq_max);

for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
}

llama_context * lctx = llama_init_from_model(model, cparams);
Expand All @@ -1189,6 +1199,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}

void common_init_result::reset_samplers() {
for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
}
}

std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
Expand Down Expand Up @@ -1304,6 +1320,9 @@ common_init_result_ptr common_init_from_params(common_params & params) {
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);

// reset the samplers so that their RNG state returns to the seeded state after the warmup run
res->reset_samplers();
}

return res;
Expand Down
4 changes: 4 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

bool backend_sampling = false;

bool has_logit_bias() const {
return !logit_bias.empty();
}
Expand Down Expand Up @@ -689,7 +691,9 @@ struct common_init_result {

llama_model * model();
llama_context * context();

common_sampler * sampler(llama_seq_id seq_id);
void reset_samplers();

std::vector<llama_adapter_lora_ptr> & lora();

Expand Down
16 changes: 10 additions & 6 deletions common/llguidance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
}

static llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
/* .backend_init = */ NULL,
/* .backend_accept = */ NULL,
/* .backend_apply = */ NULL,
/* .backend_set_input = */ NULL,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
Expand Down
54 changes: 48 additions & 6 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,34 @@ struct common_sampler {
}

void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

const int n_vocab = llama_vocab_n_tokens(vocab);

cur.resize(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
if (sampled_probs) {
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
cur.resize(sampled_probs_count);
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
}
} else if (sampled_logits) {
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
cur.resize(sampled_logits_count);
for (uint32_t i = 0; i < sampled_logits_count; i++) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
}
} else {
const auto * logits = llama_get_logits_ith(ctx, idx);
GGML_ASSERT(logits != nullptr);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}

cur_p = { cur.data(), cur.size(), -1, false };
Expand Down Expand Up @@ -159,7 +176,7 @@ std::string common_params_sampling::print() const {
return std::string(result);
}

struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);

llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
Expand Down Expand Up @@ -298,6 +315,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(chain, smpl);
}

if (grmr && params.backend_sampling) {
LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);

params.backend_sampling = false;
}

auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
Expand Down Expand Up @@ -407,6 +430,25 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

// Check if a backend sampler has already sampled a token in which case we
// return that token id directly.
{
id = llama_get_sampled_token_ith(ctx, idx);

if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");

// TODO: simplify
gsmpl->cur.resize(1);
gsmpl->cur[0] = { id, 0.0f, 1.0f };
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };

return id;
}
}

gsmpl->set_logits(ctx, idx);

if (grammar_first) {
Expand Down
4 changes: 3 additions & 1 deletion common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct common_sampler;

// llama_sampler API overloads

struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
// note: may mutate params (e.g. params.backend_sampling is set to false when a grammar is in use)
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);

void common_sampler_free(struct common_sampler * gsmpl);

Expand All @@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);

// get the underlying llama_sampler_chain
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);

// extended sampling implementation:
Expand Down
18 changes: 12 additions & 6 deletions examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;

std::vector<llama_sampler *> samplers;
std::vector<llama_sampler_seq_config> sampler_configs;

for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
Expand All @@ -78,7 +78,13 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

samplers.push_back(smpl);
sampler_configs.push_back({ i, smpl });
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
}

llama_context * ctx = llama_init_from_model(model, ctx_params);
Expand Down Expand Up @@ -180,7 +186,7 @@ int main(int argc, char ** argv) {
continue;
}

const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);

// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
Expand Down Expand Up @@ -236,15 +242,15 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

LOG("\n");
llama_perf_sampler_print(samplers[0]);
llama_perf_sampler_print(sampler_configs[0].sampler);
llama_perf_context_print(ctx);

fprintf(stderr, "\n");

llama_batch_free(batch);

for (auto & sampler_config : samplers) {
llama_sampler_free(sampler_config);
for (auto & sampler_config : sampler_configs) {
llama_sampler_free(sampler_config.sampler);
}

llama_free(ctx);
Expand Down
Loading
Loading