
Add the padding-aware bucketing strategy#762

Merged
adobrzyn merged 5 commits into vllm-project:main from yangulei:linear_limit_main on Apr 9, 2026

Conversation

@yangulei (Collaborator) commented Dec 26, 2025

This PR introduces the new PaddingAwareBucketingStrategy, which can be enabled by setting VLLM_BUCKETING_STRATEGY="pad" and tuned further via VLLM_{phase}_{dim}_BUCKET_PAD_MAX and VLLM_{phase}_{dim}_BUCKET_PAD_PERCENT, which set the maximum absolute and relative padding limits respectively.

Motivation

Exponential bucketing was introduced to significantly reduce the number of buckets. For an example with max_num_batched_tokens=8192, max_model_len=32768, max_num_seqs=256 and 4127 HPU blocks, exponential bucketing generates 120 prompt buckets and 81 decode buckets, while linear bucketing generates 14368 prompt buckets and 4042 decode buckets. The exponential buckets are filtered combinations of the following ranges:

Prompt query range: [128, 256, 384, 512, 640, 1792, 2816, 3968, 4992, 6144, 7168, 8192]
Prompt context range: [0, 1, 3, 8, 22, 56, 90, 124, 158, 192]
Decode BS range: [1, 2, 4, 8, 14, 24, 42, 78, 140, 256]
Decode context range: [1, 256, 512, 768, 1024, 1280, 1536, 1792, 2304, 2816, 3584, 4352]

The max absolute padding (max(bucket[i]-bucket[i-1]-1)) grows in proportion to the bucket max when left unlimited, and the max relative padding ((bucket[i]-bucket[i-1]-1)/bucket[i]) approaches 50% for a large bucket max. Such large padding causes significant overhead, especially for cases with long sequences.
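
As a quick check, the padding metrics above can be computed directly from a bucket range (a small illustrative sketch, not code from this PR):

```python
# Illustrative sketch: worst-case padding for a bucket range,
# using the formulas quoted above.
def max_abs_padding(buckets):
    # Worst case: a value one past the previous bucket, padded up to the next one.
    return max(b - a - 1 for a, b in zip(buckets, buckets[1:]))

def max_rel_padding(buckets):
    return max((b - a - 1) / b for a, b in zip(buckets, buckets[1:]))

decode_ctx = [1, 256, 512, 768, 1024, 1280, 1536, 1792, 2304, 2816, 3584, 4352]
print(max_abs_padding(decode_ctx))           # 767
print(round(max_rel_padding(decode_ctx), 3)) # 0.992 (for the 1 -> 256 jump)
```

Note that the relative padding is worst for small buckets, while the absolute padding is worst for large ones, which is why both limits are useful.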

We need a bucketing algorithm to balance the bucket number (warmup time) and the runtime performance (padding overhead).

Changes

  • Enhance warmup_range in linear bucketing into warmup_range_with_limits, which generates a range that ensures the absolute and relative padding never exceed the specified limits.
  • Introduce new ENVs named VLLM_{phase}_{dim}_BUCKET_PAD_MAX and VLLM_{phase}_{dim}_BUCKET_PAD_PERCENT to set the absolute and relative padding limits respectively.
  • Introduce a new ENV named VLLM_BUCKETING_STRATEGY to select the bucketing strategy from exp, lin and pad for the default exponential, the linear and the padding-aware bucketing strategies respectively.

For the above example with default settings:

Prompt query range: [128, 256, 384, 512, 640, 768, 1024, 1280, 1664, 2176, 2816, 3712, 4864, 6400, 8192]
Prompt context range: [0, 1, 2, 4, 6, 8, 12, 16, 22, 30, 40, 54, 64, 86, 116, 128, 172, 192, 255]
Decode BS range: [1, 2, 4, 6, 8, 12, 16, 22, 30, 32, 44, 60, 64, 86, 96, 128, 160, 192, 224, 256]
Decode context range: [128, 256, 384, 512, 640, 768, 1024, 1280, 1664, 2176, 2816, 3712, 4127]

This results in 284 prompt buckets and 222 decode buckets with much less padding.
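
For intuition, the prompt query range above can be reproduced by a small padding-limited range generator. This is a speculative reconstruction, not the actual warmup_range_with_limits code, and the limits pad_max=2048 and pad_percent=25 are assumptions:

```python
def warmup_range_with_limits(bmin, bstep, bmax, pad_max, pad_percent):
    # Grow the gap between buckets as fast as both padding limits allow:
    #   absolute: gap - 1 <= pad_max
    #   relative: (gap - 1) / (prev + gap) <= pad_percent / 100
    # Assumes 0 <= pad_percent < 100.
    buckets = [bmin]
    p = pad_percent / 100
    while buckets[-1] < bmax:
        prev = buckets[-1]
        gap = min((p * prev + 1) / (1 - p), pad_max + 1)
        gap = max(bstep, int(gap) // bstep * bstep)  # round down to a step multiple
        buckets.append(min(prev + gap, bmax))
    return buckets

# Under the assumed limits this reproduces the documented prompt query range.
print(warmup_range_with_limits(128, 128, 8192, 2048, 25))
```

With pad_percent=0 the gap collapses to bstep, i.e. plain linear bucketing, which matches the fallback behavior described in the Benefits section.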

Benefits

  • Can simulate exponential bucketing by setting a large VLLM_{phase}_{dim}_BUCKET_PAD_MAX and setting VLLM_{phase}_{dim}_BUCKET_PAD_PERCENT=50.
  • Can fall back to the original linear bucketing by setting VLLM_{phase}_{dim}_BUCKET_PAD_PERCENT=0.
  • Users can further tune the absolute and relative padding limits to balance warmup time against runtime performance.
  • Setting VLLM_{phase}_{dim}_BUCKET_PAD_MAX to a multiple of PT_HPU_SDPA_BR_FACTOR and PT_HPU_SDPA_BC_FACTOR generates buckets that align with the slicing chunk size and give better performance.
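
For example, the strategy and the prompt query limits could be selected before engine startup like this (a sketch; the variable names follow the pattern from this PR, the values are illustrative):

```python
import os

# Opt in to the padding-aware strategy (the default remains exponential).
os.environ["VLLM_BUCKETING_STRATEGY"] = "pad"

# Illustrative prompt query padding limits:
os.environ["VLLM_PROMPT_QUERY_BUCKET_PAD_MAX"] = "2048"    # absolute cap, in tokens
os.environ["VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT"] = "25"  # relative cap, integer percent
```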

Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces configurable absolute and relative padding limits to the linear bucketing algorithm to better balance warmup time and runtime performance. The change replaces exponential bucketing as the default strategy.

  • Adds new environment variables for controlling padding limits (PAD_MAX and PAD_PERCENT) across all bucket dimensions
  • Implements a new warmup_range_with_limits function that generates buckets respecting these padding constraints
  • Changes the default bucketing strategy from exponential to linear with limits

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/extension/features.py | Adds new environment variables for padding limits and switches default bucketing strategy to linear |
| vllm_gaudi/extension/bucketing/linear.py | Implements padding-aware bucket generation with new warmup_range_with_limits function and updated configuration handling |
| vllm_gaudi/extension/bucketing/common.py | Simplifies bucketing strategy selection and adds debug logging for bucket ranges |
| tests/unit_tests/test_bucketing.py | Updates tests to accommodate new padding parameters in bucket configuration |
| docs/configuration/env_variables.md | Documents new padding-related environment variables and updated defaults |
Comments suppressed due to low confidence (1)

vllm_gaudi/extension/bucketing/linear.py:1

  • The BUCKET_PAD_PERCENT environment variables are defined as int type, but they represent percentages. This could lead to confusion as the documentation shows 25 meaning 25%, but users might expect values like 0.25. Consider using a float type or clearly documenting that the value should be specified as an integer percentage (0-100).
import os

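The ambiguity Copilot flags can be made explicit by converting the integer percentage to a fraction at the boundary (illustrative only, not code from this PR):

```python
def pad_percent_to_fraction(value: int) -> float:
    # The env var is an integer percentage: 25 means 25%, i.e. 0.25.
    if not 0 <= value <= 100:
        raise ValueError("PAD_PERCENT must be an integer in [0, 100]")
    return value / 100

print(pad_percent_to_fraction(25))   # 0.25
```
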

@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@yangulei yangulei force-pushed the linear_limit_main branch 3 times, most recently from 0e742bc to 575e7d1 on December 30, 2025 01:51
@yangulei yangulei requested a review from Copilot December 30, 2025 02:43
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Comments suppressed due to low confidence (1)

vllm_gaudi/extension/bucketing/linear.py:1

  • The PAD_PERCENT parameter is stored as an integer but represents a percentage value (0-50). Consider using a float type or renaming to indicate it's in integer percentage points to avoid confusion.
import os


@yangulei yangulei requested a review from Copilot December 30, 2025 03:08
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.



@yangulei (Collaborator, Author) commented Jan 7, 2026

Submitted #780 to solve the OOM issue in CI.

github-actions Bot commented Jan 7, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

github-actions Bot commented Jan 8, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
cddbc2b4b2547c681d1bdb876fdd6a7b8e0ec58d

@yangulei yangulei force-pushed the linear_limit_main branch 2 times, most recently from e4eabf6 to fefd207 on January 15, 2026 01:35
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
66652e8082b69ba7d1e6aca7c234433de55f1b9b

@yangulei yangulei force-pushed the linear_limit_main branch 2 times, most recently from 4b01722 to 8c7c399 on January 15, 2026 17:44
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
4c1c501a7ee1d5efbad945ea62a702ce5cefb799

@yangulei yangulei force-pushed the linear_limit_main branch from 0de2420 to 4a8ca80 on March 24, 2026 03:42
@afierka-intel (Collaborator) left a comment

Hi @yangulei ,

Thank you for your patience and for your PR: #762. We now have nearly all the benchmark results and have made a decision on the path forward.

We've decided not to change the default bucketing strategy to the one proposed in your PR. You're right that it improves performance in long-context scenarios, and it doesn't degrade performance in other cases either. However, the warmup time is unacceptable for virtually all scenarios — with your PR we see 2.5×–3.5× longer warmup compared to main. This is visible not only on MoE models, but also on simple, small models like granite-3.3-2b.

That said, we do like your approach and would like to incorporate it — but as a separate bucketing strategy. Currently we have linear and exponential bucketing strategies. We'd be happy to add yours as a third option, targeted at long-context scenarios. It would not be the default, but could be enabled via an environment variable or a configuration parameter.

Thank you for understanding,
Artur

@yangulei yangulei changed the title Introduce absolute and relative padding limits to the linear bucketing Add the padding-aware bucketing strategy Mar 31, 2026
@yangulei yangulei force-pushed the linear_limit_main branch from 4a8ca80 to 1debb37 on March 31, 2026 07:52
@yangulei (Collaborator, Author)

@afierka-intel Moved the implementation to PaddingAwareBucketingStrategy; it is now enabled only by explicitly setting VLLM_BUCKETING_STRATEGY="pad". Please review again, thanks.

@yangulei yangulei force-pushed the linear_limit_main branch from 0400124 to 1debb37 on March 31, 2026 08:23
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@yangulei yangulei force-pushed the linear_limit_main branch from 1debb37 to 84f94f7 on March 31, 2026 08:24
@yangulei yangulei requested a review from Copilot March 31, 2026 08:25
@afierka-intel (Collaborator)

Hi @yangulei 👋

Thank you for reworking this PR to make padding-aware bucketing an opt-in strategy — the updated design with VLLM_BUCKETING_STRATEGY=pad is exactly what we discussed, and the implementation looks solid overall. The documentation and test coverage are really nice too.

After reviewing the full diff against the current codebase, I found a couple of things that need attention before we can merge. I want to be upfront: one of them is a pattern that is easy to miss because it is buried in each strategy independently rather than centralized.


🔴 Needs fix before merge

1. Missing merged_prefill handling in PaddingAwareBucketingStrategy

Both LinearBucketingStrategy (linear.py:15-54) and ExponentialBucketingStrategy (exponential.py:43-82) override bucket configs in get_prompt_cfgs() when get_config().merged_prefill is True — they set batch size to (1, 1, 1, ...) and multiply step sizes by 4. Since this logic lives in each strategy (not centralized in the manager), PaddingAwareBucketingStrategy needs the same handling.

Without it, using VLLM_BUCKETING_STRATEGY=pad together with merged prefill would produce incorrect bucket configurations.

Suggested fix: Add a similar if get_config().merged_prefill: block in PaddingAwareBucketingStrategy.get_prompt_cfgs(), adapted for the 5-element config format [min, step, max, pad_max, pad_percent].

2. VLLM_EXPONENTIAL_BUCKETING removed without deprecation path

The current codebase actively uses VLLM_EXPONENTIAL_BUCKETING in features.py (line 16) and common.py (lines 103-105). This PR removes it entirely and replaces it with VLLM_BUCKETING_STRATEGY. While the replacement is correct, users who have VLLM_EXPONENTIAL_BUCKETING=false in their scripts or CI pipelines will have that setting silently ignored after upgrading.

Suggested fix: Add a deprecation warning in common.py or features.py that detects if VLLM_EXPONENTIAL_BUCKETING is set and logs a message like:

"VLLM_EXPONENTIAL_BUCKETING is deprecated and will be removed in a future release. 
 Use VLLM_BUCKETING_STRATEGY='exp'|'lin'|'pad' instead."

Optionally, you could auto-map VLLM_EXPONENTIAL_BUCKETING=false to bucketing_strategy='lin' for backward compatibility during the transition period.
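
The suggested transition logic might look like this (a sketch of the suggestion above, not the code that was eventually merged):

```python
import logging
import os

logger = logging.getLogger(__name__)

def resolve_bucketing_strategy() -> str:
    # New-style selection takes precedence; the legacy flag is honored with a warning.
    strategy = os.environ.get("VLLM_BUCKETING_STRATEGY")
    legacy = os.environ.get("VLLM_EXPONENTIAL_BUCKETING")
    if legacy is not None:
        logger.warning(
            "VLLM_EXPONENTIAL_BUCKETING is deprecated and will be removed in a "
            "future release. Use VLLM_BUCKETING_STRATEGY='exp'|'lin'|'pad' instead.")
        if strategy is None:
            # Auto-map the legacy flag during the transition period.
            strategy = "exp" if legacy.lower() in ("1", "t", "true") else "lin"
    return strategy or "exp"
```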


🟡 Minor suggestions (non-blocking)

3. decode_query_bucket_cfg = [1, 1, 1, 1, 1] — add a comment

This is correct (decode query is always 1 token in autoregressive mode), but a reader unfamiliar with the code might wonder why all five values are 1. A short comment like # Decode query is always 1 token (autoregressive), no bucketing needed would help.

4. Default strategy test

The test test_get_bucketing_strategy_selected_by_env covers explicit exp/lin/pad selection, which is great. Consider adding one more case that verifies the default (no env var set) returns ExponentialBucketingStrategy — this would serve as a regression guard for the "don't change the default" requirement.


✅ Things that look good

  • exp is correctly the default via Value('bucketing_strategy', 'exp', ...)
  • The warmup_range_with_limits algorithm and its documented examples are clear and correct
  • The VLLM_PROMPT_SEQ_BUCKET_* to VLLM_PROMPT_QUERY_BUCKET_* fallback with deprecation warning is well handled
  • Documentation in bucketing_mechanism.md is excellent — the examples really help explain the strategy
  • The % pad_max alignment rule in warmup_range_with_limits is justified for FusedSDPA chunk alignment (though a brief inline comment explaining why would help future readers)

Thank you again for the quality work here. Looking forward to the next iteration! 🙏

@yangulei yangulei force-pushed the linear_limit_main branch from 84f94f7 to 7866ce5 on April 1, 2026 01:59
@yangulei (Collaborator, Author) commented Apr 1, 2026

Hi @afierka-intel
All the requested changes are addressed in the last two commits.

@afierka-intel (Collaborator) left a comment

Thanks for addressing all the feedback! The changes look great:

  • merged_prefill handling is now properly implemented in PaddingAwareBucketingStrategy, consistent with the linear and exponential strategies
  • VLLM_EXPONENTIAL_BUCKETING backward compatibility with deprecation warning — clean implementation
  • Helpful comment on decode_query_bucket_cfg and default strategy test added

Nice work on the comprehensive test coverage too. 👍

Could you please rebase on top of the latest main so we can trigger a fresh CI run?

yangulei added 5 commits April 7, 2026 13:15
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
@yangulei yangulei force-pushed the linear_limit_main branch from 7866ce5 to e5d3474 on April 7, 2026 13:23
@yangulei (Collaborator, Author) commented Apr 7, 2026

> Could you please rebase on top of the latest main so we can trigger a fresh CI run?

Done.

github-actions Bot commented Apr 7, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
d28d86e8a34bf2617be294c235d6e6ef3321917b

@adobrzyn adobrzyn merged commit c2641a6 into vllm-project:main Apr 9, 2026
71 checks passed
@yangulei yangulei deleted the linear_limit_main branch April 9, 2026 08:14
@yangulei yangulei restored the linear_limit_main branch April 9, 2026 08:15
Copilot AI added a commit that referenced this pull request Apr 14, 2026
Signed-off-by: copilot <copilot@github.com>

Tests cover the four PRs addressing long-context bucketing:
- PR #762:  Padding-aware bucketing strategy (warmup ranges, configs, generation)
- PR #1122: Exponential decode block formula, limit cap, filter, linear fix
- PR #1155: FusedSDPA slicing contract (pad_max bounds, strategy selection)
- PR #1346: HPU graph capture skip (cudagraph size, warmup clamp scenarios)
- Cross-PR integration: end-to-end 256K scenario, fallback, regressions

49 test functions organized in 6 test classes.

Co-authored-by: michalkuligowski <23379006+michalkuligowski@users.noreply.github.com>
osavchenkox pushed a commit to osavchenkox/vllm-gaudi that referenced this pull request May 5, 2026
kamil-kaczor pushed a commit that referenced this pull request May 11, 2026
FusedSDPA can be split into smaller chunks to improve performance when using the padding-aware bucketing strategy, which bounds the maximum absolute padding in the sequence and context dimensions.

## Usage
| Parameter name | Description | Default value |
| --- | --- | --- |
| `VLLM_HPU_FSDPA_SLICE_ENABLED` | Enable the slicing. | `True` for padding-aware bucketing strategy |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is applied. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise |
> [!IMPORTANT]
> These parameters are effective only with the padding-aware bucketing strategy set by `VLLM_BUCKETING_STRATEGY="pad"`.
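
The defaults in the table compose as documented above (a small sketch of the arithmetic; the function name is made up for illustration):

```python
def fsdpa_slice_defaults(max_num_batched_tokens):
    # Threshold above which slicing kicks in.
    thld = min(max_num_batched_tokens, 8192)
    # Default chunk size: half the threshold, rounded up to a multiple of 1024.
    chunk = -(-(thld // 2) // 1024) * 1024
    return thld, chunk

print(fsdpa_slice_defaults(8192))   # (8192, 4096)
print(fsdpa_slice_defaults(3000))   # (3000, 2048)
```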


## Implementation
Take the prefix-prefill with `[bs, query, context] = [1, 9037, 8832]` as an example. The prefill shape is first padded to `[1, 10880, 11008]` by the bucketing. The attention mask then looks like:

![fsdpa_apc_b1_s9037_c8832_pb1_ps10880_pc11008_attention_mask](https://github.com/user-attachments/assets/f11540c1-0c5d-4e5d-92d6-4b3b42b473f0)

> Note that there is padding in both the query and context dimensions.

The original implementation passes the full attention mask to the FusedSDPA kernel.

This PR introduces an implementation that computes the FSDPA in chunks by slicing `Q`, `K` and `V` as below:

![fsdpa_apc_b1_s9037_c8832_pb1_ps10880_pc11008_bf16_SliceQKV](https://github.com/user-attachments/assets/a830612a-3e46-48a0-9b1a-3c9102598a2e)

Where the color of the rectangles differentiate the `is_causal` and
`attn_mask` parameters passed to the FusedSDPA kernel:
- `rgb(255,0,0)`: `is_causal=False` and `attn_mask is not None`
- `rgb(255,255,0)`: `is_causal=True` and `attn_mask=None`
- `rgb(255,0,255)`: `is_causal=False` and `attn_mask=None`

In this way, most chunks call FusedSDPA without an attention mask for better performance, and the graphs for individual chunks may be reused across different buckets to reduce warmup duration.

## Dependencies
- #762, as the number of chunks with padding is determined by the `PAD_MAX` for the query and context.

---
### Thanks @Wei-Lin-Intel for the original idea and the detailed behavior of the FusedSDPA kernel.

---------

Signed-off-by: Youlei Yang <youlei.yang@intel.com>