
Code execution & tool use #204

Closed
KiddoZhu wants to merge 141 commits into main from tool-calling

Conversation

@KiddoZhu
Contributor

KiddoZhu commented Apr 16, 2025

Add a basic version of code execution & tool use (#55, #56)

Implemented by a custom LogitProcessor in the vLLM backend. I implement tools as pre-defined variables in the code environment, so there is no need to distinguish code execution from tool use; this also makes maintenance easier. The only concern is that code execution must be enabled whenever we want to enable tool use, though we can avoid mentioning in the prompt that arbitrary code execution is possible beyond the tools.
 

Supported features

  • On-the-fly multiple tool calls.
  • Compatible with batch decoding.
  • Stateful code executor: within the same generation, functions and variables persist across code snippets.
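The tools-as-variables design and the stateful executor could be sketched roughly as follows (a minimal illustration under my own assumptions, not the actual implementation; the `retrieve` tool and the eval/exec split are placeholders):

```python
import contextlib
import io


class CodeExecutor:
    """Minimal sketch of a stateful code executor.

    Tools are injected as pre-defined variables, so the model calls them
    like ordinary functions and no separate tool-calling syntax is needed;
    definitions persist across snippets within one generation.
    """

    def __init__(self, tools):
        # e.g. tools = {"retrieve": retrieve}
        self.namespace = dict(tools)

    def run(self, code):
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            try:
                # expressions like `x + y` are evaluated and echoed back
                result = eval(code, self.namespace)
                if result is not None:
                    print(repr(result))
            except SyntaxError:
                # statements like `x = 3; y = 4` are executed for effect
                exec(code, self.namespace)
        return buf.getvalue().strip()
```

On the example transcript that follows, `run("x = 3; y = 4")` returns an empty string, while the later `run("x + y")` returns `'7'` because the earlier assignments persisted in the namespace.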

Examples (test_vllm_tools.py)

<code>x = 3; y = 4</code>
This is some regular text.
<code>x + y</code>
<result>7</result>
<code>retrieve('Jen-Hsun Huang')</code>

<result>
['Nvidia was established in 1993 by Jen-Hsun Huang, Curtis Priem, and Chris '
 'Malachowsky. In 2000 Nvidia took intellectual possession of 3dfx, one of the '
 'biggest GPU producers in 1990s.']
</result>

Tokenizer issue

Although I designed CodeLogitProcessor to be as tokenizer-agnostic as possible, there may be edge cases where tokens don't split exactly at the end of </code>. For example, the tokenizer of GPT-4o will generate the following

... </##code##>x

If we tweaked the >x token into >, it would change the log probs used in RL. My current solution is to leave the generated tokens untouched and append the results directly afterwards

... </code>x<result> ...
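The append-after approach could be sketched like this (a hypothetical helper; the `executor` object with a `.run(code)` method is an assumption, not part of this PR's API):

```python
def append_result(generated_text, executor):
    """Execute the last <code>...</code> block and append its result.

    The generated tokens are never modified: even if the tokenizer merged
    the closing tag with following text (e.g. `</code>x`), the <result>
    tag is simply appended after whatever was generated, so RL log probs
    stay untouched.
    """
    start = generated_text.rfind("<code>")
    end = generated_text.rfind("</code>")
    if start == -1 or end <= start:
        return generated_text  # no complete code block to execute
    code = generated_text[start + len("<code>") : end]
    output = executor.run(code)
    return generated_text + "<result>" + output + "</result>"
```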

@KiddoZhu
Contributor Author

Now we have an implementation based on stop strings, making it independent of the vLLM and HF backends.

The function generate_with_code_and_tools(policy, input_batch, tokenizer) can be used as a drop-in replacement for policy.generate(input_batch). It supports

  • Multiple code executions & tool uses. Generation continues until hitting an EOS token or a user-specified stop string.
  • Batch code execution based on ray.remote. (@SahilJain314 Could you please check whether we need to take special care of the Ray workers here?)

@SahilJain314 The way we calculated generated_lengths in HFPolicyWorker is buggy when generation is stopped by custom stop strings (e.g. </code> in code execution): it counts one extra EOS token after the stop string, since the EOS token is also used for padding. I changed it to count the nonzero entries of generation_logprobs, assuming that any generated logprob differs from the padding value 0. Please let me know if you foresee any issue with this implementation.
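The nonzero-logprob counting amounts to the following (a plain-Python stand-in for the tensor op, with the padding assumption made explicit):

```python
def generated_length(generation_logprobs):
    """Count generated tokens by counting nonzero logprob entries.

    Assumption: padded positions carry logprob exactly 0.0, and a
    genuinely generated token never does -- which only holds if no token
    is ever sampled with probability 1.
    """
    return sum(1 for lp in generation_logprobs if lp != 0.0)
```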

I'll merge the main branch and clean up the logit processor implementation.

KiddoZhu requested a review from parthchadha April 29, 2025 21:26
@KiddoZhu
Contributor Author

Done. Ready for review.

.gitignore Outdated
dist/
*.egg-info/
*.vscode/
uv.lock
Contributor

@terrykong afaik this shouldn't be gitignored, right?

Collaborator

yea, we should not ignore the lock

else:
    gen_length = len(generated_part)

gen_length = (generated_logprob != 0).sum().item()
Contributor

This assumption doesn't hold. Generated logprobs are actually, surprisingly, sometimes exactly 0.

Contributor Author

Then how can we safely determine the length? The old implementation mistakes a padded EOS token for a generated EOS token, which is a serious problem for custom stop strings.

Contributor

Can we replace this with the existing run_multi_turn_generation function, and use environments to handle the stop tokens as we do right now?

Contributor Author

Sounds good to me. Will code execution always be invoked in a multi-turn chat format? How about changing run_multi_turn_rollout to a plain prompt format, and then providing an additional interface on top of it to support chat messages?

Contributor

The default multi-turn chat actually runs without chat templating (the template is applied to the string/tokens directly). To apply a chat template, the environment would have to do it. (Open to suggestions on this approach.)

@SahilJain314
Contributor

I'm a little hesitant to push this out without filesystem isolation

@KiddoZhu
Contributor Author

KiddoZhu commented May 2, 2025

I'm a little hesitant to push this out without filesystem isolation

Good point! What solution do you have in mind? I can think of chroot or Docker, but they are too heavyweight. If we assume the code is merely untrusted, not malicious, a simple solution is to 1) run the code in a temporary directory, 2) override open() to deny access outside the temporary directory, and 3) override builtins.__import__ to deny modules like os, sys, and subprocess.

@KiddoZhu
Contributor Author

KiddoZhu commented May 6, 2025

Sorry, I messed up the commit history. This branch will no longer be touched but is kept for retrieving history. I will open a new branch and pull request.

KiddoZhu deleted the tool-calling branch May 6, 2025 21:34
