Skip to content

Commit a920604

Browse files
mage: cancel context on SIGINT (#313)
* mage: cancel context on SIGINT On receiving an interrupt signal, mage cancels the context allowing the magefile to perform any cleanup before exiting. A second interrupt signal will kill the magefile process without delay. The behaviour for a timeout remains unchanged (context is cancelled and the magefile exits). * mage: add cleanup timeout to cancel Co-authored-by: Nate Finch <[email protected]>
1 parent 300bbc8 commit a920604

File tree

6 files changed

+203
-14
lines changed

6 files changed

+203
-14
lines changed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ jobs:
99
fail-fast: false
1010
matrix:
1111
go-version:
12+
- 1.18.x
1213
- 1.17.x
1314
- 1.16.x
1415
- 1.15.x

mage/main.go

+6
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ import (
1212
"log"
1313
"os"
1414
"os/exec"
15+
"os/signal"
1516
"path/filepath"
1617
"regexp"
1718
"runtime"
1819
"sort"
1920
"strings"
21+
"syscall"
2022
"text/template"
2123
"time"
2224

@@ -737,6 +739,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int {
737739
c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String()))
738740
}
739741
debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n"))
742+
// catch SIGINT to allow magefile to handle them
743+
sigCh := make(chan os.Signal, 1)
744+
signal.Notify(sigCh, syscall.SIGINT)
745+
defer signal.Stop(sigCh)
740746
err := c.Run()
741747
if !sh.CmdRan(err) {
742748
errlog.Printf("failed to run compiled magefile: %v", err)

mage/main_test.go

+108-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"runtime"
2323
"strconv"
2424
"strings"
25+
"syscall"
2526
"testing"
2627
"time"
2728

@@ -1292,7 +1293,7 @@ func TestCompiledFlags(t *testing.T) {
12921293
if err == nil {
12931294
t.Fatalf("expected an error because of timeout")
12941295
}
1295-
got = stdout.String()
1296+
got = stderr.String()
12961297
want = "context deadline exceeded"
12971298
if strings.Contains(got, want) == false {
12981299
t.Errorf("got %q, does not contain %q", got, want)
@@ -1384,7 +1385,7 @@ func TestCompiledEnvironmentVars(t *testing.T) {
13841385
if err == nil {
13851386
t.Fatalf("expected an error because of timeout")
13861387
}
1387-
got = stdout.String()
1388+
got = stderr.String()
13881389
want = "context deadline exceeded"
13891390
if strings.Contains(got, want) == false {
13901391
t.Errorf("got %q, does not contain %q", got, want)
@@ -1457,6 +1458,111 @@ func TestCompiledVerboseFlag(t *testing.T) {
14571458
}
14581459
}
14591460

1461+
func TestSignals(t *testing.T) {
1462+
stderr := &bytes.Buffer{}
1463+
stdout := &bytes.Buffer{}
1464+
dir := "./testdata/signals"
1465+
compileDir, err := ioutil.TempDir(dir, "")
1466+
if err != nil {
1467+
t.Fatal(err)
1468+
}
1469+
name := filepath.Join(compileDir, "mage_out")
1470+
// The CompileOut directory is relative to the
1471+
// invocation directory, so chop off the invocation dir.
1472+
outName := "./" + name[len(dir)-1:]
1473+
defer os.RemoveAll(compileDir)
1474+
inv := Invocation{
1475+
Dir: dir,
1476+
Stdout: stdout,
1477+
Stderr: stderr,
1478+
CompileOut: outName,
1479+
}
1480+
code := Invoke(inv)
1481+
if code != 0 {
1482+
t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr)
1483+
}
1484+
1485+
run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error {
1486+
stderr.Reset()
1487+
stdout.Reset()
1488+
cmd := exec.Command(filename, target)
1489+
cmd.Stderr = stderr
1490+
cmd.Stdout = stdout
1491+
if err := cmd.Start(); err != nil {
1492+
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
1493+
filename, target, err, stdout, stderr)
1494+
}
1495+
pid := cmd.Process.Pid
1496+
go func() {
1497+
time.Sleep(time.Millisecond * 500)
1498+
for _, s := range signals {
1499+
syscall.Kill(pid, s)
1500+
time.Sleep(time.Millisecond * 50)
1501+
}
1502+
}()
1503+
if err := cmd.Wait(); err != nil {
1504+
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
1505+
filename, target, err, stdout, stderr)
1506+
}
1507+
return nil
1508+
}
1509+
1510+
if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil {
1511+
t.Fatal(err)
1512+
}
1513+
got := stdout.String()
1514+
want := "received sighup\n"
1515+
if strings.Contains(got, want) == false {
1516+
t.Errorf("got %q, does not contain %q", got, want)
1517+
}
1518+
1519+
if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil {
1520+
t.Fatal(err)
1521+
}
1522+
got = stdout.String()
1523+
want = "exiting...done\n"
1524+
if strings.Contains(got, want) == false {
1525+
t.Errorf("got %q, does not contain %q", got, want)
1526+
}
1527+
got = stderr.String()
1528+
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n"
1529+
if strings.Contains(got, want) == false {
1530+
t.Errorf("got %q, does not contain %q", got, want)
1531+
}
1532+
1533+
if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil {
1534+
t.Fatal(err)
1535+
}
1536+
got = stdout.String()
1537+
want = "exiting...done\ndeferred cleanup\n"
1538+
if strings.Contains(got, want) == false {
1539+
t.Errorf("got %q, does not contain %q", got, want)
1540+
}
1541+
got = stderr.String()
1542+
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n"
1543+
if strings.Contains(got, want) == false {
1544+
t.Errorf("got %q, does not contain %q", got, want)
1545+
}
1546+
1547+
if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil {
1548+
t.Fatalf("expected an error because of force kill")
1549+
}
1550+
got = stderr.String()
1551+
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nexiting mage\nError: exit forced\n"
1552+
if strings.Contains(got, want) == false {
1553+
t.Errorf("got %q, does not contain %q", got, want)
1554+
}
1555+
1556+
if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT); err == nil {
1557+
t.Fatalf("expected an error because of force kill")
1558+
}
1559+
got = stderr.String()
1560+
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nError: cleanup timeout exceeded\n"
1561+
if strings.Contains(got, want) == false {
1562+
t.Errorf("got %q, does not contain %q", got, want)
1563+
}
1564+
}
1565+
14601566
func TestCompiledDeterministic(t *testing.T) {
14611567
dir := "./testdata/compiled"
14621568
compileDir, err := ioutil.TempDir(dir, "")

mage/template.go

+37-11
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ import (
1414
_ioutil "io/ioutil"
1515
_log "log"
1616
"os"
17+
"os/signal"
1718
_filepath "path/filepath"
1819
_sort "sort"
1920
"strconv"
2021
_strings "strings"
22+
"syscall"
2123
_tabwriter "text/tabwriter"
2224
"time"
2325
{{range .Imports}}{{.UniqueName}} "{{.Path}}"
@@ -256,23 +258,27 @@ Options:
256258
}
257259
258260
var ctx context.Context
259-
var ctxCancel func()
261+
ctxCancel := func(){}
262+
263+
// by deferring in a closure, we let the cancel function get replaced
264+
// by the getContext function.
265+
defer func() {
266+
ctxCancel()
267+
}()
260268
261269
getContext := func() (context.Context, func()) {
262-
if ctx != nil {
263-
return ctx, ctxCancel
270+
if ctx == nil {
271+
if args.Timeout != 0 {
272+
ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout)
273+
} else {
274+
ctx, ctxCancel = context.WithCancel(context.Background())
275+
}
264276
}
265277
266-
if args.Timeout != 0 {
267-
ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout)
268-
} else {
269-
ctx = context.Background()
270-
ctxCancel = func() {}
271-
}
272278
return ctx, ctxCancel
273279
}
274280
275-
runTarget := func(fn func(context.Context) error) interface{} {
281+
runTarget := func(logger *_log.Logger, fn func(context.Context) error) interface{} {
276282
var err interface{}
277283
ctx, cancel := getContext()
278284
d := make(chan interface{})
@@ -284,14 +290,34 @@ Options:
284290
err := fn(ctx)
285291
d <- err
286292
}()
293+
sigCh := make(chan os.Signal, 1)
294+
signal.Notify(sigCh, syscall.SIGINT)
287295
select {
296+
case <-sigCh:
297+
logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...")
298+
cancel()
299+
cleanupCh := time.After(5 * time.Second)
300+
301+
select {
302+
// target exited by itself
303+
case err = <-d:
304+
return err
305+
// cleanup timeout exceeded
306+
case <-cleanupCh:
307+
return _fmt.Errorf("cleanup timeout exceeded")
308+
// second SIGINT received
309+
case <-sigCh:
310+
logger.Println("exiting mage")
311+
return _fmt.Errorf("exit forced")
312+
}
288313
case <-ctx.Done():
289314
cancel()
290315
e := ctx.Err()
291316
_fmt.Printf("ctx err: %v\n", e)
292317
return e
293318
case err = <-d:
294-
cancel()
319+
// we intentionally don't cancel the context here, because
320+
// the next target will need to run with the same context.
295321
return err
296322
}
297323
}

mage/testdata/signals/signals.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//+build mage
2+
3+
package main
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"os"
9+
"os/signal"
10+
"syscall"
11+
"time"
12+
)
13+
14+
// Exits after receiving SIGHUP
15+
func ExitsAfterSighup(ctx context.Context) {
16+
sigC := make(chan os.Signal, 1)
17+
signal.Notify(sigC, syscall.SIGHUP)
18+
<-sigC
19+
fmt.Println("received sighup")
20+
}
21+
22+
// Exits after SIGINT and wait
23+
func ExitsAfterSigint(ctx context.Context) {
24+
sigC := make(chan os.Signal, 1)
25+
signal.Notify(sigC, syscall.SIGINT)
26+
<-sigC
27+
fmt.Printf("exiting...")
28+
time.Sleep(200 * time.Millisecond)
29+
fmt.Println("done")
30+
}
31+
32+
// Exits after ctx cancel and wait
33+
func ExitsAfterCancel(ctx context.Context) {
34+
defer func() {
35+
fmt.Println("deferred cleanup")
36+
}()
37+
<-ctx.Done()
38+
fmt.Printf("exiting...")
39+
time.Sleep(200 * time.Millisecond)
40+
fmt.Println("done")
41+
}
42+
43+
// Ignores all signals, requires killing via timeout or second SIGINT
44+
func IgnoresSignals(ctx context.Context) {
45+
sigC := make(chan os.Signal, 1)
46+
signal.Notify(sigC, syscall.SIGINT)
47+
for {
48+
<-sigC
49+
}
50+
}

parse/parse.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (f Function) ExecCode() string {
169169
}
170170
out += `
171171
}
172-
ret := runTarget(wrapFn)`
172+
ret := runTarget(logger, wrapFn)`
173173
return out
174174
}
175175

0 commit comments

Comments
 (0)