
Improve outlines.processors, add integration tests to test_generate.py #998

Merged: 1 commit merged into dottxt-ai:main on Jun 30, 2024

Conversation

@lapp0 (Contributor) commented Jun 21, 2024

A lot of these fixes were intended for #966; however, that's blocked until there's a new transformers release.

These improvements are general to all models and will enable PRs resolving #806 and #965.

Structure of OutlinesLogitsProcessor

The goal is to create a base class which allows a logits processor to be implemented once and used with any outlines.models inference library.

To accomplish this we must normalize the input arrays: they must have a consistent type (torch.Tensor) and consistent dimensionality (2D). We can normalize both simply and without any copy operations.

mlx.core.array, numpy.ndarray, and torch.Tensor all support Python's __dlpack__ array-interchange standard, which allows casting between array types without copying.

torch.Tensor is the only input type which cannot always be cast to the other types, because torch tensors may live in GPU memory. Therefore we cast all arrays to torch.Tensor, implement the logits processors using torch methods, and convert back to the original array type in OutlinesLogitsProcessor. See the docstring of OutlinesLogitsProcessor.__call__() for more details.
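
For illustration, here is a minimal DLPack round trip between NumPy and torch (a sketch, not code from this PR; NumPy stands in for any __dlpack__-capable library):

```python
import numpy as np
import torch

# A (4, 30000) float32 logits array, matching the benchmark below.
np_logits = np.random.rand(4, 30000).astype(np.float32)

# NumPy -> torch via DLPack: the tensor is a view, not a copy.
torch_logits = torch.from_dlpack(np_logits)

# Writes through the torch view are visible from NumPy, confirming
# that both objects share one buffer.
torch_logits[0, 0] = 1.0
assert np_logits[0, 0] == 1.0

# ...and back to NumPy, again without copying.
round_tripped = np.from_dlpack(torch_logits)
assert round_tripped.shape == (4, 30000)
```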

Detailed Changes

  • Rename BaseLogitsProcessor to OutlinesLogitsProcessor
  • Ensure OutlinesLogitsProcessor.process_logits() is always passed a 2D batch request with torch.Tensor logits and List input_ids. Also clean up OutlinesLogitsProcessor.__call__() for readability (a condensed sketch follows this list)
  • Ensure FSMLogitsProcessor tolerates unstable sequence ordering (beam search in transformers and vLLM changes the order of sequences)
  • Update tests/generate/test_generate.py to cover more permutations of
    • regex / text
    • batch / single
    • greedy / multinomial / beam search
    • stream() / generate()
  • Verify performance stability across different array libraries through bench_processors.py
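
Condensed, the contract looks roughly like this (a sketch: process_logits and _to_torch appear in the diffs below, while _from_torch is a hypothetical helper name used here for illustration):

```python
from typing import List, Protocol

import torch


class OutlinesLogitsProcessor(Protocol):
    def process_logits(
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        """Implemented once per processor, always on a 2D batch of
        torch logits, regardless of the calling inference library."""
        ...

    def __call__(self, input_ids, logits):
        # Normalize any library's array to a torch tensor
        # (zero-copy via DLPack where possible)...
        torch_logits = self._to_torch(logits)

        # ...ensure a 2D batch even for single-sequence requests...
        unbatched = torch_logits.ndim == 1
        if unbatched:
            torch_logits = torch_logits.unsqueeze(0)
            input_ids = [input_ids]

        processed = self.process_logits(input_ids, torch_logits)

        if unbatched:
            processed = processed.squeeze(0)

        # ...and hand back the caller's original array type.
        return self._from_torch(processed, target_type=type(logits))
```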

@lapp0 lapp0 marked this pull request as ready for review June 21, 2024 17:50
@lapp0 lapp0 requested a review from rlouf June 21, 2024 17:55
@rlouf rlouf added the enhancement and structured generation labels Jun 22, 2024
@@ -14,10 +14,11 @@ def is_mlx_array(logits):
     return isinstance(logits, mx.array)


-class BaseLogitsProcessor(Protocol):
+class OutlinesLogitsProcessor(Protocol):
@rlouf (Member) commented:

Why not call it LogitsProcessor?

@rlouf (Member) commented:

Actually, do we really need that class if all logits processors are ever going to be FSMLogitsProcessors?

@lapp0 (Contributor, Author) commented Jun 22, 2024:

My thinking here was that we might want a new logits processor which isn't for structured generation, e.g. RepetitionPenaltyLogitsProcessor. We can implement such a processor once and use it in every outlines.models integration without any model-specific changes.

This gives us the flexibility to have the inference library handle the decoder pass and sampling, while outlines handles all logits augmentation.

Another example: mlxlm doesn't support stop_strings, and stop_strings is broken in the latest transformers version (4.41.2). We could implement a single StopStringsLogitsProcessor and make it available to all inference libraries. I'm not arguing that implementing this specific logits processor should be a priority, but I am arguing that a separate base class which doesn't demand an FSM gives us flexibility and opportunities.
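
For example, a repetition penalty written once against the base class might look roughly like this (a hypothetical sketch, not part of this PR, assuming the OutlinesLogitsProcessor contract sketched earlier):

```python
import torch


class RepetitionPenaltyLogitsProcessor(OutlinesLogitsProcessor):
    """Hypothetical: penalize already-generated tokens, CTRL-style."""

    def __init__(self, penalty: float = 1.1):
        self.penalty = penalty

    def process_logits(self, input_ids, logits):
        for i, seq in enumerate(input_ids):
            if not seq:
                continue
            seen = torch.tensor(sorted(set(seq)), device=logits.device)
            scores = logits[i, seen]
            # Shrink positive logits and amplify negative ones for
            # tokens this sequence has already produced.
            logits[i, seen] = torch.where(
                scores > 0, scores / self.penalty, scores * self.penalty
            )
        return logits
```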

 )

+        # ensure logits are torch Tensors
+        torch_logits = self._to_torch(logits)
@rlouf (Member) commented Jun 22, 2024:

What is the cost of this conversion? We need to profile this.

I would personally handle this by dispatching the core logic used to process the logits, process_logits, depending on the type of input_ids or logits.

@lapp0 (Contributor, Author) commented Jun 22, 2024:

The idea is that we can cast numpy and mlx.core arrays to torch.Tensor using shared memory, not a copy operation :)

Then we can implement a single OutlinesLogitsProcessor subclass with a single process_logits() method and it works out of the box for any library. While the tests aren't in main yet, I've tested FSMLogitsProcessor and RegexLogitsProcessor with nearly all outlines.models options (I haven't tested against exllamav2) and they work with no additional changes.

This makes the processors easy to maintain and implement.
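
Concretely, the cast can be a small type dispatch along these lines (a sketch under the assumptions above; the mlx path goes through a zero-copy NumPy view until a direct DLPack route exists):

```python
import numpy as np
import torch


def _to_torch(tensor_like):
    """Sketch: normalize mlx/numpy/torch arrays to torch.Tensor
    without copying where the libraries allow it."""
    if isinstance(tensor_like, torch.Tensor):
        return tensor_like
    if isinstance(tensor_like, np.ndarray):
        # Shares the underlying buffer via DLPack.
        return torch.from_dlpack(tensor_like)
    if is_mlx_array(tensor_like):
        # mlx -> numpy is zero-copy on unified memory; the numpy
        # view then shares its buffer with torch as above.
        return torch.from_dlpack(np.array(tensor_like, copy=False))
    raise TypeError(f"Unsupported logits type: {type(tensor_like)}")
```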

        self.fsm: Guide = fsm
        self._is_first_token = True
        self._seq_start_idx: Optional[int] = None

    def process_logits(
@rlouf (Member) commented:

cf. the above remark on dispatching the function depending on the type of the tensors.

@rlouf (Member) left a review comment:

I would not convert back-and-forth to torch arrays and instead dispatch the logic depending on the type of input_ids and logits so as to not add a performance penalty during inference.

@lapp0 force-pushed the logits-processor-integrations-fix branch from 6525fd4 to 7301a7f on June 22, 2024 17:11
@lapp0 force-pushed the logits-processor-integrations-fix branch 10 times, most recently from 0afefbf to 6362217 on June 22, 2024 18:13
@lapp0 (Contributor, Author) commented Jun 22, 2024:

> I would not convert back-and-forth to torch arrays and instead dispatch the logic depending on the type of input_ids and logits so as to not add a performance penalty during inference.

Here are the benchmark results:

[50.00%] ··· Running (bench_processors.LogitsProcessorBenchmark.time_logits_processor--).
[100.00%] ··· bench_processors.LogitsProcessorBenchmark.time_logits_processor                                                                                                                                   ok
[100.00%] ··· ======== ==========
               param1            
              -------- ----------
               torch    150±20μs 
               numpy    161±7μs  
                mlx     195±20μs 
              ======== ==========

The performance penalty for casting (not copying) mlx.core.array -> torch.Tensor -> mlx.core.array is 45 microseconds. The benchmark has float logits with shape (4, 30000).

  • This is around 1-3 orders of magnitude faster than applying a mask in FSMLogitsProcessor (this varies a lot with the size of the mask; huge masks can take ~20ms).
  • This is 3-4 orders of magnitude faster than a decoder pass on a tiny model (mlx-community/Qwen1.5-1.8B-Chat-4bit).
  • I expect the 45μs to be closer to numpy's 11μs once mlx supports conversion directly from torch without a numpy intermediate.
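
For reference, the benchmark's shape is roughly the following asv-style class (a sketch inferred from the output above, torch/numpy only; the real bench_processors.py may differ in detail):

```python
import numpy as np
import torch


class LogitsProcessorBenchmark:
    params = ["torch", "numpy"]
    param_names = ["array_library"]

    def setup(self, array_library):
        shape = (4, 30000)  # float logits, batch of 4, as above
        if array_library == "torch":
            self.logits = torch.rand(shape)
        else:
            self.logits = np.random.rand(*shape).astype(np.float32)

    def time_logits_processor(self, array_library):
        # Time the cast round trip that __call__ wraps around
        # process_logits (the processor's own work is omitted here).
        torch_logits = (
            self.logits
            if isinstance(self.logits, torch.Tensor)
            else torch.from_dlpack(self.logits)
        )
        if isinstance(self.logits, np.ndarray):
            np.from_dlpack(torch_logits)
```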

@lapp0 force-pushed the logits-processor-integrations-fix branch 4 times, most recently from 19077a6 to 00052d5 on June 22, 2024 18:48
@lapp0 force-pushed the logits-processor-integrations-fix branch from 00052d5 to f00307a on June 23, 2024 18:36
@lapp0 force-pushed the logits-processor-integrations-fix branch from f00307a to c07de55 on June 29, 2024 12:47
@lapp0 force-pushed the logits-processor-integrations-fix branch from c07de55 to c3e8673 on June 29, 2024 13:14
@rlouf rlouf merged commit a643cb0 into dottxt-ai:main Jun 30, 2024
7 checks passed