
[Spec Decode][CUDA Graphs] Enables Eagle drafter support for FULL CUDA Graph mode #34880

Open
yiz-liu wants to merge 1 commit into vllm-project:main from yiz-liu:full-spec

Conversation


@yiz-liu yiz-liu commented Feb 19, 2026

Purpose

As mentioned in vllm-project/vllm-ascend#5459 and #33341 , this PR enables Full CUDA Graph mode for the Eagle drafter model to improve performance.

The main changes include:

  1. CUDA Graph Integration: Wraps the drafter model with CUDAGraphWrapper during load_model and initializes the necessary keys for the dispatcher to manage graph-based execution.
  2. Graph Capture Support: Builds dummy attention metadata during the dummy_run phase, which is required for successful graph capture.
  3. Dispatch: On the first step, the draft model shares the target model's uniform_decode and has essentially the same batch_desc and cudagraph_mode; on the following steps, uniform_decode_query_len is set to 1 and uniform_decode to True, making it possible to have separate cudagraph_keys.
  4. Metadata Correction: Corrects the memory address handling for query_start_loc and slot_mapping within the attention metadata.
  5. Bug Fix: Adjusts CUDA graph capture sizes to resolve a runtime error that occurred when num_speculative_tokens was set to 2. Also fixes padding issues in prepare_inputs_padded and prepare_next_token_ids_padded.
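The dispatch-key idea in point 3 can be sketched as follows. This is a minimal stand-in: `BatchDescriptor` and the dispatcher here are simplified illustrations of the concept, not vLLM's actual classes.

```python
from dataclasses import dataclass

# Simplified stand-ins for vLLM's batch descriptor and dispatcher: distinct
# (hashable) descriptors act as keys, so different shapes get separate graphs.
@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool
    uniform_decode_query_len: int

class GraphDispatcher:
    def __init__(self):
        self.graphs = {}  # descriptor -> captured graph (stubbed as a string)

    def get_or_capture(self, desc, capture_fn):
        if desc not in self.graphs:
            self.graphs[desc] = capture_fn(desc)
        return self.graphs[desc]

# First drafter step: shares the target model's descriptor (query_len = 1 + K).
step0 = BatchDescriptor(num_tokens=40, uniform_decode=True, uniform_decode_query_len=3)
# Subsequent steps: one token per request, so query_len drops to 1.
step_n = BatchDescriptor(num_tokens=20, uniform_decode=True, uniform_decode_query_len=1)

d = GraphDispatcher()
capture = lambda desc: f"graph[{desc.num_tokens},{desc.uniform_decode_query_len}]"
g0 = d.get_or_capture(step0, capture)
gn = d.get_or_capture(step_n, capture)
```

Because the two descriptors differ in uniform_decode_query_len, each step shape resolves to its own captured graph.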

Collectively, these changes allow the Eagle drafter to leverage the performance benefits of Full CUDA Graph mode, enhancing throughput for speculative decoding.

Test Plan

The feature was tested by running the model with the following configuration:

  • num_speculative_tokens=2 (and also validated with 3/4/5)
  • cudagraph_mode="FULL" (and also validated with FULL_DECODE_ONLY and FULL_AND_PIECEWISE)

Test Result

The model's acceptance rate in Full CUDA Graph mode is consistent with the results from eager mode.

For FULL_AND_PIECEWISE:

...
[cuda_graph.py:123] | Unpadded Tokens | Padded Tokens | Num Paddings | Runtime Mode | Count |
[cuda_graph.py:123] |-----------------|---------------|--------------|--------------|-------|
[cuda_graph.py:123] | 40              | 40            | 0            | FULL         | 28    |
[cuda_graph.py:123] | 42              | 50            | 8            | PIECEWISE    | 1     |
[cuda_graph.py:123] | 35              | 35            | 0            | FULL         | 1     |
...
--------------------------------------------------
total_num_output_tokens: 768
num_drafts: 402
num_draft_tokens: 1608
num_accepted_tokens: 370
mean acceptance length: 1.92
--------------------------------------------------
acceptance at token 0: 0.46
acceptance at token 1: 0.22
acceptance at token 2: 0.14
acceptance at token 3: 0.10

For FULL:

...
[cuda_graph.py:123] | Unpadded Tokens | Padded Tokens | Num Paddings | Runtime Mode | Count |
[cuda_graph.py:123] |-----------------|---------------|--------------|--------------|-------|
[cuda_graph.py:123] | 40              | 40            | 0            | FULL         | 29    |
[cuda_graph.py:123] | 5               | 5             | 0            | FULL         | 2     |
[cuda_graph.py:123] | 44              | 50            | 6            | FULL         | 1     |
...
--------------------------------------------------
total_num_output_tokens: 768
num_drafts: 402
num_draft_tokens: 1608
num_accepted_tokens: 365
mean acceptance length: 1.91
--------------------------------------------------
acceptance at token 0: 0.48
acceptance at token 1: 0.23
acceptance at token 2: 0.12
acceptance at token 3: 0.08



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables full CUDA graph support for the Eagle drafter model, which is a significant performance enhancement. The changes are well-structured, including necessary modifications for CUDA graph compatibility like in-place tensor updates and proper dummy run setup for graph capturing. I've identified one critical issue: a logging statement with incorrect formatting that will cause a TypeError at runtime. I've provided a suggestion to fix it. Overall, this is a great contribution.


@yiz-liu yiz-liu left a comment


These are some questions I am not entirely sure about. Any comments?

Comment on lines +373 to +381
if not self.speculative_config.enforce_eager:
    # This is a temporary mapping open to discussion
    # FULL_AND_PIECEWISE -> PIECEWISE, FULL_DECODE_ONLY -> FULL
    # PIECEWISE -> PIECEWISE, FULL -> FULL
    eagle_cudagraph_mode = (
        CUDAGraphMode.PIECEWISE
        if cudagraph_mode.has_piecewise_cudagraphs()
        else cudagraph_mode.decode_mode()
    )
Contributor Author

Do we have any other thoughts on this?

Contributor Author

This is updated by c79e3ac

Comment on lines +571 to +574
common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[
    : batch_size + 1
]
common_attn_metadata.query_start_loc_cpu[: batch_size + 1] = torch.from_numpy(
Contributor Author

I am not sure why we set query_start_loc or slot_mapping to a different buffer in the first place, but I assume it's always safe to use the original buffer.

Collaborator

I think consistent addressing wasn't a consideration, since we didn't have full graphs yet.
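The addressing point can be illustrated with a small sketch (the buffer name is hypothetical): a captured CUDA graph records raw device pointers at capture time, so metadata must be written into the original buffer in place rather than rebound to a new tensor.

```python
import torch

# Sketch of why consistent addressing matters for full graphs: a captured
# graph replays against the pointer recorded at capture time, so updates must
# be in-place writes that preserve the tensor's storage address.
buf = torch.zeros(8, dtype=torch.int64)  # stands in for query_start_loc
addr = buf.data_ptr()

buf[:4] = torch.arange(4)       # in-place slice assignment: address unchanged
assert buf.data_ptr() == addr   # graph-safe

old = buf
buf = torch.arange(8)           # rebinding allocates new storage
assert buf.data_ptr() != old.data_ptr()  # a replayed graph would still read `old`
```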

Comment on lines +909 to +910
# NOTE: For CUDA Graph, we need the `num_reqs_padded` here
batch_size = common_attn_metadata.num_reqs
Contributor Author

This is another core change: input_batch.num_reqs != common_attn_metadata.num_reqs after padding. I wonder if there is a better way to deal with this?
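To make the mismatch concrete, here is a toy illustration (the helper and capture sizes are invented for the example): with CUDA graphs, the batch is padded up to the nearest captured size, so the request count seen by the attention metadata can exceed input_batch.num_reqs.

```python
import bisect

# Toy illustration (names and sizes are hypothetical): CUDA graph execution
# pads a batch up to the nearest captured size, so the request count in the
# attention metadata can exceed the actual number of requests.
capture_sizes = [1, 2, 4, 8, 16, 32, 50]

def pad_to_capture_size(num_reqs: int) -> int:
    # round up to the smallest captured size >= num_reqs
    return capture_sizes[bisect.bisect_left(capture_sizes, num_reqs)]

input_batch_num_reqs = 42                           # actual requests this step
padded = pad_to_capture_size(input_batch_num_reqs)  # what the graph expects
```

This mirrors the padding visible in the log table above, where 42 unpadded tokens run under the 50-token graph.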


yiz-liu commented Feb 19, 2026

@tomasruizt @LucasWilkinson Could you please take a look at this? Thanks!

Comment on lines 5588 to 5597
# if we have dedicated decode cudagraphs, and spec-decode is enabled,
# we need to adjust the cudagraph sizes to be a multiple of the uniform
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
# Will be removed in the near future when we have separate cudagraph capture
# sizes for decode and mixed prefill-decode.
if (
    cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
    and cudagraph_mode.separate_routine()
    and self.uniform_decode_query_len > 1
):
Contributor Author

Also, I checked the comments here: since we already have separate capture sizes now, we can remove this condition, right? This resolves the num_speculative_tokens=2 issue.

@tomasruizt

@benchislett


tomasruizt commented Feb 19, 2026

@yiz-liu Thanks a lot for this PR!

Edit: Perhaps the observation below is just a matter of enabling full CG also for method=draft_model. Let me know! :)

I profiled your branch using the PyTorch profiler and found:

  • That the target model is running with full cuda graphs (CG) and dispatching the forward using cudaGraphLaunch.
  • However, the draft model is not using full CG: instead it's launching many separate PyTorch ops during the forward (see below; left circle is the target forward, right circle is the draft forward).
(screenshot: sd-full-cg-profile)

I attach the command I used to generate the trace as well as the PyTorch trace, which you can open in https://ui.perfetto.dev/.

Script: profile-4b-sd-0.6b.sh

Trace: rank0.1771496331250686317.pt.trace.json.gz

I assume that you are seeing no changes in performance whatsoever compared to main (TPOT, ITL). If correct, it means that some wiring up is still missing to enable full CG for the drafter. Let me know if I'm wrong or missing something.
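The profiling check described above can be sketched as follows (CPU-only stand-in model; on a GPU run you would add ProfilerActivity.CUDA and search the trace for cudaGraphLaunch):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# CPU-only sketch of the check described above: profile a forward pass and
# inspect which ops were dispatched. With full CUDA graphs active on a GPU,
# the forward shows up as a cudaGraphLaunch instead of many small kernels.
model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        model(x)

op_names = [evt.key for evt in prof.key_averages()]
# On a GPU trace you would check: any("cudaGraphLaunch" in n for n in op_names)
```

The exported Chrome trace (prof.export_chrome_trace) can then be opened in https://ui.perfetto.dev/ as done above.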


yiz-liu commented Feb 19, 2026

@tomasruizt Weird, I'll look into this later; in the meantime, the scripts and profiling results are attached below:
FULL.zip
data_parallel.py

python3 data_parallel.py \
--model="/home/weight/Qwen3-30B-A3B-FP8" \
-dp=1 \
-tp=2

@tomasruizt

For EAGLE3, I'm observing the same phenomenon. I used gpt-oss-20b + eagle3.

  • target model forward dispatches to cudaGraphLaunch, while
  • draft model forward dispatches to a bunch of small ops
(screenshot: sd-eagle3-full-cg-profile)

Profiling script: profile-gpt-oss-20b-eagle3.sh
PyTorch trace: dp0_pp0_tp0_dcp0_ep0_rank0.1771505307158794037.pt.trace.json.gz


yiz-liu commented Feb 19, 2026

For EAGLE3, I'm observing the same phenomenon. I used gpt-oss-20b + eagle3.

  • target model forward dispatches to cudaGraphLaunch, while
  • draft model forward dispatches to a bunch of small ops

@tomasruizt Oh yeah, I checked these scripts and the profiling. I believe the behavior you're observing is due to the default CUDA graph mode, which resolves to FULL_AND_PIECEWISE. As I mentioned in this comment, this configuration correctly results in a PIECEWISE graph for the speculative decoding step. The resulting piecewise graph contains very few ops, which can make it appear as though no graph is active, but this is the expected outcome for that mode:


Could you please try explicitly setting the CUDA graph mode to FULL and re-running the profile?

This brings up a design question; to elaborate on my earlier comment: do you think we should change the default strategy to be more aggressive (i.e., prefer FULL over PIECEWISE)? My take is that with async scheduling now available, the host launch overhead for the drafter model is likely masked by the target model's computation time, making PIECEWISE still a safe and robust default, which is exactly why I kept it. What are your thoughts?


tomasruizt commented Feb 19, 2026

If the target model runs in full cg, then the draft model should run in full cg if possible, right? The higher performance setting should be the default. What is the problem with setting full cg as a default?


yiz-liu commented Feb 19, 2026

If the target model runs in full cg, then the draft model should run in full cg if possible, right? The higher performance setting should be the default. What is the problem with setting full cg as a default?

Yeah that's a good point. No problem at all, my initial design was just trying to honor the existing default behavior for consistency. However, I agree that prioritizing performance is the right way. I'll go ahead and update the PR. Of course, for others who might have concerns, this is still open for discussion.

Thanks for the valuable feedback!

@benchislett

+1, behaviour should match the base model for consistency whenever possible. If base model uses full graphs for a certain shape, so should the drafter.


Neo9061 commented Feb 19, 2026

Thanks for the great work! Can full CUDA graphs be applied to parallel-EAGLE as well? CC @benchislett


yiz-liu commented Feb 20, 2026

@tomasruizt @benchislett Hi, please see the latest commit for the unified CUDA Graph mode; the target model and drafter model should now share the same behavior.


mergify bot commented Feb 21, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


mergify bot commented Feb 21, 2026

Hi @yiz-liu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@tomasruizt

@yiz-liu Are you able to generate the PyTorch profile? You can attach it once done as proof that the drafter uses cudaGraphLaunch. If for some reason you cannot, let me know, I probably can do it on Monday.


yiz-liu commented Mar 2, 2026

@DingYibin I noticed in #34102 it's been tweaked again, anyway, I'll try to remove the dispatch and see if it works.

  File "/home/ubuntu/repos/vllm/vllm/v1/worker/gpu_model_runner.py", line 3752, in propose_draft_token_ids
    self._draft_token_ids = self.propose_draft_token_ids(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/repos/vllm/vllm/v1/worker/gpu_model_runner.py", line 4162, in propose_draft_token_ids
    draft_token_ids = self.drafter.propose(
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/repos/vllm/vllm/v1/spec_decode/eagle.py", line 694, in propose
    ret_hidden_states = self.model(**model_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/repos/vllm/vllm/compilation/cuda_graph.py", line 246, in __call__
    validate_cudagraph_capturing_enabled()
  File "/home/ubuntu/repos/vllm/vllm/compilation/monitor.py", line 55, in validate_cudagraph_capturing_enabled
    raise RuntimeError(
RuntimeError: CUDA graph capturing detected at an inappropriate time. This operation is currently disabled.

Hmm, ran into this error again, I'll take a closer look at it.


gorski-m commented Mar 2, 2026

Hey @yiz-liu, I saw the changes on your PR, looks really cool! I wanted to check the TPOT improvement on my setup, and saw one issue connected to the acceptance rate. Specifically, when testing with batch_size > 1, the acceptance length drops from ~4 to ~3 on my specific dataset and model, with num_speculative_tokens=5. For bs=1 the acceptance length is the same as without the fix.
Did you manage to test the change with bs > 1 and higher num_speculative_tokens?

Thanks for the feedback! I did test under those conditions, but my baseline acceptance length is quite low (~1.4), which might be why I did not observe any regression. Could you provide a minimal reproduction script or more details about your model/dataset setup?

Thank you. Perhaps that's an issue on my end then. I'll look into the setup closer, and if I'm able to reproduce it in a simple way, I'll share the details here


yiz-liu commented Mar 3, 2026

How does this PR handle the two different shapes of EAGLE batch? We have:

  • First step with (1+K) tokens per request, in a uniform batch (when padded batch mode is on)
  • Remaining decoding steps with 1 token per request.

Do we need to record graphs for all the shapes? How do we handle the padding in the respective cases?

Oh sorry, I missed this comment before. Good catch. I've been struggling with the dispatching and cudagraph_keys since the recent rebases (including #34043 and #34102). I tried relaxing the batch_desc constraints, but it led to two issues:

  1. In FULL_DECODE_ONLY/PIECEWISE modes, the acceptance rate is abnormal after the first token (@gorski-m, maybe this is the same issue as yours). See the stats below:
    acceptance at token 0: 0.42
    acceptance at token 1: 0.07
    acceptance at token 2: 0.01
    acceptance at token 3: 0.01
    
    while it should be:
    acceptance at token 0: 0.48
    acceptance at token 1: 0.23
    acceptance at token 2: 0.12
    acceptance at token 3: 0.08
    
  2. Switching to FULL mode triggers an illegal memory access error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered.

I suspect the regression is due to some mismatch between the first (1+K) step and the subsequent 1-token decoding steps. Maybe they are picking up stale values, and the illegal memory access may suggest that the CUDA graph captured for the larger shape is reading out-of-bounds memory when replayed on the smaller 1-token batch? Do you have any insights on how to capture and dispatch the graphs, @benchislett? Thanks a lot.

My next step is to handle these two scenarios independently by ensuring they are captured as distinct CUDA graph instances.
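The stale-value hypothesis can be illustrated with a tiny pure-Python model of graph replay (no CUDA involved; the buffer handling is deliberately simplified):

```python
# Pure-Python model of graph replay semantics (no CUDA): a captured graph
# always reads its fixed buffers, so running a graph captured for a larger
# shape on a smaller batch leaves stale values in the tail of the buffer.
static_buffers = {}

def capture(num_tokens):
    # allocate the fixed input buffer this "graph" will always read
    static_buffers[num_tokens] = [0.0] * num_tokens
    return num_tokens

def replay(key, inputs):
    buf = static_buffers[key]
    buf[: len(inputs)] = inputs   # copy inputs into the captured buffer
    return sum(buf)               # the "kernel" reads the whole buffer

k = capture(num_tokens=4)         # captured for a (1+K)-token shape
replay(k, [1.0, 1.0, 1.0, 1.0])   # a full batch fills every slot
out = replay(k, [5.0])            # a 1-token batch on the same graph
# out is 8.0: three stale 1.0 entries from the previous step leak in,
# analogous to the degraded acceptance stats reported above
```

Capturing a distinct graph per shape key, as proposed, avoids exactly this reuse.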


mergify bot commented Mar 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 4, 2026
@benchislett

@LucasWilkinson might have some ideas. I'm not entirely sure how to avoid the issue without having an additional set of graphs, which seems like it would be a pain to maintain.


mergify bot commented Mar 17, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Mar 17, 2026

yiz-liu commented Mar 17, 2026

Sorry for the late reply, didn't feel well last week. I’ve been overhauling things the past couple of days. The main changes:

  1. Rebased onto the latest code.
  2. Fixed an issue in prepare_inputs_padded related to padding.
  3. Redesigned cudagraph_keys so step=0 and step>0 no longer share the same batch_desc as the key; the main differences are uniform_decode_query_len and uniform_decode.
  4. dummy_run now correctly constructs different data for step=0 and step>0 during capture.
  5. In propose, step=0 now directly inherits the target model’s batch_desc.uniform, so no extra checks are needed.

FULL_AND_PIECEWISE:
When the target model is dispatched to PIECEWISE, the draft model will use PIECEWISE for step 0 and FULL for the remaining steps. If the target model is dispatched to FULL, then the draft model should use FULL for all steps.
(screenshot: FULL_AND_PIECEWISE)

FULL:
For FULL mode, all steps should always be FULL as long as the num_tokens is small enough.
(screenshot: FULL)

FULL_DECODE_ONLY:
For FULL_DECODE_ONLY, things are basically the same as with FULL_AND_PIECEWISE, except step 0 may be dispatched to NONE when uniform_decode=False.
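The per-mode, per-step dispatch just described can be summarized schematically (a hypothetical helper, not vLLM's dispatcher; the real logic also checks captured sizes, so FULL applies only when num_tokens fits a captured graph):

```python
from enum import Enum

# Schematic of the dispatch rules described above (hypothetical helper).
class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def drafter_mode(config_mode: str, target_runtime_mode: Mode,
                 step: int, uniform_decode: bool) -> Mode:
    if config_mode == "FULL":
        return Mode.FULL
    if config_mode == "FULL_AND_PIECEWISE":
        # step 0 follows the target model's dispatch; later steps are
        # uniform 1-token decodes and use a FULL graph
        return target_runtime_mode if step == 0 else Mode.FULL
    if config_mode == "FULL_DECODE_ONLY":
        if step == 0:
            return Mode.FULL if uniform_decode else Mode.NONE
        return Mode.FULL
    return Mode.NONE
```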

Also, the acceptance rate should be OK now as you can see in the Test Result in PR description, @gorski-m .

Could you please review this PR again? Thanks. @benchislett @LucasWilkinson


mergify bot commented Mar 17, 2026

Hi @yiz-liu, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

…ificantly improving inference performance by reducing CPU overhead during the draft speculative steps.

1. CudagraphDispatcher
* Added a for_draft_model flag to allow specialized graph capture logic for speculative decoding.
* Updated initialize_cudagraph_keys to capture graphs up to max_num_tokens specifically for steps > 0.
* Made uniform_decode_query_len an independent parameter, since for steps > 0 it should be 1.

2. EAGLE Proposer Updates
* Model Wrapping: The draft model is now wrapped in CUDAGraphWrapper when FULL mode is enabled and padding is not disabled.
* Metadata Padding: Fixed a potential crash by padding spec_decode_metadata.cu_num_draft_tokens to match the padded batch size.
* Refined Dispatching: Updated _determine_batch_execution_and_padding to return and pass BatchDescriptor objects, ensuring the runtime uses the correct graph key.
* Capture Logic: Enhanced dummy_run to simulate the actual speculative decoding steps during the graph capture phase.

3. GPUModelRunner
* Introduced supports_sd_full_graph to identify proposers (like EAGLE) that are compatible with FULL graph mode.
* Modified ExecuteModelState to track batch_desc, ensuring consistency between the target model and draft model.
* Ensured CommonAttentionMetadata is correctly passed to the drafter's warmup/dummy runs to facilitate accurate metadata building during capture.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
mgehre-amd added a commit to mgehre-amd/vllm that referenced this pull request Mar 20, 2026
…oposer

Cherry-pick 409a12e to enable FULL CUDAGraph mode for the EAGLE
proposer during draft speculative steps, reducing CPU overhead.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd added a commit to mgehre-amd/vllm that referenced this pull request Mar 20, 2026
…oposer

Cherry-pick 409a12e to enable FULL CUDAGraph mode for the EAGLE
proposer during draft speculative steps, reducing CPU overhead.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd added a commit to mgehre-amd/vllm that referenced this pull request Mar 20, 2026
- Pass spec_decode_common_attn_metadata to drafter.dummy_run() so the
  drafter can dispatch uniform_decode=True and match FULL batch keys
- Allow any non-NONE cudagraph mode during capture (not just PIECEWISE)
  so the drafter's FULL CUDAGraphWrapper actually triggers capture
- Add hasattr fallback for get_eagle3_aux_hidden_state_layers to support
  models like Qwen3 that only have the default method
- Add _dump_all_full_graphs() call after capture for hipGraph debugging
- Re-apply PR vllm-project#34880 changes lost during merge with awq_gemv_ifdef_sweep

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd

I'm happy about seeing eagle drafters in CUDA Graph!
Is there a fundamental reason why we cannot capture the eagle drafter together with the base model in a single CUDA graph? Instead of capturing each of them into their own CUDA graph?


yiz-liu commented Mar 22, 2026

I'm happy about seeing eagle drafters in CUDA Graph! Is there a fundamental reason why we cannot capture the eagle drafter together with the base model in a single CUDA graph? Instead of capturing each of them into their own CUDA graph?

I think the main blocker is the D2H/H2D transfer between the target model and the draft model, for example in prepare_next_token_ids_padded and prepare_inputs_padded?

I am not sure yet whether we can get rid of those transfers in the future, and I will check Model Runner v2 design later.


yiz-liu commented Mar 22, 2026

@LucasWilkinson Hi, I noticed that similar features have been added in MRV2 (#35959 ) over the past couple of weeks. I’ll take a look as well, but would you mind reviewing this PR when you get a chance? I think V1 still needs this support, please let me know what you think, any help would be greatly appreciated. Thanks!


mergify bot commented Mar 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 23, 2026