Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lmformatenforcer/integrations/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, List, Optional, Tuple, Union
try:
from transformers import AutoModelForCausalLM
from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor
from transformers.generation.logits_process import LogitsProcessor, PrefixConstrainedLogitsProcessor
from transformers.tokenization_utils import PreTrainedTokenizerBase
except ImportError:
raise ImportError('transformers is not installed. Please install it with "pip install transformers[torch]"')
Expand All @@ -16,7 +16,7 @@
from ..tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
from ..analyzer import FormatEnforcerAnalyzer

class LogitsSaverWarper(LogitsWarper):
class LogitsSaverWarper(LogitsProcessor):
def __init__(self, analyzer: FormatEnforcerAnalyzer) -> None:
self.analyzer = analyzer

Expand Down
Loading