Skip to content

support online FP8 quantization for FA on NPU #2236#2640

Merged
gcanlin merged 35 commits into
vllm-project:mainfrom
lyj-jjj:main
May 13, 2026
Merged

support online FP8 quantization for FA on NPU #2236#2640
gcanlin merged 35 commits into
vllm-project:mainfrom
lyj-jjj:main

Conversation

@lyj-jjj
Copy link
Copy Markdown
Contributor

@lyj-jjj lyj-jjj commented Apr 9, 2026

1. Background and Goal

In generative models, FA accounts for more than 50% of the time when generating 480p videos and more than 70% when generating 720p videos. Therefore, online quantization of FA can significantly reduce the DIT time, which is essential for the support of view generation models.

Based on PR1413, the FA online FP8 quantization capability on the NPU is extended. The goal of this PR is to introduce online FP8-quantized FA on NPU, with optional step/layer-level fallback, so that DiT latency can be significantly reduced while maintaining generation quality as much as possible.


2. Scope

  • Platform scope: NPU
  • Feature scope:
    • Add end-to-end support for kv_cache_dtype=fp8
    • Support kv_cache_skip_steps and kv_cache_skip_layers selectors for fine-grained fallback
    • Keep layout/call compatibility for both cross-attention and self-attention
  • Non-goals:
    • No changes to CUDA/XPU quantized paths
    • No offline pre-quantized weight workflow (this PR focuses on online quantization)

3. Core Design

3.1 Config and Parameter Plumbing

Introduce and propagate the following parameters from config/entrypoints down to the attention execution layer:

  • kv_cache_dtype (e.g., fp8)
  • kv_cache_skip_steps (string index set)
  • kv_cache_skip_layers (string index set)
    These parameters are carried through the diffusion attention metadata system. At each layer forward, whether FP8 path is enabled is determined by current step/layer.

3.2 Backend Capability Declaration and Safe Fallback

FlashAttentionBackend uses supports_kv_cache_dtype() to declare supported kv-cache dtypes per platform.
On NPU, when kv_cache_dtype in metadata is valid, execution enters the quantized FA path; otherwise it falls back to native FA.
This design guarantees:

  • Safe degradation for invalid config / unsupported platform
  • Reusable attention abstraction for future platform capability extension

3.3 NPU Execution Path Routing

forward_npu() routes by attention type and quantization switch:

  1. cross-attention: use normal FA (layout-adapted)
  2. self-attention with effective kv_cache_dtype: use forward_fa_quant_npu()
  3. otherwise: use normal forward_fa_npu()
    In the quantized path, fp8_rotate_quant_fa() calls the NPU quantized fused operator. Necessary tensor transposes are applied according to layout requirements so that existing model tensor formats remain compatible.

3.4 Step/Layer Selective Fallback

In the attention layer, kv_cache_skip_steps and kv_cache_skip_layers are parsed into index sets.
At runtime, if either skip condition is matched, FP8 is disabled for the current layer/step and execution falls back to native dtype FA.
Value:

  • Provides controllable knobs for quality/performance trade-off
  • Enables gradual quantization rollout and model-level tuning

4. Compatibility and Risk Control

  • Default behavior unchanged: if kv_cache_dtype is not configured, runtime behavior is identical to previous versions.
  • Platform compatibility: FP8 quantization capability is enabled only for NPU; other platforms are unaffected.
  • Error fallback: unsupported dtypes are downgraded safely, avoiding hard runtime failures.
  • Quality risk: step/layer fallback can limit impact scope and support scenario-based tuning.

5. Performance and Effect (Current Test)

In Wan2.2 I2V scenario (1280x720, 61 frames, 4 steps):


6. Test Plan

  • Functional correctness:
    • Inference succeeds both before and after enabling kv_cache_dtype=fp8
    • kv_cache_skip_steps/layers fallback works as expected when selectors match
  • Stability:
    • Basic regression on single-card / multi-card NPU
    • Cover both cross-attention and self-attention paths
  • Performance:
    • Compare DiT total latency and FA time ratio vs BF16 baseline
  • Quality:
    • Compare generated video quality to detect visible artifacts/regressions

7. Usage Example

export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3

export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
export PYTHONPATH=/home/vllm-omni-0.18.0/vllm-omni:$PYTHONPATH

python image_to_video.py \
--model /home/weights/Wan2.2-I2V-A14B-LightX2V-Diffusers \
--image i2v_input.jpg \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside" \
--height 1280 \
--width 720 \
--num-frames 61 \
--guidance-scale 1.0 \
--guidance-scale-high 1.0 \
--num-inference-steps 4 \
--cfg-parallel-size 1 \
--ulysses-degree 4 \
--boundary-ratio 0.875 \
--flow-shift 12.0 \
--fps 16 \
--output i2v_output_origin_weight_1.mp4 \
--vae-patch-parallel-size 4 \
--vae-use-tiling \
--kv-cache-dtype fp8 \
--kv-cache-skip-steps "0,1"  \
--kv-cache-skip-layers "0-2" \

@lyj-jjj lyj-jjj requested a review from hsliuustc0106 as a code owner April 9, 2026 08:50
@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@gcanlin gcanlin self-assigned this Apr 11, 2026
@Gaohan123 Gaohan123 added this to the v0.20.0 milestone Apr 16, 2026
Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early Review -- WIP Online FP8 Quantization for FA on NPU

Thanks for the PR. This is a useful feature -- online FP8 quantization for flash attention on NPU with selective step/layer skipping. Below is an early review with the items I think should be addressed before this leaves WIP status.


Architecture & Design

  1. _update_attn_metadata mutates its input in-place but also returns it. The name "update" suggests mutation, but the function also creates a new AttentionMetadata when base is None. This dual behavior is error-prone. Consider either (a) always returning a new copy (safe, no aliasing surprises), or (b) always mutating in-place and requiring a non-None input. The current mixed mode means callers who pass a shared base_attn_metadata across self-attn and cross-attn in the same block will see cross-contamination of attn_kind / attn_mask because the same object is mutated twice.

    In wan2_2_transformer.py, WanTransformerBlock.forward() calls _update_attn_metadata(base_attn_metadata, ...) for both self-attn and cross-attn. If the first call mutates the shared base_attn_metadata (setting attn_kind="self"), the second call receives that already-mutated object and overwrites attn_kind to "cross". The first attention op has already consumed it, so it works today, but this is fragile. A shallow copy before mutation would be safer:

    metadata = copy.copy(base) if base is not None else AttentionMetadata()
  2. _update_attn_metadata is a module-level private function exported in the __all__-equivalent import in wan2_2_transformer.py. Consider making it a proper method on AttentionMetadata (e.g., AttentionMetadata.with_updates(...)) to improve discoverability and encapsulation.

  3. Lazy resolution pattern in Attention.forward() calls _resolve_kv_cache_dtype() and _resolve_kv_cache_skip_selectors_from_config() on every forward call. After the first resolution, the _resolved flags short-circuit, but the function-call overhead remains. Consider resolving once in a dedicated setup hook (e.g., first forward only, or a post_init method) rather than checking flags in the hot path.


Correctness Concerns

  1. forward_fa_quant_npu calls is_quantized_kv_cache(kv_cache_dtype) after forward_npu already checked kv_cache_dtype is not None. If kv_cache_dtype is a non-None string that is not in _FP8_KV_LABELS (e.g., "int8"), the function logs a warning and falls back to forward_fa_npu. However, _handle_kv_cache_dtype in the base class AttentionImpl already cleared unsupported dtypes to None with its own warning. This means the forward_fa_quant_npu warning path is dead code for any dtype that isn't in _supported_kv_cache_dtypes. Consider removing the redundant check or making the intent clearer.

  2. Global mutable state in kv_quant_npu.py -- _ROT_MATRIX and _IS_NOT_IMPORTED are module-level globals guarded by a boolean flag but not thread-safe. In async serving with multiple workers, two threads could race on _IS_NOT_IMPORTED and _ROT_MATRIX. Use threading.Lock or functools.lru_cache for the import guard, and be aware that _ROT_MATRIX.to(device) reassigns the global without synchronization.

  3. _ROT_MATRIX device migration is lossy. if _ROT_MATRIX.device != device: _ROT_MATRIX = _ROT_MATRIX.to(device) replaces the global with the new-device copy. In multi-device scenarios (e.g., tensor parallel across NPU cards), the second device overwrites the global and breaks the first. Consider using a dict[torch.device, torch.Tensor] cache.

  4. Magic numbers in npu_fused_infer_attention_score_v2 call. pre_tokens=2147483647, next_tokens=2147483647, query_quant_mode=7, etc. are opaque. Please add brief inline comments explaining what these values mean (e.g., "INT32_MAX = no causal masking", "7 = per-block FP8 quantization mode").


Plumbing / Config

  1. serve.py adds --kv-cache-skip-steps and --kv-cache-skip-layers but does NOT add --kv-cache-dtype for the serve entrypoint. The offline example adds all three, but the online serve path only adds the skip selectors. This means users who launch via vllm_omni serve cannot enable FP8 KV quantization without a YAML config. Is this intentional? If so, document it; if not, add the missing CLI arg.

  2. _resolve_stage_configs in async_omni_engine.py uses hasattr checks (not hasattr(cfg.engine_args, "kv_cache_dtype")). Since OmniDiffusionConfig is a dataclass with defaults, hasattr will almost always be True. The guard is likely intended to be cfg.engine_args.kv_cache_dtype is None, which is already the second condition. The hasattr check is misleading dead code -- consider removing it.

  3. The timestep_scalar variable computed in pipeline_wan2_2_i2v.py line ~508 is unused. It is assigned but never referenced. Remove it or use it.


Style / Cleanup

  1. Double blank line after forward_fa_quant_npu method and before forward_fa_npu in flash_attn.py -- one method has standard indentation, the other (forward_fa_npu) uses 8-space indentation for its parameter list instead of the 4-space style used everywhere else in this file. Please normalize.

  2. _parse_selector_indices is a @staticmethod on Attention but has no dependency on the class. It would be cleaner as a module-level utility or on the config dataclass where it is semantically relevant.

  3. PR description is empty -- Purpose, Test Plan, and Test Result sections are blank. Even for WIP, please add a brief description of the approach and any benchmark numbers you have so far (e.g., memory savings, latency impact on NPU).


Summary

The overall approach is sound: threading KV-cache quantization dtype through attention metadata and gating it per step/layer is a clean design. The main risks are the in-place mutation aliasing in _update_attn_metadata, the thread-unsafe globals in kv_quant_npu.py, and the missing --kv-cache-dtype CLI arg for the serve path. Looking forward to the non-WIP version.

@lyj-jjj
Copy link
Copy Markdown
Contributor Author

lyj-jjj commented Apr 17, 2026

@lishunyang12 Thank you for your careful review. This is very helpful to us. After analyzing the benefits, I will revise and submit the formal PR accordingly.

Comment thread vllm_omni/entrypoints/cli/serve.py
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lyj-jjj lyj-jjj changed the title [wip]support online FP8 quantization for FA on NPU #2236 [WIP]support online FP8 quantization for FA on NPU #2236 Apr 24, 2026
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lyj-jjj lyj-jjj changed the title [WIP]support online FP8 quantization for FA on NPU #2236 support online FP8 quantization for FA on NPU #2236 Apr 24, 2026
@lyj-jjj lyj-jjj changed the title support online FP8 quantization for FA on NPU #2236 [WIP] support online FP8 quantization for FA on NPU #2236 Apr 24, 2026
lyj-jjj added 4 commits April 24, 2026 18:44
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lyj-jjj
Copy link
Copy Markdown
Contributor Author

lyj-jjj commented Apr 25, 2026

@lishunyang12 @gcanlin @hsliuustc0106
This PR is ready. Please help me review and merge it.

@lyj-jjj lyj-jjj changed the title [WIP] support online FP8 quantization for FA on NPU #2236 support online FP8 quantization for FA on NPU #2236 Apr 25, 2026
lyj-jjj added 4 commits April 25, 2026 16:52
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Copy link
Copy Markdown
Collaborator

@Gaohan123 Gaohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a UT for it. Thanks

@Gaohan123 Gaohan123 removed this from the v0.20.0 milestone Apr 30, 2026
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lyj-jjj
Copy link
Copy Markdown
Contributor Author

lyj-jjj commented May 7, 2026

Please add a UT for it. Thanks

I have added UT, please review it again.

@lyj-jjj
Copy link
Copy Markdown
Contributor Author

lyj-jjj commented May 12, 2026

vllm 0.20.1.dev0+g88d34c640.d20260511.empty
vllm_ascend 0.19.1rc2.dev47+g592cc8f7d
vllm-omni 0.20.1.dev54+g41ab49dc8.npu

FA-BF16:
WARNING 05-12 09:52:34 [kv_transfer_manager.py:985] No connector available for receiving KV cache
 50%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                         | 2/4 [00:05<00:05,  2.89s/it][rank0]:[W512 09:52:44.394165360 compiler_depend.ts:250] Warning: CAUTION: The operator 'aten::_linalg_solve_ex.result' is not currently supported on the NPU backend and will fall back to run on the CPU. This may have performance implications. (function npu_cpu_fallback)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:11<00:00,  2.89s/it]
INFO 05-12 09:52:49 [pipeline_wan2_2_i2v.py:642] Pipeline stage timing summary: TextEncoding=51.44 ms, ImageEncoding=0.04 ms, LatentPreparation=1200.88 ms, Denoising=11602.64 ms (4 steps), Decoding=2169.06 ms, StagesSum=15024.06 ms, PipelineWall=15024.45 ms, Unaccounted=0.38 ms
INFO 05-12 09:52:49 [diffusion_model_runner.py:215] Peak GPU memory (this request): 67.45 GB reserved, 67.34 GB allocated, 0.11 GB pool overhead (0.2%)
INFO 05-12 09:52:50 [diffusion_engine.py:160] Generation completed successfully.
INFO 05-12 09:52:50 [diffusion_engine.py:207] Post-processing completed in 0.1305 seconds
INFO 05-12 09:52:50 [diffusion_engine.py:210] DiffusionEngine.step breakdown: preprocess=1.37 ms, add_req_and_wait=15708.93 ms, postprocess=130.49 ms, total=15841.35 ms
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.84s/it]
Total generation time: 15.8569 seconds (15856.94 ms)
Saved generated video to i2v_output_origin_weight_1.mp4
INFO 05-12 09:52:51 [async_omni_engine.py:1950] [AsyncOmniEngine] Shutting down Orchestrator
INFO 05-12 09:52:51 [orchestrator.py:212] [Orchestrator] Received shutdown signal

FA-FP8:

('Warning: torch.save with "_use_new_zipfile_serialization = False" is not recommended for npu tensor, which may bring unexpected errors and hopefully set "_use_new_zipfile_serialization = True"', 'if it is necessary to use this, please convert the npu tensor to cpu tensor for saving')
WARNING 05-12 09:48:58 [kv_transfer_manager.py:985] No connector available for receiving KV cache
 50%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                         | 2/4 [00:04<00:04,  2.48s/it][rank0]:[W512 09:49:07.866251390 compiler_depend.ts:250] Warning: CAUTION: The operator 'aten::_linalg_solve_ex.result' is not currently supported on the NPU backend and will fall back to run on the CPU. This may have performance implications. (function npu_cpu_fallback)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.48s/it]
INFO 05-12 09:49:11 [pipeline_wan2_2_i2v.py:642] Pipeline stage timing summary: TextEncoding=50.78 ms, ImageEncoding=0.04 ms, LatentPreparation=1204.40 ms, Denoising=9990.01 ms (4 steps), Decoding=2143.19 ms, StagesSum=13388.42 ms, PipelineWall=13388.85 ms, Unaccounted=0.43 ms
INFO 05-12 09:49:11 [diffusion_model_runner.py:215] Peak GPU memory (this request): 67.61 GB reserved, 67.43 GB allocated, 0.18 GB pool overhead (0.3%)
INFO 05-12 09:49:12 [diffusion_engine.py:160] Generation completed successfully.
INFO 05-12 09:49:12 [diffusion_engine.py:207] Post-processing completed in 0.1392 seconds
INFO 05-12 09:49:12 [diffusion_engine.py:210] DiffusionEngine.step breakdown: preprocess=1.46 ms, add_req_and_wait=14180.92 ms, postprocess=139.15 ms, total=14321.99 ms
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:14<00:00, 14.32s/it]
Total generation time: 14.3322 seconds (14332.22 ms)
Saved generated video to i2v_output_origin_weight_1.mp4

UT:

================================================================================================================================ test session starts =================================================================================================================================
platform linux -- Python 3.11.9, pytest-9.0.3, pluggy-1.6.0 -- /home/lyj/omni_skf/bin/python
cachedir: .pytest_cache
rootdir: /home/lyj/0511/vllm-omni
configfile: pyproject.toml
plugins: asyncio-1.3.0, anyio-4.13.0
asyncio: mode=Mode.AUTO, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 6 items                                                                                                                                                                                                                                                                    

test_kv_quant_npu.py::test_is_quantized_kv_cache PASSED                                                                                                                                                                                                                        [ 16%]
test_kv_quant_npu.py::TestKVQuantNPUUnit::test_get_rot_matrix_caches_by_device_dtype_and_head_dim PASSED                                                                                                                                                                       [ 33%]
test_kv_quant_npu.py::TestKVQuantNPUUnit::test_fp8_rotate_quant_fa_layouts_scale_and_crop[BNSD-input_shape0-out_shape0-None-0.35355339059327373] PASSED                                                                                                                        [ 50%]
test_kv_quant_npu.py::TestKVQuantNPUUnit::test_fp8_rotate_quant_fa_layouts_scale_and_crop[BSND-input_shape1-out_shape1-0.125-0.125] PASSED                                                                                                                                     [ 66%]
test_kv_quant_npu.py::TestKVQuantNPUUnit::test_fp8_rotate_quant_fa_invalid_layout_raises PASSED                                                                                                                                                                                [ 83%]
test_kv_quant_npu.py::TestKVQuantNPUSmoke::test_fp8_rotate_quant_fa_real_npu_shape_contract 

PASSED                                                                                                                                                                             [100%]

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Copy link
Copy Markdown
Collaborator

@david6666666 david6666666 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some remaining issues.

Comment thread vllm_omni/diffusion/attention/backends/abstract.py Outdated
Comment thread vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py Outdated
Comment thread vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py Outdated
lyj-jjj added 2 commits May 12, 2026 11:31
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lyj-jjj lyj-jjj changed the title [WIP] support online FP8 quantization for FA on NPU #2236 support online FP8 quantization for FA on NPU #2236 May 12, 2026
@david6666666
Copy link
Copy Markdown
Collaborator

LGTM now. @gcanlin @lishunyang12 ptal thx

@david6666666
Copy link
Copy Markdown
Collaborator

david6666666 commented May 12, 2026

Otherwise, please add doc such as https://docs.vllm.ai/en/latest/features/quantization/quantized_kvcache/ and
follow the doc template for vllm-omni quantization, under docs/user_guide/quantization and update .nav.yml

@david6666666 david6666666 added the ready label to trigger buildkite CI label May 12, 2026
lyj-jjj added 3 commits May 12, 2026 19:04
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Comment thread vllm_omni/diffusion/attention/backends/flash_attn.py Outdated
lyj-jjj and others added 8 commits May 12, 2026 21:03
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Copy link
Copy Markdown
Collaborator

gcanlin commented May 12, 2026

@lyj-jjj Could you please test this PR again? I made some refactor and would be better to take another look.

Copy link
Copy Markdown
Collaborator

@gcanlin gcanlin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @lyj-jjj has checked the new code and it can pass the test locally.

@gcanlin gcanlin merged commit 56ca7dd into vllm-project:main May 13, 2026
8 of 9 checks passed
tzhouam pushed a commit that referenced this pull request May 14, 2026
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Galleons2029 pushed a commit to Galleons2029/vllm-omni-ljl that referenced this pull request May 18, 2026
…-project#2640)

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants