Skip to content

Commit 24da084

Browse files
authored
Updating safety_identifier usage (#27)
1 parent b26aad8 commit 24da084

File tree

11 files changed

+361
-58
lines changed

11 files changed

+361
-58
lines changed

src/__tests__/unit/chat-resources.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ describe('Chat resource', () => {
9595
messages,
9696
model: 'gpt-4',
9797
stream: false,
98-
safety_identifier: 'oai-guardrails-ts',
98+
safety_identifier: 'openai-guardrails-js',
9999
});
100100
expect(client.handleLlmResponse).toHaveBeenCalledWith(
101101
{ id: 'chat-response' },
@@ -156,7 +156,7 @@ describe('Responses resource', () => {
156156
model: 'gpt-4o',
157157
stream: false,
158158
tools: undefined,
159-
safety_identifier: 'oai-guardrails-ts',
159+
safety_identifier: 'openai-guardrails-js',
160160
});
161161
expect(client.handleLlmResponse).toHaveBeenCalledWith(
162162
{ id: 'responses-api' },

src/__tests__/unit/checks/moderation-secret-keys.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ describe('moderation guardrail', () => {
4848
expect(createMock).toHaveBeenCalledWith({
4949
model: 'omni-moderation-latest',
5050
input: 'bad content',
51+
safety_identifier: 'openai-guardrails-js',
5152
});
5253
expect(result.tripwireTriggered).toBe(true);
5354
expect(result.info?.flagged_categories).toEqual([Category.HATE]);

src/__tests__/unit/checks/user-defined-llm.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ describe('userDefinedLLMCheck', () => {
5252
model: 'gpt-test',
5353
temperature: 0.0,
5454
response_format: { type: 'json_object' },
55+
safety_identifier: 'openai-guardrails-js',
5556
});
5657
expect(result.tripwireTriggered).toBe(true);
5758
expect(result.info?.flagged).toBe(true);
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/**
2+
* Unit tests for safety identifier utilities.
3+
*
4+
* These tests verify the detection logic for determining whether a client
5+
* supports the safety_identifier parameter in OpenAI API calls.
6+
*/
7+
8+
import { describe, it, expect } from 'vitest';
9+
import { supportsSafetyIdentifier, SAFETY_IDENTIFIER } from '../../../utils/safety-identifier';
10+
11+
describe('Safety Identifier utilities', () => {
12+
describe('SAFETY_IDENTIFIER constant', () => {
13+
it('should have the correct value', () => {
14+
expect(SAFETY_IDENTIFIER).toBe('openai-guardrails-js');
15+
});
16+
});
17+
18+
describe('supportsSafetyIdentifier', () => {
19+
it('should return true for official OpenAI client with default baseURL', () => {
20+
// Mock an official OpenAI client (no custom baseURL)
21+
const mockClient = {
22+
constructor: { name: 'OpenAI' },
23+
baseURL: undefined,
24+
};
25+
26+
expect(supportsSafetyIdentifier(mockClient)).toBe(true);
27+
});
28+
29+
it('should return true for OpenAI client with explicit api.openai.com baseURL', () => {
30+
const mockClient = {
31+
constructor: { name: 'OpenAI' },
32+
baseURL: 'https://api.openai.com/v1',
33+
};
34+
35+
expect(supportsSafetyIdentifier(mockClient)).toBe(true);
36+
});
37+
38+
it('should return false for Azure OpenAI client', () => {
39+
const mockClient = {
40+
constructor: { name: 'AzureOpenAI' },
41+
baseURL: 'https://example.openai.azure.com/v1',
42+
};
43+
44+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
45+
});
46+
47+
it('should return false for AsyncAzureOpenAI client', () => {
48+
const mockClient = {
49+
constructor: { name: 'AsyncAzureOpenAI' },
50+
baseURL: 'https://example.openai.azure.com/v1',
51+
};
52+
53+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
54+
});
55+
56+
it('should return false for local model with custom baseURL (Ollama)', () => {
57+
const mockClient = {
58+
constructor: { name: 'OpenAI' },
59+
baseURL: 'http://localhost:11434/v1',
60+
};
61+
62+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
63+
});
64+
65+
it('should return false for alternative OpenAI-compatible provider', () => {
66+
const mockClient = {
67+
constructor: { name: 'OpenAI' },
68+
baseURL: 'https://api.together.xyz/v1',
69+
};
70+
71+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
72+
});
73+
74+
it('should return false for vLLM server', () => {
75+
const mockClient = {
76+
constructor: { name: 'OpenAI' },
77+
baseURL: 'http://localhost:8000/v1',
78+
};
79+
80+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
81+
});
82+
83+
it('should return false for null client', () => {
84+
expect(supportsSafetyIdentifier(null)).toBe(false);
85+
});
86+
87+
it('should return false for undefined client', () => {
88+
expect(supportsSafetyIdentifier(undefined)).toBe(false);
89+
});
90+
91+
it('should return false for non-object client', () => {
92+
expect(supportsSafetyIdentifier('not an object')).toBe(false);
93+
expect(supportsSafetyIdentifier(123)).toBe(false);
94+
});
95+
96+
it('should check _client.baseURL if baseURL is not directly accessible', () => {
97+
const mockClient = {
98+
constructor: { name: 'OpenAI' },
99+
_client: {
100+
baseURL: 'http://localhost:11434/v1',
101+
},
102+
};
103+
104+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
105+
});
106+
107+
it('should check _baseURL if baseURL and _client.baseURL are not accessible', () => {
108+
const mockClient = {
109+
constructor: { name: 'OpenAI' },
110+
_baseURL: 'http://localhost:11434/v1',
111+
};
112+
113+
expect(supportsSafetyIdentifier(mockClient)).toBe(false);
114+
});
115+
116+
it('should return true when api.openai.com is found via _client.baseURL', () => {
117+
const mockClient = {
118+
constructor: { name: 'OpenAI' },
119+
_client: {
120+
baseURL: 'https://api.openai.com/v1',
121+
},
122+
};
123+
124+
expect(supportsSafetyIdentifier(mockClient)).toBe(true);
125+
});
126+
});
127+
});
128+

src/checks/llm-base.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { z } from 'zod';
1111
import { OpenAI } from 'openai';
1212
import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types';
1313
import { defaultSpecRegistry } from '../registry';
14+
import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
1415

1516
/**
1617
* Configuration schema for LLM-based content checks.
@@ -195,15 +196,25 @@ export async function runLLM(
195196
temperature = 1.0;
196197
}
197198

198-
const response = await client.chat.completions.create({
199+
// Build API call parameters
200+
const params: Record<string, unknown> = {
199201
messages: [
200202
{ role: 'system', content: fullPrompt },
201203
{ role: 'user', content: `# Text\n\n${text}` },
202204
],
203205
model: model,
204206
temperature: temperature,
205207
response_format: { type: 'json_object' },
206-
});
208+
};
209+
210+
// Only include safety_identifier for official OpenAI API (not Azure or local providers)
211+
if (supportsSafetyIdentifier(client)) {
212+
// @ts-ignore - safety_identifier is not defined in OpenAI types yet
213+
params.safety_identifier = SAFETY_IDENTIFIER;
214+
}
215+
216+
// @ts-ignore - safety_identifier is not in the OpenAI types yet
217+
const response = await client.chat.completions.create(params);
207218

208219
const result = response.choices[0]?.message?.content;
209220
if (!result) {

src/checks/moderation.ts

Lines changed: 83 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import { z } from 'zod';
2222
import { CheckFn, GuardrailResult } from '../types';
2323
import { defaultSpecRegistry } from '../registry';
2424
import OpenAI from 'openai';
25+
import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
2526

2627
/**
2728
* Enumeration of supported moderation categories.
@@ -78,6 +79,42 @@ export const ModerationContext = z.object({
7879

7980
export type ModerationContext = z.infer<typeof ModerationContext>;
8081

82+
/**
83+
* Check if an error is a 404 Not Found error from the OpenAI API.
84+
*
85+
* @param error The error to check
86+
* @returns True if the error is a 404 error
87+
*/
88+
function isNotFoundError(error: unknown): boolean {
89+
return !!(error && typeof error === 'object' && 'status' in error && error.status === 404);
90+
}
91+
92+
/**
93+
* Call the OpenAI moderation API.
94+
*
95+
* @param client The OpenAI client to use
96+
* @param data The text to analyze
97+
* @returns The moderation API response
98+
*/
99+
function callModerationAPI(
100+
client: OpenAI,
101+
data: string
102+
): ReturnType<OpenAI['moderations']['create']> {
103+
const params: Record<string, unknown> = {
104+
model: 'omni-moderation-latest',
105+
input: data,
106+
};
107+
108+
// Only include safety_identifier for official OpenAI API (not Azure or local providers)
109+
if (supportsSafetyIdentifier(client)) {
110+
// @ts-ignore - safety_identifier is not defined in OpenAI types yet
111+
params.safety_identifier = SAFETY_IDENTIFIER;
112+
}
113+
114+
// @ts-ignore - safety_identifier is not in the OpenAI types yet
115+
return client.moderations.create(params);
116+
}
117+
81118
/**
82119
* Guardrail check_fn to flag disallowed content categories using OpenAI moderation API.
83120
*
@@ -102,39 +139,55 @@ export const moderationCheck: CheckFn<ModerationContext, string, ModerationConfi
102139
const configObj = actualConfig as Record<string, unknown>;
103140
const categories = (configObj.categories as string[]) || Object.values(Category);
104141

105-
// Reuse provided client only if it targets the official OpenAI API.
106-
const reuseClientIfOpenAI = (context: unknown): OpenAI | null => {
107-
try {
108-
const contextObj = context as Record<string, unknown>;
109-
const candidate = contextObj?.guardrailLlm;
110-
if (!candidate || typeof candidate !== 'object') return null;
111-
if (!(candidate instanceof OpenAI)) return null;
112-
113-
const candidateObj = candidate as unknown as Record<string, unknown>;
114-
const baseURL: string | undefined =
115-
(candidateObj.baseURL as string) ??
116-
((candidateObj._client as Record<string, unknown>)?.baseURL as string) ??
117-
(candidateObj._baseURL as string);
118-
119-
if (
120-
baseURL === undefined ||
121-
(typeof baseURL === 'string' && baseURL.includes('api.openai.com'))
122-
) {
123-
return candidate as OpenAI;
124-
}
125-
return null;
126-
} catch {
127-
return null;
142+
// Get client from context if available
143+
let client: OpenAI | null = null;
144+
if (ctx) {
145+
const contextObj = ctx as Record<string, unknown>;
146+
const candidate = contextObj.guardrailLlm;
147+
if (candidate && candidate instanceof OpenAI) {
148+
client = candidate;
128149
}
129-
};
130-
131-
const client = reuseClientIfOpenAI(ctx) ?? new OpenAI();
150+
}
132151

133152
try {
134-
const resp = await client.moderations.create({
135-
model: 'omni-moderation-latest',
136-
input: data,
137-
});
153+
// Try the context client first, fall back if moderation endpoint doesn't exist
154+
let resp: Awaited<ReturnType<typeof callModerationAPI>>;
155+
if (client !== null) {
156+
try {
157+
resp = await callModerationAPI(client, data);
158+
} catch (error) {
159+
160+
// Moderation endpoint doesn't exist on this provider (e.g., third-party)
161+
// Fall back to the OpenAI client
162+
if (isNotFoundError(error)) {
163+
try {
164+
resp = await callModerationAPI(new OpenAI(), data);
165+
} catch (fallbackError) {
166+
// If fallback fails, provide a helpful error message
167+
const errorMessage = fallbackError instanceof Error
168+
? fallbackError.message
169+
: String(fallbackError);
170+
171+
// Check if it's an API key error
172+
if (errorMessage.includes('api_key') || errorMessage.includes('OPENAI_API_KEY')) {
173+
return {
174+
tripwireTriggered: false,
175+
info: {
176+
checked_text: data,
177+
error: 'Moderation API requires OpenAI API key. Set OPENAI_API_KEY environment variable or pass a client with valid credentials.',
178+
},
179+
};
180+
}
181+
throw fallbackError;
182+
}
183+
} else {
184+
throw error;
185+
}
186+
}
187+
} else {
188+
// No context client, use fallback
189+
resp = await callModerationAPI(new OpenAI(), data);
190+
}
138191

139192
const results = resp.results || [];
140193
if (!results.length) {

src/checks/user-defined-llm.ts

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import { z } from 'zod';
1010
import { CheckFn, GuardrailResult } from '../types';
1111
import { defaultSpecRegistry } from '../registry';
12+
import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
1213

1314
/**
1415
* Configuration schema for user-defined LLM moderation checks.
@@ -91,27 +92,45 @@ export const userDefinedLLMCheck: CheckFn<UserDefinedContext, string, UserDefine
9192
// Try with JSON response format first, fall back to text if not supported
9293
let response;
9394
try {
94-
response = await ctx.guardrailLlm.chat.completions.create({
95+
// Build API call parameters
96+
const params: Record<string, unknown> = {
9597
messages: [
9698
{ role: 'system', content: renderedSystemPrompt },
9799
{ role: 'user', content: data },
98100
],
99101
model: config.model,
100102
temperature: 0.0,
101103
response_format: { type: 'json_object' },
102-
});
104+
};
105+
106+
// Only include safety_identifier for official OpenAI API (not Azure or local providers)
107+
if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
108+
// @ts-ignore - safety_identifier is not defined in OpenAI types yet
109+
params.safety_identifier = SAFETY_IDENTIFIER;
110+
}
111+
112+
response = await ctx.guardrailLlm.chat.completions.create(params);
103113
} catch (error: unknown) {
104114
// If JSON response format is not supported, try without it
105115
if (error && typeof error === 'object' && 'error' in error &&
106116
(error as { error?: { param?: string } }).error?.param === 'response_format') {
107-
response = await ctx.guardrailLlm.chat.completions.create({
117+
// Build fallback parameters without response_format
118+
const fallbackParams: Record<string, unknown> = {
108119
messages: [
109120
{ role: 'system', content: renderedSystemPrompt },
110121
{ role: 'user', content: data },
111122
],
112123
model: config.model,
113124
temperature: 0.0,
114-
});
125+
};
126+
127+
// Only include safety_identifier for official OpenAI API (not Azure or local providers)
128+
if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
129+
// @ts-ignore - safety_identifier is not defined in OpenAI types yet
130+
fallbackParams.safety_identifier = SAFETY_IDENTIFIER;
131+
}
132+
133+
response = await ctx.guardrailLlm.chat.completions.create(fallbackParams);
115134
} else {
116135
// Return error information instead of re-throwing
117136
return {

0 commit comments

Comments (0)