Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1248,8 +1248,12 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem
},
}
}
// Prompt is not required when type is background_removal
if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") &&
// Prompt is not required for certain operation types that work without a text prompt
var imageEditParamsType *string
if req.Params != nil {
imageEditParamsType = req.Params.Type
}
if !isPromptOptionalImageEditType(imageEditParamsType) &&
(req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down Expand Up @@ -1316,8 +1320,12 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req
},
}
}
// Prompt is not required when type is background_removal
if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") &&
// Prompt is not required for certain operation types that work without a text prompt
var imageEditStreamParamsType *string
if req.Params != nil {
imageEditStreamParamsType = req.Params.Type
}
if !isPromptOptionalImageEditType(imageEditStreamParamsType) &&
(req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down
18 changes: 14 additions & 4 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1953,8 +1953,10 @@ func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostConte
}

// ImageEdit performs image editing using Amazon Bedrock.
// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2.
// Supports three edit types: INPAINTING, OUTPAINTING, and BACKGROUND_REMOVAL.
// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2 (three edit types:
// INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL), and Stability AI edit models (inpaint, outpaint,
// recolor, search-replace, erase-object, remove-bg, control-sketch, control-structure, style-guide,
// style-transfer, upscale-creative, upscale-conservative, upscale-fast).
// Returns a BifrostImageGenerationResponse containing the edited images and any error that occurred.
func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageEditRequest); err != nil {
Expand All @@ -1969,17 +1971,25 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche
var jsonData []byte
var bifrostError *schemas.BifrostError

// Resolve deployment alias before building the request body so that
// Stability AI routing and task-type inference use the actual model ID.
path, deployment := provider.getModelPath("invoke", request.Model, key)

jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageEditRequest(request) },
func() (providerUtils.RequestBodyWithExtraParams, error) {
if isStabilityAIModel(deployment) {
return ToStabilityAIImageEditRequest(request, deployment)
}
return ToBedrockImageEditRequest(request)
},
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
Radheshg04 marked this conversation as resolved.
provider.GetProviderKey())
if bifrostError != nil {
return nil, bifrostError
}

// Make API request (same URL as image generation)
path, deployment := provider.getModelPath("invoke", request.Model, key)
rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key)
if providerResponseHeaders != nil {
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
Expand Down
286 changes: 286 additions & 0 deletions core/providers/bedrock/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,292 @@ func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGener
return config
}

// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value
// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type.
func getStabilityAITaskTypeFromParams(t string) string {
switch strings.ToLower(t) {
case "inpainting", "inpaint":
return "inpaint"
case "outpainting", "outpaint":
return "outpaint"
case "background_removal", "remove_background":
return "remove-bg"
case "erase_object":
return "erase-object"
case "upscale_fast":
return "upscale-fast"
case "upscale_creative":
return "upscale-creative"
case "upscale_conservative":
return "upscale-conservative"
case "recolor":
return "recolor"
case "search_replace":
return "search-replace"
case "control_sketch":
return "control-sketch"
case "control_structure":
return "control-structure"
case "style_guide":
return "style-guide"
case "style_transfer":
return "style-transfer"
default:
return ""
Comment thread
Radheshg04 marked this conversation as resolved.
}
}

// getStabilityAIEditTaskType infers the Stability AI edit task from the model name.
// Returns an error if the model name does not match any known pattern.
func getStabilityAIEditTaskType(model string) (string, error) {
m := strings.ToLower(model)
switch {
case strings.Contains(m, "stable-creative-upscale"):
return "upscale-creative", nil
case strings.Contains(m, "stable-conservative-upscale"):
return "upscale-conservative", nil
case strings.Contains(m, "stable-fast-upscale"):
return "upscale-fast", nil
case strings.Contains(m, "stable-image-inpaint"):
return "inpaint", nil
case strings.Contains(m, "stable-outpaint"):
return "outpaint", nil
case strings.Contains(m, "stable-image-search-recolor"):
return "recolor", nil
case strings.Contains(m, "stable-image-search-replace"):
return "search-replace", nil
case strings.Contains(m, "stable-image-erase-object"):
return "erase-object", nil
case strings.Contains(m, "stable-image-remove-background"):
return "remove-bg", nil
case strings.Contains(m, "stable-image-control-sketch"):
return "control-sketch", nil
case strings.Contains(m, "stable-image-control-structure"):
return "control-structure", nil
case strings.Contains(m, "stable-image-style-guide"):
return "style-guide", nil
case strings.Contains(m, "stable-style-transfer"):
return "style-transfer", nil
default:
return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
// used for task-type inference so that alias-mapped models route correctly.
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}

var taskType string
if request.Params != nil && request.Params.Type != nil {
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
}
if taskType == "" {
var err error
taskType, err = getStabilityAIEditTaskType(deployment)
if err != nil {
return nil, err
}
}
Comment thread
Radheshg04 marked this conversation as resolved.

req := &StabilityAIImageEditRequest{}

// Image sourcing
if taskType == "style-transfer" {
if len(request.Input.Images) != 2 {
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
}
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
req.InitImage = &initB64
req.StyleImage = &styleB64
Comment thread
coderabbitai[bot] marked this conversation as resolved.
} else {
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
req.Image = &imageB64
}

// Common fields populated based on task allowlist
prompt := request.Input.Prompt
switch taskType {
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
req.Prompt = &prompt
}

// Negative prompt
if request.Params != nil && request.Params.NegativePrompt != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.NegativePrompt = request.Params.NegativePrompt
}
}

// Seed
if request.Params != nil && request.Params.Seed != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.Seed = request.Params.Seed
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// Mask (from Params.Mask bytes)
if request.Params != nil && len(request.Params.Mask) > 0 {
switch taskType {
case "inpaint", "erase-object":
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
req.Mask = &maskB64
}
}

// ExtraParams
if request.Params != nil {
// Typed OutputFormat takes priority over ExtraParams
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}

if request.Params.ExtraParams != nil {
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
for k, v := range request.Params.ExtraParams {
ep[k] = v
}

// output_format — all tasks (fallback if not already set by typed field)
if req.OutputFormat == nil {
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
delete(ep, "output_format")
req.OutputFormat = v
}
}

// style_preset
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative":
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
delete(ep, "style_preset")
req.StylePreset = v
}
}

// grow_mask
switch taskType {
case "inpaint", "recolor", "search-replace", "erase-object":
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
delete(ep, "grow_mask")
req.GrowMask = v
}
}

// outpaint directional fields
if taskType == "outpaint" {
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
delete(ep, "left")
req.Left = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
delete(ep, "right")
req.Right = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
delete(ep, "up")
req.Up = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
delete(ep, "down")
req.Down = v
}
}

// creativity
switch taskType {
case "upscale-creative", "upscale-conservative", "outpaint":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
delete(ep, "creativity")
req.Creativity = v
}
}

// select_prompt (recolor)
if taskType == "recolor" {
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
delete(ep, "select_prompt")
req.SelectPrompt = v
}
}

// search_prompt (search-replace)
if taskType == "search-replace" {
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
delete(ep, "search_prompt")
req.SearchPrompt = v
}
}

// control_strength
switch taskType {
case "control-sketch", "control-structure":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
delete(ep, "control_strength")
req.ControlStrength = v
}
}

// style-guide fields
if taskType == "style-guide" {
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
delete(ep, "aspect_ratio")
req.AspectRatio = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
delete(ep, "fidelity")
req.Fidelity = v
}
}

// style-transfer fields
if taskType == "style-transfer" {
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
delete(ep, "style_strength")
req.StyleStrength = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
delete(ep, "composition_fidelity")
req.CompositionFidelity = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
delete(ep, "change_strength")
req.ChangeStrength = v
}
}

req.ExtraParams = ep
Comment thread
Radheshg04 marked this conversation as resolved.
}
}

// Validate required per-task fields
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
}
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
}

return req, nil
}

// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
if response == nil {
Expand Down
Loading
Loading