[api] implements text-generation search algorithm #2637
private boolean suffixPadding;

/** Constructs a new SearchConfig object with default values. */
public SearchConfig() {
Are there any plans to support different configurations, since not all text-generation models are the same? I'm personally more interested in T5 than GPT2. T5 in particular is a different beast, with both an encoder and a decoder, in contrast to GPT2's decoder-only approach. T5 also supports over a hundred special tokens. There are 100 "extra" tokens that can be used for a variety of things, including mask filling and potentially representing special words or instructions in the generated output.
https://huggingface.co/transformers/v3.0.2/model_doc/t5.html#t5tokenizer
There probably need to be different configurations and generation classes for each family of models out there. If we hardcode everything to GPT2, there are going to be breaking changes in the future. I'd suggest adding support for two different models to start with, and coming up with a solution for adding support for others in the future.
About the SearchConfig: I'm thinking of simply adding parameters to it. Not all of them need to be used by any single model. This should address the issue of differing search configurations that you mentioned, right?
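A minimal sketch of that idea, assuming a hypothetical `GenerationConfigSketch` class (the class name, field names, and default values are illustrative and are not taken from this PR's `SearchConfig`). Every parameter lives in one config object, and a parameter that a given model family does not need is simply left unset:

```java
import java.util.OptionalInt;

/** Hypothetical shared config; not the SearchConfig from this PR. */
final class GenerationConfigSketch {
    // Read by every search algorithm (illustrative default).
    private int maxSeqLength = 30;
    // Only read by beam search (illustrative default).
    private int beam = 3;
    // Only meaningful for encoder-decoder models; empty for decoder-only models.
    private OptionalInt decoderStartTokenId = OptionalInt.empty();

    GenerationConfigSketch setMaxSeqLength(int maxSeqLength) {
        this.maxSeqLength = maxSeqLength;
        return this;
    }

    GenerationConfigSketch setBeam(int beam) {
        this.beam = beam;
        return this;
    }

    GenerationConfigSketch setDecoderStartTokenId(int tokenId) {
        this.decoderStartTokenId = OptionalInt.of(tokenId);
        return this;
    }

    int getMaxSeqLength() {
        return maxSeqLength;
    }

    int getBeam() {
        return beam;
    }

    /** In this sketch, a model counts as encoder-decoder when a decoder start token was set. */
    boolean isEncoderDecoder() {
        return decoderStartTokenId.isPresent();
    }
}
```

With this shape, a decoder-only setup would call only `setMaxSeqLength(...)`, while an encoder-decoder setup would additionally call `setDecoderStartTokenId(...)`. The trade-off is that the single class grows with every model family, which is the refactoring risk raised in this thread.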
I'm just leaving my 2 cents since I'm interested in your work. It's great to see it get added to DJL. My only real fear is that a lot of refactoring might be needed to support other models. Flan-T5 is one of the most powerful open-source models (that supports commercial use) we have available, and it comes in a variety of sizes. I'd be most interested in seeing it supported. I'm not super familiar with GPT2 aside from it being a decoder-only model. There's a chance that T5 and GPT2 share some similarities in the decoder, but T5 has an initial encoder pass over the inputs. The hidden state of the encoded inputs is then used for each pass of the decoder, alongside the ids that have been selected so far for generation.
That's all I've got to add, good work. I just want to see this turn into a bigger feature beyond what you're working on.
We are planning to add the T5 model. This is just a starting point for adding text-generation support.
@jawaff Thanks for pointing out the encoder-decoder model T5 to us and reminding us of the possible refactoring. I think that, to implement an encoder-decoder model, the major change will be in the search algorithms, where we will need an if (encoderDecoder is true) block that computes the encoding. The rest of the code will basically be shared. This structure is also seen in Hugging Face Transformers.
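The branching described above could look roughly like the following. This is only a sketch under stated assumptions: `ToyModel` and `GreedySearchSketch` are hypothetical stand-ins, not DJL classes or anything in this PR, and greedy decoding is used here only because it is the simplest search loop to show.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy stand-in for a language model; hypothetical, not a DJL interface. */
interface ToyModel {
    /** Encodes the prompt once; only called for encoder-decoder models. */
    int[] encode(int[] inputIds);

    /** Picks the next token id from the tokens so far and an optional encoder state (may be null). */
    int nextToken(List<Integer> generated, int[] encoderState);
}

final class GreedySearchSketch {
    static List<Integer> generate(
            ToyModel model,
            int[] inputIds,
            boolean encoderDecoder,
            int decoderStartId,
            int eosId,
            int maxNewTokens) {
        int[] encoderState = null;
        List<Integer> sequence = new ArrayList<>();
        if (encoderDecoder) {
            // T5-style: run the encoder once and start decoding from a start token.
            encoderState = model.encode(inputIds);
            sequence.add(decoderStartId);
        } else {
            // GPT2-style: the prompt itself is the beginning of the decoded sequence.
            for (int id : inputIds) {
                sequence.add(id);
            }
        }
        // Shared decoding loop: identical for both model families.
        for (int step = 0; step < maxNewTokens; step++) {
            int next = model.nextToken(sequence, encoderState);
            sequence.add(next);
            if (next == eosId) {
                break;
            }
        }
        return sequence;
    }
}
```

Only the setup before the loop differs between the two families. The encoder state is computed once and passed unchanged to every decoder step, which matches the description of T5 earlier in this thread.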
Codecov Report
Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2637      +/-   ##
============================================
- Coverage     72.08%   72.06%   -0.03%
- Complexity     5126     7020    +1894
============================================
  Files           473      698     +225
  Lines         21970    31252    +9282
  Branches       2351     3224     +873
============================================
+ Hits          15838    22521    +6683
- Misses         4925     7200    +2275
- Partials       1207     1531     +324
This PR succeeds PRs #2547, #2509, #2557, and #2572, which contain the benchmark outputs of the search results.
This PR contains only the features of LMSearch.
djl/examples/src/main/java/ai/djl/examples/inference/GPTInference.java contains the front-end design.
The model conversion to TorchScript and ONNX
See the Model tracing section in the PR descriptions of #2547 and #2509.
Demonstration
PR #2723 provides several examples that demonstrate the usage of language model text generation.