Skip to content

Commit c8cae92

Browse files
authored
Support using prompt param with GuardrailAgent (#23)
1 parent 8f30ead commit c8cae92

File tree

2 files changed

+194
-42
lines changed

2 files changed

+194
-42
lines changed

src/__tests__/unit/agents.test.ts

Lines changed: 163 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,27 @@
33
*/
44

55
import { describe, it, expect, vi, beforeEach } from 'vitest';
6+
import type { InputGuardrail, OutputGuardrail } from '@openai/agents-core';
67
import { GuardrailAgent } from '../../agents';
78
import { TextInput } from '../../types';
89
import { z } from 'zod';
910

1011
// Define the expected agent interface for testing
1112
interface MockAgent {
1213
name: string;
13-
instructions: string;
14-
inputGuardrails: Array<{ execute: (input: TextInput) => Promise<{ outputInfo: Record<string, unknown>; tripwireTriggered: boolean }> }>;
15-
outputGuardrails: Array<{ execute: (input: TextInput) => Promise<{ outputInfo: Record<string, unknown>; tripwireTriggered: boolean }> }>;
14+
instructions?: string | ((context: unknown, agent: unknown) => string | Promise<string>);
15+
inputGuardrails: Array<{
16+
name?: string;
17+
execute: (
18+
input: TextInput
19+
) => Promise<{ outputInfo: Record<string, unknown>; tripwireTriggered: boolean }>;
20+
}>;
21+
outputGuardrails: Array<{
22+
name?: string;
23+
execute: (
24+
input: TextInput
25+
) => Promise<{ outputInfo: Record<string, unknown>; tripwireTriggered: boolean }>;
26+
}>;
1627
model?: string;
1728
temperature?: number;
1829
max_tokens?: number;
@@ -35,20 +46,20 @@ vi.mock('../../runtime', () => ({
3546
instantiateGuardrails: vi.fn(() =>
3647
Promise.resolve([
3748
{
38-
definition: {
49+
definition: {
3950
name: 'Keywords',
4051
description: 'Test guardrail',
4152
mediaType: 'text/plain',
4253
configSchema: z.object({}),
4354
checkFn: vi.fn(),
4455
contextSchema: z.object({}),
45-
metadata: {}
56+
metadata: {},
4657
},
4758
config: {},
48-
run: vi.fn().mockResolvedValue({
49-
tripwireTriggered: false,
50-
info: { checked_text: 'test input' },
51-
}),
59+
run: vi.fn().mockResolvedValue({
60+
tripwireTriggered: false,
61+
info: { checked_text: 'test input' },
62+
}),
5263
},
5364
])
5465
),
@@ -83,7 +94,11 @@ describe('GuardrailAgent', () => {
8394
},
8495
};
8596

86-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
97+
const agent = (await GuardrailAgent.create(
98+
config,
99+
'Test Agent',
100+
'Test instructions'
101+
)) as MockAgent;
87102

88103
expect(agent.name).toBe('Test Agent');
89104
expect(agent.instructions).toBe('Test instructions');
@@ -100,7 +115,11 @@ describe('GuardrailAgent', () => {
100115
},
101116
};
102117

103-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
118+
const agent = (await GuardrailAgent.create(
119+
config,
120+
'Test Agent',
121+
'Test instructions'
122+
)) as MockAgent;
104123

105124
expect(agent.name).toBe('Test Agent');
106125
expect(agent.instructions).toBe('Test instructions');
@@ -125,7 +144,11 @@ describe('GuardrailAgent', () => {
125144
},
126145
};
127146

128-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
147+
const agent = (await GuardrailAgent.create(
148+
config,
149+
'Test Agent',
150+
'Test instructions'
151+
)) as MockAgent;
129152

130153
expect(agent.name).toBe('Test Agent');
131154
expect(agent.instructions).toBe('Test instructions');
@@ -148,12 +171,12 @@ describe('GuardrailAgent', () => {
148171
max_tokens: 1000,
149172
};
150173

151-
const agent = await GuardrailAgent.create(
174+
const agent = (await GuardrailAgent.create(
152175
config,
153176
'Test Agent',
154177
'Test instructions',
155178
agentKwargs
156-
) as MockAgent;
179+
)) as MockAgent;
157180

158181
expect(agent.model).toBe('gpt-4');
159182
expect(agent.temperature).toBe(0.7);
@@ -163,7 +186,11 @@ describe('GuardrailAgent', () => {
163186
it('should handle empty configuration gracefully', async () => {
164187
const config = { version: 1 };
165188

166-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
189+
const agent = (await GuardrailAgent.create(
190+
config,
191+
'Test Agent',
192+
'Test instructions'
193+
)) as MockAgent;
167194

168195
expect(agent.name).toBe('Test Agent');
169196
expect(agent.instructions).toBe('Test instructions');
@@ -180,13 +207,13 @@ describe('GuardrailAgent', () => {
180207
},
181208
};
182209

183-
const agent = await GuardrailAgent.create(
210+
const agent = (await GuardrailAgent.create(
184211
config,
185212
'Test Agent',
186213
'Test instructions',
187214
{},
188215
true // raiseGuardrailErrors = true
189-
) as MockAgent;
216+
)) as MockAgent;
190217

191218
expect(agent.name).toBe('Test Agent');
192219
expect(agent.instructions).toBe('Test instructions');
@@ -202,7 +229,11 @@ describe('GuardrailAgent', () => {
202229
},
203230
};
204231

205-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
232+
const agent = (await GuardrailAgent.create(
233+
config,
234+
'Test Agent',
235+
'Test instructions'
236+
)) as MockAgent;
206237

207238
expect(agent.name).toBe('Test Agent');
208239
expect(agent.instructions).toBe('Test instructions');
@@ -214,6 +245,97 @@ describe('GuardrailAgent', () => {
214245
// For now, we'll skip it since the error handling is tested in the actual implementation
215246
expect(true).toBe(true); // Placeholder assertion
216247
});
248+
249+
it('should work without instructions parameter', async () => {
250+
const config = { version: 1 };
251+
252+
// Should not throw TypeError about missing instructions
253+
const agent = (await GuardrailAgent.create(config, 'NoInstructions')) as MockAgent;
254+
255+
expect(agent.name).toBe('NoInstructions');
256+
expect(agent.instructions).toBeUndefined();
257+
});
258+
259+
it('should accept callable instructions', async () => {
260+
const config = { version: 1 };
261+
262+
const dynamicInstructions = (ctx: unknown, agent: unknown) => {
263+
return `You are ${(agent as { name: string }).name}`;
264+
};
265+
266+
const agent = (await GuardrailAgent.create(
267+
config,
268+
'DynamicAgent',
269+
dynamicInstructions
270+
)) as MockAgent;
271+
272+
expect(agent.name).toBe('DynamicAgent');
273+
expect(typeof agent.instructions).toBe('function');
274+
expect(agent.instructions).toBe(dynamicInstructions);
275+
});
276+
277+
it('should merge user input guardrails with config guardrails', async () => {
278+
const config = {
279+
version: 1,
280+
input: {
281+
version: 1,
282+
guardrails: [{ name: 'Keywords', config: {} }],
283+
},
284+
};
285+
286+
// Create a custom user guardrail
287+
const customGuardrail: InputGuardrail = {
288+
name: 'Custom Input Guard',
289+
execute: async () => ({ outputInfo: {}, tripwireTriggered: false }),
290+
};
291+
292+
const agent = (await GuardrailAgent.create(config, 'MergedAgent', 'Test instructions', {
293+
inputGuardrails: [customGuardrail],
294+
})) as MockAgent;
295+
296+
// Should have both config and user guardrails merged (config first, then user)
297+
expect(agent.inputGuardrails).toHaveLength(2);
298+
expect(agent.inputGuardrails[0].name).toContain('input:');
299+
expect(agent.inputGuardrails[1].name).toBe('Custom Input Guard');
300+
});
301+
302+
it('should merge user output guardrails with config guardrails', async () => {
303+
const config = {
304+
version: 1,
305+
output: {
306+
version: 1,
307+
guardrails: [{ name: 'URL Filter', config: {} }],
308+
},
309+
};
310+
311+
// Create a custom user guardrail
312+
const customGuardrail: OutputGuardrail = {
313+
name: 'Custom Output Guard',
314+
execute: async () => ({ outputInfo: {}, tripwireTriggered: false }),
315+
};
316+
317+
const agent = (await GuardrailAgent.create(config, 'MergedAgent', 'Test instructions', {
318+
outputGuardrails: [customGuardrail],
319+
})) as MockAgent;
320+
321+
// Should have both config and user guardrails merged (config first, then user)
322+
expect(agent.outputGuardrails).toHaveLength(2);
323+
expect(agent.outputGuardrails[0].name).toContain('output:');
324+
expect(agent.outputGuardrails[1].name).toBe('Custom Output Guard');
325+
});
326+
327+
it('should handle empty user guardrail arrays gracefully', async () => {
328+
const config = { version: 1 };
329+
330+
const agent = (await GuardrailAgent.create(config, 'EmptyListAgent', 'Test instructions', {
331+
inputGuardrails: [],
332+
outputGuardrails: [],
333+
})) as MockAgent;
334+
335+
expect(agent.name).toBe('EmptyListAgent');
336+
expect(agent.inputGuardrails).toHaveLength(0);
337+
expect(agent.outputGuardrails).toHaveLength(0);
338+
});
217339
});
218340

219341
describe('guardrail function creation', () => {
@@ -226,7 +348,11 @@ describe('GuardrailAgent', () => {
226348
},
227349
};
228350

229-
const agent = await GuardrailAgent.create(config, 'Test Agent', 'Test instructions') as MockAgent;
351+
const agent = (await GuardrailAgent.create(
352+
config,
353+
'Test Agent',
354+
'Test instructions'
355+
)) as MockAgent;
230356

231357
expect(agent.inputGuardrails).toHaveLength(1);
232358

@@ -254,7 +380,7 @@ describe('GuardrailAgent', () => {
254380
vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
255381
Promise.resolve([
256382
{
257-
definition: {
383+
definition: {
258384
name: 'Keywords',
259385
description: 'Test guardrail',
260386
mediaType: 'text/plain',
@@ -263,22 +389,26 @@ describe('GuardrailAgent', () => {
263389
metadata: {},
264390
ctxRequirements: z.object({}),
265391
schema: () => ({}),
266-
instantiate: vi.fn()
392+
instantiate: vi.fn(),
267393
},
268394
config: {},
269395
run: vi.fn().mockRejectedValue(new Error('Guardrail execution failed')),
270-
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T> ? T extends readonly (infer U)[] ? U : never : never,
396+
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T>
397+
? T extends readonly (infer U)[]
398+
? U
399+
: never
400+
: never,
271401
])
272402
);
273403

274404
// Test with raiseGuardrailErrors = false (default behavior)
275-
const agentDefault = await GuardrailAgent.create(
405+
const agentDefault = (await GuardrailAgent.create(
276406
config,
277407
'Test Agent',
278408
'Test instructions',
279409
{},
280410
false
281-
) as MockAgent;
411+
)) as MockAgent;
282412

283413
const guardrailFunctionDefault = agentDefault.inputGuardrails[0];
284414
const resultDefault = await guardrailFunctionDefault.execute('test');
@@ -293,7 +423,7 @@ describe('GuardrailAgent', () => {
293423
vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
294424
Promise.resolve([
295425
{
296-
definition: {
426+
definition: {
297427
name: 'Keywords',
298428
description: 'Test guardrail',
299429
mediaType: 'text/plain',
@@ -302,22 +432,26 @@ describe('GuardrailAgent', () => {
302432
metadata: {},
303433
ctxRequirements: z.object({}),
304434
schema: () => ({}),
305-
instantiate: vi.fn()
435+
instantiate: vi.fn(),
306436
},
307437
config: {},
308438
run: vi.fn().mockRejectedValue(new Error('Guardrail execution failed')),
309-
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T> ? T extends readonly (infer U)[] ? U : never : never,
439+
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T>
440+
? T extends readonly (infer U)[]
441+
? U
442+
: never
443+
: never,
310444
])
311445
);
312446

313447
// Test with raiseGuardrailErrors = true (fail-secure mode)
314-
const agentStrict = await GuardrailAgent.create(
448+
const agentStrict = (await GuardrailAgent.create(
315449
config,
316450
'Test Agent',
317451
'Test instructions',
318452
{},
319453
true
320-
) as MockAgent;
454+
)) as MockAgent;
321455

322456
const guardrailFunctionStrict = agentStrict.inputGuardrails[0];
323457

0 commit comments

Comments
 (0)