Skip to content

Commit

Permalink
feat: Add PromptTemplate type (#5787)
Browse files Browse the repository at this point in the history
* Flesh out PromptVersionTemplate type

* Return GraphQL objects with the correct type

* Use new types in node query

* Decouple pydantic models and gql types

* Rebuild gql schema

* Rework model names

* Update gql schema

* Propagate name into schema

* Incorporate feedback

* Update schema

* adjust UI to new schema

* cleanup

* Remove `hub` naming and clean up type annotations

---------

Co-authored-by: Mikyo King <[email protected]>
  • Loading branch information
anticorrelator and mikeldking committed Feb 19, 2025
1 parent a0d10ed commit 7221d79
Show file tree
Hide file tree
Showing 14 changed files with 488 additions and 169 deletions.
32 changes: 31 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,11 @@ type JSONInvocationParameter implements InvocationParameterBase {
defaultValue: JSON
}

"""A prompt message whose content is an arbitrary JSON value (presumably for structured content such as tool calls — see the UI TODO in PromptChatMessages)."""
type JSONPromptMessage {
"""The role of the message author."""
role: PromptMessageRole!
"""The message content as a JSON value."""
content: JSON!
}

type LabelFraction {
label: String!
fraction: Float!
Expand Down Expand Up @@ -1385,6 +1390,11 @@ type Prompt implements Node {
promptVersions(first: Int = 50, last: Int, after: String, before: String): PromptVersionConnection!
}

"""A chat-style prompt template: an ordered list of template messages."""
type PromptChatTemplate {
# Renamed from `Version`: GraphQL fields are camelCase by convention, and every
# other field in this schema (messages, role, content, template, ...) follows it.
"""The version of the chat template structure."""
version: String!
"""The ordered messages that make up the chat template."""
messages: [PromptTemplateMessage!]!
}

"""A connection to a list of items."""
type PromptConnection {
"""Pagination data for this connection"""
Expand All @@ -1403,6 +1413,13 @@ type PromptEdge {
node: Prompt!
}

"""The role of the author of a prompt message."""
enum PromptMessageRole {
USER
SYSTEM
AI
TOOL
}

type PromptResponse {
"""The prompt submitted to the LLM"""
prompt: String
Expand All @@ -1411,12 +1428,20 @@ type PromptResponse {
response: String
}

"""A prompt template consisting of a single template string."""
type PromptStringTemplate {
"""The template text."""
template: String!
}

"""A prompt template is either a plain string template or a chat template."""
union PromptTemplate = PromptStringTemplate | PromptChatTemplate

"""The templating syntax used by a prompt template (Mustache, Python f-string, or no templating)."""
enum PromptTemplateFormat {
MUSTACHE
FSTRING
NONE
}

"""A message within a chat prompt template: either plain-text content or JSON content."""
union PromptTemplateMessage = TextPromptMessage | JSONPromptMessage

enum PromptTemplateType {
STRING
CHAT
Expand All @@ -1429,7 +1454,7 @@ type PromptVersion implements Node {
description: String!
templateType: PromptTemplateType!
templateFormat: PromptTemplateFormat!
template: JSON!
template: PromptTemplate!
invocationParameters: JSON
tools: JSON
outputSchema: JSON
Expand Down Expand Up @@ -1792,6 +1817,11 @@ type TextChunk implements ChatCompletionSubscriptionPayload {
content: String!
}

"""A prompt message with plain-text content."""
type TextPromptMessage {
"""The role of the message author."""
role: PromptMessageRole!
"""The message content as text."""
content: String!
}

input TimeRange {
"""The start of the time range"""
start: DateTime = null
Expand Down
15 changes: 8 additions & 7 deletions app/src/hooks/useChatMessageStyles.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
import { useMemo } from "react";

import { ViewProps } from "@phoenix/components";
import { ViewStyleProps } from "@phoenix/components/types";

export function useChatMessageStyles(
role: string
): Pick<ViewProps, "backgroundColor" | "borderColor"> {
return useMemo<ViewProps>(() => {
if (role === "user" || role === "human") {
): Pick<ViewStyleProps, "backgroundColor" | "borderColor"> {
return useMemo<ViewStyleProps>(() => {
const normalizedRole = role.toLowerCase();
if (normalizedRole === "user" || normalizedRole === "human") {
return {
backgroundColor: "grey-100",
borderColor: "grey-500",
};
} else if (role === "assistant" || role === "ai") {
} else if (normalizedRole === "assistant" || normalizedRole === "ai") {
return {
backgroundColor: "blue-100",
borderColor: "blue-700",
};
} else if (role === "system") {
} else if (normalizedRole === "system" || normalizedRole === "developer") {
return {
backgroundColor: "indigo-100",
borderColor: "indigo-700",
};
} else if (["function", "tool"].includes(role)) {
} else if (["function", "tool"].includes(normalizedRole)) {
return {
backgroundColor: "yellow-100",
borderColor: "yellow-700",
Expand Down
67 changes: 43 additions & 24 deletions app/src/pages/prompt/PromptChatMessages.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import React, { useMemo } from "react";
import { useFragment } from "react-relay";
import { graphql } from "relay-runtime";
import React from "react";
import { graphql, useFragment } from "react-relay";

import { Flex, Text } from "@arizeai/components";

import { TemplateLanguages } from "@phoenix/components/templateEditor/constants";
import { TemplateLanguage } from "@phoenix/components/templateEditor/types";

import {
PromptChatMessages__main$data,
PromptChatMessages__main$key,
PromptTemplateFormat,
} from "./__generated__/PromptChatMessages__main.graphql";
import { ChatTemplateMessage } from "./ChatTemplateMessage";
import { PromptChatTemplate, PromptChatTemplateSchema } from "./schemas";

const convertTemplateFormat = (
templateFormat: PromptTemplateFormat
Expand All @@ -30,49 +29,69 @@ export function PromptChatMessages({
}: {
promptVersion: PromptChatMessages__main$key;
}) {
const { template, templateType, templateFormat } = useFragment(
const { template, templateFormat } = useFragment(
graphql`
fragment PromptChatMessages__main on PromptVersion {
template
template {
__typename
... on PromptChatTemplate {
messages {
... on JSONPromptMessage {
role
jsonContent: content
}
... on TextPromptMessage {
role
content
}
}
}
... on PromptStringTemplate {
template
}
}
templateType
templateFormat
}
`,
promptVersion
);

if (templateType === "STRING") {
return <Text>{template}</Text>;
if (template.__typename === "PromptStringTemplate") {
return <Text>{template.template}</Text>;
}
if (template.__typename === "PromptChatTemplate") {
return (
<ChatMessages
template={template}
templateFormat={convertTemplateFormat(templateFormat)}
/>
);
}
if (template.__typename === "%other") {
throw new Error("Unknown template type" + template.__typename);
}

return (
<ChatMessages
template={template}
templateFormat={convertTemplateFormat(templateFormat)}
/>
);
}

function ChatMessages({
template,
templateFormat,
}: {
template: PromptChatTemplate | unknown;
template: Extract<
PromptChatMessages__main$data["template"],
{ __typename: "PromptChatTemplate" }
>;
templateFormat: TemplateLanguage;
}) {
const messages = useMemo(() => {
const parsedTemplate = PromptChatTemplateSchema.safeParse(template);
if (!parsedTemplate.success) {
return [];
}
return parsedTemplate.data.messages;
}, [template]);
const { messages } = template;
return (
<Flex direction="column" gap="size-200">
{messages.map((message, i) => (
// TODO: Handle JSON content for things like tool calls
<ChatTemplateMessage
key={i}
{...message}
role={message.role as string}
content={message.content || JSON.stringify(message.jsonContent)}
templateFormat={templateFormat}
/>
))}
Expand Down
107 changes: 102 additions & 5 deletions app/src/pages/prompt/__generated__/PromptChatMessages__main.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 7221d79

Please sign in to comment.