Text generation inference integration #12
Open
andrewramsay wants to merge 60 commits into main from text_generation_inference
Conversation
- Remove existing requirements and replace with huggingface_hub
- Use oat_common as a base since we don't need CUDA support now
I've replaced the existing model-loading code with the creation of an `InferenceClient` object using the endpoint URL defined in the docker-compose.yml file. Creating the object doesn't trigger any connection, so the service currently submits a simple query to check if the TGI endpoint is actually available. The endpoint might not be ready immediately (e.g. if it's still downloading or loading a model), so there's a basic retry setup for now, but it might need some more thought put into it. The `call_model` and `batch_call_model` methods are updated to call the `.text_generation` method on the `InferenceClient`. The batch method submits requests in parallel using a ThreadPoolExecutor; TGI doesn't offer a batch-specific endpoint, but according to the docs it should automatically batch the requests internally.
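For reference, a minimal sketch of this approach (not the actual OAT code; the endpoint URL, retry values and method signatures are assumptions):

```python
import time
from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import InferenceClient

TGI_ENDPOINT = "http://tgi:80"  # assumed value; the real URL comes from docker-compose.yml

client = InferenceClient(model=TGI_ENDPOINT)

def wait_for_endpoint(retries: int = 5, delay: float = 10.0) -> bool:
    """Send a trivial generation request until the TGI endpoint responds."""
    for _ in range(retries):
        try:
            client.text_generation("ping", max_new_tokens=1)
            return True
        except Exception:
            time.sleep(delay)  # TGI may still be downloading/loading the model
    return False

def call_model(prompt: str) -> str:
    return client.text_generation(prompt, max_new_tokens=128)

def batch_call_model(prompts: list[str]) -> list[str]:
    # TGI has no batch-specific endpoint; parallel requests are batched
    # internally by the server.
    with ThreadPoolExecutor(max_workers=8) as pool:
        return list(pool.map(call_model, prompts))
```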
- Add some env vars with default values to the tgi service definition to allow easy control of some options that might need to be changed depending on model/hardware
- Add a wrapper script as the entrypoint for the tgi container to allow passing in extra CLI parameters using a TGI_PARAMS env var
These 2 env vars can be used to adjust the number of retries llm_functionalities will make when attempting to connect to the TGI endpoint, and the delay between successive retries.
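Something along these lines, assuming the env vars follow the `TGI_CONNECTION_*` naming mentioned elsewhere in this PR (the exact names and defaults here are assumptions):

```python
import os

# Number of connection attempts and the delay (seconds) between them.
TGI_CONNECTION_RETRIES = int(os.environ.get("TGI_CONNECTION_RETRIES", "5"))
TGI_CONNECTION_DELAY = float(os.environ.get("TGI_CONNECTION_DELAY", "10"))
```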
There seems to be a bug in the `InferenceClient.summarization` method in the recent officially released versions; it's fixed in the current development version.
Use the new env vars from `docker-compose.yml` to control the connection attempts to the TGI endpoint
This just adds TGI equivalents for the `generate_summary` and `generate_summaries` methods that pass requests through to the `InferenceClient.summarization` method (https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.summarization)
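A hedged sketch of the pass-through; the method names come from the commit description, but the surrounding structure (module-level client, worker count, endpoint URL) is assumed:

```python
from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import InferenceClient

client = InferenceClient(model="http://tgi:80")  # assumed endpoint URL

def generate_summary(text: str) -> str:
    result = client.summarization(text)
    # Depending on the huggingface_hub version, this is a plain string or an
    # object with a `summary_text` attribute.
    return getattr(result, "summary_text", result)

def generate_summaries(texts: list[str]) -> list[str]:
    with ThreadPoolExecutor(max_workers=8) as pool:
        return list(pool.map(generate_summary, texts))
```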
This is currently used by one of the LLM tests to check if a particular step in the policy was activated to produce the current response.
Adding `delete_taskmap` and `delete_session` methods. These aren't much use for the online system; they're used in some new tests.
This is to make some new tests simpler to implement; it's not relevant for the online system.
- Define new `delete_session` and `delete_taskmap` RPCs
- Add `delete_session` and `delete_taskmap` methods to `DynamoDB`
- Add a `delete` method to `ProtoDB` which those 2 methods call (a rough sketch follows below)
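A rough sketch of what the new delete path could look like, assuming a boto3-backed DynamoDB table; the class structure, table names and key schema here are illustrative assumptions, not the actual OAT code:

```python
import boto3

class ProtoDB:
    """Stand-in for OAT's ProtoDB wrapper (structure assumed)."""

    def __init__(self, table_name: str, key_name: str = "id"):
        self.table = boto3.resource("dynamodb").Table(table_name)
        self.key_name = key_name

    def delete(self, object_id: str) -> None:
        # Remove the serialized protobuf entry for this ID.
        self.table.delete_item(Key={self.key_name: object_id})

class DynamoDB:
    def __init__(self):
        self.session_db = ProtoDB("sessions")   # table names are assumptions
        self.taskmap_db = ProtoDB("taskmaps")

    def delete_session(self, session_id: str) -> None:
        self.session_db.delete(session_id)

    def delete_taskmap(self, taskmap_id: str) -> None:
        self.taskmap_db.delete(taskmap_id)
```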
This commit removes most of the timeout-handling and exception-checking for the LLM code in `functionalities`. Previously the system couldn't assume `llm_functionalities` would be available, so there had to be error handling for failed RPC calls within the various `llm_*` components in `functionalities`. There also had to be timeout handling in some cases to prevent lengthy LLM calls from delaying system responses. In the TGI-enabled version of OAT, `llm_functionalities` no longer has any GPU requirements, so we can probably assume it's always going to be available like the other services. That also means we can do the timeout/error handling in `llm_functionalities` itself, since all calls to TGI will be routed through there. This ultimately means the code around the RPCs to `llm_functionalities` can be simplified to remove the existing timeout and exception handling.
More removal of timeout and error handling around RPCs to `llm_functionalities` (now handled when `llm_functionalities` makes calls to TGI).
Various changes:
- Remove special handling of summary requests
- Define a default timeout for TGI calls
- Update `_check_connectivity` to return True/False rather than throw an exception, to allow returning empty responses to the client instead
- Add a `_call_with_timeout` wrapper method for the two `call_model` methods so they can submit TGI requests with a timeout applied (see the sketch below)
- Timeouts are set from the `ModelRequest` objects, falling back to the default value if not set there
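A minimal sketch of such a wrapper; the default value, executor size and function names are assumptions, and the `timeout` argument would come from the `ModelRequest` when set there:

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from typing import Optional

DEFAULT_TIMEOUT = 5.0  # seconds; assumed default for TGI calls

_executor = ThreadPoolExecutor(max_workers=8)

def _call_with_timeout(fn, prompt: str, timeout: Optional[float] = None):
    """Run a TGI call with a timeout; return None so the caller can send an empty response."""
    future = _executor.submit(fn, prompt)
    try:
        return future.result(timeout=timeout or DEFAULT_TIMEOUT)
    except FutureTimeoutError:
        future.cancel()
        return None
```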
Python strings don't support item assignment like `foo[1] = "."`, so this line would have thrown a TypeError if it was executed
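To illustrate (the variable name is just for the example):

```python
foo = "hello?"
# foo[1] = "."                 # TypeError: 'str' object does not support item assignment
foo = foo[:1] + "." + foo[2:]  # strings are immutable, so build a new string instead
```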
- Remove old TGI_CONNECTION_* env vars
- Add a volume to the TGI container to load local models from
- Update comments
- Adjust default SHM size
- Set the default MODEL_ID to be the local Alpaca model
Instead of reloading the session from the database we can just check the JSON object returned by the orchestrator, since it should include some of the changed text.
The previous version of this class entered a loop on startup to periodically check if the TGI endpoint address was connectable. If it was able to make a connection, it assumed that connection would remain valid indefinitely. This version should be a bit more flexible in that it now only attempts to connect when LLM requests are triggered. It should also handle the endpoint becoming unavailable and then available again, because the first request sent after it comes back up should create a new client object for future requests. The other important change here is the use of dnspython to test if the TGI endpoint hostname is resolvable. I found that if it isn't (e.g. if you launch OAT with a remote TGI endpoint that is still starting up), the DNS resolution process takes 10+ seconds to time out. This seems to be difficult to handle using the Python stdlib, but the dnspython package makes it very simple to set a timeout on the resolution process. This allows the per-request connection checking to work without excessive delays.
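A sketch of the dnspython check, assuming a one-second budget (the timeout value and function name are illustrative):

```python
import dns.exception
import dns.resolver

def hostname_resolvable(hostname: str, timeout: float = 1.0) -> bool:
    """Return quickly instead of waiting 10+ seconds for stdlib DNS resolution to fail."""
    resolver = dns.resolver.Resolver()
    resolver.timeout = timeout   # per-nameserver timeout
    resolver.lifetime = timeout  # total time allowed for the query
    try:
        resolver.resolve(hostname, "A")
        return True
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer,
            dns.resolver.NoNameservers, dns.exception.Timeout):
        return False
```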
While I was trying to add some LLM tests, I found that this class didn't seem to be set up to parse the model responses correctly. The original prompt asks the model to produce output in the format:
> {"step_text": "<model output>"}
and then ends the prompt with:
> {"step_text": "
The Alpaca model we've been using does successfully complete this with something like:
> one two three"}
but the problem is that the `extract_response` method assumes the generated text will be a complete parseable JSON string, not a partial one. This leads the method to return an empty dict as if the response had failed to generate any text, when it will normally have generated something valid. I've changed the prompt so it outputs the text without any extra formatting, removed `extract_response` because it's no longer required, and removed `process_response_text` because it was already unused.
If the input string was empty, these methods would throw an IndexError when calling `re.search` (because of trying to do `""[-1]`)
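An illustration of the guard that avoids this (the function name and pattern are hypothetical):

```python
import re

def ends_with_sentence_punctuation(text: str) -> bool:
    if not text:  # previously text[-1] raised IndexError on an empty string
        return False
    return re.search(r"[.!?]", text[-1]) is not None
```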
The prompt for this type of request expects a JSON-compatible string response with 2 fields. However, the prompt already includes the first part of the expected string (`"{\"name\"`), and the parsing fails because it's run only on the response, which is an incomplete JSON string. This just adjusts the parsing method to prepend the section of the response format included in the prompt to the actual response, allowing the parsing to function as intended.
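In other words, something like this (the prefix is taken from the description above; the function name is hypothetical):

```python
import json

PROMPT_JSON_PREFIX = '{"name"'  # fixed prefix already included in the prompt

def parse_response(response_text: str) -> dict:
    # The model only generates the remainder of the JSON object, so prepend
    # the prefix before parsing.
    try:
        return json.loads(PROMPT_JSON_PREFIX + response_text)
    except json.JSONDecodeError:
        return {}
```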
By keeping this in LLM functionalities, the `Downloader` class can continue working as normal, even though it's technically going to be downloading files for use by TGI
This helps make clear which files are local and which have been downloaded from huggingface.co
The connection will be attempted for the first time when the first LLM request is received instead
This PR is about replacing the original custom LLM implementation in OAT with HuggingFace's Text Generation Inference framework.
The core change here is refactoring OAT's architecture around LLM requests.
In the current main version of OAT, online services make RPCs to `llm_functionalities`, which then makes calls to the configured `alpaca_llm` model loaded into its container. This obviously means `llm_functionalities` requires a GPU, and so the other OAT services have to be prepared for it to be unavailable if the system is running somewhere without sufficient GPU resources.
With TGI integrated, `llm_functionalities` becomes a relatively thin wrapper around calls to a TGI endpoint using the `InferenceClient` API. This allows most of the error/timeout handling code around LLM calls in the current codebase to be removed and concentrated in `llm_functionalities`, since all LLM requests will get routed through there before reaching the TGI endpoint.
Edited to include features added since the PR was created:
- Local models can be used by placing them in `shared/file_system/downloads/llm_functionalities/<model_name>` and then setting `MODEL_ID` to `/models/<model_name>`
- A `tgi` folder is used to distinguish between local models and downloaded models
- `llm_functionalities` still downloads the `alpaca_llm` model but stores it in the location corresponding to the `tgi` Docker volume for local models
- `llm_functionalities` has a different way of connecting to the TGI endpoint. Originally it would enter a loop attempting to connect and give up after some number of retries. Now it only attempts to connect when the first LLM request arrives from OAT (returning empty responses while the endpoint is unavailable), and it should also be able to handle the case where the endpoint goes down and comes back up while OAT is running. It also includes a fix for the case where the endpoint's hostname is unresolvable, which was causing a long timeout
- New LLM tests have been added in `shared/tests/integration_tests/test_llm.py`. They don't cover all the LLM components but it's a start at least (some of them are awkward to trigger)
Testing
To run the new LLM tests:
- `docker compose up` (this should download the `alpaca_llm` folder to `shared/file_system/downloads/tgi/local/`)
- `docker compose run tester --runslow -vk test_llm` should run the 3 LLM tests