feat: Add PromptTemplate type #5787

Merged: 13 commits, Dec 20, 2024
31 changes: 30 additions & 1 deletion app/schema.graphql
@@ -1086,6 +1086,11 @@ type JSONInvocationParameter implements InvocationParameterBase {
defaultValue: JSON
}

type JSONPromptMessage {
role: PromptMessageRole!
content: JSON!
}

type LabelFraction {
label: String!
fraction: Float!
@@ -1377,6 +1382,11 @@ type Prompt implements Node {
promptVersions(first: Int = 50, last: Int, after: String, before: String): PromptVersionConnection!
}

type PromptChatTemplateV1 {
Contributor suggested change:
-  type PromptChatTemplateV1 {
+  type PromptChatTemplate {

version: String!
Contributor suggested change:
-  version: String!
+  __version: String!

messages: [TextPromptMessageJSONPromptMessage!]!
Contributor suggested change:
-  messages: [TextPromptMessageJSONPromptMessage!]!
+  messages: [PromptTemplateMessages!]!

}

"""A connection to a list of items."""
type PromptConnection {
"""Pagination data for this connection"""
@@ -1395,6 +1405,12 @@ type PromptEdge {
node: Prompt!
}

enum PromptMessageRole {
USER
Contributor: Missing tool

SYSTEM
Contributor: add a description that this includes the OpenAI developer role

AI
}

type PromptResponse {
"""The prompt submitted to the LLM"""
prompt: String
Expand All @@ -1403,6 +1419,10 @@ type PromptResponse {
response: String
}

type PromptStringTemplate {
template: String!
}

enum PromptTemplateFormat {
MUSTACHE
FSTRING
@@ -1414,14 +1434,16 @@ enum PromptTemplateType {
CHAT
}

union PromptTemplateVersion = PromptStringTemplate | PromptChatTemplateV1

type PromptVersion implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
user: String
description: String!
templateType: PromptTemplateType!
templateFormat: PromptTemplateFormat!
template: JSON!
template: PromptTemplateVersion!
Contributor: The union type should be here

invocationParameters: JSON
tools: JSON
outputSchema: JSON
@@ -1783,6 +1805,13 @@ type TextChunk implements ChatCompletionSubscriptionPayload {
content: String!
}

type TextPromptMessage {
role: PromptMessageRole!
content: String!
}

union TextPromptMessageJSONPromptMessage = TextPromptMessage | JSONPromptMessage
axiomofjoy (Contributor), Dec 20, 2024: Can we name this union type as PromptMessage?


input TimeRange {
"""The start of the time range"""
start: DateTime!
Empty file.
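For context on how the new schema surfaces to clients: `PromptVersion.template` is now the `PromptTemplateVersion` union, so a query selects per-variant fields with inline fragments. The sketch below shows one way such a query might be issued from Python; the endpoint URL and node ID are placeholders rather than values from this PR, and it assumes a locally running Phoenix server.

```python
# Sketch only: query the union-typed `template` field added in this PR.
# The URL and the GlobalID value are placeholders.
import requests

QUERY = """
query PromptVersionTemplate($id: GlobalID!) {
  node(id: $id) {
    ... on PromptVersion {
      templateType
      templateFormat
      template {
        ... on PromptStringTemplate {
          template
        }
        ... on PromptChatTemplateV1 {
          version
          messages {
            ... on TextPromptMessage { role content }
            ... on JSONPromptMessage { role content }
          }
        }
      }
    }
  }
}
"""

response = requests.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix endpoint
    json={"query": QUERY, "variables": {"id": "UHJvbXB0VmVyc2lvbjoy"}},  # placeholder GlobalID
    timeout=10,
)
print(response.json())
```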
51 changes: 51 additions & 0 deletions src/phoenix/server/api/helpers/prompthub/models.py
@@ -0,0 +1,51 @@
from enum import Enum
from typing import Any, Union

import strawberry
from pydantic import BaseModel

JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]


@strawberry.enum
class PromptMessageRole(str, Enum):
USER = "user"
SYSTEM = "system"
AI = "ai" # E.g. the assistant. Normalize to AI for consistency.


class TextPromptMessage(BaseModel):
role: PromptMessageRole
content: str


class JSONPromptMessage(BaseModel):
role: PromptMessageRole
content: JSONSerializable


class PromptChatTemplateV1(BaseModel):
_version: str = "messages-v1"
template: list[Union[TextPromptMessage, JSONPromptMessage]]


class PromptStringTemplate(BaseModel):
template: str


# TODO: Figure out enums, maybe just store whole tool blobs
# class PromptToolParameter(BaseModel):
# name: str
# type: str
# description: str
# required: bool
# default: str


class PromptToolDefinition(BaseModel):
definition: JSONSerializable


class PromptTools(BaseModel):
_version: str = "tools-v1"
tools: list[PromptToolDefinition]
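As a quick illustration of how these pydantic models might be used, here is a minimal round-trip sketch. It is not code from this PR; it assumes pydantic v2 and that the module path stays as added here (a reviewer suggests renaming `prompthub` to `prompts`).

```python
# Minimal round-trip sketch for the new prompt template models (assumes pydantic v2).
from phoenix.server.api.helpers.prompthub.models import (
    PromptChatTemplateV1,
    PromptMessageRole,
    TextPromptMessage,
)

chat_template = PromptChatTemplateV1(
    template=[
        TextPromptMessage(
            role=PromptMessageRole.USER,
            content="Hello what's the weather in {{location}} like?",
        )
    ]
)

# Serialize for storage, then re-validate on the way back out.
payload = chat_template.model_dump()
restored = PromptChatTemplateV1.model_validate(payload)
assert restored.template[0].content == chat_template.template[0].content
```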
42 changes: 17 additions & 25 deletions src/phoenix/server/api/queries.py
@@ -61,6 +61,11 @@
PromptTemplateType,
PromptVersion,
)
from phoenix.server.api.types.PromptVersionTemplate import (
PromptChatTemplateV1,
PromptMessageRole,
TextPromptMessage,
)
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.SystemApiKey import SystemApiKey
@@ -539,29 +544,24 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
created_at=datetime.now(),
)
elif type_name == PromptVersion.__name__:
template = PromptChatTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
content="Hello what's the weather in Antarctica like?",
)
]
)

if node_id == 2:
return PromptVersion(
id_attr=2,
user="alice",
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template={
"_version": "messages-v1",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{
"role": "user",
"content": "Hello what's the weather in {{location}} like?",
},
{"role": "ai", "content": "Looking up the weather in {{location}}..."},
],
},
invocation_parameters={
"temperature": 0.5,
"model": "gpt-4o",
"max_tokens": 100,
},
template=template,
invocation_parameters={"temperature": 0.5},
tools={
"_version": "tools-v1",
"tools": [
@@ -600,15 +600,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template={
"_version": "messages-v1",
"messages": [
{
"role": "user",
"content": "Hello what's the weather in {{location}} like?",
}
],
},
template=template,
invocation_parameters=None,
tools=None,
output_schema=None,
48 changes: 22 additions & 26 deletions src/phoenix/server/api/types/Prompt.py
@@ -9,7 +9,16 @@
from strawberry.types import Info

from phoenix.server.api.context import Context
from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
from phoenix.server.api.types.pagination import (
ConnectionArgs,
CursorString,
connection_from_list,
)
from phoenix.server.api.types.PromptVersionTemplate import (
PromptChatTemplateV1,
PromptMessageRole,
TextPromptMessage,
)

from .PromptVersion import PromptTemplateFormat, PromptTemplateType, PromptVersion

@@ -37,29 +46,24 @@ async def prompt_versions(
before=before if isinstance(before, CursorString) else None,
)

template = PromptChatTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
content="Hello what's the weather in Antarctica like?",
)
]
)

dummy_data = [
PromptVersion(
id_attr=2,
user="alice",
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template={
"_version": "messages-v1",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{
"role": "user",
"content": "Hello what's the weather in {{location}} like?",
},
{"role": "ai", "content": "Looking up the weather in {{location}}..."},
],
},
invocation_parameters={
"temperature": 0.5,
"model": "gpt-4o",
"max_tokens": 100,
},
template=template,
invocation_parameters={"temperature": 0.5},
tools={
"_version": "tools-v1",
"tools": [
@@ -96,15 +100,7 @@ async def prompt_versions(
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template={
"_version": "messages-v1",
"messages": [
{
"role": "user",
"content": "Hello what's the weather in {{location}} like?",
}
],
},
template=template,
model_name="gpt-4o",
model_provider="openai",
),
4 changes: 3 additions & 1 deletion src/phoenix/server/api/types/PromptVersion.py
@@ -7,6 +7,8 @@
from strawberry.relay import Node, NodeID
from strawberry.scalars import JSON

from phoenix.server.api.types.PromptVersionTemplate import PromptTemplate


@strawberry.enum
class PromptTemplateType(str, Enum):
@@ -28,7 +30,7 @@ class PromptVersion(Node):
description: str
template_type: PromptTemplateType
template_format: PromptTemplateFormat
template: JSON
template: PromptTemplate
invocation_parameters: Optional[JSON] = None
tools: Optional[JSON] = None
output_schema: Optional[JSON] = None
38 changes: 38 additions & 0 deletions src/phoenix/server/api/types/PromptVersionTemplate.py
@@ -0,0 +1,38 @@
# Part of the Phoenix PromptHub feature set

from typing import Union

import strawberry
from strawberry.scalars import JSON

from phoenix.server.api.helpers.prompthub.models import (
Contributor suggested change:
-  from phoenix.server.api.helpers.prompthub.models import (
+  from phoenix.server.api.helpers.prompts.models import (

PromptMessageRole,
)


@strawberry.type
class TextPromptMessage:
role: PromptMessageRole
content: str


@strawberry.type
class JSONPromptMessage:
role: PromptMessageRole
content: JSON
Contributor, on lines +18 to +21: Is this for tool call messages, images, etc.?



@strawberry.type
class PromptChatTemplateV1:
version: str = "messages-v1"
messages: list[Union[TextPromptMessage, JSONPromptMessage]]


@strawberry.type
class PromptStringTemplate:
template: str


PromptTemplate = strawberry.union(
"PromptTemplateVersion", (PromptStringTemplate, PromptChatTemplateV1)
)
Contributor: Maybe just name PromptTemplateVersion to match the underlying GraphQL type.
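One way the GraphQL layer could consume the pydantic models in `helpers` is a small converter like the sketch below. It is not part of this diff; the function name is hypothetical, and the import paths assume the modules land as written above.

```python
# Hypothetical glue between the pydantic models and the strawberry types above.
# Not part of this PR; names and paths are illustrative only.
from typing import Union

from phoenix.server.api.helpers.prompthub import models
from phoenix.server.api.types.PromptVersionTemplate import (
    JSONPromptMessage,
    PromptChatTemplateV1,
    PromptStringTemplate,
    TextPromptMessage,
)


def to_gql_template(
    template: Union[models.PromptStringTemplate, models.PromptChatTemplateV1],
) -> Union[PromptStringTemplate, PromptChatTemplateV1]:
    # String templates map one-to-one.
    if isinstance(template, models.PromptStringTemplate):
        return PromptStringTemplate(template=template.template)
    # Chat templates map each message to its text or JSON GraphQL counterpart.
    messages = [
        TextPromptMessage(role=message.role, content=message.content)
        if isinstance(message, models.TextPromptMessage)
        else JSONPromptMessage(role=message.role, content=message.content)
        for message in template.template
    ]
    return PromptChatTemplateV1(messages=messages)
```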
