Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vertexai): explicit caching #10363

Merged
merged 12 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions vertexai/genai/aiplatformpb_veneer.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package genai

import (
"fmt"
"time"

pb "cloud.google.com/go/aiplatform/apiv1beta1/aiplatformpb"
"cloud.google.com/go/civil"
Expand Down Expand Up @@ -84,6 +85,77 @@ func (v BlockedReason) String() string {
return fmt.Sprintf("BlockedReason(%d)", v)
}

// CachedContent is a resource used in LLM queries for users to explicitly specify what to cache
// and how to cache.
type CachedContent struct {
// Expiration time of the cached content.
//
// Types that are assignable to Expiration:
//
// *CachedContent_ExpireTime
// *CachedContent_Ttl
Expiration ExpireTimeOrTTL
// Immutable. Identifier. The resource name of the cached content
// Format:
// projects/{project}/locations/{location}/cachedContents/{cached_content}
Name string
// Immutable. The name of the publisher model to use for cached content.
// Format:
// projects/{project}/locations/{location}/publishers/{publisher}/models/{model}
Model string
// Optional. Input only. Immutable. Developer set system instruction.
// Currently, text only
SystemInstruction *Content
// Optional. Input only. Immutable. The content to cache
Contents []*Content
// Optional. Input only. Immutable. A list of `Tools` the model may use to
// generate the next response
Tools []*Tool
// Optional. Input only. Immutable. Tool config. This config is shared for all
// tools
ToolConfig *ToolConfig
// Output only. Creatation time of the cache entry.
CreateTime time.Time
// Output only. When the cache entry was last updated in UTC time.
UpdateTime time.Time
}

func (v *CachedContent) toProto() *pb.CachedContent {
if v == nil {
return nil
}
p := &pb.CachedContent{
Name: v.Name,
Model: v.Model,
SystemInstruction: v.SystemInstruction.toProto(),
Contents: support.TransformSlice(v.Contents, (*Content).toProto),
Tools: support.TransformSlice(v.Tools, (*Tool).toProto),
ToolConfig: v.ToolConfig.toProto(),
CreateTime: support.TimeToProto(v.CreateTime),
UpdateTime: support.TimeToProto(v.UpdateTime),
}
populateCachedContentTo(p, v)
return p
}

func (CachedContent) fromProto(p *pb.CachedContent) *CachedContent {
if p == nil {
return nil
}
v := &CachedContent{
Name: p.Name,
Model: p.Model,
SystemInstruction: (Content{}).fromProto(p.SystemInstruction),
Contents: support.TransformSlice(p.Contents, (Content{}).fromProto),
Tools: support.TransformSlice(p.Tools, (Tool{}).fromProto),
ToolConfig: (ToolConfig{}).fromProto(p.ToolConfig),
CreateTime: support.TimeFromProto(p.CreateTime),
UpdateTime: support.TimeFromProto(p.UpdateTime),
}
populateCachedContentFrom(v, p)
return v
}

// Candidate is a response candidate generated from the model.
type Candidate struct {
// Output only. Index of the candidate.
Expand Down Expand Up @@ -592,6 +664,14 @@ type GenerationConfig struct {
// otherwise the behavior is undefined.
// This is a preview feature.
ResponseMIMEType string
// Optional. The `Schema` object allows the definition of input and output
// data types. These types can be objects, but also primitives and arrays.
// Represents a select subset of an [OpenAPI 3.0 schema
// object](https://spec.openapis.org/oas/v3.0.3#schema).
// If set, a compatible response_mime_type must also be set.
// Compatible mimetypes:
// `application/json`: Schema for JSON response.
ResponseSchema *Schema
}

func (v *GenerationConfig) toProto() *pb.GenerationConfig {
Expand All @@ -608,6 +688,7 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig {
PresencePenalty: v.PresencePenalty,
FrequencyPenalty: v.FrequencyPenalty,
ResponseMimeType: v.ResponseMIMEType,
ResponseSchema: v.ResponseSchema.toProto(),
}
}

Expand All @@ -625,6 +706,7 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig {
PresencePenalty: p.PresencePenalty,
FrequencyPenalty: p.FrequencyPenalty,
ResponseMIMEType: p.ResponseMimeType,
ResponseSchema: (Schema{}).fromProto(p.ResponseSchema),
}
}

Expand Down
192 changes: 192 additions & 0 deletions vertexai/genai/caching.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
"context"
"errors"
"fmt"
"time"

aiplatform "cloud.google.com/go/aiplatform/apiv1beta1"
pb "cloud.google.com/go/aiplatform/apiv1beta1/aiplatformpb"
"cloud.google.com/go/vertexai/internal/support"
"google.golang.org/api/iterator"
durationpb "google.golang.org/protobuf/types/known/durationpb"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
)

type cacheClient = aiplatform.GenAiCacheClient

var (
newCacheClient = aiplatform.NewGenAiCacheClient
newCacheRESTClient = aiplatform.NewGenAiCacheRESTClient
)

// GenerativeModelFromCachedContent returns a [GenerativeModel] that uses the given [CachedContent].
// The argument should come from a call to [Client.CreateCachedContent] or [Client.GetCachedContent].
func (c *Client) GenerativeModelFromCachedContent(cc *CachedContent) *GenerativeModel {
return &GenerativeModel{
c: c,
name: cc.Model,
fullName: inferFullModelName(c.projectID, c.location, cc.Model),
CachedContentName: cc.Name,
}
}

// CreateCachedContent creates a new CachedContent.
// The argument should contain a model name and some data to be cached, which can include
// contents, a system instruction, tools and/or tool configuration. It can also
// include an expiration time or TTL. But it should not include a name; the system
// will generate one.
//
// The return value will contain the name, which should be used to refer to the CachedContent
// in other API calls. It will also hold various metadata like expiration and creation time.
// It will not contain any of the actual content provided as input.
//
// You can use the return value to create a model with [Client.GenerativeModelFromCachedContent].
// Or you can set [GenerativeModel.CachedContentName] to the name of the CachedContent, in which
// case you must ensure that the model provided in this call matches the name in the [GenerativeModel].
func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*CachedContent, error) {
if cc.Name != "" {
return nil, errors.New("genai.CreateCachedContent: do not provide a name; one will be generated")
}
pcc := cc.toProto()
pcc.Model = inferFullModelName(c.projectID, c.location, pcc.Model)
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, &pb.CreateCachedContentRequest{
Parent: c.parent(),
CachedContent: pcc,
}))
}

// GetCachedContent retrieves the CachedContent with the given name.
func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedContent, error) {
return c.cachedContentFromProto(c.cc.GetCachedContent(ctx, &pb.GetCachedContentRequest{Name: name}))
}

// DeleteCachedContent deletes the CachedContent with the given name.
func (c *Client) DeleteCachedContent(ctx context.Context, name string) error {
return c.cc.DeleteCachedContent(ctx, &pb.DeleteCachedContentRequest{Name: name})
}

// CachedContentToUpdate specifies which fields of a CachedContent to modify in a call to
// [Client.UpdateCachedContent].
type CachedContentToUpdate struct {
// If non-nil, update the expire time or TTL.
Expiration *ExpireTimeOrTTL
}

// UpdateCachedContent modifies the [CachedContent] according to the values
// of the [CachedContentToUpdate] struct.
// It returns the modified CachedContent.
//
// The argument CachedContent must have its Name field populated.
// If its UpdateTime field is non-zero, it will be compared with the update time
// of the stored CachedContent and the call will fail if they differ.
// This avoids a race condition when two updates are attempted concurrently.
// All other fields of the argument CachedContent are ignored.
func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu *CachedContentToUpdate) (*CachedContent, error) {
if ccu == nil || ccu.Expiration == nil {
return nil, errors.New("cloud.google.com/go/vertexai/genai.UpdateCachedContent: no update specified")
}
cc2 := &CachedContent{
Name: cc.Name,
UpdateTime: cc.UpdateTime,
Expiration: *ccu.Expiration,
}
mask := "expire_time"
if ccu.Expiration.ExpireTime.IsZero() {
mask = "ttl"
}
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, &pb.UpdateCachedContentRequest{
CachedContent: cc2.toProto(),
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}},
}))
}

// ListCachedContents lists all the CachedContents associated with the project and location.
func (c *Client) ListCachedContents(ctx context.Context) *CachedContentIterator {
return &CachedContentIterator{
it: c.cc.ListCachedContents(ctx, &pb.ListCachedContentsRequest{Parent: c.parent()}),
}
}

// A CachedContentIterator iterates over CachedContents.
type CachedContentIterator struct {
it *aiplatform.CachedContentIterator
}

// Next returns the next result. Its second return value is iterator.Done if there are no more
// results. Once Next returns Done, all subsequent calls will return Done.
func (it *CachedContentIterator) Next() (*CachedContent, error) {
m, err := it.it.Next()
if err != nil {
return nil, err
}
return (CachedContent{}).fromProto(m), nil
}

// PageInfo supports pagination. See the google.golang.org/api/iterator package for details.
func (it *CachedContentIterator) PageInfo() *iterator.PageInfo {
return it.it.PageInfo()
}

func (c *Client) cachedContentFromProto(pcc *pb.CachedContent, err error) (*CachedContent, error) {
if err != nil {
return nil, err
}
cc := (CachedContent{}).fromProto(pcc)
return cc, nil
}

// ExpireTimeOrTTL describes the time when a resource expires.
// If ExpireTime is non-zero, it is the expiration time.
// Otherwise, the expiration time is the value of TTL ("time to live") added
// to the current time.
type ExpireTimeOrTTL struct {
ExpireTime time.Time
TTL time.Duration
}

// populateCachedContentTo populates some fields of p from v.
func populateCachedContentTo(p *pb.CachedContent, v *CachedContent) {
exp := v.Expiration
if !exp.ExpireTime.IsZero() {
p.Expiration = &pb.CachedContent_ExpireTime{
ExpireTime: timestamppb.New(exp.ExpireTime),
}
} else if exp.TTL != 0 {
p.Expiration = &pb.CachedContent_Ttl{
Ttl: durationpb.New(exp.TTL),
}
}
// If both fields of v.Expiration are zero, leave p.Expiration unset.
}

// populateCachedContentFrom populates some fields of v from p.
func populateCachedContentFrom(v *CachedContent, p *pb.CachedContent) {
if p.Expiration == nil {
return
}
switch e := p.Expiration.(type) {
case *pb.CachedContent_ExpireTime:
v.Expiration.ExpireTime = support.TimeFromProto(e.ExpireTime)
case *pb.CachedContent_Ttl:
v.Expiration.TTL = e.Ttl.AsDuration()
default:
panic(fmt.Sprintf("unknown type of CachedContent.Expiration: %T", p.Expiration))
}
}
Loading
Loading