From 2c068153b245ec75b146d15fac9bc746d290d106 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 15 Apr 2025 03:45:18 +0000 Subject: [PATCH 1/2] Fix broken GritLM model and tests Pass pooling_metadata to pooler head in gritlm. This was broken by PR https://github.com/vllm-project/vllm/pull/16331. PR https://github.com/vllm-project/vllm/pull/14516 broke gritlm tests due to changing xformers to flash_attn. Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 21 ++++++++++--------- vllm/model_executor/models/gritlm.py | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index d6bf7d270639..87a1dde9381f 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): def server_embedding(): # GritLM embedding implementation is only supported by XFormers backend. 
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") def server_generate(): args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture -async def client_embedding(monkeypatch: pytest.MonkeyPatch, - server_embedding: RemoteOpenAIServer): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") - async with server_embedding.get_async_client() as async_client: - yield async_client +async def client_embedding(server_embedding: RemoteOpenAIServer): + async with server_embedding.get_async_client() as async_client: + yield async_client @pytest_asyncio.fixture diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 2984f2241286..ad60d2b600b3 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -170,7 +170,7 @@ def forward( mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = self.head(mean_embeddings) + pooled_data = self.head(mean_embeddings, pooling_metadata=pooling_metadata) pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data From bec5700ab08d13e73a864500bfbc7cec5be2047b Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 15 Apr 2025 04:18:55 +0000 Subject: [PATCH 2/2] Fix pre-commit Signed-off-by: Pooya Davoodi --- vllm/model_executor/models/gritlm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index ad60d2b600b3..e4692c458088 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -170,7 +170,8 @@ def forward( mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = self.head(mean_embeddings, pooling_metadata=pooling_metadata) + pooled_data = self.head(mean_embeddings, + pooling_metadata=pooling_metadata) pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data