Skip to content

Commit a813345

Browse files
committed
implement WriteStateBytes (WIP)
1 parent aec0209 commit a813345

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

tfprotov6/state_store.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ type ReadStateBytesStream struct {
5656
Chunks iter.Seq[ReadStateByteChunk]
5757
}
5858

59+
// type ChunkIterator func(StateByteChunkRequest) StateByteChunk
60+
61+
type WriteStateBytesStream struct {
62+
Chunks iter.Seq[WriteStateByteChunk]
63+
}
64+
65+
type WriteStateBytesResponse struct {
66+
Diagnostics []*Diagnostic
67+
}
68+
69+
type WriteStateByteChunk = StateByteChunk
70+
5971
type ReadStateByteChunk struct {
6072
StateByteChunk
6173
Diagnostics []*Diagnostic

tfprotov6/tf6server/server.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/json"
99
"errors"
1010
"fmt"
11+
"io"
1112
"os"
1213
"os/signal"
1314
"regexp"
@@ -1645,6 +1646,69 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16451646
return nil
16461647
}
16471648

1649+
func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error {
1650+
rpc := "WriteStateBytes"
1651+
ctx := srv.Context()
1652+
ctx = s.loggingContext(ctx)
1653+
ctx = logging.RpcContext(ctx, rpc)
1654+
// ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1655+
ctx = s.stoppableContext(ctx)
1656+
// logging.ProtocolTrace(ctx, "Received request")
1657+
// defer logging.ProtocolTrace(ctx, "Served request")
1658+
1659+
ctx = tf6serverlogging.DownstreamRequest(ctx)
1660+
1661+
server, ok := s.downstream.(tfprotov6.StateStoreServer)
1662+
if !ok {
1663+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1664+
logging.ProtocolError(ctx, err.Error())
1665+
return err
1666+
}
1667+
1668+
var iteratorErr error
1669+
1670+
// TODO: what about error handling per chunk and providers having the ability to do cleanup on interruption?
1671+
1672+
iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) {
1673+
for {
1674+
chunk, err := srv.Recv()
1675+
if err == io.EOF {
1676+
break
1677+
}
1678+
if err != nil {
1679+
iteratorErr = err
1680+
srv.SendMsg(&tfplugin6.WriteStateBytes_Response{
1681+
// Diagnostics: ,
1682+
})
1683+
return
1684+
}
1685+
1686+
yield(tfprotov6.WriteStateByteChunk{
1687+
Bytes: chunk.Bytes,
1688+
TotalLength: chunk.TotalLength,
1689+
Range: tfprotov6.StateByteRange{
1690+
Start: chunk.Range.Start,
1691+
End: chunk.Range.End,
1692+
},
1693+
})
1694+
1695+
}
1696+
}
1697+
1698+
resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{
1699+
Chunks: iterator,
1700+
})
1701+
if err != nil {
1702+
return err
1703+
}
1704+
1705+
err = srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{
1706+
// Diagnostics: resp.Diagnostics,
1707+
})
1708+
1709+
return nil
1710+
}
1711+
16481712
func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {
16491713
rpc := "GetStates"
16501714
ctx = s.loggingContext(ctx)

0 commit comments

Comments
 (0)