diff --git a/internal/io/limitedwriter.go b/internal/io/limitedwriter.go new file mode 100644 index 00000000..a42bb258 --- /dev/null +++ b/internal/io/limitedwriter.go @@ -0,0 +1,53 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package io provides a LimitWriter that writes to an underlying writer up to +// a limit. + +package io + +import ( + "errors" + "io" +) + +// ErrLimitExceeded is returned when the write limit is exceeded. +var ErrLimitExceeded = errors.New("write limit exceeded") + +// LimitedWriter is a writer that writes to an underlying writer up to a limit. +type LimitedWriter struct { + W io.Writer // underlying writer + N int64 // remaining bytes +} + +// LimitWriter returns a new LimitWriter that writes to w. +// +// parameters: +// w: the writer to write to +// limit: the maximum number of bytes to write +func LimitWriter(w io.Writer, limit int64) *LimitedWriter { + return &LimitedWriter{W: w, N: limit} +} + +// Write writes p to the underlying writer up to the limit. +func (l *LimitedWriter) Write(p []byte) (int, error) { + if l.N <= 0 { + return 0, ErrLimitExceeded + } + if int64(len(p)) > l.N { + p = p[:l.N] + } + n, err := l.W.Write(p) + l.N -= int64(n) + return n, err +} diff --git a/internal/io/limitedwriter_test.go b/internal/io/limitedwriter_test.go new file mode 100644 index 00000000..264886df --- /dev/null +++ b/internal/io/limitedwriter_test.go @@ -0,0 +1,67 @@ +// Copyright The Notary Project Authors. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package io + +import ( + "bytes" + "errors" + "testing" +) + +func TestLimitWriter(t *testing.T) { + limit := int64(10) + + tests := []struct { + input string + expected string + written int + }{ + {"hello", "hello", 5}, + {" world", " world", 6}, + {"!", "!", 1}, + {"1234567891011", "1234567891", 10}, + } + + for _, tt := range tests { + var buf bytes.Buffer + lw := LimitWriter(&buf, limit) + n, err := lw.Write([]byte(tt.input)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != tt.written { + t.Errorf("expected %d bytes written, got %d", tt.written, n) + } + if buf.String() != tt.expected { + t.Errorf("expected buffer %q, got %q", tt.expected, buf.String()) + } + } +} + +func TestLimitWriterFailed(t *testing.T) { + limit := int64(10) + longString := "1234567891011" + + var buf bytes.Buffer + lw := LimitWriter(&buf, limit) + _, err := lw.Write([]byte(longString)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = lw.Write([]byte(longString)) + expectedErr := errors.New("write limit exceeded") + if err.Error() != expectedErr.Error() { + t.Errorf("expected error %v, got %v", expectedErr, err) + } +} diff --git a/plugin/plugin.go b/plugin/plugin.go index d7f0a932..6402fb96 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -27,12 +27,16 @@ import ( "path/filepath" "strings" + "github.com/notaryproject/notation-go/internal/io" "github.com/notaryproject/notation-go/internal/slices" "github.com/notaryproject/notation-go/log" "github.com/notaryproject/notation-go/plugin/proto" "github.com/notaryproject/notation-plugin-framework-go/plugin" ) +// maxPluginOutputSize is the maximum size of the plugin output. +const maxPluginOutputSize = 64 * 1024 * 1024 // 64 MiB + var executor commander = &execCommander{} // for unit test // GenericPlugin is the base requirement to be a plugin. @@ -218,12 +222,14 @@ func (c execCommander) Output(ctx context.Context, name string, command plugin.C var stdout, stderr bytes.Buffer cmd := exec.CommandContext(ctx, name, string(command)) cmd.Stdin = bytes.NewReader(req) - cmd.Stderr = &stderr - cmd.Stdout = &stdout + // The limit writer will be handled by the caller in run() by comparing the + // bytes written with the expected length of the bytes. + cmd.Stderr = io.LimitWriter(&stderr, maxPluginOutputSize) + cmd.Stdout = io.LimitWriter(&stdout, maxPluginOutputSize) err := cmd.Run() if err != nil { if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil, stderr.Bytes(), fmt.Errorf("'%s %s' command execution timeout: %w", name, string(command), err); + return nil, stderr.Bytes(), fmt.Errorf("'%s %s' command execution timeout: %w", name, string(command), err) } return nil, stderr.Bytes(), err }