Skip to content

Commit

Permalink
Add support for the conversation API (#646)
Browse files Browse the repository at this point in the history
* feat: conversation api implementation

Signed-off-by: mikeee <[email protected]>

* chore: deps for conversation api

Signed-off-by: mikeee <[email protected]>

* fix: cleanup convo example

Signed-off-by: mikeee <[email protected]>

* refactor: add a conversationrequest builder and docs

Signed-off-by: mikeee <[email protected]>

* fix: lint and refactor, adding preallocations for ins/outs

Signed-off-by: Mike Nguyen <[email protected]>

* fix: lint and imports

Signed-off-by: mikeee <[email protected]>

* fix: bump to dapr master refs

Signed-off-by: mikeee <[email protected]>

* fix: enable scheduler with cli fix + tidy

Signed-off-by: mikeee <[email protected]>

---------

Signed-off-by: mikeee <[email protected]>
Signed-off-by: Mike Nguyen <[email protected]>
  • Loading branch information
mikeee authored Nov 27, 2024
1 parent e52d60c commit dce63f1
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 60 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/validate_examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
GOARCH: amd64
GOPROXY: https://proxy.golang.org
DAPR_INSTALL_URL: https://raw.githubusercontent.com/dapr/cli/master/install/install.sh
DAPR_CLI_REF: ${{ github.event.inputs.daprcli_commit }}
DAPR_REF: ${{ github.event.inputs.daprdapr_commit }}
DAPR_CLI_REF: 8bf3a1605f7b2ecfa7d4633ce4c5de13cdb65c5e
DAPR_REF: c86a77f6db5fb9f294f39d096ff0d9a053e55982
CHECKOUT_REPO: ${{ github.repository }}
CHECKOUT_REF: ${{ github.ref }}
outputs:
Expand Down Expand Up @@ -164,6 +164,7 @@ jobs:
[
"actor",
"configuration",
"conversation",
"crypto",
"dist-scheduler",
"grpc-service",
Expand Down
3 changes: 3 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ type Client interface {
// DeleteJobAlpha1 deletes a scheduled job.
DeleteJobAlpha1(ctx context.Context, name string) error

// ConverseAlpha1 interacts with a conversational AI model.
ConverseAlpha1(ctx context.Context, request conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error)

// GrpcClient returns the base grpc client if grpc is used and nil otherwise
GrpcClient() pb.DaprClient

Expand Down
146 changes: 146 additions & 0 deletions client/conversation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"

"google.golang.org/protobuf/types/known/anypb"

runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1"
)

// conversationRequest object - currently unexported as used in a functions option pattern
type conversationRequest struct {
name string
inputs []ConversationInput
Parameters map[string]*anypb.Any
Metadata map[string]string
ContextID *string
ScrubPII *bool // Scrub PII from the output
Temperature *float64
}

// NewConversationRequest defines a request with a component name and one or more inputs as a slice
func NewConversationRequest(llmName string, inputs []ConversationInput) conversationRequest {
return conversationRequest{
name: llmName,
inputs: inputs,
}
}

type conversationRequestOption func(request *conversationRequest)

// ConversationInput defines a single input.
type ConversationInput struct {
// The string to send to the llm.
Message string
// The role of the message.
Role *string
// Whether to Scrub PII from the input
ScrubPII *bool
}

// ConversationResponse is the basic response from a conversationRequest.
type ConversationResponse struct {
ContextID string
Outputs []ConversationResult
}

// ConversationResult is the individual
type ConversationResult struct {
Result string
Parameters map[string]*anypb.Any
}

// WithParameters should be used to provide parameters for custom fields.
func WithParameters(parameters map[string]*anypb.Any) conversationRequestOption {
return func(o *conversationRequest) {
o.Parameters = parameters
}
}

// WithMetadata used to define metadata to be passed to components.
func WithMetadata(metadata map[string]string) conversationRequestOption {
return func(o *conversationRequest) {
o.Metadata = metadata
}
}

// WithContextID to provide a new context or continue an existing one.
func WithContextID(id string) conversationRequestOption {
return func(o *conversationRequest) {
o.ContextID = &id
}
}

// WithScrubPII to define whether the outputs should have PII removed.
func WithScrubPII(scrub bool) conversationRequestOption {
return func(o *conversationRequest) {
o.ScrubPII = &scrub
}
}

// WithTemperature to specify which way the LLM leans.
func WithTemperature(temp float64) conversationRequestOption {
return func(o *conversationRequest) {
o.Temperature = &temp
}
}

// ConverseAlpha1 can invoke an LLM given a request created by the NewConversationRequest function.
func (c *GRPCClient) ConverseAlpha1(ctx context.Context, req conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error) {
cinputs := make([]*runtimev1pb.ConversationInput, len(req.inputs))
for i, in := range req.inputs {
cinputs[i] = &runtimev1pb.ConversationInput{
Message: in.Message,
Role: in.Role,
ScrubPII: in.ScrubPII,
}
}

for _, opt := range options {
if opt != nil {
opt(&req)
}
}

request := runtimev1pb.ConversationRequest{
Name: req.name,
ContextID: req.ContextID,
Inputs: cinputs,
Parameters: req.Parameters,
Metadata: req.Metadata,
ScrubPII: req.ScrubPII,
Temperature: req.Temperature,
}

resp, err := c.protoClient.ConverseAlpha1(ctx, &request)
if err != nil {
return nil, err
}

outputs := make([]ConversationResult, len(resp.GetOutputs()))
for i, o := range resp.GetOutputs() {
outputs[i] = ConversationResult{
Result: o.GetResult(),
Parameters: o.GetParameters(),
}
}

return &ConversationResponse{
ContextID: resp.GetContextID(),
Outputs: outputs,
}, nil
}
36 changes: 36 additions & 0 deletions examples/conversation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Dapr Conversation Example with go-sdk

## Step

### Prepare

- Dapr installed

### Run Conversation Example

<!-- STEP
name: Run Conversation
output_match_mode: substring
expected_stdout_lines:
- '== APP == conversation output: hello world'
background: true
sleep: 60
timeout_seconds: 60
-->

```bash
dapr run --app-id conversation \
--dapr-grpc-port 50001 \
--log-level debug \
--resources-path ./config \
-- go run ./main.go
```

<!-- END_STEP -->

## Result

```
- '== APP == conversation output: hello world'
```
7 changes: 7 additions & 0 deletions examples/conversation/config/conversation-echo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: echo
spec:
type: conversation.echo
version: v1
48 changes: 48 additions & 0 deletions examples/conversation/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main

import (
"context"
"fmt"
dapr "github.com/dapr/go-sdk/client"
"log"
)

func main() {
client, err := dapr.NewClient()
if err != nil {
panic(err)
}

input := dapr.ConversationInput{
Message: "hello world",
// Role: nil, // Optional
// ScrubPII: nil, // Optional
}

fmt.Printf("conversation input: %s\n", input.Message)

var conversationComponent = "echo"

request := dapr.NewConversationRequest(conversationComponent, []dapr.ConversationInput{input})

resp, err := client.ConverseAlpha1(context.Background(), request)
if err != nil {
log.Fatalf("err: %v", err)
}

fmt.Printf("conversation output: %s\n", resp.Outputs[0].Result)
}
18 changes: 9 additions & 9 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/dapr/go-sdk v0.0.0-00010101000000-000000000000
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809
google.golang.org/protobuf v1.34.2
)
Expand All @@ -18,7 +18,7 @@ require (
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dapr/dapr v1.14.1 // indirect
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-chi/chi/v5 v5.1.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
Expand All @@ -28,12 +28,12 @@ require (
github.com/marusama/semaphore/v2 v2.5.0 // indirect
github.com/microsoft/durabletask-go v0.5.1-0.20241024170039-0c4afbc95428 // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
go.opentelemetry.io/otel v1.27.0 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
40 changes: 20 additions & 20 deletions examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dapr/dapr v1.14.1 h1:n+FGF82caTsBjmnmKdBfrO94GRuLeuYs6qrAN5oG4ZM=
github.com/dapr/dapr v1.14.1/go.mod h1:oDNgaPHQIDZ3G4n4g89TElXWgkluYwcar41DI/oF4gw=
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f h1:wXPHK2o5FIABU5BvKk/21MN6GKaoUvWc7fESH/hwVls=
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f/go.mod h1:WlsLcudco11+BhaIvg2XyGxD+2GcZf8OTOawd94dAQs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -59,24 +59,24 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
go.opentelemetry.io/otel v1.27.0 h1:9BZoF3yMK/O1AafMiQTVu0YDj5Ea4hPhxCs7sGva+cg=
go.opentelemetry.io/otel v1.27.0/go.mod h1:DMpAK8fzYRzs+bi3rS5REupisuqTheUlSZJ1WnZaPAQ=
go.opentelemetry.io/otel/metric v1.27.0 h1:hvj3vdEKyeCi4YaYfNjv2NUje8FqKqUY8IlF0FxV/ik=
go.opentelemetry.io/otel/metric v1.27.0/go.mod h1:mVFgmRlhljgBiuk/MP/oKylr4hs85GZAylncepAX/ak=
go.opentelemetry.io/otel/trace v1.27.0 h1:IqYb813p7cmbHk0a5y6pD5JPakbVfftRXABGt5/Rscw=
go.opentelemetry.io/otel/trace v1.27.0/go.mod h1:6RiD1hkAprV4/q+yd2ln1HG9GoPx39SuvvstaLBl+l4=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 h1:N9BgCIAUvn/M+p4NJccWPWb3BWh88+zyL0ll9HgbEeM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw=
google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809 h1:f96Rv5C5Y2CWlbKK6KhKDdyFgGOjPHPEMsdyaxE9k0c=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809/go.mod h1:uaPEAc5V00jjG3DPhGFLXGT290RUV3+aNQigs1W50/8=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
Expand Down
18 changes: 9 additions & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ module github.com/dapr/go-sdk
go 1.23.3

require (
github.com/dapr/dapr v1.14.1
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f
github.com/go-chi/chi/v5 v5.1.0
github.com/golang/mock v1.6.0
github.com/google/uuid v1.6.0
github.com/microsoft/durabletask-go v0.5.1-0.20241024170039-0c4afbc95428
github.com/stretchr/testify v1.9.0
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v3 v3.0.1
)
Expand All @@ -23,12 +23,12 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/marusama/semaphore/v2 v2.5.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
go.opentelemetry.io/otel v1.27.0 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
Loading

0 comments on commit dce63f1

Please sign in to comment.