Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions instrumentation/rum/inject.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2025 Datadog, Inc.

package rum

import (
"net/http"
"unicode"
)

// state identifies how much of the closing "</head>" tag has been matched so
// far by the injector.
type state int

const (
	// sInit: nothing matched yet; looking for '<'.
	sInit state = iota
	// sLt: saw '<'; '/' is expected next.
	sLt
	// sSlash: saw "</"; spaces are allowed, then 'h' is expected.
	sSlash
	// sH: saw 'h'; 'e' is expected (no spaces allowed).
	sH
	// sE: saw 'e'; 'a' is expected (no spaces allowed).
	sE
	// sA: saw 'a'; 'd' is expected (no spaces allowed).
	sA
	// sD: saw "...head"; spaces are allowed, then '>' is expected.
	sD
	// sDone: "</head>" was found.
	sDone
)

// snippet is the placeholder markup the injector writes in front of the
// closing head tag.
var snippet = []byte("<snippet>")

// injector is a proof-of-concept http.ResponseWriter wrapper for RUM snippet
// injection.
// It doesn't handle Content-Length manipulation.
// It isn't concurrent safe.
type injector struct {
	// wrapped is the underlying writer that receives the (possibly modified)
	// output.
	wrapped http.ResponseWriter
	// st is the current position of the "</head>" matcher.
	st state
	// lastSeen is the index of the most recent '<' inside the chunk in which
	// it was seen.
	lastSeen int
	// seenInCurrent reports whether a '<' was seen in the chunk most recently
	// scanned by match.
	seenInCurrent bool
	// buf holds chunks that are kept back while a match is only partial.
	buf [][]byte
}

// Header implements http.ResponseWriter by exposing the header map of the
// wrapped writer.
func (ij *injector) Header() http.Header {
	// TODO: this is a good place to inject Content-Length to the right
	// length, not the original one, if injection happened.
	hdr := ij.wrapped.Header()
	return hdr
}

// WriteHeader implements http.ResponseWriter; the status code is forwarded
// unchanged to the wrapped writer.
func (ij *injector) WriteHeader(statusCode int) {
	inner := ij.wrapped
	inner.WriteHeader(statusCode)
}

// Write implements http.ResponseWriter.
//
// The payload may arrive split across any number of calls, so Write keeps the
// matcher state between calls in order to find "</head>" (in every variant
// accepted by match) and writes snippet immediately before the tag's '<'.
//
// While a tag is only partially matched, the chunks from the one holding the
// candidate '<' onwards are copied into ij.buf and held back. They are
// forwarded, in their original order, as soon as the candidate completes or
// is abandoned, or when Flush is called.
//
// The returned count is the number of bytes forwarded to the wrapped writer
// during this call: it is 0 for a chunk that was only buffered, and it
// includes the snippet and previously held-back chunks when those are written.
// NOTE(review): this differs from the io.Writer contract (n == len(p) on
// success), so callers that compare n with len(p) may misbehave — confirm
// before using this outside the POC.
func (ij *injector) Write(chunk []byte) (int, error) {
	// Once the snippet has been injected the writer is a plain pass-through.
	if ij.st == sDone {
		return ij.wrapped.Write(chunk)
	}
	ij.match(chunk)

	written := 0
	// A '<' in this chunk restarted the candidate match, and sInit means the
	// candidate was abandoned. Either way nothing held back can contain the
	// start of the tag anymore, so forward it first to keep the bytes in order.
	if len(ij.buf) > 0 && (ij.seenInCurrent || ij.st == sInit) {
		n, err := multiWrite(ij.wrapped, ij.buf...)
		ij.buf = ij.buf[:0]
		written += n
		if err != nil {
			return written, err
		}
	}

	switch {
	case ij.st == sInit, !ij.seenInCurrent && len(ij.buf) == 0:
		// Either no candidate is in progress, or the candidate's '<' has
		// already been forwarded (for instance by a Flush in the middle of a
		// match) and there is no place left to inject before it.
		n, err := ij.wrapped.Write(chunk)
		return written + n, err

	case ij.st != sDone:
		// Partial match in progress. The caller may reuse chunk after Write
		// returns, so a private copy is held back.
		ij.buf = append(ij.buf, append([]byte(nil), chunk...))
		return written, nil

	case len(ij.buf) == 0:
		// Full match that starts in this chunk; ij.lastSeen indexes chunk.
		n, err := multiWrite(ij.wrapped, chunk[:ij.lastSeen], snippet, chunk[ij.lastSeen:])
		return written + n, err
	}

	// Full match that started in a held-back chunk. The buffer is only ever
	// started with the chunk containing the candidate '<', so that chunk is
	// ij.buf[0] and ij.lastSeen indexes it.
	head := ij.buf[0]
	parts := make([][]byte, 0, len(ij.buf)+3)
	parts = append(parts, head[:ij.lastSeen], snippet, head[ij.lastSeen:])
	parts = append(parts, ij.buf[1:]...)
	parts = append(parts, chunk)
	n, err := multiWrite(ij.wrapped, parts...)
	// Drop the held-back chunks even on error so Flush cannot send them twice.
	ij.buf = ij.buf[:0]
	return written + n, err
}

// multiWrite writes chunks to w in order and stops at the first error.
// It returns the total number of bytes w reported as written, including any
// bytes accepted by the call that failed, together with that error. Calling
// it with no chunks writes nothing and returns (0, nil).
func multiWrite(w http.ResponseWriter, chunks ...[]byte) (int, error) {
	total := 0
	for _, chunk := range chunks {
		n, err := w.Write(chunk)
		// Count the bytes before looking at err: a Write may accept part of
		// the chunk and still fail.
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// match feeds p through the "</head>" matcher, advancing ij.st one byte at a
// time. Bytes are compared case-insensitively. Whitespace is skipped only
// right after "</" and right before '>'. It does nothing once the tag has
// been found.
func (ij *injector) match(p []byte) {
	if ij.st == sDone {
		return
	}
	ij.seenInCurrent = false
	for pos, b := range p {
		r := unicode.ToLower(rune(b))

		// Pick the byte the current state is waiting for and where it leads.
		var (
			want rune
			next state
		)
		switch ij.st {
		case sInit:
			want, next = '<', sLt
		case sLt:
			want, next = '/', sSlash
		case sSlash:
			want, next = 'h', sH
		case sH:
			want, next = 'e', sE
		case sE:
			want, next = 'a', sA
		case sA:
			want, next = 'd', sD
		case sD:
			want, next = '>', sDone
		default:
			// The tag was completed earlier in this chunk; the remaining
			// bytes are not inspected.
			continue
		}

		// Only these two states tolerate whitespace.
		if (ij.st == sSlash || ij.st == sD) && unicode.IsSpace(r) {
			continue
		}
		ij.transition(want, r, next, pos)
	}
}

// transition moves the matcher according to current: to target when it is
// the expected rune, back to sLt when it is a '<' (a new candidate tag), and
// to sInit otherwise. Every '<' also records pos as the latest candidate
// start and marks it as seen in the chunk being scanned.
func (ij *injector) transition(expected, current rune, target state, pos int) {
	opensTag := current == '<'
	if opensTag {
		ij.lastSeen = pos
		ij.seenInCurrent = true
	}

	if current == expected {
		ij.st = target
		return
	}
	if opensTag {
		ij.st = sLt
		return
	}
	ij.st = sInit
}

// Flush writes every held-back chunk to the wrapped writer, in order, and
// empties the buffer whether or not the write succeeded. It returns the
// number of bytes written and the first write error, if any.
func (ij *injector) Flush() (int, error) {
	pending := ij.buf
	if len(pending) == 0 {
		return 0, nil
	}
	n, err := multiWrite(ij.wrapped, pending...)
	ij.buf = pending[:0]
	return n, err
}

// Reset returns the injector to its initial matching state so it can process
// a new payload: the matcher goes back to sInit, the recorded '<' position is
// cleared, and held-back chunks are dropped without being written. The
// backing array of buf is kept for reuse.
func (ij *injector) Reset() {
	ij.st = sInit
	ij.lastSeen = -1
	ij.seenInCurrent = false
	ij.buf = ij.buf[:0]
}

// NewInjector returns an http.Handler that calls fn with a ResponseWriter
// which writes snippet in front of the first "</head>" found in the response
// body. Anything still held back when fn returns (an unfinished partial
// match) is flushed to the client unchanged.
func NewInjector(fn func(w http.ResponseWriter, r *http.Request)) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ij := &injector{
			wrapped:  w,
			lastSeen: -1,
			buf:      make([][]byte, 0, 10),
		}
		fn(ij, r)
		// fn has already returned, so there is no caller left to report a
		// failed flush to; the result is deliberately discarded.
		_, _ = ij.Flush()
	})
}
180 changes: 180 additions & 0 deletions instrumentation/rum/inject_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2025 Datadog, Inc.

package rum

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestInjector(t *testing.T) {
payload := []byte("Hello, world!")
h := func(w http.ResponseWriter, r *http.Request) {
w.Write(payload)
}
injector := NewInjector(h)
server := httptest.NewServer(injector)
defer server.Close()

resp, err := http.DefaultClient.Get(server.URL)
assert.NoError(t, err)
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello, world!", string(respBody))
// TODO: when Content-Length is implemented, uncomment this.
// assert.Equal(t, int64(len(payload) + len(snippet)), resp.ContentLength)
}

// TestInjectorMatch runs single inputs through match and checks the state the
// matcher ends in.
func TestInjectorMatch(t *testing.T) {
	type matchCase struct {
		in   []byte
		want state
	}
	table := []matchCase{
		{in: []byte("hello </head> world"), want: sDone},
		{in: []byte("noise </ head > tail"), want: sDone},  // spaces after '/' and before '>'
		{in: []byte("nope < /head>"), want: sInit},         // space between '<' and '/'
		{in: []byte("nope </ he ad >"), want: sInit},       // spaces inside "head"
		{in: []byte("ok </\tHead\t\t >"), want: sDone},     // tabs after '/', spaces before '>'
		{in: []byte("partial </hea>"), want: sInit},        // missing 'd'
		{in: []byte("wrong </header>"), want: sInit},       // extra letters before '>'
		{in: []byte("caps </HEAD>"), want: sDone},          // case-insensitive
		{in: []byte("attr-like </head foo>"), want: sInit}, // rejected by our custom rule
		{in: []byte("prefix << / h e a d >"), want: sInit}, // multiple violations
	}

	for _, tc := range table {
		t.Run(string(tc.in), func(t *testing.T) {
			ij := new(injector)
			ij.match(tc.in)
			// Capture the outcome before Reset wipes it.
			got := ij.st
			ij.Reset()
			if got == tc.want {
				return
			}
			t.Fatalf("match(%q) = %v; want %v", tc.in, got, tc.want)
		})
	}
}

// TestInjectorWrite feeds each input to Write chunk by chunk (chunks are
// separated by commas), flushes, and checks both the bytes that reached the
// wrapped writer and the sum of the counts returned by Write and Flush.
func TestInjectorWrite(t *testing.T) {
	cases := []struct {
		category string
		in       string // comma separated chunks
		out      string
	}{
		{"basic", "abc</head>def", "abc<snippet></head>def"},
		{"basic", "abc</he,ad>def", "abc<snippet></head>def"},
		{"basic", "abc,</head>def", "abc<snippet></head>def"},
		{"basic", "abc</head>,def", "abc<snippet></head>def"},
		{"basic", "abc</h,ea,d>def", "abc<snippet></head>def"},
		{"basic", "abc,</hea,d>def", "abc<snippet></head>def"},
		{"no-head", "abc", "abc"},
		{"no-head", "abc</hea", "abc</hea"},
		{"empty", "", ""},
		{"empty", ",", ""},
		{"incomplete", "abc</he</head>def", "abc</he<snippet></head>def"},
		{"incomplete", "abc</he,</head>def", "abc</he<snippet></head>def"},
		{"casing", "abc</HeAd>def", "abc<snippet></HeAd>def"},
		{"casing", "abc</HEAD>def", "abc<snippet></HEAD>def"},
		{"spaces", "abc </head>def", "abc <snippet></head>def"},
		{"spaces", "abc </hea,d>def", "abc <snippet></head>def"},
		{"spaces", "abc</ head>def", "abc<snippet></ head>def"},
		{"spaces", "abc</h ead>def", "abc</h ead>def"},
		{"spaces", "abc</he ad>def", "abc</he ad>def"},
		{"spaces", "abc</hea d>def", "abc</hea d>def"},
		{"spaces", "abc</head >def", "abc<snippet></head >def"},
		{"spaces", "abc</head> def", "abc<snippet></head> def"},
		// {"comment", "<!-- </head>", "<!-- </head>"}, // TODO: don't inject if </head> is found in a comment
	}

	for _, tc := range cases {
		name := tc.category + ":" + tc.in
		t.Run(name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			ij := &injector{wrapped: rec}

			written := 0
			for _, part := range strings.Split(tc.in, ",") {
				n, err := ij.Write([]byte(part))
				assert.NoError(t, err)
				written += n
			}
			// Whatever is still held back by a partial match comes out here.
			n, err := ij.Flush()
			assert.NoError(t, err)
			written += n

			assert.Equal(t, tc.out, rec.Body.String())
			assert.Equal(t, len(tc.out), written)
		})
	}
}

// sinkResponseWriter is an in-memory http.ResponseWriter that accumulates
// every written byte in out.
type sinkResponseWriter struct {
	out []byte
}

// Header returns a fresh, empty header map on every call.
func (s *sinkResponseWriter) Header() http.Header {
	return make(http.Header)
}

// Write appends p to out and reports all of it as written.
func (s *sinkResponseWriter) Write(p []byte) (int, error) {
	n := len(p)
	s.out = append(s.out, p...)
	return n, nil
}

// WriteHeader ignores the status code.
func (s *sinkResponseWriter) WriteHeader(_ int) {}

// BenchmarkInjectorWrite measures a single-chunk Write that contains a full
// "</head>" match, resetting the injector and the sink between iterations.
func BenchmarkInjectorWrite(b *testing.B) {
	// Fixtures are built once, outside the timed loop, so the reported
	// allocations belong to Write rather than to the benchmark itself.
	input := []byte("abc</head>def")
	want := []byte("abc<snippet></head>def")
	sink := &sinkResponseWriter{}
	ij := &injector{
		wrapped: sink,
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := ij.Write(input); err != nil {
			b.Fatalf("unexpected write error: %v", err)
		}
		if !bytes.Equal(sink.out, want) {
			b.Fatalf("out is not as expected: %q", sink.out)
		}
		sink.out = sink.out[:0]
		ij.Reset()
	}
}

func FuzzInjectorWrite(f *testing.F) {
cases := []string{
"abc</head>def",
"abc",
"abc</hea",
"abc</he</head>def",
"abc</HeAd>def",
"abc</HEAD>def",
"abc </head>def",
"abc</ head>def",
"abc</h ead>def",
"abc</he ad>def",
"abc</hea d>def",
"abc</head >def",
"abc</head> def",
"",
}
for _, tc := range cases {
f.Add([]byte(tc))
}
f.Fuzz(func(t *testing.T, in []byte) {
sink := &sinkResponseWriter{}
ij := &injector{
wrapped: sink,
}
ij.Write(in)
})
}
Loading