-
Notifications
You must be signed in to change notification settings - Fork 11
/
main.go
79 lines (64 loc) · 1.52 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package main
import (
"bufio"
"flag"
"fmt"
"io"
"os"
"strings"
gpt2 "github.com/go-skynet/go-ggml-transformers.cpp"
)
// Generation settings shared with main, which binds each one to a
// command-line flag before the model is used.
var (
	// tokens is the number of tokens to predict (flag -n).
	tokens = 128

	// threads is the number of threads to use during computation (flag -t).
	// Registering the flag stores the flag's own default (7) in this
	// variable, replacing the initial value given here.
	threads = 4
)
// main loads the model named by -m and then, in an endless loop, reads a
// prompt from standard input, runs a prediction on it and prints the result.
//
// Flags:
//
//	-m  path to the model file to load
//	-t  number of threads to use during computation
//	-n  number of tokens to predict
//
// Failures (argument parsing, model loading, prediction) are reported on
// standard error and end the process with exit status 1, so that standard
// output carries only the prompts and predictions.
func main() {
	var model string

	flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
	flags.StringVar(&model, "m", "./models/7B/ggml-model-q4_0.bin", "path to q4_0.bin model file to load")
	// Registering a flag writes its default into the target variable, so
	// threads starts at 7 here regardless of its package-level initial value.
	flags.IntVar(&threads, "t", 7, "number of threads to use during computation")
	flags.IntVar(&tokens, "n", 128, "number of tokens to predict")

	// With flag.ExitOnError, Parse exits the process itself on bad input;
	// the check below is kept as a safeguard.
	if err := flags.Parse(os.Args[1:]); err != nil {
		fmt.Fprintf(os.Stderr, "Parsing program arguments failed: %s\n", err)
		os.Exit(1)
	}

	l, err := gpt2.New(model)
	if err != nil {
		fmt.Fprintln(os.Stderr, "Loading the model failed:", err)
		os.Exit(1)
	}
	fmt.Printf("Model loaded successfully.\n")

	reader := bufio.NewReader(os.Stdin)
	for {
		text := readMultiLineInput(reader)

		res, err := l.Predict(text, gpt2.SetTokens(tokens), gpt2.SetThreads(threads))
		if err != nil {
			// Report and exit like every other failure path instead of
			// panicking with a stack trace.
			fmt.Fprintln(os.Stderr, "Predicting failed:", err)
			os.Exit(1)
		}

		fmt.Printf("\ngolang: %s\n", res)
		fmt.Printf("\n\n")
	}
}
// readMultiLineInput prints the ">>> " prompt and reads lines from reader
// until a blank (whitespace-only) line is entered. It returns the collected
// lines concatenated, with their line terminators kept, after echoing them
// with a "Sending" prefix.
//
// When the input ends (io.EOF) while text is pending, for example a prompt
// piped in without a trailing blank line, the pending text is returned rather
// than dropped. EOF with nothing pending exits the process with status 0.
// Any other read error is reported on standard error and exits with status 1.
func readMultiLineInput(reader *bufio.Reader) string {
	var prompt strings.Builder

	fmt.Print(">>> ")
	for {
		// ReadString returns whatever was read before an error, so line may
		// hold a final, unterminated line even when err is io.EOF.
		line, err := reader.ReadString('\n')
		if err != nil && err != io.EOF {
			fmt.Fprintf(os.Stderr, "Reading the prompt failed: %s\n", err)
			os.Exit(1)
		}

		blank := len(strings.TrimSpace(line)) == 0
		if !blank {
			prompt.WriteString(line)
		}

		if err == io.EOF {
			if prompt.Len() == 0 {
				// Nothing left to send: end of the session.
				os.Exit(0)
			}
			break
		}
		if blank {
			break
		}
	}

	text := prompt.String()
	fmt.Println("Sending", text)
	return text
}