diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 002a129272d8..7bb9bc23442b 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -27,6 +27,9 @@ def __init__(self, rid, input_text, input_ids): self.input_ids = input_ids self.output_ids = [] + # The original prompt length; jump-forward may append tokens to input_ids. + self.orig_prompt_tokens = len(input_ids) + # For vision input self.pixel_values = None self.image_size = None diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index dfcf34378e67..b57f6ce52c4d 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -534,10 +534,16 @@ def handle_finished_requests(self, batch: Batch): output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) + + # len(req.input_ids) grows as jump-forward appends tokens, so use the + # original prompt length to compute the token usage info correctly. meta_info = { - "prompt_tokens": len(req.input_ids), - "completion_tokens": len(req.output_ids), + "prompt_tokens": req.orig_prompt_tokens, + "completion_tokens": len(req.input_ids) + + len(req.output_ids) + - req.orig_prompt_tokens, } + if req.return_logprob: meta_info["prompt_logprob"] = req.logprob meta_info["token_logprob"] = req.token_logprob