Skip to content

Commit bebe220

Browse files
committed
implement WriteStateBytes
1 parent 3828557 commit bebe220

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

tfprotov6/state_store.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ type ReadStateBytesStream struct {
5656
Chunks iter.Seq[ReadStateByteChunk]
5757
}
5858

59+
type WriteStateBytesStream struct {
60+
Chunks iter.Seq[WriteStateByteChunk]
61+
}
62+
63+
type WriteStateBytesResponse struct {
64+
Diagnostics []*Diagnostic
65+
}
66+
67+
type WriteStateByteChunk = StateByteChunk
68+
5969
type ReadStateByteChunk struct {
6070
StateByteChunk
6171
Diagnostics []*Diagnostic

tfprotov6/tf6server/server.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/json"
99
"errors"
1010
"fmt"
11+
"io"
1112
"os"
1213
"os/signal"
1314
"regexp"
@@ -1645,6 +1646,78 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16451646
return nil
16461647
}
16471648

1649+
func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error {
1650+
rpc := "WriteStateBytes"
1651+
ctx := srv.Context()
1652+
ctx = s.loggingContext(ctx)
1653+
ctx = logging.RpcContext(ctx, rpc)
1654+
// ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1655+
ctx = s.stoppableContext(ctx)
1656+
// logging.ProtocolTrace(ctx, "Received request")
1657+
// defer logging.ProtocolTrace(ctx, "Served request")
1658+
1659+
ctx = tf6serverlogging.DownstreamRequest(ctx)
1660+
1661+
server, ok := s.downstream.(tfprotov6.StateStoreServer)
1662+
if !ok {
1663+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1664+
logging.ProtocolError(ctx, err.Error())
1665+
return err
1666+
}
1667+
1668+
iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) {
1669+
for {
1670+
chunk, err := srv.Recv()
1671+
if err == io.EOF {
1672+
break
1673+
}
1674+
if err != nil {
1675+
// attempt to send the error back to client
1676+
msgErr := srv.SendMsg(&tfplugin6.WriteStateBytes_Response{
1677+
Diagnostics: toproto.Diagnostics([]*tfprotov6.Diagnostic{
1678+
{
1679+
Severity: tfprotov6.DiagnosticSeverityError,
1680+
Summary: "Writing state chunk failed",
1681+
Detail: fmt.Sprintf("Attempt to write a byte chunk of state %q to %q failed: %s",
1682+
chunk.StateId, chunk.TypeName, err),
1683+
},
1684+
}),
1685+
})
1686+
if msgErr != nil {
1687+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1688+
logging.ProtocolError(ctx, err.Error())
1689+
return
1690+
}
1691+
return
1692+
}
1693+
1694+
ok := yield(tfprotov6.WriteStateByteChunk{
1695+
Bytes: chunk.Bytes,
1696+
TotalLength: chunk.TotalLength,
1697+
Range: tfprotov6.StateByteRange{
1698+
Start: chunk.Range.Start,
1699+
End: chunk.Range.End,
1700+
},
1701+
})
1702+
if !ok {
1703+
return
1704+
}
1705+
1706+
}
1707+
}
1708+
1709+
resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{
1710+
Chunks: iterator,
1711+
})
1712+
if err != nil {
1713+
return err
1714+
}
1715+
1716+
return srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{
1717+
Diagnostics: toproto.Diagnostics(resp.Diagnostics),
1718+
})
1719+
}
1720+
16481721
func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {
16491722
rpc := "GetStates"
16501723
ctx = s.loggingContext(ctx)

0 commit comments

Comments
 (0)