Skip to content
37 changes: 30 additions & 7 deletions src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,8 @@ public sealed interface PromptMessageContent {
/**
* Represents prompt message content that is either text or an image.
*/
@Serializable(with = PromptMessageContentTextOrImagePolymorphicSerializer::class)
public sealed interface PromptMessageContentTextOrImage : PromptMessageContent
@Serializable(with = PromptMessageContentMultimodalPolymorphicSerializer::class)
public sealed interface PromptMessageContentMultimodal : PromptMessageContent

/**
* Text provided to or from an LLM.
Expand All @@ -898,7 +898,7 @@ public data class TextContent(
* The text content of the message.
*/
val text: String? = null,
) : PromptMessageContentTextOrImage {
) : PromptMessageContentMultimodal {
override val type: String = TYPE

public companion object {
Expand All @@ -920,21 +920,44 @@ public data class ImageContent(
* The MIME type of the image. Different providers may support different image types.
*/
val mimeType: String,
) : PromptMessageContentTextOrImage {
) : PromptMessageContentMultimodal {
override val type: String = TYPE

public companion object {
public const val TYPE: String = "image"
}
}

/**
* Audio provided to or from an LLM.
*/
@Serializable
public data class AudioContent(
/**
* The base64-encoded audio data.
*/
val data: String,

/**
* The MIME type of the audio. Different providers may support different audio types.
*/
val mimeType: String,
) : PromptMessageContentMultimodal {
override val type: String = TYPE

public companion object {
public const val TYPE: String = "audio"
}
}


/**
* An image provided to or from an LLM.
*/
@Serializable
public data class UnknownContent(
override val type: String,
) : PromptMessageContentTextOrImage
) : PromptMessageContentMultimodal

/**
* The contents of a resource, embedded into a prompt or tool call result.
Expand Down Expand Up @@ -1204,7 +1227,7 @@ public class ModelPreferences(
@Serializable
public data class SamplingMessage(
val role: Role,
val content: PromptMessageContentTextOrImage,
val content: PromptMessageContentMultimodal,
)

/**
Expand Down Expand Up @@ -1286,7 +1309,7 @@ public data class CreateMessageResult(
*/
val stopReason: StopReason? = null,
val role: Role,
val content: PromptMessageContentTextOrImage,
val content: PromptMessageContentMultimodal,
override val _meta: JsonObject = EmptyJsonObject,
) : ClientResult

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,19 @@ internal object PromptMessageContentPolymorphicSerializer :
ImageContent.TYPE -> ImageContent.serializer()
TextContent.TYPE -> TextContent.serializer()
EmbeddedResource.TYPE -> EmbeddedResource.serializer()
AudioContent.TYPE -> AudioContent.serializer()
else -> UnknownContent.serializer()
}
}
}

internal object PromptMessageContentTextOrImagePolymorphicSerializer :
JsonContentPolymorphicSerializer<PromptMessageContentTextOrImage>(PromptMessageContentTextOrImage::class) {
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<PromptMessageContentTextOrImage> {
internal object PromptMessageContentMultimodalPolymorphicSerializer :
JsonContentPolymorphicSerializer<PromptMessageContentMultimodal>(PromptMessageContentMultimodal::class) {
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<PromptMessageContentMultimodal> {
return when (element.jsonObject.getValue("type").jsonPrimitive.content) {
ImageContent.TYPE -> ImageContent.serializer()
TextContent.TYPE -> TextContent.serializer()
AudioContent.TYPE -> AudioContent.serializer()
else -> UnknownContent.serializer()
}
}
Expand Down