Skip to content

Commit

Permalink
fix(streaming): call stream.abort() explicitly when request is aborted (
Browse files Browse the repository at this point in the history
#3042)

* feat(utils/stream): enable to abort streaming manually

* feat(utils/stream): prevent multiple aborts, and enable to get the abort status

* fix(streaming): call `stream.abort()` explicitly when request is aborted

* test: add tests for streaming

* docs(stream): add comments

* test: add --allow-net to deno test command in ci.yml

* test(streaming): update test code

* test(stream): retry flaky test up to 3 times at "bun"

* test(streaming): refactor test to use afterEach

* fix(streaming): in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming

* refactor(streaming): tweaks code layout
  • Loading branch information
usualoma authored Jun 28, 2024
1 parent a6ad42d commit 2d3bc55
Show file tree
Hide file tree
Showing 11 changed files with 318 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
- uses: denoland/setup-deno@v1
with:
deno-version: v1.x
- run: env NAME=Deno deno test --coverage=coverage/raw/deno-runtime --allow-read --allow-env --allow-write -c runtime_tests/deno/deno.json runtime_tests/deno
- run: env NAME=Deno deno test --coverage=coverage/raw/deno-runtime --allow-read --allow-env --allow-write --allow-net -c runtime_tests/deno/deno.json runtime_tests/deno
- run: deno test -c runtime_tests/deno-jsx/deno.precompile.json --coverage=coverage/raw/deno-precompile-jsx runtime_tests/deno-jsx
- run: deno test -c runtime_tests/deno-jsx/deno.react-jsx.json --coverage=coverage/raw/deno-react-jsx runtime_tests/deno-jsx
- uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"scripts": {
"test": "tsc --noEmit && vitest --run && vitest -c .vitest.config/jsx-runtime-default.ts --run && vitest -c .vitest.config/jsx-runtime-dom.ts --run",
"test:watch": "vitest --watch",
"test:deno": "deno test --allow-read --allow-env --allow-write -c runtime_tests/deno/deno.json runtime_tests/deno && deno test --no-lock -c runtime_tests/deno-jsx/deno.precompile.json runtime_tests/deno-jsx && deno test --no-lock -c runtime_tests/deno-jsx/deno.react-jsx.json runtime_tests/deno-jsx",
"test:deno": "deno test --allow-read --allow-env --allow-write --allow-net -c runtime_tests/deno/deno.json runtime_tests/deno && deno test --no-lock -c runtime_tests/deno-jsx/deno.precompile.json runtime_tests/deno-jsx && deno test --no-lock -c runtime_tests/deno-jsx/deno.react-jsx.json runtime_tests/deno-jsx",
"test:bun": "bun test --jsx-import-source ../../src/jsx runtime_tests/bun/index.test.tsx",
"test:fastly": "vitest --run --config ./runtime_tests/fastly/vitest.config.ts",
"test:node": "vitest --run --config ./runtime_tests/node/vitest.config.ts",
Expand Down
74 changes: 73 additions & 1 deletion runtime_tests/bun/index.test.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { afterAll, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { serveStatic, toSSG } from '../../src/adapter/bun'
import { createBunWebSocket } from '../../src/adapter/bun/websocket'
import type { BunWebSocketData } from '../../src/adapter/bun/websocket'
Expand All @@ -11,6 +11,7 @@ import { jsx } from '../../src/jsx'
import { basicAuth } from '../../src/middleware/basic-auth'
import { jwt } from '../../src/middleware/jwt'
import { HonoRequest } from '../../src/request'
import { stream, streamSSE } from '../..//src/helper/streaming'

// Test just only minimal patterns.
// Because others are tested well in Cloudflare Workers environment already.
Expand Down Expand Up @@ -316,3 +317,74 @@ async function deleteDirectory(dirPath) {
await fs.unlink(dirPath)
}
}

describe('streaming', () => {
const app = new Hono()
let server: ReturnType<typeof Bun.serve>
let aborted = false

app.get('/stream', (c) => {
return stream(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})
app.get('/streamSSE', (c) => {
return streamSSE(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})

beforeEach(() => {
aborted = false
server = Bun.serve({ port: 0, fetch: app.fetch })
})

afterEach(() => {
server.stop()
})

describe('stream', () => {
it('Should call onAbort', async () => {
const ac = new AbortController()
const req = new Request(`http://localhost:${server.port}/stream`, {
signal: ac.signal,
})
expect(aborted).toBe(false)
const res = fetch(req).catch(() => {})
await new Promise((resolve) => setTimeout(resolve, 10))
ac.abort()
await res
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
expect(aborted).toBe(true)
})
})

describe('streamSSE', () => {
it('Should call onAbort', async () => {
const ac = new AbortController()
const req = new Request(`http://localhost:${server.port}/streamSSE`, {
signal: ac.signal,
})
const res = fetch(req).catch(() => {})
await new Promise((resolve) => setTimeout(resolve, 10))
ac.abort()
await res
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
expect(aborted).toBe(true)
})
})
})
69 changes: 69 additions & 0 deletions runtime_tests/deno/stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { Hono } from '../../src/hono.ts'
import { assertEquals } from './deps.ts'
import { stream, streamSSE } from '../../src/helper/streaming/index.ts'

Deno.test('Shuld call onAbort via stream', async () => {
const app = new Hono()
let aborted = false
app.get('/stream', (c) => {
return stream(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})

const server = Deno.serve({ port: 0 }, app.fetch)
const ac = new AbortController()
const req = new Request(`http://localhost:${server.addr.port}/stream`, {
signal: ac.signal,
})
assertEquals
const res = fetch(req).catch(() => {})
assertEquals(aborted, false)
await new Promise((resolve) => setTimeout(resolve, 10))
ac.abort()
await res
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
assertEquals(aborted, true)

await server.shutdown()
})

Deno.test('Shuld call onAbort via streamSSE', async () => {
const app = new Hono()
let aborted = false
app.get('/stream', (c) => {
return streamSSE(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})

const server = Deno.serve({ port: 0 }, app.fetch)
const ac = new AbortController()
const req = new Request(`http://localhost:${server.addr.port}/stream`, {
signal: ac.signal,
})
assertEquals
const res = fetch(req).catch(() => {})
assertEquals(aborted, false)
await new Promise((resolve) => setTimeout(resolve, 10))
ac.abort()
await res
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
assertEquals(aborted, true)

await server.shutdown()
})
67 changes: 67 additions & 0 deletions runtime_tests/node/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { env, getRuntimeKey } from '../../src/helper/adapter'
import { basicAuth } from '../../src/middleware/basic-auth'
import { jwt } from '../../src/middleware/jwt'
import { HonoRequest } from '../../src/request'
import { stream, streamSSE } from '../../src/helper/streaming'

// Test only minimal patterns.
// See <https://github.com/honojs/node-server> for more tests and information.
Expand Down Expand Up @@ -96,3 +97,69 @@ describe('JWT Auth Middleware', () => {
expect(res.text).toBe('auth')
})
})

describe('stream', () => {
const app = new Hono()

let aborted = false

app.get('/stream', (c) => {
return stream(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})

const server = createAdaptorServer(app)

it('Should call onAbort', async () => {
const req = request(server)
.get('/stream')
.end(() => {})

expect(aborted).toBe(false)
await new Promise((resolve) => setTimeout(resolve, 10))
req.abort()
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
expect(aborted).toBe(true)
})
})

describe('streamSSE', () => {
const app = new Hono()

let aborted = false

app.get('/stream', (c) => {
return streamSSE(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
return new Promise<void>((resolve) => {
stream.onAbort(resolve)
})
})
})

const server = createAdaptorServer(app)

it('Should call onAbort', async () => {
const req = request(server)
.get('/stream')
.end(() => {})

expect(aborted).toBe(false)
await new Promise((resolve) => setTimeout(resolve, 10))
req.abort()
while (!aborted) {
await new Promise((resolve) => setTimeout(resolve))
}
expect(aborted).toBe(true)
})
})
27 changes: 27 additions & 0 deletions src/helper/streaming/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,33 @@ describe('SSE Streaming helper', () => {
expect(aborted).toBeTruthy()
})

it('Check streamSSE Response if aborted by abort signal', async () => {
const ac = new AbortController()
const req = new Request('http://localhost/', { signal: ac.signal })
const c = new Context(req)

let aborted = false
const res = streamSSE(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
for (let i = 0; i < 3; i++) {
await stream.writeSSE({
data: `Message ${i}`,
})
await stream.sleep(1)
}
})
if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const { value } = await reader.read()
expect(value).toEqual(new TextEncoder().encode('data: Message 0\n\n'))
ac.abort()
expect(aborted).toBeTruthy()
})

it('Should include retry in the SSE message', async () => {
const retryTime = 3000 // 3 seconds
const res = streamSSE(c, async (stream) => {
Expand Down
8 changes: 8 additions & 0 deletions src/helper/streaming/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ const run = async (
}
}

const contextStash = new WeakMap<ReadableStream, Context>()
export const streamSSE = (
c: Context,
cb: (stream: SSEStreamingApi) => Promise<void>,
Expand All @@ -66,6 +67,13 @@ export const streamSSE = (
const { readable, writable } = new TransformStream()
const stream = new SSEStreamingApi(writable, readable)

// bun does not cancel response stream when request is canceled, so detect abort by signal
c.req.raw.signal.addEventListener('abort', () => {
stream.abort()
})
// in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming
contextStash.set(stream.responseReadable, c)

c.header('Transfer-Encoding', 'chunked')
c.header('Content-Type', 'text/event-stream')
c.header('Cache-Control', 'no-cache')
Expand Down
25 changes: 25 additions & 0 deletions src/helper/streaming/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,31 @@ describe('Basic Streaming Helper', () => {
expect(aborted).toBeTruthy()
})

it('Check stream Response if aborted by abort signal', async () => {
const ac = new AbortController()
const req = new Request('http://localhost/', { signal: ac.signal })
const c = new Context(req)

let aborted = false
const res = stream(c, async (stream) => {
stream.onAbort(() => {
aborted = true
})
for (let i = 0; i < 3; i++) {
await stream.write(new Uint8Array([i]))
await stream.sleep(1)
}
})
if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
const { value } = await reader.read()
expect(value).toEqual(new Uint8Array([0]))
ac.abort()
expect(aborted).toBeTruthy()
})

it('Check stream Response if error occurred', async () => {
const onError = vi.fn()
const res = stream(
Expand Down
9 changes: 9 additions & 0 deletions src/helper/streaming/stream.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import type { Context } from '../../context'
import { StreamingApi } from '../../utils/stream'

const contextStash = new WeakMap<ReadableStream, Context>()
export const stream = (
c: Context,
cb: (stream: StreamingApi) => Promise<void>,
onError?: (e: Error, stream: StreamingApi) => Promise<void>
): Response => {
const { readable, writable } = new TransformStream()
const stream = new StreamingApi(writable, readable)

// bun does not cancel response stream when request is canceled, so detect abort by signal
c.req.raw.signal.addEventListener('abort', () => {
stream.abort()
})
// in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming
contextStash.set(stream.responseReadable, c)
;(async () => {
try {
await cb(stream)
Expand All @@ -21,5 +29,6 @@ export const stream = (
stream.close()
}
})()

return c.newResponse(stream.responseReadable)
}
22 changes: 22 additions & 0 deletions src/utils/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,26 @@ describe('StreamingApi', () => {
expect(handleAbort1).toBeCalled()
expect(handleAbort2).toBeCalled()
})

it('abort()', async () => {
const { readable, writable } = new TransformStream()
const handleAbort1 = vi.fn()
const handleAbort2 = vi.fn()
const api = new StreamingApi(writable, readable)
api.onAbort(handleAbort1)
api.onAbort(handleAbort2)
expect(handleAbort1).not.toBeCalled()
expect(handleAbort2).not.toBeCalled()
expect(api.aborted).toBe(false)

api.abort()
expect(handleAbort1).toHaveBeenCalledOnce()
expect(handleAbort2).toHaveBeenCalledOnce()
expect(api.aborted).toBe(true)

api.abort()
expect(handleAbort1).toHaveBeenCalledOnce()
expect(handleAbort2).toHaveBeenCalledOnce()
expect(api.aborted).toBe(true)
})
})
Loading

0 comments on commit 2d3bc55

Please sign in to comment.