extensions/openai/logits.py (1 addition, 1 deletion)

@@ -8,4 +8,4 @@ def _get_next_logits(body):
     state = process_parameters(body) if use_samplers else {}
     state['stream'] = True
 
-    return get_next_logits(body['prompt'], state, use_samplers, "", return_dict=True)
+    return get_next_logits(body['prompt'], state, use_samplers, "", top_logits=body['top_logits'], return_dict=True)
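This threads the request's `top_logits` value through to `modules/logits.py`. As a sanity check, here is a minimal client sketch exercising the new parameter (illustrative, assuming the OpenAI-compatible API is listening on its default local port; adjust the URL and prompt for your setup):

```python
# Hypothetical client call for the /v1/internal/logits endpoint.
import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/internal/logits",
    json={
        "prompt": "The capital of France is",
        "use_samplers": False,
        "top_logits": 10,  # ask for only the 10 most likely next tokens
    },
)

# LogitsResponse.logits maps each decoded token string to its probability.
for token, prob in resp.json()["logits"].items():
    print(f"{prob:.5f} - {token!r}")
```

Omitting `top_logits` keeps the previous behavior, since the new pydantic field defaults to 50.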
extensions/openai/typing.py (3 additions, 2 deletions)

@@ -1,6 +1,6 @@
 import json
 import time
-from typing import List
+from typing import Dict, List
 
 from pydantic import BaseModel, Field
 

@@ -156,6 +156,7 @@ class TokenCountResponse(BaseModel):
 class LogitsRequestParams(BaseModel):
     prompt: str
     use_samplers: bool = False
+    top_logits: int | None = 50
     frequency_penalty: float | None = 0
     max_tokens: int | None = 16
     presence_penalty: float | None = 0

@@ -168,7 +169,7 @@ class LogitsRequest(GenerationOptions, LogitsRequestParams):
 
 
 class LogitsResponse(BaseModel):
-    logits: dict
+    logits: Dict[str, float]
 
 
 class ModelInfoResponse(BaseModel):
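For reference, a small sketch of how the updated models behave (field names are from this diff; the import path assumes the repo is on `sys.path`):

```python
# Sketch: the new field defaults to 50, preserving the old k=50 behavior
# for clients that don't send it.
from extensions.openai.typing import LogitsRequestParams, LogitsResponse

req = LogitsRequestParams(prompt="Hello")
assert req.top_logits == 50

# The response type is now explicit about its shape: token -> probability.
resp = LogitsResponse(logits={" world": 0.31415})
assert isinstance(resp.logits[" world"], float)
```

Typing the response as `Dict[str, float]` also makes the generated OpenAPI schema more precise than the bare `dict` it replaces.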
modules/logits.py (4 additions, 3 deletions)

@@ -8,7 +8,7 @@
 global_scores = None
 
 
-def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
+def get_next_logits(prompt, state, use_samplers, previous, top_logits=50, return_dict=False):
     if shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded! Select one in the Model tab.', previous

@@ -50,8 +50,7 @@ def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
     scores = output['logits'][-1][-1]
 
     probs = torch.softmax(scores, dim=-1, dtype=torch.float)
-    topk_values, topk_indices = torch.topk(probs, k=50, largest=True, sorted=True)
-    topk_values = [f"{float(i):.5f}" for i in topk_values]
+    topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
     if is_non_hf_exllamav1 or is_non_hf_llamacpp:
         topk_indices = [i.expand((1, 1)) for i in topk_indices]
 

@@ -61,12 +60,14 @@ def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
     tokens = [shared.tokenizer.decode(i) for i in topk_indices]
 
     if return_dict:
+        topk_values = [float(i) for i in topk_values]
         output = {}
         for row in list(zip(topk_values, tokens)):
             output[row[1]] = row[0]
 
         return output
     else:
+        topk_values = [f"{float(i):.5f}" for i in topk_values]
         output = ''
         for row in list(zip(topk_values, tokens)):
             output += f"{row[0]} - {repr(row[1])}\n"
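Taken together, the change in this file boils down to the following self-contained sketch (illustrative names, assuming a Hugging Face-style tokenizer and a scores vector over the vocabulary):

```python
# Sketch of the parameterized top-k extraction performed above.
import torch

def top_token_probs(scores: torch.Tensor, tokenizer, top_logits: int = 50) -> dict:
    """Return {decoded token: probability} for the `top_logits` most likely tokens."""
    probs = torch.softmax(scores, dim=-1, dtype=torch.float)
    topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
    tokens = [tokenizer.decode(i) for i in topk_indices]
    # Mirrors the return_dict branch: raw floats rather than the ".5f"
    # strings used by the text path, so JSON consumers get numbers.
    return {token: float(p) for token, p in zip(tokens, topk_values)}
```

One subtlety the diff inherits: distinct token IDs can decode to the same string, in which case later entries overwrite earlier ones in the returned dict.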