
✨ feat: support qwen-vl and tool call for qwen #3114

Open
wants to merge 10 commits into main
22 changes: 22 additions & 0 deletions src/config/modelProviders/qwen.ts
@@ -7,13 +7,15 @@ const Qwen: ModelProviderCard = {
description: '通义千问超大规模语言模型,支持中文、英文等不同语言输入',
displayName: 'Qwen Turbo',
enabled: true,
functionCall: true,
id: 'qwen-turbo',
tokens: 8000,
},
{
description: '通义千问超大规模语言模型增强版,支持中文、英文等不同语言输入',
displayName: 'Qwen Plus',
enabled: true,
functionCall: true,
id: 'qwen-plus',
tokens: 32_000,
},
@@ -22,13 +24,15 @@ const Qwen: ModelProviderCard = {
'通义千问千亿级别超大规模语言模型,支持中文、英文等不同语言输入,当前通义千问2.5产品版本背后的API模型',
displayName: 'Qwen Max',
enabled: true,
functionCall: true,
id: 'qwen-max',
tokens: 8000,
},
{
description:
'通义千问千亿级别超大规模语言模型,支持中文、英文等不同语言输入,扩展了上下文窗口',
displayName: 'Qwen Max LongContext',
functionCall: true,
id: 'qwen-max-longcontext',
tokens: 30_000,
},
@@ -50,6 +54,24 @@
id: 'qwen2-72b-instruct',
tokens: 131_072,
},
{
description:
'通义千问大规模视觉语言模型增强版。大幅提升细节识别能力和文字识别能力,支持超百万像素分辨率和任意长宽比规格的图像。',
displayName: 'Qwen VL Plus',
enabled: true,
id: 'qwen-vl-plus',
tokens: 6144,
vision: true,
},
{
description:
'通义千问超大规模视觉语言模型。相比增强版,再次提升视觉推理能力和指令遵循能力,提供更高的视觉感知和认知水平。',
displayName: 'Qwen VL Max',
enabled: true,
id: 'qwen-vl-max',
tokens: 6144,
vision: true,
},
],
checkModel: 'qwen-turbo',
disableBrowserRequest: true,
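The config diff above flips `functionCall: true` on for the text-model cards. For reviewers, here is a minimal sketch of a tool-call request against the new runtime; the weather tool and the `DASHSCOPE_API_KEY` variable are illustrative, not part of this PR:

```ts
import { LobeQwenAI } from '@/libs/agent-runtime/qwen';

// Sketch under assumptions: the weather tool is hypothetical.
// Because DashScope's compatible mode does not stream tool calls,
// the runtime forces stream: false whenever tools are present
// (see buildCompletionParamsByModel in index.ts below).
const runtime = new LobeQwenAI({ apiKey: process.env.DASHSCOPE_API_KEY });

const response = await runtime.chat({
  messages: [{ content: 'What is the weather in Hangzhou?', role: 'user' }],
  model: 'qwen-max',
  temperature: 1,
  tools: [
    {
      function: {
        description: 'Look up the current weather for a city',
        name: 'get_weather', // hypothetical tool
        parameters: {
          properties: { city: { type: 'string' } },
          required: ['city'],
          type: 'object',
        },
      },
      type: 'function',
    },
  ],
});
```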
118 changes: 114 additions & 4 deletions src/libs/agent-runtime/qwen/index.test.ts
@@ -2,6 +2,7 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import Qwen from '@/config/modelProviders/qwen';
import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import { ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeErrorType } from '@/libs/agent-runtime';
@@ -17,7 +18,7 @@ const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey;
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeOpenAICompatibleRuntime;
let instance: LobeQwenAI;

beforeEach(() => {
instance = new LobeQwenAI({ apiKey: 'test' });
@@ -41,7 +42,116 @@ describe('LobeQwenAI', () => {
});
});

describe('models', () => {
it('should correctly list available models', async () => {
const instance = new LobeQwenAI({ apiKey: 'test_api_key' });
vi.spyOn(instance, 'models').mockResolvedValue(Qwen.chatModels);

const models = await instance.models();
expect(models).toEqual(Qwen.chatModels);
});
});

describe('chat', () => {
describe('Params', () => {
it('should call llms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: true,
top_p: 0.7,
result_format: 'message',
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should call vlms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
stream: true,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should transform non-streaming response to stream correctly', async () => {
const mockResponse: OpenAI.ChatCompletion = {
id: 'chatcmpl-fc539f49-51a8-94be-8061',
object: 'chat.completion',
created: 1719901794,
model: 'qwen-turbo',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Hello' },
finish_reason: 'stop',
logprobs: null,
},
],
};
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
mockResponse as any,
);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: false,
});

const decoder = new TextDecoder();

const reader = result.body!.getReader();
expect(decoder.decode((await reader.read()).value)).toContain(
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
);
expect(decoder.decode((await reader.read()).value)).toContain('event: text\n');
expect(decoder.decode((await reader.read()).value)).toContain('data: "Hello"\n\n');

expect(decoder.decode((await reader.read()).value)).toContain(
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
);
expect(decoder.decode((await reader.read()).value)).toContain('event: stop\n');
expect(decoder.decode((await reader.read()).value)).toContain('');

expect((await reader.read()).done).toBe(true);
});
});

describe('Error', () => {
it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
@@ -129,8 +239,7 @@ describe('LobeQwenAI', () => {

instance = new LobeQwenAI({
apiKey: 'test',

baseURL: 'https://api.abc.com/v1',
baseURL: defaultBaseURL,
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
@@ -144,7 +253,8 @@
});
} catch (e) {
expect(e).toEqual({
endpoint: 'https://api.***.com/v1',
/* Desensitizing is unnecessary for a publicly accessible gateway endpoint. */
endpoint: defaultBaseURL,
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
156 changes: 128 additions & 28 deletions src/libs/agent-runtime/qwen/index.ts
@@ -1,28 +1,128 @@
import OpenAI from 'openai';

import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

export const LobeQwenAI = LobeOpenAICompatibleFactory({
baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
chatCompletion: {
handlePayload: (payload) => {
const top_p = payload.top_p;
return {
...payload,
stream: payload.stream ?? true,
top_p: top_p && top_p >= 1 ? 0.9999 : top_p,
} as OpenAI.ChatCompletionCreateParamsStreaming;
},
},
constructorOptions: {
defaultHeaders: {
'Content-Type': 'application/json',
},
},
debug: {
chatCompletion: () => process.env.DEBUG_QWEN_CHAT_COMPLETION === '1',
},

provider: ModelProvider.Qwen,
});
import { omit } from 'lodash-es';
import OpenAI, { ClientOptions } from 'openai';

import Qwen from '@/config/modelProviders/qwen';

import { LobeOpenAICompatibleRuntime, LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { handleOpenAIError } from '../utils/handleOpenAIError';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { StreamingResponse } from '../utils/response';
import { QwenAIStream } from '../utils/streams';

const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1';

/**
* Use DashScope OpenAI compatible mode for now.
* DashScope OpenAI [compatible mode](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api) currently supports base64 image input for vision models e.g. qwen-vl-plus.
* You can provide image input in either of two ways:
* 1. Use qwen-vl-* out of the box with base64 image_url input;
* or
* 2. Set the S3-* environment variables properly so that all uploaded files are stored there.
*/
export class LobeQwenAI extends LobeOpenAICompatibleRuntime implements LobeRuntimeAI {
client: OpenAI;
baseURL: string;

constructor({
apiKey,
baseURL = DEFAULT_BASE_URL,
...res
}: ClientOptions & Record<string, any> = {}) {
super();
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
this.client = new OpenAI({ apiKey, baseURL, ...res });
this.baseURL = this.client.baseURL;
}

async models() {
return Qwen.chatModels;
}

async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
try {
const params = this.buildCompletionParamsByModel(payload);

const response = await this.client.chat.completions.create(
params as OpenAI.ChatCompletionCreateParamsStreaming & { result_format: string },
{
headers: { Accept: '*/*' },
signal: options?.signal,
},
);

if (params.stream) {
const [prod, debug] = response.tee();

if (process.env.DEBUG_QWEN_CHAT_COMPLETION === '1') {
debugStream(debug.toReadableStream()).catch(console.error);
}

return StreamingResponse(QwenAIStream(prod, options?.callback), {
headers: options?.headers,
});
}

const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion);

return StreamingResponse(QwenAIStream(stream, options?.callback), {
headers: options?.headers,
});
} catch (error) {
if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: error as any,
errorType: AgentRuntimeErrorType.InvalidProviderAPIKey,
provider: ModelProvider.Qwen,
});
}

default: {
break;
}
}
}
const { errorResult, RuntimeError } = handleOpenAIError(error);
const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError;

throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: errorResult,
errorType,
provider: ModelProvider.Qwen,
});
}
}

private buildCompletionParamsByModel(payload: ChatStreamPayload) {
const { model, top_p, stream, messages, tools } = payload;
const isVisionModel = model.startsWith('qwen-vl');

const params = {
...payload,
messages,
result_format: 'message',
stream: !!tools?.length ? false : stream ?? true,
top_p: top_p && top_p >= 1 ? 0.999 : top_p,
};

/* Qwen-vl models do not support the parameters below for now. */
/* Notice: `top_p` has a significant impact on the result; the default 1 (or the clamped 0.999) is not a proper choice. */
return isVisionModel
? omit(
params,
'presence_penalty',
'frequency_penalty',
'temperature',
'result_format',
'top_p',
)
: params;
}
}
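Per the doc comment above, base64 `image_url` input works out of the box for the vision models. A minimal usage sketch, where the data URL is a truncated placeholder and the message shape follows the OpenAI vision format the compatible mode accepts:

```ts
import { LobeQwenAI } from '@/libs/agent-runtime/qwen';

// Sketch only: the base64 payload is illustrative. The data URL is sent
// straight through to the DashScope compatible-mode endpoint, so no S3
// storage needs to be configured for this path.
const runtime = new LobeQwenAI({ apiKey: process.env.DASHSCOPE_API_KEY });

const response = await runtime.chat({
  messages: [
    {
      content: [
        { text: 'Describe this image.', type: 'text' },
        // Truncated placeholder; any base64-encoded image works here.
        { image_url: { url: 'data:image/png;base64,iVBORw0KGgo...' }, type: 'image_url' },
      ],
      role: 'user',
    },
  ],
  model: 'qwen-vl-plus',
  temperature: 1,
});
```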