Skip to content

Commit 2cb2771

Browse files
committed
Retrieve history from db when generating prompt
1 parent c951e86 commit 2cb2771

File tree

13 files changed

+121
-82
lines changed

13 files changed

+121
-82
lines changed

.dockerignore

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
db/
2+
.parcel-cache
3+
node_modules
4+
.env
5+
.token_secret

Dockerfile

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
FROM node:18-alpine
2+
3+
WORKDIR /app
4+
VOLUME [ "/app/db" ]
5+
6+
RUN npm install pnpm -g
7+
8+
ADD package.json pnpm-lock.yaml ./
9+
RUN pnpm i --frozen-lockfile
10+
11+
ADD public.ts requirements.txt tailwind.config.js tsconfig.json .babelrc .postcssrc .prettierrc ./
12+
ADD common/ ./common/
13+
ADD srv/ ./srv/
14+
ADD web/ ./web
15+
16+
17+
EXPOSE 3001
18+
EXPOSE 5001
19+
20+
CMD ["pnpm", "start"]

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"name": "agn-ai",
2+
"name": "agnaistic",
33
"private": true,
44
"version": "0.1.0",
55
"description": "Agnostic AI Chat",

srv/api/adapter/chai.ts

+2-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import needle from 'needle'
22
import { config } from '../../config'
33
import { logger } from '../../logger'
44
import { trimResponse } from '../chat/common'
5-
import { createPrompt } from './prompt'
65
import { ModelAdapter } from './type'
76

87
const base = {
@@ -13,14 +12,7 @@ const base = {
1312
top_p: 1,
1413
}
1514

16-
export const handleChai: ModelAdapter = async function* ({
17-
chat,
18-
char,
19-
history,
20-
message,
21-
sender,
22-
members,
23-
}) {
15+
export const handleChai: ModelAdapter = async function* ({ char, members, prompt }) {
2416
if (!config.chai.url) {
2517
yield { error: 'Chai URL not set' }
2618
return
@@ -33,7 +25,7 @@ export const handleChai: ModelAdapter = async function* ({
3325

3426
const body = {
3527
...base,
36-
text: createPrompt({ sender, chat, char, history, message, members }),
28+
text: prompt,
3729
}
3830

3931
const response = await needle('post', `${config.chai.url}/generate/gptj`, body, {

srv/api/adapter/generate.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ import { errors, StatusError } from '../wrap'
77
import { handleChai } from './chai'
88
import { handleKobold } from './kobold'
99
import { handleNovel } from './novel'
10+
import { createPrompt } from './prompt'
1011
import { ModelAdapter } from './type'
1112

1213
export type GenerateOptions = {
1314
senderId: string
1415
chatId: string
15-
history: AppSchema.ChatMessage[]
1616
message: string
1717
log: AppLog
18+
retry?: AppSchema.ChatMessage
1819
}
1920

2021
const handlers: { [key in ChatAdapter]: ModelAdapter } = {
@@ -41,7 +42,8 @@ export async function generateResponse(
4142
const adapter =
4243
(opts.chat.adapter === 'default' ? user.defaultAdapter : opts.chat.adapter) ||
4344
user.defaultAdapter
44-
const adapterOpts = { ...opts, members, user, sender }
45+
const prompt = await createPrompt({ ...opts, members, sender })
46+
const adapterOpts = { ...opts, members, user, sender, prompt }
4547

4648
const handler = handlers[adapter]
4749
return handler(adapterOpts)

srv/api/adapter/kobold.ts

+2-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import needle from 'needle'
22
import { config } from '../../config'
33
import { logger } from '../../logger'
44
import { joinParts, trimResponse } from '../chat/common'
5-
import { createPrompt } from './prompt'
65
import { ModelAdapter } from './type'
76

87
const MAX_NEW_TOKENS = 196
@@ -31,19 +30,8 @@ const base = {
3130
sampler_order: [6, 0, 1, 2, 3, 4, 5],
3231
}
3332

34-
export const handleKobold: ModelAdapter = async function* ({
35-
chat,
36-
char,
37-
history,
38-
message,
39-
sender,
40-
members,
41-
user,
42-
}) {
43-
const body = {
44-
...base,
45-
prompt: createPrompt({ chat, char, history, message, sender, members }),
46-
}
33+
export const handleKobold: ModelAdapter = async function* ({ char, members, user, prompt }) {
34+
const body = { ...base, prompt }
4735

4836
let attempts = 0
4937
let maxAttempts = body.max_length / MAX_NEW_TOKENS + 4

srv/api/adapter/novel.ts

+3-22
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import needle from 'needle'
22
import { logger } from '../../logger'
33
import { sanitise, trimResponse } from '../chat/common'
44
import { badWordIds } from './novel-bad-words'
5-
import { createPrompt } from './prompt'
65
import { ModelAdapter } from './type'
76

87
const novelUrl = `https://api.novelai.net/ai/generate`
@@ -37,38 +36,20 @@ const base = {
3736
},
3837
}
3938

40-
export const handleNovel: ModelAdapter = async function* ({
41-
chat,
42-
char,
43-
history,
44-
sender,
45-
message,
46-
members,
47-
user,
48-
}) {
39+
export const handleNovel: ModelAdapter = async function* ({ char, members, user, prompt }) {
4940
if (!user.novelApiKey) {
5041
yield { error: 'Novel API key not set' }
5142
return
5243
}
5344

54-
const body = {
55-
...base,
56-
input: createPrompt({
57-
chat,
58-
char,
59-
history,
60-
message,
61-
sender,
62-
members,
63-
}),
64-
}
45+
const body = { ...base, input: prompt }
6546

6647
const endTokens = ['***', 'Scenario:', '----']
6748

6849
const response = await needle('post', novelUrl, body, {
6950
json: true,
7051
timeout: 2000,
71-
response_timeout: 8000,
52+
response_timeout: 10000,
7253
headers: { Authorization: `Bearer ${user.novelApiKey}` },
7354
}).catch((err) => ({ err }))
7455

srv/api/adapter/prompt.ts

+58-15
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,88 @@
11
import { AppSchema } from '../../db/schema'
22
import gpt from 'gpt-3-encoder'
33
import { logger } from '../../logger'
4+
import { store } from '../../db'
45

56
type PromptOpts = {
67
sender: AppSchema.Profile
78
chat: AppSchema.Chat
89
char: AppSchema.Character
9-
history: AppSchema.ChatMessage[]
1010
message: string
1111
members: AppSchema.Profile[]
12+
retry?: AppSchema.ChatMessage
1213
}
1314

15+
const MAX_TOKENS = 2048
1416
const BOT_REPLACE = /\{\{char\}\}/g
1517
const SELF_REPLACE = /\{\{user\}\}/g
1618

17-
export function createPrompt({ sender, chat, char, history, message, members }: PromptOpts) {
19+
export async function createPrompt({ sender, chat, char, message, members }: PromptOpts) {
1820
const username = sender.handle || 'You'
1921

20-
const lines: string[] = [`${char.name}'s Persona: ${formatCharacter(char.name, chat.overrides)}`]
22+
const pre: string[] = [`${char.name}'s Persona: ${formatCharacter(char.name, chat.overrides)}`]
2123

2224
if (chat.scenario) {
23-
lines.push(`Scenario: ${chat.scenario}`)
25+
pre.push(`Scenario: ${chat.scenario}`)
2426
}
2527

26-
lines.push(
27-
`<START>`,
28-
...chat.sampleChat.split('\n'),
29-
...history.map((chat) => prefix(chat, char.name, members) + chat.msg),
30-
`${username}: ${message}`,
31-
`${char.name}:`
32-
)
28+
pre.push(`<START>`, ...chat.sampleChat.split('\n'))
29+
const post = [`${username}: ${message}`, `${char.name}:`]
3330

34-
const prompt = lines
31+
const prompt = await appendHistory(chat, members, char, pre, post)
32+
return prompt
33+
}
34+
35+
async function appendHistory(
36+
chat: AppSchema.Chat,
37+
members: AppSchema.Profile[],
38+
char: AppSchema.Character,
39+
pre: string[],
40+
post: string[],
41+
retry?: AppSchema.ChatMessage
42+
) {
43+
const owner = members.find((mem) => mem.userId === chat.userId)
44+
if (!owner) {
45+
throw new Error(`Cannot produce prompt: Owner profile not found`)
46+
}
47+
48+
// We need to do this early for accurate token counts
49+
const preamble = pre
50+
.filter(removeEmpty)
51+
.join('\n')
52+
.replace(BOT_REPLACE, char.name)
53+
.replace(SELF_REPLACE, owner.handle)
54+
const postamble = post
3555
.filter(removeEmpty)
3656
.join('\n')
3757
.replace(BOT_REPLACE, char.name)
38-
.replace(SELF_REPLACE, username)
58+
.replace(SELF_REPLACE, owner.handle)
3959

40-
const tokens = gpt.encode(prompt).length
41-
logger.debug({ tokens, prompt }, 'Tokens')
60+
let tokens = gpt.encode(preamble + '\n' + postamble).length
61+
const lines: string[] = []
62+
let before = retry?.updatedAt
63+
64+
do {
65+
const messages = await store.chats.getRecentMessages(chat._id, before)
66+
const history = messages.map((chat) => prefix(chat, char.name, members) + chat.msg)
67+
68+
for (const hist of history) {
69+
const nextTokens = gpt.encode(hist).length
70+
if (nextTokens + tokens > MAX_TOKENS) break
71+
tokens += nextTokens
72+
lines.unshift(hist)
73+
}
74+
75+
if (tokens >= MAX_TOKENS || messages.length < 50) break
76+
before = messages.slice(-1)[0].createdAt
77+
} while (true)
78+
79+
const middle = lines
80+
.join('\n')
81+
.replace(BOT_REPLACE, char.name)
82+
.replace(SELF_REPLACE, owner.handle)
4283

84+
const prompt = [preamble, middle, postamble].join('\n')
85+
logger.info({ tokens, prompt }, 'Tokens used')
4386
return prompt
4487
}
4588

srv/api/adapter/type.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export type ModelAdapter = (opts: {
55
char: AppSchema.Character
66
user: AppSchema.User
77
members: AppSchema.Profile[]
8-
history: AppSchema.ChatMessage[]
98
message: string
109
sender: AppSchema.Profile
10+
prompt: string
1111
}) => AsyncGenerator<string | { error: any }>

srv/api/chat/message.ts

+4-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { obtainLock, releaseLock, verifyLock } from './lock'
77

88
export const generateMessage = handle(async ({ userId, params, body, log }, res) => {
99
const id = params.id
10-
assertValid({ message: 'string', history: 'any', ephemeral: 'boolean?', retry: 'boolean?' }, body)
10+
assertValid({ message: 'string', ephemeral: 'boolean?', retry: 'boolean?' }, body)
1111

1212
const lockId = await obtainLock(id)
1313

@@ -41,7 +41,6 @@ export const generateMessage = handle(async ({ userId, params, body, log }, res)
4141
senderId: userId!,
4242
chatId: id,
4343
message: body.message,
44-
history: body.history,
4544
log,
4645
})
4746

@@ -73,19 +72,12 @@ export const generateMessage = handle(async ({ userId, params, body, log }, res)
7372
export const retryMessage = handle(async ({ body, params, userId, log }, res) => {
7473
const { id, messageId } = params
7574

76-
assertValid(
77-
{
78-
history: 'any',
79-
message: 'string',
80-
ephemeral: 'boolean?',
81-
},
82-
body
83-
)
75+
assertValid({ message: 'string', ephemeral: 'boolean?' }, body)
8476

8577
const lockId = await obtainLock(id)
8678

8779
const prev = await store.chats.getMessageAndChat(messageId)
88-
if (!prev || !prev.chat) throw errors.NotFound
80+
if (!prev || !prev.chat || !prev.msg) throw errors.NotFound
8981

9082
const members = prev.chat.memberIds.concat(prev.chat.userId)
9183
if (!members.includes(userId!)) throw errors.Forbidden
@@ -99,10 +91,10 @@ export const retryMessage = handle(async ({ body, params, userId, log }, res) =>
9991

10092
const { stream } = await createResponseStream({
10193
chatId: params.id,
102-
history: body.history,
10394
message: body.message,
10495
senderId: userId!,
10596
log,
97+
retry: prev.msg,
10698
})
10799

108100
let generated = ''

srv/api/ws/handle.ts

+1-3
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,12 @@ export function publishMany<T extends { type: string }>(userIds: string[], data:
9191
for (const userId of unique) {
9292
count += publishOne(userId, data)
9393
}
94-
95-
logger.debug({ count }, 'Messages sent')
9694
}
9795

9896
export function publishOne<T extends { type: string }>(userId: string, data: T) {
9997
let count = 0
10098
const sockets = userSockets.get(userId)
101-
logger.info({ count: sockets?.length, type: data.type }, 'Publishing')
99+
102100
if (!sockets) return count
103101

104102
for (const socket of sockets) {

srv/db/chats.ts

+20
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ export async function editMessage(id: string, content: string) {
128128
return doc
129129
}
130130

131+
export async function getMessage(messageId: string) {
132+
const msg = await msgs.findOne({ _id: messageId, kind: 'chat-message' })
133+
return msg
134+
}
135+
131136
export async function deleteMessages(messageIds: string[]) {
132137
await chats.deleteMany({ _id: { $in: messageIds } }, { multi: true })
133138
}
@@ -170,3 +175,18 @@ export async function getAllChats(userId: string) {
170175

171176
return list
172177
}
178+
179+
/**
180+
*
181+
* @param chatId
182+
* @param before Date ISO string
183+
*/
184+
export async function getRecentMessages(chatId: string, before?: string) {
185+
const query: any = { kind: 'chat-message', chatId }
186+
if (before) {
187+
query.createdAt = { $lt: before }
188+
}
189+
190+
const messages = await msgs.find(query).sort({ createdAt: -1 }).limit(50)
191+
return messages
192+
}

0 commit comments

Comments
 (0)