Skip to content

Commit 2dafa71

Browse files
authored
feat: query for the body tag rather than using string replacements (#766)
1 parent 2782591 commit 2dafa71

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

cmd/templ/generatecmd/proxy/proxy.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"strings"
1717
"time"
1818

19+
"github.com/PuerkitoBio/goquery"
1920
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
2021
"github.com/andybalholm/brotli"
2122

@@ -36,7 +37,16 @@ type Handler struct {
3637
}
3738

3839
func insertScriptTagIntoBody(body string) (updated string) {
39-
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
40+
doc, err := goquery.NewDocumentFromReader(strings.NewReader(body))
41+
if err != nil {
42+
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
43+
}
44+
doc.Find("body").AppendHtml(scriptTag)
45+
r, err := doc.Html()
46+
if err != nil {
47+
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
48+
}
49+
return r
4050
}
4151

4252
type passthroughWriteCloser struct {

cmd/templ/generatecmd/proxy/proxy_test.go

+43
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,49 @@ func TestProxy(t *testing.T) {
161161
t.Errorf("unexpected response body (-got +want):\n%s", diff)
162162
}
163163
})
164+
t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) {
165+
// Arrange
166+
r := &http.Response{
167+
Body: io.NopCloser(strings.NewReader(`<html><body><script>console.log("<body></body>")</script></body></html>`)),
168+
Header: make(http.Header),
169+
Request: &http.Request{
170+
URL: &url.URL{
171+
Scheme: "http",
172+
Host: "example.com",
173+
},
174+
},
175+
}
176+
r.Header.Set("Content-Type", "text/html, charset=utf-8")
177+
r.Header.Set("Content-Length", "26")
178+
179+
expectedString := insertScriptTagIntoBody(`<html><body><script>console.log("<body></body>")</script></body></html>`)
180+
if !strings.Contains(expectedString, scriptTag) {
181+
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
182+
}
183+
if !strings.Contains(expectedString, `console.log("<body></body>")`) {
184+
t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString)
185+
}
186+
187+
// Act
188+
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
189+
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
190+
err := h.modifyResponse(r)
191+
if err != nil {
192+
t.Fatalf("unexpected error: %v", err)
193+
}
194+
195+
// Assert
196+
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
197+
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
198+
}
199+
actualBody, err := io.ReadAll(r.Body)
200+
if err != nil {
201+
t.Fatalf("unexpected error reading response: %v", err)
202+
}
203+
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
204+
t.Errorf("unexpected response body (-got +want):\n%s", diff)
205+
}
206+
})
164207
t.Run("gzip: non-html content is not modified", func(t *testing.T) {
165208
// Arrange
166209
r := &http.Response{

0 commit comments

Comments
 (0)