Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
a902e46
log answer key to wandb
BjarniHaukur Mar 28, 2025
1fa6c69
all Table
BjarniHaukur Mar 28, 2025
46bf3f8
HTML logging
BjarniHaukur Mar 29, 2025
338730e
table
BjarniHaukur Mar 29, 2025
c9777d3
bump patch
BjarniHaukur Mar 29, 2025
d81295a
hmm
BjarniHaukur Mar 29, 2025
bd92df8
formatting
BjarniHaukur Mar 29, 2025
05610d6
html esacape
BjarniHaukur Mar 29, 2025
908d739
reward isnt string
BjarniHaukur Mar 29, 2025
99fd8de
sync fork
BjarniHaukur Apr 2, 2025
3943e72
preliminary openai compatible endpoint
BjarniHaukur Apr 4, 2025
1723c56
early concept, needs refining
BjarniHaukur Apr 4, 2025
1491cb1
dedupe
BjarniHaukur Apr 4, 2025
9a9c416
debug print
BjarniHaukur Apr 5, 2025
57557da
some slop to work on
BjarniHaukur Apr 6, 2025
044a490
unslop, missing hist
BjarniHaukur Apr 6, 2025
2a3c178
almost valid pseudocode
BjarniHaukur Apr 6, 2025
faf116c
middle-ware monkey patch in mp.Pool()...
BjarniHaukur Apr 6, 2025
7f2c730
remove unused
BjarniHaukur Apr 7, 2025
1348c23
More accurate .md
BjarniHaukur Apr 7, 2025
fb79fb6
need gpu
BjarniHaukur Apr 9, 2025
b16f072
renting lambda again
BjarniHaukur Apr 9, 2025
14be11b
much nicer
BjarniHaukur Apr 9, 2025
7af1273
small
BjarniHaukur Apr 9, 2025
4506bad
aider-chat and datasets conflict
BjarniHaukur Apr 9, 2025
8388e83
risky reqs change
BjarniHaukur Apr 9, 2025
63088e0
should work, but hacky
BjarniHaukur Apr 10, 2025
8b0ed76
some insights, but monkeypatching probably wont suffice
BjarniHaukur Apr 10, 2025
cea5eec
refactor: Rewrite test script to use SWE-bench dataset with MultiProc…
BjarniHaukur Apr 10, 2025
50ea732
refactor: Remove logging statements from test.py
BjarniHaukur Apr 10, 2025
55d24a6
one step closer
BjarniHaukur Apr 10, 2025
1b6e42b
finally, the correct abstraction
BjarniHaukur Apr 11, 2025
c607e8f
doc
BjarniHaukur Apr 11, 2025
cbf90d9
todo
BjarniHaukur Apr 11, 2025
7dd119f
unslop
BjarniHaukur Apr 12, 2025
ce53aa3
unslop
BjarniHaukur Apr 12, 2025
0d51a14
undo accidental black
BjarniHaukur Apr 12, 2025
9aabf0a
cleaner abstraction
BjarniHaukur Apr 14, 2025
18f38a7
new abstraction
BjarniHaukur Apr 15, 2025
38ef847
💎 Gemma 3 VLM SFT example script for single-image and multi-image (#3…
sergiopaniego Mar 26, 2025
caa3be3
merge
BjarniHaukur Apr 16, 2025
d11f0f5
upstream
BjarniHaukur May 13, 2025
8c67add
misssed one
BjarniHaukur May 13, 2025
3df4d61
working
BjarniHaukur May 14, 2025
d722c5b
baby steps
BjarniHaukur May 15, 2025
81e99ca
getting closer
BjarniHaukur May 15, 2025
64a6d6e
temporary fix
BjarniHaukur May 16, 2025
2618826
close to being ready
BjarniHaukur May 16, 2025
50624c7
Merge remote-tracking branch 'upstream/main'
BjarniHaukur May 16, 2025
0e762a6
almost pull ready
BjarniHaukur May 16, 2025
7c48129
undo rename
BjarniHaukur May 16, 2025
e2c4c4a
accidental
BjarniHaukur May 16, 2025
42dca01
fixes
BjarniHaukur May 16, 2025
415d91e
comment
BjarniHaukur May 16, 2025
2c81b08
sync with upstream
BjarniHaukur May 16, 2025
833d849
bug bump
BjarniHaukur May 17, 2025
5e9ee94
missing
BjarniHaukur May 17, 2025
207ab2f
we should probably warn users to not use 8bit
BjarniHaukur May 17, 2025
adf6508
new dev branch
BjarniHaukur May 18, 2025
80617ec
refactor
BjarniHaukur May 18, 2025
36da30c
correct logging
BjarniHaukur May 19, 2025
0afcbb8
removed usecase specific
BjarniHaukur May 19, 2025
c01eebe
pr cleanup
BjarniHaukur May 19, 2025
48fec41
pr cleanup
BjarniHaukur May 19, 2025
5e6b02f
pr cleanup
BjarniHaukur May 19, 2025
f4e4f11
pr cleanup
BjarniHaukur May 19, 2025
cd87130
pr cleanup
BjarniHaukur May 19, 2025
254510c
pr cleanup
BjarniHaukur May 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,209 @@ def tearDownClass(cls):
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()


@pytest.mark.slow
@require_torch_multi_gpu
class TestVLLMClientServerAsync(unittest.TestCase):
    """Exercise `VLLMClient` against a `trl vllm-serve-async` server started as a subprocess."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # Expose only GPU 1 to the server process; the rest of the environment is inherited unchanged.
        server_env = {**os.environ, "CUDA_VISIBLE_DEVICES": "1"}

        # Launch the async server in the background.
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve-async", "--model", cls.model_id],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=server_env,
        )

        # Create the client (240 s connection timeout) and open its communicator.
        cls.client = VLLMClient(connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        # Send one completion request to the server's `/v1/completions` route via the client's HTTP session.
        payload = {
            "model": self.model_id,
            "prompt": "Hello, AI! Tell me a joke.",
            "max_tokens": 50,
        }
        reply = self.client.session.post(url="http://localhost:8000/v1/completions", json=payload)
        reply.raise_for_status()
        body = reply.json()

        # The reply must contain at least one choice ...
        self.assertIn("choices", body)
        self.assertGreater(len(body["choices"]), 0)

        # ... and the first choice must carry non-empty text.
        top_choice = body["choices"][0]
        self.assertIn("text", top_choice)
        self.assertGreater(len(top_choice["text"]), 0)

    def test_update_model_params(self):
        # Load the same checkpoint locally and push its parameters through the client.
        local_model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(local_model)

    def test_reset_prefix_cache(self):
        # Passes as long as the reset call does not raise.
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client's communicator before stopping the server.
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, every
        # descendant of the server is sent SIGTERM explicitly before the server itself is terminated and reaped.
        server = cls.server_process
        for descendant in psutil.Process(server.pid).children(recursive=True):
            descendant.send_signal(signal.SIGTERM)
        server.terminate()
        server.wait()


@pytest.mark.slow
@require_3_gpus
class TestVLLMClientAsyncServerTP(unittest.TestCase):
    """Exercise `VLLMClient` against a `trl vllm-serve-async` server started with `--tensor_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # The server only gets to see GPUs 1 and 2.
        server_env = dict(os.environ, CUDA_VISIBLE_DEVICES="1,2")
        command = ["trl", "vllm-serve-async", "--model", cls.model_id, "--tensor_parallel_size", "2"]

        # Spawn the server in the background.
        cls.server_process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=server_env,
        )

        # Build the client (240 s connection timeout), then open its communicator.
        cls.client = VLLMClient(connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        # One completion request against the server's `/v1/completions` route.
        request_body = {
            "model": self.model_id,
            "prompt": "Hello, AI! Tell me a joke.",
            "max_tokens": 50,
        }
        http_response = self.client.session.post(
            url="http://localhost:8000/v1/completions",
            json=request_body,
        )
        http_response.raise_for_status()
        parsed = http_response.json()

        # Structural check: a non-empty "choices" list is present.
        self.assertIn("choices", parsed)
        self.assertGreater(len(parsed["choices"]), 0)

        # Content check: the first choice has non-empty "text".
        head = parsed["choices"][0]
        self.assertIn("text", head)
        self.assertGreater(len(head["text"]), 0)

    def test_update_model_params(self):
        # Load the checkpoint locally and hand it to the client for a parameter update.
        reference_model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(reference_model)

    def test_reset_prefix_cache(self):
        # Succeeds if the call returns without raising.
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client's communicator first.
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, SIGTERM
        # is sent to each descendant of the server before the server process is terminated and waited on.
        server = cls.server_process
        descendants = psutil.Process(server.pid).children(recursive=True)
        for proc in descendants:
            proc.send_signal(signal.SIGTERM)
        server.terminate()
        server.wait()


@pytest.mark.slow
@require_3_gpus
class TestVLLMClientAsyncServerDP(unittest.TestCase):
    """Exercise `VLLMClient` against a `trl vllm-serve-async` server started with `--data_parallel_size 2`."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve-async", "--model", cls.model_id, "--data_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client
        cls.client = VLLMClient(connection_timeout=240)
        # Open the communicator explicitly: tearDownClass closes it, and test_update_model_params presumably needs
        # it to push weights to the server.
        cls.client.init_communicator()

    def test_generate(self):
        # Send one completion request to the server's `/v1/completions` route via the client's HTTP session.
        prompt = "Hello, AI! Tell me a joke."
        response = self.client.session.post(
            url="http://localhost:8000/v1/completions",
            json={
                "model": self.model_id,
                "prompt": prompt,
                "max_tokens": 50,
            },
        )
        response.raise_for_status()
        response_json = response.json()

        # Check basic response structure
        self.assertIn("choices", response_json)
        self.assertGreater(len(response_json["choices"]), 0)

        # Check that we got a non-empty text response
        first_choice = response_json["choices"][0]
        self.assertIn("text", first_choice)
        self.assertGreater(len(first_choice["text"]), 0)

    def test_update_model_params(self):
        # Load the same checkpoint locally and push its parameters through the client.
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
        # kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
9 changes: 9 additions & 0 deletions trl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from .scripts.utils import TrlParser
from .scripts.vllm_serve import main as vllm_serve_main
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
from .scripts.vllm_serve_async import main as vllm_serve_async_main
from .scripts.vllm_serve_async import make_parser as make_vllm_serve_async_parser


def main():
Expand All @@ -45,6 +47,7 @@ def main():
make_kto_parser(subparsers)
make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers)
make_vllm_serve_async_parser(subparsers)

# Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
# Duplicates may occur if the same argument is provided in both the config file and CLI.
Expand Down Expand Up @@ -138,7 +141,13 @@ def main():
)

vllm_serve_main(script_args)

elif args.command == "vllm-serve-async":
# Here we defer to vllm's argument parser, so that we don't have to reimplement all of its logic
sys.argv = ["trl/scripts/vllm_serve_async.py"] + launch_args
vllm_serve_async_main()


if __name__ == "__main__":
main()

5 changes: 5 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_requests_available = _is_package_available("requests")
_unsloth_available = _is_package_available("unsloth")
_uvicorn_available = _is_package_available("uvicorn")
_uvloop_available = _is_package_available("uvloop")
_vllm_available = _is_package_available("vllm")
_vllm_ascend_available = _is_package_available("vllm_ascend")
_joblib_available = _is_package_available("joblib")
Expand Down Expand Up @@ -80,6 +81,10 @@ def is_uvicorn_available() -> bool:
return _uvicorn_available


def is_uvloop_available() -> bool:
    """Return the module-level flag computed by `_is_package_available("uvloop")`."""
    return _uvloop_available


def is_vllm_available() -> bool:
    """Return the module-level flag computed by `_is_package_available("vllm")`."""
    return _vllm_available

Expand Down
Loading