Skip to content

Commit 7a18e60

Browse files
committed
bedrock cleanup
1 parent ce39234 commit 7a18e60

File tree

6 files changed

+58
-25
lines changed

6 files changed

+58
-25
lines changed

src/providers/bedrock/api.ts

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import { Context } from 'hono';
2-
import { Options } from '../../types/requestBody';
2+
import { Options, Params } from '../../types/requestBody';
33
import { endpointStrings, ProviderAPIConfig } from '../types';
44
import { bedrockInvokeModels } from './constants';
55
import {
6+
getAwsEndpointDomain,
67
generateAWSHeaders,
7-
getAssumedRoleCredentials,
88
getFoundationModelFromInferenceProfile,
99
providerAssumedRoleCredentials,
1010
} from './utils';
@@ -18,6 +18,7 @@ interface BedrockAPIConfigInterface extends Omit<ProviderAPIConfig, 'headers'> {
1818
transformedRequestBody: Record<string, any> | string;
1919
transformedRequestUrl: string;
2020
gatewayRequestBody?: Params;
21+
headers?: Record<string, string>;
2122
}) => Promise<Record<string, any>> | Record<string, any>;
2223
}
2324

@@ -66,7 +67,14 @@ const ENDPOINTS_TO_ROUTE_TO_S3 = [
6667
'initiateMultipartUpload',
6768
];
6869

69-
const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
70+
const getMethod = (
71+
fn: endpointStrings,
72+
transformedRequestUrl: string,
73+
c: Context
74+
) => {
75+
if (fn === 'proxy') {
76+
return c.req.method;
77+
}
7078
if (fn === 'uploadFile') {
7179
const url = new URL(transformedRequestUrl);
7280
return url.searchParams.get('partNumber') ? 'PUT' : 'POST';
@@ -121,36 +129,47 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
121129
gatewayRequestURL.split('/v1/files/')[1]
122130
);
123131
const bucketName = s3URL.replace('s3://', '').split('/')[0];
124-
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
132+
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
125133
}
126134
if (fn === 'retrieveFileContent') {
127135
const s3URL = decodeURIComponent(
128136
gatewayRequestURL.split('/v1/files/')[1]
129137
);
130138
const bucketName = s3URL.replace('s3://', '').split('/')[0];
131-
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
139+
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
132140
}
133141
if (fn === 'uploadFile')
134-
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
142+
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
135143
const isAWSControlPlaneEndpoint =
136144
fn && AWS_CONTROL_PLANE_ENDPOINTS.includes(fn);
137-
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
145+
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
138146
},
139147
headers: async ({
140148
c,
141149
fn,
142150
providerOptions,
143151
transformedRequestBody,
144152
transformedRequestUrl,
153+
gatewayRequestBody, // for proxy use the passed body blindly
154+
headers: requestHeaders,
145155
}) => {
146-
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
147-
const service = getService(fn as endpointStrings);
156+
const { awsService } = providerOptions;
157+
const method =
158+
c.get('method') || // method set specifically into context
159+
getMethod(fn as endpointStrings, transformedRequestUrl, c); // method calculated
160+
const service = awsService || getService(fn as endpointStrings);
148161

149-
const headers: Record<string, string> = {
150-
'content-type': 'application/json',
151-
};
162+
let headers: Record<string, string> = {};
152163

153-
if (method === 'PUT' || method === 'GET') {
164+
if (fn === 'proxy' && service !== 'bedrock') {
165+
headers = { ...(requestHeaders ?? {}) };
166+
} else {
167+
headers = {
168+
'content-type': 'application/json',
169+
};
170+
}
171+
172+
if ((method === 'PUT' || method === 'GET') && fn !== 'proxy') {
154173
delete headers['content-type'];
155174
}
156175

@@ -160,7 +179,8 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
160179
await providerAssumedRoleCredentials(c, providerOptions);
161180
}
162181

163-
let finalRequestBody = transformedRequestBody;
182+
let finalRequestBody =
183+
fn === 'proxy' ? gatewayRequestBody : transformedRequestBody;
164184

165185
if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {
166186
// Cancel doesn't require any body, but fetch is sending empty body, to match the signature this block is required.
@@ -183,7 +203,6 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
183203
fn,
184204
gatewayRequestBodyJSON: gatewayRequestBody,
185205
gatewayRequestURL,
186-
c,
187206
}) => {
188207
if (fn === 'retrieveFile') {
189208
const fileId = decodeURIComponent(

src/providers/bedrock/listBatches.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,8 @@ export const BedrockListBatchesResponseTransform = (
2828
output_file_id: encodeURIComponent(
2929
batch.outputDataConfig.s3OutputDataConfig.s3Uri
3030
),
31-
finalizing_at: batch.endTime
32-
? new Date(batch.endTime).getTime()
33-
: undefined,
34-
expires_at: batch.jobExpirationTime
35-
? new Date(batch.jobExpirationTime).getTime()
36-
: undefined,
31+
finalizing_at: new Date(batch.endTime).getTime(),
32+
expires_at: new Date(batch.jobExpirationTime).getTime(),
3733
}));
3834

3935
return {

src/providers/bedrock/listFinetunes.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export const BedrockListFinetuneResponseTransform: (
1010
if (responseStatus !== 200) {
1111
return BedrockErrorResponseTransform(response) || response;
1212
}
13+
1314
const records =
1415
response?.modelCustomizationJobSummaries as BedrockFinetuneRecord[];
1516
const openaiRecords = records.map(bedrockFinetuneToOpenAI);

src/providers/bedrock/types.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ export interface BedrockInferenceProfile {
8181
type: string;
8282
}
8383

84+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
85+
export enum BEDROCK_STOP_REASON {
86+
end_turn = 'end_turn',
87+
tool_use = 'tool_use',
88+
max_tokens = 'max_tokens',
89+
stop_sequence = 'stop_sequence',
90+
guardrail_intervened = 'guardrail_intervened',
91+
content_filtered = 'content_filtered',
92+
}
93+
8494
export interface BedrockMessagesParams extends MessageCreateParamsBase {
8595
additionalModelRequestFields?: Record<string, any>;
8696
additional_model_request_fields?: Record<string, any>;

src/providers/bedrock/uploadFileUtils.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,10 @@ interface BedrockAnthropicChatCompleteResponse {
830830
stop_reason: string;
831831
model: string;
832832
stop_sequence: null | string;
833+
usage: {
834+
input_tokens: number;
835+
output_tokens: number;
836+
};
833837
}
834838

835839
export const BedrockAnthropicChatCompleteResponseTransform: (
@@ -874,9 +878,10 @@ export const BedrockAnthropicChatCompleteResponseTransform: (
874878
},
875879
],
876880
usage: {
877-
prompt_tokens: 0,
878-
completion_tokens: 0,
879-
total_tokens: 0,
881+
prompt_tokens: response.usage.input_tokens,
882+
completion_tokens: response.usage.output_tokens,
883+
total_tokens:
884+
response.usage.input_tokens + response.usage.output_tokens,
880885
},
881886
};
882887
}
@@ -933,6 +938,7 @@ export const BedrockMistralChatCompleteResponseTransform: (
933938
finish_reason: response.outputs[0].stop_reason,
934939
},
935940
],
941+
// mistral not sending usage.
936942
usage: {
937943
prompt_tokens: 0,
938944
completion_tokens: 0,

src/providers/bedrock/utils.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ import { GatewayError } from '../../errors/GatewayError';
1313
import { BedrockFinetuneRecord, BedrockInferenceProfile } from './types';
1414
import { FinetuneRequest } from '../types';
1515
import { BEDROCK } from '../../globals';
16+
import { Environment } from '../../utils/env';
1617

1718
export const getAwsEndpointDomain = (c: Context) =>
18-
env(c).AWS_ENDPOINT_DOMAIN || 'amazonaws.com';
19+
Environment(c).AWS_ENDPOINT_DOMAIN || 'amazonaws.com';
1920

2021
export const generateAWSHeaders = async (
2122
body: Record<string, any> | string | undefined,

0 commit comments

Comments
 (0)