Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*_creds*
**/venv/
**/__pycache__/
**/__pycache__/**
107 changes: 103 additions & 4 deletions transports/bifrost-http/integrations/genai/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,103 @@ import (

var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction))

// CustomBlob handles URL-safe base64 decoding for Google GenAI requests
type CustomBlob struct {
Data []byte `json:"data,omitempty"`
MIMEType string `json:"mimeType,omitempty"`
}

// UnmarshalJSON custom unmarshalling to handle URL-safe base64 encoding
func (b *CustomBlob) UnmarshalJSON(data []byte) error {
// First unmarshal into a temporary struct with string data
var temp struct {
Data string `json:"data,omitempty"`
MIMEType string `json:"mimeType,omitempty"`
}

if err := json.Unmarshal(data, &temp); err != nil {
return err
}

b.MIMEType = temp.MIMEType

if temp.Data != "" {
// Convert URL-safe base64 to standard base64
standardBase64 := strings.ReplaceAll(strings.ReplaceAll(temp.Data, "_", "/"), "-", "+")

// Add padding if necessary
switch len(standardBase64) % 4 {
case 2:
standardBase64 += "=="
case 3:
standardBase64 += "="
}

decoded, err := base64.StdEncoding.DecodeString(standardBase64)
if err != nil {
return fmt.Errorf("failed to decode base64 data: %v", err)
}
b.Data = decoded
}

return nil
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// CustomPart handles Google GenAI Part with custom Blob unmarshalling
type CustomPart struct {
VideoMetadata *genai_sdk.VideoMetadata `json:"videoMetadata,omitempty"`
Thought bool `json:"thought,omitempty"`
CodeExecutionResult *genai_sdk.CodeExecutionResult `json:"codeExecutionResult,omitempty"`
ExecutableCode *genai_sdk.ExecutableCode `json:"executableCode,omitempty"`
FileData *genai_sdk.FileData `json:"fileData,omitempty"`
FunctionCall *genai_sdk.FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *genai_sdk.FunctionResponse `json:"functionResponse,omitempty"`
InlineData *CustomBlob `json:"inlineData,omitempty"`
Text string `json:"text,omitempty"`
}

// ToGenAIPart converts CustomPart to genai_sdk.Part
func (p *CustomPart) ToGenAIPart() *genai_sdk.Part {
part := &genai_sdk.Part{
VideoMetadata: p.VideoMetadata,
Thought: p.Thought,
CodeExecutionResult: p.CodeExecutionResult,
ExecutableCode: p.ExecutableCode,
FileData: p.FileData,
FunctionCall: p.FunctionCall,
FunctionResponse: p.FunctionResponse,
Text: p.Text,
}

if p.InlineData != nil {
part.InlineData = &genai_sdk.Blob{
Data: p.InlineData.Data,
MIMEType: p.InlineData.MIMEType,
}
}

return part
}

// CustomContent handles Google GenAI Content with custom Part unmarshalling
type CustomContent struct {
Parts []*CustomPart `json:"parts,omitempty"`
Role string `json:"role,omitempty"`
}

// ToGenAIContent converts CustomContent to genai_sdk.Content
func (c *CustomContent) ToGenAIContent() genai_sdk.Content {
parts := make([]*genai_sdk.Part, len(c.Parts))
for i, part := range c.Parts {
parts[i] = part.ToGenAIPart()
}

return genai_sdk.Content{
Parts: parts,
Role: c.Role,
}
}

// ensureExtraParams ensures that bifrostReq.Params and bifrostReq.Params.ExtraParams are initialized
func ensureExtraParams(bifrostReq *schemas.BifrostRequest) {
if bifrostReq.Params == nil {
Expand All @@ -27,8 +124,8 @@ func ensureExtraParams(bifrostReq *schemas.BifrostRequest) {

type GeminiChatRequest struct {
Model string `json:"model,omitempty"` // Model field for explicit model specification
Contents []genai_sdk.Content `json:"contents"`
SystemInstruction *genai_sdk.Content `json:"systemInstruction,omitempty"`
Contents []CustomContent `json:"contents"`
SystemInstruction *CustomContent `json:"systemInstruction,omitempty"`
GenerationConfig genai_sdk.GenerationConfig `json:"generationConfig,omitempty"`
SafetySettings []genai_sdk.SafetySetting `json:"safetySettings,omitempty"`
Tools []genai_sdk.Tool `json:"tools,omitempty"`
Expand All @@ -51,9 +148,11 @@ func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest {

allGenAiMessages := []genai_sdk.Content{}
if r.SystemInstruction != nil {
allGenAiMessages = append(allGenAiMessages, *r.SystemInstruction)
allGenAiMessages = append(allGenAiMessages, r.SystemInstruction.ToGenAIContent())
}
for _, content := range r.Contents {
allGenAiMessages = append(allGenAiMessages, content.ToGenAIContent())
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
allGenAiMessages = append(allGenAiMessages, r.Contents...)

for _, content := range allGenAiMessages {
if len(content.Parts) == 0 {
Expand Down
18 changes: 5 additions & 13 deletions transports/bifrost-http/integrations/openai/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters {

// Direct field mapping
if r.MaxTokens != nil {
params.ExtraParams["max_tokens"] = *r.MaxTokens
params.MaxTokens = r.MaxTokens
}
if r.Temperature != nil {
params.ExtraParams["temperature"] = *r.Temperature
params.Temperature = r.Temperature
}
if r.TopP != nil {
params.ExtraParams["top_p"] = *r.TopP
params.TopP = r.TopP
}
if r.PresencePenalty != nil {
params.ExtraParams["presence_penalty"] = *r.PresencePenalty
params.PresencePenalty = r.PresencePenalty
}
if r.FrequencyPenalty != nil {
params.ExtraParams["frequency_penalty"] = *r.FrequencyPenalty
params.FrequencyPenalty = r.FrequencyPenalty
}
if r.N != nil {
params.ExtraParams["n"] = *r.N
Expand All @@ -101,18 +101,10 @@ func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters {
if r.Stream != nil {
params.ExtraParams["stream"] = *r.Stream
}
if r.ResponseFormat != nil {
params.ExtraParams["response_format"] = r.ResponseFormat
}
if r.Seed != nil {
params.ExtraParams["seed"] = *r.Seed
}

// Return nil if no parameters were set
if len(params.ExtraParams) == 0 {
return nil
}

return params
}

Expand Down