Improve outlines.processors, add integration tests to test_generate.py #998
Conversation
@@ -14,10 +14,11 @@ def is_mlx_array(logits):
     return isinstance(logits, mx.array)


-class BaseLogitsProcessor(Protocol):
+class OutlinesLogitsProcessor(Protocol):
Why not call it LogitsProcessor?
Actually, do we really need that class if all logits processors are ever going to be FSMLogitsProcessors?
My thinking here was: if we want a new logits processor which isn't for structured generation, e.g. RepetitionPenaltyLogitsProcessor, we can implement it once and use it with every outlines.models backend without any model-specific changes. This gives us the flexibility to let the inference library handle the decoder pass and sampling, while outlines handles all logits augmentation.
Another example: mlxlm doesn't support stop_strings, and stop_strings is broken in the latest transformers release (4.41.2). We could implement a single StopStringsLogitsProcessor and make it available to all inference libraries. I'm not arguing that implementing this specific logits processor should be a priority, but I am arguing that having a separate base class which doesn't demand an FSM gives us flexibility and opportunities.
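For instance, a repetition-penalty processor along these lines could look roughly like the sketch below. This is a hedged illustration, not code from this PR: it assumes the process_logits(input_ids, logits) interface described in the PR description (2D torch.Tensor logits, per-sequence input_ids lists), and the import path and penalty scheme are only examples.

```python
# Illustrative sketch only; assumes the OutlinesLogitsProcessor base class and the
# process_logits(input_ids, logits) signature introduced in this PR.
from typing import List

import torch

from outlines.processors import OutlinesLogitsProcessor  # assumed import path


class RepetitionPenaltyLogitsProcessor(OutlinesLogitsProcessor):
    """Penalize tokens that already appear in each sequence (a non-FSM processor)."""

    def __init__(self, penalty: float = 1.2):
        self.penalty = penalty

    def process_logits(
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        for i, seq_ids in enumerate(input_ids):
            seen = torch.tensor(
                sorted(set(seq_ids)), dtype=torch.long, device=logits.device
            )
            scores = logits[i, seen]
            # divide positive scores, multiply negative ones, as in the usual
            # repetition-penalty formulation
            logits[i, seen] = torch.where(
                scores > 0, scores / self.penalty, scores * self.penalty
            )
        return logits
```

Because everything inside process_logits() is plain torch, the same class would work unchanged with any outlines.models backend the base class supports.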
)

    # ensure logits are torch Tensors
    torch_logits = self._to_torch(logits)
What is the cost of this conversion? We need to profile this. I would personally handle this by dispatching the core logic used to process the logits, process_logits, depending on the type of input_ids or logits.
The idea is that we can cast numpy and mlx.core arrays to torch.Tensor using shared memory, not a copy operation :) Then we can implement a single OutlinesLogitsProcessor subclass with a single process_logits() method and it works out of the box for any library. While the tests aren't in main yet, I've tested FSMLogitsProcessor and RegexLogitsProcessor with nearly all outlines.models options (I haven't tested against exllamav2) and they work with no additional changes. This makes the processors easy to maintain and implement.
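As a small illustration of the shared-memory claim (not code from this PR), torch.from_numpy() and torch.from_dlpack() wrap the existing buffer instead of copying it:

```python
# Minimal demonstration that the cast shares memory rather than copying.
import numpy as np
import torch

np_logits = np.zeros((1, 8), dtype=np.float32)
torch_logits = torch.from_numpy(np_logits)  # wraps the same buffer

torch_logits[0, 0] = 1.0
assert np_logits[0, 0] == 1.0  # the write is visible through the numpy view

# The same idea applies to mlx via the DLPack protocol (requires mlx installed):
# import mlx.core as mx
# torch_logits = torch.from_dlpack(mx.zeros((1, 8)))
```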
        self.fsm: Guide = fsm
        self._is_first_token = True
        self._seq_start_idx: Optional[int] = None

    def process_logits(
Cf. the above remark on dispatching the function depending on the type of the tensors.
I would not convert back and forth to torch arrays; instead, I would dispatch the logic depending on the type of input_ids and logits, so as not to add a performance penalty during inference.
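For reference, the dispatch alternative being suggested here could be sketched with functools.singledispatch (hypothetical names, not this PR's code):

```python
# Hypothetical sketch of type-based dispatch; the masking bodies are placeholders.
from functools import singledispatch

import numpy as np
import torch


@singledispatch
def process_logits(logits, input_ids):
    raise TypeError(f"Unsupported logits type: {type(logits)}")


@process_logits.register
def _(logits: torch.Tensor, input_ids):
    # torch-specific masking logic would go here
    return logits


@process_logits.register
def _(logits: np.ndarray, input_ids):
    # numpy-specific masking logic would go here
    return logits
```

The trade-off discussed in this thread is that each backend then needs its own implementation of the masking logic, whereas the cast-to-torch approach keeps a single implementation.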
Here are the benchmark results:
The performance penalty for casting (not copying)
A lot of these fixes were intended for #966; however, that's blocked until there's a new transformers release. These improvements are general to all models and will enable PRs resolving #806 and #965.
Structure of OutlinesLogitsProcessor

The goal is to create a base class which allows a logits processor to be implemented once and used with any outlines.models inference library.

To accomplish this we must normalize the input array: it must have a consistent type (torch.Tensor) and consistent dimensionality (2). We can normalize both of these simply, and without any copy operations. mlx.core.array, numpy.array, and torch.Tensor all support Python's array standard __dlpack__, which allows casting between array types without copying. torch.Tensor is the only input type which cannot always be cast to any other type, because torch tensors may live in GPU memory. Therefore we cast all arrays to torch.Tensor, implement the logits processors using torch methods, and convert back to the original array type in OutlinesLogitsProcessor. See the docstring of OutlinesLogitsProcessor.__call__() for more details.
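A rough sketch of that normalization, under the assumptions above (the helper names are illustrative; see OutlinesLogitsProcessor.__call__() in the PR for the actual implementation):

```python
# Illustrative helpers only; the real logic lives in OutlinesLogitsProcessor.__call__().
import numpy as np
import torch


def _to_torch(tensor_like):
    """Cast numpy / mlx inputs to torch.Tensor without copying."""
    if isinstance(tensor_like, torch.Tensor):
        return tensor_like
    if isinstance(tensor_like, np.ndarray):
        return torch.from_numpy(tensor_like)  # shared memory, no copy
    # mlx.core.array and other producers of the __dlpack__ protocol
    return torch.from_dlpack(tensor_like)


def _from_torch(tensor: torch.Tensor, target_type):
    """Cast the processed logits back to the caller's array type."""
    if target_type is torch.Tensor:
        return tensor
    if target_type is np.ndarray:
        return tensor.numpy()
    # fall back to the target type's constructor (e.g. mlx.core.array)
    return target_type(np.asarray(tensor))
```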
Detailed Changes

- Rename BaseLogitsProcessor to OutlinesLogitsProcessor.
- OutlinesLogitsProcessor.process_logits() is always passed a 2D batch request with torch.Tensor logits and List input_ids; also clean up OutlinesLogitsProcessor.__call__() to be more readable.
- FSMLogitsProcessor allows unstable sequence ordering (beam search in transformers and vLLM changes the order of sequences).
- Expand tests/generate/test_generate.py to cover more permutations of stream() / generate() (see the sketch after this list).
- Add benchmark_processors.py.
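As an illustration of the stream()/generate() permutations, a test could look roughly like this (the model name and patterns are placeholders, not the exact tests added in this PR):

```python
# Illustrative pytest sketch; the fixture model and regex patterns are placeholders.
import pytest

import outlines


@pytest.fixture(scope="session")
def model():
    return outlines.models.transformers("hf-internal-testing/tiny-random-gpt2")


@pytest.mark.parametrize("pattern", [r"[0-9]{3}", r"(yes|no)"])
def test_regex_generate_and_stream(model, pattern):
    generator = outlines.generate.regex(model, pattern)

    # one-shot generation
    output = generator("Pick one:", max_tokens=8)
    assert isinstance(output, str)

    # streaming generation over the same prompt
    streamed = "".join(generator.stream("Pick one:", max_tokens=8))
    assert isinstance(streamed, str)
```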