-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtest_step.py
100 lines (84 loc) · 2.63 KB
/
test_step.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import asyncio
import inspect
import os
import time

import pytest
from dotenv import load_dotenv
from llama_cpp import Llama
from outlines import models
from transformers import AutoTokenizer, AutoModelForCausalLM

from gigax.parse import CharacterAction
from gigax.scene import Character, Item, Location, ProtagonistCharacter
from gigax.step import NPCStepper
# Load variables from a .env file into the process environment at import
# time; test_stepper_api reads API_KEY through os.getenv.
load_dotenv()
def test_stepper_local_llamacpp(
    context: str,
    locations: list[Location],
    NPCs: list[Character],
    protagonist: ProtagonistCharacter,
    items: list[Item],
    events: list[CharacterAction],
):
    """Run the NPC stepper on the GGUF model through llama.cpp.

    Loads ``npc-llm-3_8B.gguf`` from the ``Gigax/NPC-LLM-3_8B-GGUF``
    repository via ``Llama.from_pretrained``, wraps it as an ``outlines``
    model, prints how long the query took, and checks the action chosen
    for the scene described by the fixtures.
    """
    llm = Llama.from_pretrained(
        repo_id="Gigax/NPC-LLM-3_8B-GGUF",
        # Trailing comma so the options below can be uncommented as-is.
        filename="npc-llm-3_8B.gguf",
        # n_gpu_layers=-1,  # Uncomment to use GPU acceleration
        # seed=1337,  # Uncomment to set a specific seed
        # n_ctx=2048,  # Uncomment to increase the context window
    )
    model = models.LlamaCpp(llm)
    stepper = NPCStepper(model=model)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    action = stepper.get_action(
        context=context,
        locations=locations,
        NPCs=NPCs,
        protagonist=protagonist,
        items=items,
        events=events,
    )
    # NOTE(review): test_stepper_api awaits get_action. If it returns a
    # coroutine here too, run it so we assert on the real action rather
    # than the coroutine object.
    if inspect.iscoroutine(action):
        action = asyncio.run(action)
    print(f"Query time: {time.perf_counter() - start}")

    assert str(action) == "Aldren: Attack John the Brave"
def test_stepper_local_transformers(
    context: str,
    locations: list[Location],
    NPCs: list[Character],
    protagonist: ProtagonistCharacter,
    items: list[Item],
    events: list[CharacterAction],
):
    """Run the NPC stepper on the ``gigax/NPC-LLM-3_8B`` Hugging Face model.

    Loads the causal LM and its tokenizer with ``transformers``, wraps them
    as an ``outlines`` model, and checks the action chosen for the scene
    described by the fixtures.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        "gigax/NPC-LLM-3_8B", output_attentions=True, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained("gigax/NPC-LLM-3_8B")
    model = models.Transformers(llm, tokenizer)  # type: ignore

    # Get the NPC's input
    stepper = NPCStepper(model=model)
    action = stepper.get_action(
        context=context,
        locations=locations,
        NPCs=NPCs,
        protagonist=protagonist,
        items=items,
        events=events,
    )
    # NOTE(review): test_stepper_api awaits get_action. If it returns a
    # coroutine here too, run it so we assert on the real action rather
    # than the coroutine object.
    if inspect.iscoroutine(action):
        action = asyncio.run(action)

    assert str(action) == "Aldren: Attack John the Brave"
def test_stepper_api(
    context: str,
    locations: list[Location],
    NPCs: list[Character],
    protagonist: ProtagonistCharacter,
    items: list[Item],
    events: list[CharacterAction],
):
    """Query the hosted model by name and check the NPC's chosen action.

    First checks that ``NPCStepper(model="mistral-7b-regex")`` with no
    ``api_key`` raises ``ValueError``. Then builds a stepper for
    ``"mistral_7b_regex"`` using the ``API_KEY`` environment variable, and
    asserts on the action chosen for the scene described by the fixtures.
    Skipped when ``API_KEY`` is not set.
    """
    with pytest.raises(ValueError):
        NPCStepper(model="mistral-7b-regex")

    api_key = os.getenv("API_KEY")
    if not api_key:
        pytest.skip("API_KEY is not set; skipping the hosted-API test")

    # Get the NPC's input
    stepper = NPCStepper(model="mistral_7b_regex", api_key=api_key)
    action = stepper.get_action(
        context=context,
        locations=locations,
        NPCs=NPCs,
        protagonist=protagonist,
        items=items,
        events=events,
    )
    # `await` is a SyntaxError inside a plain `def`. Drive the coroutine
    # explicitly instead, which also avoids needing an async pytest plugin.
    if inspect.iscoroutine(action):
        action = asyncio.run(action)

    assert str(action) == "Aldren: Attack John the Brave"