Skip to content

Commit

Permalink
Improve shutdown handling
Browse files Browse the repository at this point in the history
  • Loading branch information
bep committed Dec 28, 2020
1 parent 77f7923 commit c6a216c
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 17 deletions.
32 changes: 29 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package godartsass

import (
"bytes"
"errors"
"io"
"os/exec"
"regexp"
"time"
)

Expand All @@ -19,22 +21,26 @@ func newConn(cmd *exec.Cmd) (_ conn, err error) {
}()

out, err := cmd.StdoutPipe()
stdErr := &tailBuffer{limit: 1024}
c := conn{out, in, stdErr, cmd}
cmd.Stderr = c.stdErr

return conn{out, in, cmd}, err
return c, err
}

// conn wraps a ReadCloser, WriteCloser, and a Cmd.
type conn struct {
io.ReadCloser
io.WriteCloser
cmd *exec.Cmd
stdErr *tailBuffer
cmd *exec.Cmd
}

// Start starts conn's Cmd.
func (c conn) Start() error {
err := c.cmd.Start()
if err != nil {
c.Close()
return c.Close()
}
return err
}
Expand All @@ -56,15 +62,35 @@ func (c conn) Close() error {
return cmdErr
}

var brokenPipeRe = regexp.MustCompile("Broken pipe|pipe is being closed")

// dart-sass-embedded ends on itself on EOF, this is just to give it some
// time to do so.
func (c conn) waitWithTimeout() error {
result := make(chan error, 1)
go func() { result <- c.cmd.Wait() }()
select {
case err := <-result:
if _, ok := err.(*exec.ExitError); ok {
if brokenPipeRe.MatchString(c.stdErr.String()) {
return nil
}
}
return err
case <-time.After(time.Second):
return errors.New("timed out waiting for dart-sass-embedded to finish")
}
}

type tailBuffer struct {
limit int
bytes.Buffer
}

func (b *tailBuffer) Write(p []byte) (n int, err error) {
if len(p)+b.Buffer.Len() > b.limit {
b.Reset()
}
n, err = b.Buffer.Write(p)
return
}
69 changes: 56 additions & 13 deletions transpiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"errors"
"fmt"
"io"
"strings"

"os"
"os/exec"
"regexp"
"runtime"
"sync"

"github.com/cli/safeexec"
Expand All @@ -20,7 +20,11 @@ import (
"google.golang.org/protobuf/proto"
)

var defaultDartSassEmbeddedFilename = "dart-sass-embedded"
const defaultDartSassEmbeddedFilename = "dart-sass-embedded"

// ErrShutdown will be returned from Execute if the transpiler is or
// is about to be shut down.
var ErrShutdown = errors.New("connection is shut down")

// Start creates an starts a new SCSS transpiler that communicates with the
// Dass Sass Embedded protocol via Stdin and Stdout.
Expand Down Expand Up @@ -70,6 +74,9 @@ type Transpiler struct {
// stdin/stdout of the Dart Sass protocol
conn io.ReadWriteCloser

closing bool
shutdown bool

// Protects the sending of messages to Dart Sass.
sendMu sync.Mutex

Expand Down Expand Up @@ -106,15 +113,27 @@ func (e SassError) Error() string {
return e.Message
}

// Close closes the stream to the embedded Dart Sass Protocol, which
// shuts down.
// Close closes the stream to the embedded Dart Sass Protocol, shutting it down.
// If it is already shutting down, ErrShutdown is returned.
func (t *Transpiler) Close() error {
return t.conn.Close()
t.sendMu.Lock()
defer t.sendMu.Unlock()
t.mu.Lock()
defer t.mu.Unlock()

if t.closing {
return ErrShutdown
}

t.closing = true
err := t.conn.Close()

return err
}

// Execute transpiles the string Source given in Args into CSS.
// If Dart Sass resturns a "compile failure", the error returned will be
// of type SassError..
// of type SassError.
func (t *Transpiler) Execute(args Args) (Result, error) {
var result Result

Expand Down Expand Up @@ -241,7 +260,7 @@ func (t *Transpiler) input() {
},
}

t.sendInboundMessage(
err = t.sendInboundMessage(
&embeddedsass.InboundMessage{
Message: response,
},
Expand Down Expand Up @@ -272,7 +291,7 @@ func (t *Transpiler) input() {
},
}

t.sendInboundMessage(
err = t.sendInboundMessage(
&embeddedsass.InboundMessage{
Message: response,
},
Expand All @@ -284,7 +303,22 @@ func (t *Transpiler) input() {
default:
err = fmt.Errorf("unsupported response message type. %T", msg.Message)
}
}

// Terminate pending calls.
t.sendMu.Lock()
defer t.sendMu.Unlock()
t.mu.Lock()
defer t.mu.Unlock()

t.shutdown = true
isEOF := err == io.EOF || strings.Contains(err.Error(), "already closed")
if isEOF {
if t.closing {
err = ErrShutdown
} else {
err = io.ErrUnexpectedEOF
}
}

for _, call := range t.pending {
Expand All @@ -295,7 +329,6 @@ func (t *Transpiler) input() {

func (t *Transpiler) newCall(createInbound func(seq uint32) (*embeddedsass.InboundMessage, error), args Args) (*call, error) {
t.mu.Lock()
// TODO1 handle shutdown.
id := t.seq
req, err := createInbound(id)
if err != nil {
Expand All @@ -309,8 +342,16 @@ func (t *Transpiler) newCall(createInbound func(seq uint32) (*embeddedsass.Inbou
importResolver: args.ImportResolver,
}

if t.shutdown || t.closing {
t.mu.Unlock()
call.Error = ErrShutdown
call.done()
return call, nil
}

t.pending[id] = call
t.seq++

t.mu.Unlock()

switch c := call.Request.Message.(type) {
Expand All @@ -326,6 +367,12 @@ func (t *Transpiler) newCall(createInbound func(seq uint32) (*embeddedsass.Inbou
func (t *Transpiler) sendInboundMessage(message *embeddedsass.InboundMessage) error {
t.sendMu.Lock()
defer t.sendMu.Unlock()
t.mu.Lock()
if t.closing || t.shutdown {
t.mu.Unlock()
return ErrShutdown
}
t.mu.Unlock()

out, err := proto.Marshal(message)
if err != nil {
Expand Down Expand Up @@ -363,10 +410,6 @@ func (call *call) done() {
}
}

func isWindows() bool {
return runtime.GOOS == "windows"
}

var hasSchemaRe = regexp.MustCompile("^[a-z]*:")

func hasSchema(s string) bool {
Expand Down
49 changes: 48 additions & 1 deletion transpiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,52 @@ div { p { width: $width; } }`,

}

func TestTranspilerClose(t *testing.T) {
c := qt.New(t)
transpiler, _ := newTestTranspiler(c, Options{})
var wg sync.WaitGroup

for i := 0; i < 10; i++ {
wg.Add(1)
go func(gor int) {
defer wg.Done()
for j := 0; j < 4; j++ {
src := fmt.Sprintf(`
$primary-color: #%03d;
div { color: $primary-color; }`, gor)

num := gor + j

if num == 10 {
err := transpiler.Close()
if err != nil {
c.Check(err, qt.Equals, ErrShutdown)
}
}

result, err := transpiler.Execute(Args{Source: src})

if err != nil {
c.Check(err, qt.Equals, ErrShutdown)
} else {
c.Check(err, qt.IsNil)
c.Check(result.CSS, qt.Equals, fmt.Sprintf("div {\n color: #%03d;\n}", gor))
}

if c.Failed() {
return
}
}
}(i)
}
wg.Wait()

for _, p := range transpiler.pending {
c.Assert(p.Error, qt.Equals, ErrShutdown)
}
}

func BenchmarkTranspiler(b *testing.B) {
type tester struct {
src string
Expand Down Expand Up @@ -318,7 +364,8 @@ func newTestTranspiler(c *qt.C, opts Options) (*Transpiler, func()) {
c.Assert(err, qt.IsNil)

return transpiler, func() {
c.Assert(transpiler.Close(), qt.IsNil)
err := transpiler.Close()
c.Assert(err, qt.IsNil)
}
}

Expand Down

0 comments on commit c6a216c

Please sign in to comment.