Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve watch mode to reduce recompilation #366

Merged
merged 12 commits into from
Jan 7, 2024
21 changes: 6 additions & 15 deletions benchmarks/templ/template_templ.go
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

182 changes: 151 additions & 31 deletions cmd/templ/generatecmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/md5"
_ "embed"
"errors"
"fmt"
Expand Down Expand Up @@ -56,6 +57,8 @@ var defaultWorkerCount = runtime.NumCPU()

func Run(w io.Writer, args Arguments) (err error) {
ctx, cancel := context.WithCancel(context.Background())
watchCtx, watchCancel := context.WithCancel(context.Background())

signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
defer func() {
Expand All @@ -67,29 +70,44 @@ func Run(w io.Writer, args Arguments) (err error) {
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", args.PPROFPort), nil)
}()
}

go func() {
select {
case <-signalChan: // First signal, cancel context.
fmt.Fprintln(w, "\nCancelling...")
err = run.Stop()
if err != nil {
fmt.Fprintf(w, "Error killing command: %v\n", err)
watching := args.Watch
for {
select {
case <-signalChan: // First signal, cancel context.
if watching {
fmt.Println("stop watching")
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
watchCancel()
continue
}

if ctx.Err() != nil {
fmt.Fprintln(w, "\nHARD EXIT")
os.Exit(2) // hard exit
continue
}

fmt.Fprintln(w, "\nCancelling...")
cancel()

case <-ctx.Done():
break
}
cancel()
case <-ctx.Done():
}
<-signalChan // Second signal, hard exit.
os.Exit(2)
}()
err = runCmd(ctx, w, args)

err = runCmd(ctx, watchCtx, w, args)
if errors.Is(err, context.Canceled) {
return nil
}

return err
}

func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
start := time.Now()
func runCmd(ctx, watchCtx context.Context, w io.Writer, args Arguments) error {
var err error

if args.Watch && args.FileName != "" {
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
}
Expand All @@ -101,7 +119,7 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
opts = append(opts, generator.WithTimestamp(time.Now()))
}
if args.FileName != "" {
return processSingleFile(ctx, w, "", args.FileName, args.GenerateSourceMapVisualisations, opts)
return processSingleFile(ctx, w, "", args.FileName, nil, args.GenerateSourceMapVisualisations, opts)
}
var target *url.URL
if args.Proxy != "" {
Expand All @@ -120,7 +138,7 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
if !path.IsAbs(args.Path) {
args.Path, err = filepath.Abs(args.Path)
if err != nil {
return
return err
}
}

Expand All @@ -129,15 +147,36 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
p = proxy.New(args.ProxyPort, target)
}
fmt.Fprintln(w, "Processing path:", args.Path)

if args.Watch {
err = generateWatched(watchCtx, w, args, opts, p)
if err != nil && !errors.Is(err, context.Canceled) {
return err
}
}

return generateProduction(ctx, w, args, opts, p)
}

func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error {
fmt.Fprintln(w, "Generating dev code:", args.Path)
start := time.Now()

bo := backoff.NewExponentialBackOff()
bo.InitialInterval = time.Millisecond * 500
bo.MaxInterval = time.Second * 3
bo.MaxElapsedTime = 0

var firstRunComplete bool
fileNameToLastModTime := make(map[string]time.Time)
fileNameToHash := make(map[string][md5.Size]byte)
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved

for !firstRunComplete || args.Watch {
changesFound, errs := processChanges(ctx, w, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount, args.KeepOrphanedFiles)
changesFound, errs := processChanges(
ctx, w,
fileNameToLastModTime, fileNameToHash,
args.Path, args.GenerateSourceMapVisualisations,
opts, args.WorkerCount, true, args.KeepOrphanedFiles)
if len(errs) > 0 {
if errors.Is(errs[0], context.Canceled) {
return errs[0]
Expand Down Expand Up @@ -179,20 +218,62 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
}()
}
}
if err = checkTemplVersion(args.Path); err != nil {
if err := checkTemplVersion(args.Path); err != nil {
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
logWarning(w, "templ version check failed: %v\n", err)
err = nil
}

if firstRunComplete {
if changesFound > 0 {
bo.Reset()
}
time.Sleep(bo.NextBackOff())
}

firstRunComplete = true
start = time.Now()
}
return err

return nil
}

func generateProduction(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error {
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
fmt.Fprintln(w, "Generating production code:", args.Path)
start := time.Now()

changesFound, errs := processChanges(
ctx, w, nil, nil,
args.Path, args.GenerateSourceMapVisualisations,
opts, args.WorkerCount, false, args.KeepOrphanedFiles)
if len(errs) > 0 {
if errors.Is(errs[0], context.Canceled) {
return errs[0]
}
logError(w, "Error processing path: %v\n", errors.Join(errs...))
}

if changesFound > 0 {
if len(errs) > 0 {
logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
} else {
logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
}
if args.Command != "" {
fmt.Fprintf(w, "Executing command: %s\n", args.Command)
if _, err := run.Run(ctx, args.Path, args.Command); err != nil {
fmt.Fprintf(w, "Error starting command: %v\n", err)
}
}
// Send server-sent event.
if p != nil {
p.SendSSE("message", "reload")
}
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
}

if err := checkTemplVersion(args.Path); err != nil {
logWarning(w, "templ version check failed: %v\n", err)
}

return nil
}

func shouldSkipDir(dir string) bool {
Expand All @@ -210,10 +291,18 @@ func shouldSkipDir(dir string) bool {
return false
}

func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, keepOrphanedFiles bool) (changesFound int, errs []error) {
func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, hashes map[string][md5.Size]byte, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, watching, keepOrphanedFiles bool) (changesFound int, errs []error) {
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
sem := make(chan struct{}, maxWorkerCount)
var wg sync.WaitGroup

if watching {
opts = append(opts, generator.WithExtractStrings())
}

if fileNameToLastModTime == nil {
fileNameToLastModTime = make(map[string]time.Time)
}

err := filepath.WalkDir(path, func(fileName string, info os.DirEntry, err error) error {
if err != nil {
return err
Expand All @@ -227,19 +316,25 @@ func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime
if info.IsDir() {
return nil
}
if !keepOrphanedFiles && strings.HasSuffix(fileName, "_templ.go") {

orphaned := !keepOrphanedFiles && strings.HasSuffix(fileName, "_templ.go")
if orphaned {
// Make sure the generated file is orphaned
// by checking if the corresponding .templ file exists.
if _, err := os.Stat(strings.TrimSuffix(fileName, "_templ.go") + ".templ"); err == nil {
// The .templ file exists, so we don't delete the generated file.
return nil
orphaned = false
}
}

devTextFile := !watching && strings.HasSuffix(fileName, "_templ.txt")
if orphaned || devTextFile {
if err = os.Remove(fileName); err != nil {
return fmt.Errorf("failed to remove file: %w", err)
}
logWarning(stdout, "Deleted orphaned file %q\n", fileName)
logWarning(stdout, "Deleted file %q\n", fileName)
return nil
}

if strings.HasSuffix(fileName, ".templ") {
lastModTime := fileNameToLastModTime[fileName]
fileInfo, err := info.Info()
Expand All @@ -255,7 +350,7 @@ func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime
wg.Add(1)
go func() {
defer wg.Done()
if err := processSingleFile(ctx, stdout, path, fileName, generateSourceMapVisualisations, opts); err != nil {
if err := processSingleFile(ctx, stdout, path, fileName, hashes, generateSourceMapVisualisations, opts); err != nil {
errs = append(errs, err)
}
<-sem
Expand Down Expand Up @@ -291,9 +386,9 @@ func openURL(w io.Writer, url string) error {

// processSingleFile generates Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) {
func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, hashes map[string][md5.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) {
start := time.Now()
diag, err := generate(ctx, basePath, fileName, generateSourceMapVisualisations, opts)
diag, err := generate(ctx, basePath, fileName, hashes, generateSourceMapVisualisations, opts)
if err != nil {
return err
}
Expand All @@ -320,11 +415,15 @@ func printDiagnostics(w io.Writer, fileName string, diags []parser.Diagnostic) {

// generate Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func generate(ctx context.Context, basePath, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) {
func generate(ctx context.Context, basePath, fileName string, hashes map[string][md5.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) {
if err = ctx.Err(); err != nil {
return
}

if hashes == nil {
hashes = make(map[string][md5.Size]byte)
}

t, err := parser.Parse(fileName)
if err != nil {
return nil, fmt.Errorf("%s parsing error: %w", fileName, err)
Expand All @@ -338,7 +437,7 @@ func generate(ctx context.Context, basePath, fileName string, generateSourceMapV
}

var b bytes.Buffer
sourceMap, err := generator.Generate(t, &b, append(opts, generator.WithFileName(errorMessageFileName))...)
sourceMap, literals, err := generator.Generate(t, &b, append(opts, generator.WithFileName(errorMessageFileName))...)
if err != nil {
return nil, fmt.Errorf("%s generation error: %w", fileName, err)
}
Expand All @@ -348,8 +447,26 @@ func generate(ctx context.Context, basePath, fileName string, generateSourceMapV
return nil, fmt.Errorf("%s source formatting error: %w", fileName, err)
}

if err = os.WriteFile(targetFileName, data, 0644); err != nil {
return nil, fmt.Errorf("%s write file error: %w", targetFileName, err)
// Hash and write the file if the hash has changed
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
hash := md5.Sum(data)
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
if hashes[targetFileName] != hash {
if err = os.WriteFile(targetFileName, data, 0o644); err != nil {
return nil, fmt.Errorf("%s write file error: %w", targetFileName, err)
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
}
hashes[targetFileName] = hash
}

// Add the txt file if it has changed
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
if len(literals) > 0 {
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
contents := strings.Join(literals, "\n")
txtHash := md5.Sum([]byte(contents))
if hashes[txtFileName] != txtHash {
if err = os.WriteFile(txtFileName, []byte(strings.Join(literals, "\n")), 0o644); err != nil {
return nil, fmt.Errorf("%s write static text error: %w", txtFileName, err)
stephenafamo marked this conversation as resolved.
Show resolved Hide resolved
}
hashes[txtFileName] = txtHash
}
}

if generateSourceMapVisualisations {
Expand Down Expand Up @@ -397,12 +514,15 @@ func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileNa
func logError(w io.Writer, format string, a ...any) {
logWithDecoration(w, "✗", color.FgRed, format, a...)
}

func logWarning(w io.Writer, format string, a ...any) {
logWithDecoration(w, "!", color.FgYellow, format, a...)
}

func logSuccess(w io.Writer, format string, a ...any) {
logWithDecoration(w, "✓", color.FgGreen, format, a...)
}

func logWithDecoration(w io.Writer, decoration string, col color.Attribute, format string, a ...any) {
color.New(col).Fprintf(w, "(%s) ", decoration)
fmt.Fprintf(w, format, a...)
Expand Down
Loading