Skip to content
Merged
26 changes: 26 additions & 0 deletions FirebaseAI/Sources/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public struct GenerationConfig: Sendable {
/// Output schema of the generated candidate text.
let responseSchema: Schema?

/// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format.
///
/// If set, `responseSchema` must be omitted and `responseMIMEType` is required.
let responseJSONSchema: JSONObject?

/// Supported modalities of the response.
let responseModalities: [ResponseModality]?

Expand Down Expand Up @@ -175,6 +180,26 @@ public struct GenerationConfig: Sendable {
self.stopSequences = stopSequences
self.responseMIMEType = responseMIMEType
self.responseSchema = responseSchema
responseJSONSchema = nil
self.responseModalities = responseModalities
self.thinkingConfig = thinkingConfig
}

/// Creates a generation configuration whose output is described by a raw JSON Schema object.
///
/// This overload stores `responseJSONSchema` and always leaves `responseSchema` as `nil`;
/// `responseMIMEType` is a non-optional argument here.
init(
  temperature: Float? = nil,
  topP: Float? = nil,
  topK: Int? = nil,
  candidateCount: Int? = nil,
  maxOutputTokens: Int? = nil,
  presencePenalty: Float? = nil,
  frequencyPenalty: Float? = nil,
  stopSequences: [String]? = nil,
  responseMIMEType: String,
  responseJSONSchema: JSONObject,
  responseModalities: [ResponseModality]? = nil,
  thinkingConfig: ThinkingConfig? = nil
) {
  // Structured-output settings: the JSON Schema replaces the typed `Schema`.
  self.responseMIMEType = responseMIMEType
  self.responseJSONSchema = responseJSONSchema
  self.responseSchema = nil

  // Sampling parameters.
  self.temperature = temperature
  self.topP = topP
  self.topK = topK

  // Candidate and length limits.
  self.candidateCount = candidateCount
  self.maxOutputTokens = maxOutputTokens
  self.stopSequences = stopSequences

  // Repetition penalties.
  self.presencePenalty = presencePenalty
  self.frequencyPenalty = frequencyPenalty

  // Remaining response options.
  self.responseModalities = responseModalities
  self.thinkingConfig = thinkingConfig
}
Expand All @@ -195,6 +220,7 @@ extension GenerationConfig: Encodable {
case stopSequences
case responseMIMEType = "responseMimeType"
case responseSchema
case responseJSONSchema = "responseJsonSchema"
case responseModalities
case thinkingConfig
}
Expand Down
270 changes: 248 additions & 22 deletions FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ struct SchemaTests {
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
}

/// Requests a JSON array of city names constrained by `minItems`/`maxItems` and checks that the
/// decoded array has between 3 and 5 entries.
@Test(arguments: InstanceConfig.allConfigs)
func generateContentJSONSchemaItems(_ config: InstanceConfig) async throws {
  // JSON Schema: an array of 3...5 strings.
  let cityListSchema: JSONObject = [
    "type": .string("array"),
    "description": .string("A list of city names"),
    "items": .object([
      "type": .string("string"),
      "description": .string("The name of the city"),
    ]),
    "minItems": .number(3),
    "maxItems": .number(5),
  ]
  let generationConfig = GenerationConfig(
    responseMIMEType: "application/json",
    responseJSONSchema: cityListSchema
  )
  let model = FirebaseAI.componentInstance(config).generativeModel(
    modelName: ModelNames.gemini2_5_FlashLite,
    generationConfig: generationConfig,
    safetySettings: safetySettings
  )

  let response = try await model.generateContent("What are the biggest cities in Canada?")

  let rawText = try #require(response.text)
  let trimmedText = rawText.trimmingCharacters(in: .whitespacesAndNewlines)
  let payload = try #require(trimmedText.data(using: .utf8))
  let cities = try JSONDecoder().decode([String].self, from: payload)
  #expect(cities.count >= 3, "Expected at least 3 cities, but got \(cities.count)")
  #expect(cities.count <= 5, "Expected at most 5 cities, but got \(cities.count)")
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
Expand All @@ -96,14 +124,41 @@ struct SchemaTests {
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
}

/// Requests a single integer constrained by `minimum`/`maximum` and checks that the decoded
/// value lies within 110...120.
@Test(arguments: InstanceConfig.allConfigs)
func generateContentJSONSchemaNumberRange(_ config: InstanceConfig) async throws {
  // JSON Schema: one integer in the closed range 110...120.
  let boundedIntegerSchema: JSONObject = [
    "type": .string("integer"),
    "description": .string("A number"),
    "minimum": .number(110),
    "maximum": .number(120),
  ]
  let generationConfig = GenerationConfig(
    responseMIMEType: "application/json",
    responseJSONSchema: boundedIntegerSchema
  )
  let model = FirebaseAI.componentInstance(config).generativeModel(
    modelName: ModelNames.gemini2_5_FlashLite,
    generationConfig: generationConfig,
    safetySettings: safetySettings
  )

  let response = try await model.generateContent("Give me a number")

  let rawText = try #require(response.text)
  let trimmedText = rawText.trimmingCharacters(in: .whitespacesAndNewlines)
  let payload = try #require(trimmedText.data(using: .utf8))
  // Decoded as `Double`, matching the floating-point bounds checked below.
  let value = try JSONDecoder().decode(Double.self, from: payload)
  #expect(value >= 110.0, "Expected a number >= 110, but got \(value)")
  #expect(value <= 120.0, "Expected a number <= 120, but got \(value)")
}

/// Decoding target for the product-description tests; the model's JSON response is decoded into
/// this type and its numeric fields are range-checked.
private struct ProductInfo: Codable {
// Name of the product; declared as a `string` property in the JSON Schema test.
let productName: String
// Declared as an `integer` property (1...5) in the JSON Schema test.
let rating: Int
// Declared as a `number` property (10...120) in the JSON Schema test.
let price: Double
// Declared as a `number` property (5...90) in the JSON Schema test.
let salePrice: Float
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
struct ProductInfo: Codable {
let productName: String
let rating: Int // Will correspond to .integer in schema
let price: Double // Will correspond to .double in schema
let salePrice: Float // Will correspond to .float in schema
}
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
generationConfig: GenerationConfig(
Expand Down Expand Up @@ -150,28 +205,95 @@ struct SchemaTests {
}

/// Requests a product description constrained by an object JSON Schema with bounded numeric
/// properties, decodes it into `ProductInfo`, and checks every numeric field against its bounds.
@Test(arguments: InstanceConfig.allConfigs)
func generateContentJSONSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
  let model = FirebaseAI.componentInstance(config).generativeModel(
    modelName: ModelNames.gemini2_5_FlashLite,
    generationConfig: GenerationConfig(
      responseMIMEType: "application/json",
      responseJSONSchema: [
        "type": .string("object"),
        "title": .string("ProductInfo"),
        "properties": .object([
          "productName": .object([
            "type": .string("string"),
            "description": .string("The name of the product"),
          ]),
          "price": .object([
            "type": .string("number"),
            "description": .string("A price"),
            "minimum": .number(10.00),
            "maximum": .number(120.00),
          ]),
          "salePrice": .object([
            "type": .string("number"),
            "description": .string("A sale price"),
            "minimum": .number(5.00),
            "maximum": .number(90.00),
          ]),
          "rating": .object([
            "type": .string("integer"),
            "description": .string("A rating"),
            "minimum": .number(1),
            "maximum": .number(5),
          ]),
        ]),
        "required": .array([
          .string("productName"),
          .string("price"),
          .string("salePrice"),
          .string("rating"),
        ]),
        // Requested output order of the properties in the generated JSON.
        "propertyOrdering": .array([
          .string("salePrice"),
          .string("rating"),
          .string("price"),
          .string("productName"),
        ]),
      ]
    ),
    safetySettings: safetySettings
  )
  let prompt = "Describe a premium wireless headphone, including a user rating and price."

  let response = try await model.generateContent(prompt)

  let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
  let jsonData = try #require(text.data(using: .utf8))
  let decodedProduct = try JSONDecoder().decode(ProductInfo.self, from: jsonData)
  let price = decodedProduct.price
  let salePrice = decodedProduct.salePrice
  let rating = decodedProduct.rating
  #expect(price >= 10.0, "Expected a price >= 10.00, but got \(price)")
  #expect(price <= 120.0, "Expected a price <= 120.00, but got \(price)")
  #expect(salePrice >= 5.0, "Expected a salePrice >= 5.00, but got \(salePrice)")
  #expect(salePrice <= 90.0, "Expected a salePrice <= 90.00, but got \(salePrice)")
  #expect(rating >= 1, "Expected a rating >= 1, but got \(rating)")
  #expect(rating <= 5, "Expected a rating <= 5, but got \(rating)")
}

/// Decoding target for the `anyOf` address tests.
///
/// An address carries either the Canadian-specific pair (`province`, `postalCode`) or the
/// U.S.-specific pair (`state`, `zipCode`); the unused pair decodes as `nil`.
private struct MailingAddress: Decodable {
  let streetAddress: String
  let city: String

  // Canadian-specific
  let province: String?
  let postalCode: String?

  // U.S.-specific
  let state: String?
  let zipCode: String?

  /// `true` when both Canadian fields are present and both U.S. fields are absent.
  var isCanadian: Bool {
    return province != nil && postalCode != nil && state == nil && zipCode == nil
  }

  /// `true` when both U.S. fields are present and both Canadian fields are absent.
  var isAmerican: Bool {
    return province == nil && postalCode == nil && state != nil && zipCode != nil
  }
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
let streetSchema = Schema.string(description:
"The civic number and street name, for example, '123 Main Street'.")
let citySchema = Schema.string(description: "The name of the city.")
Expand Down Expand Up @@ -232,4 +354,108 @@ struct SchemaTests {
"Expected Canadian Queen's University address, got \(queensAddress)."
)
}

/// Requests three university mailing addresses using an array schema whose items are `anyOf`
/// a Canadian or a U.S. address object, then checks that each decoded address has the expected
/// nationality-specific fields.
@Test(arguments: InstanceConfig.allConfigs)
func generateContentAnyOfJSONSchema(_ config: InstanceConfig) async throws {
  /// Builds a `string`-typed JSON Schema property with the given description.
  func stringProperty(_ description: String) -> JSONValue {
    return .object([
      "type": .string("string"),
      "description": .string(description),
    ])
  }

  // Properties shared by both address variants.
  let street = stringProperty("The civic number and street name, for example, '123 Main Street'.")
  let city = stringProperty("The name of the city.")

  let canadianAddress: JSONObject = [
    "type": .string("object"),
    "description": .string("A Canadian mailing address"),
    "properties": .object([
      "streetAddress": street,
      "city": city,
      "province": stringProperty(
        "The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."
      ),
      "postalCode": stringProperty("The postal code, for example, 'A1A 1A1'."),
    ]),
    "required": .array([
      .string("streetAddress"),
      .string("city"),
      .string("province"),
      .string("postalCode"),
    ]),
  ]
  let americanAddress: JSONObject = [
    "type": .string("object"),
    "description": .string("A U.S. mailing address"),
    "properties": .object([
      "streetAddress": street,
      "city": city,
      "state": stringProperty(
        "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."
      ),
      "zipCode": stringProperty("The 5-digit ZIP code, for example, '12345'."),
    ]),
    "required": .array([
      .string("streetAddress"),
      .string("city"),
      .string("state"),
      .string("zipCode"),
    ]),
  ]

  // Array whose items may match either address variant.
  let addressListSchema: JSONObject = [
    "type": .string("array"),
    "items": .object([
      "anyOf": .array([
        .object(canadianAddress),
        .object(americanAddress),
      ]),
    ]),
  ]
  let generationConfig = GenerationConfig(
    temperature: 0.0,
    topP: 0.0,
    topK: 1,
    responseMIMEType: "application/json",
    responseJSONSchema: addressListSchema
  )
  let model = FirebaseAI.componentInstance(config).generativeModel(
    modelName: ModelNames.gemini2_5_Flash,
    generationConfig: generationConfig,
    safetySettings: safetySettings
  )
  let prompt = """
  What are the mailing addresses for the University of Waterloo, UC Berkeley and Queen's U?
  """

  let response = try await model.generateContent(prompt)

  let text = try #require(response.text)
  let payload = try #require(text.data(using: .utf8))
  let addresses = try JSONDecoder().decode([MailingAddress].self, from: payload)
  // Stop here if the count is wrong so the indexed accesses below cannot trap.
  try #require(addresses.count == 3, "Expected 3 JSON addresses, got \(text).")

  let (waterloo, berkeley, queens) = (addresses[0], addresses[1], addresses[2])
  #expect(
    waterloo.isCanadian,
    "Expected Canadian University of Waterloo address, got \(waterloo)."
  )
  #expect(
    berkeley.isAmerican,
    "Expected American UC Berkeley address, got \(berkeley)."
  )
  #expect(
    queens.isCanadian,
    "Expected Canadian Queen's University address, got \(queens)."
  )
}
}
Loading
Loading