diff --git a/deno_dist/client/client.ts b/deno_dist/client/client.ts index e6bbb2b7c..de04ee7f4 100644 --- a/deno_dist/client/client.ts +++ b/deno_dist/client/client.ts @@ -3,7 +3,13 @@ import type { ValidationTargets } from '../types.ts' import { serialize } from '../utils/cookie.ts' import type { UnionToIntersection } from '../utils/types.ts' import type { Callback, Client, ClientRequestOptions } from './types.ts' -import { deepMerge, mergePath, removeIndexString, replaceUrlParam } from './utils.ts' +import { + deepMerge, + mergePath, + removeIndexString, + replaceUrlParam, + replaceUrlProtocol, +} from './utils.ts' const createProxy = (callback: Callback, path: string[]) => { const proxy: unknown = new Proxy(() => {}, { @@ -147,8 +153,11 @@ export const hc = >( return new URL(url) } if (method === 'ws') { - const targetUrl = - opts.args[0] && opts.args[0].param ? replaceUrlParam(url, opts.args[0].param) : url + const targetUrl = replaceUrlProtocol( + opts.args[0] && opts.args[0].param ? replaceUrlParam(url, opts.args[0].param) : url, + 'ws' + ) + return new WebSocket(targetUrl) } diff --git a/deno_dist/client/utils.ts b/deno_dist/client/utils.ts index 517993619..ccaf84095 100644 --- a/deno_dist/client/utils.ts +++ b/deno_dist/client/utils.ts @@ -15,6 +15,15 @@ export const replaceUrlParam = (urlString: string, params: Record { + switch (protocol) { + case 'ws': + return urlString.replace(/^http/, 'ws') + case 'http': + return urlString.replace(/^ws/, 'http') + } +} + export const removeIndexString = (urlSting: string) => { return urlSting.replace(/\/index$/, '') } diff --git a/src/client/client.test.ts b/src/client/client.test.ts index 799f555a5..afe7f9fb4 100644 --- a/src/client/client.test.ts +++ b/src/client/client.test.ts @@ -4,6 +4,7 @@ import { rest } from 'msw' import { setupServer } from 'msw/node' import { expectTypeOf, vi } from 'vitest' +import { upgradeWebSocket } from '../helper' import { Hono } from '../hono' import { parse } from '../utils/cookie' import type { Equal, Expect } from '../utils/types' @@ -686,3 +687,58 @@ describe('Dynamic headers', () => { expect(data.requestDynamic).toEqual('two') }) }) + +describe('WebSocket URL Protocol Translation', () => { + const app = new Hono() + const route = app.get( + '/', + upgradeWebSocket((c) => ({ + onMessage(event, ws) { + console.log(`Message from client: ${event.data}`) + ws.send('Hello from server!') + }, + onClose: () => { + console.log('Connection closed') + }, + })) + ) + + type AppType = typeof route + + const server = setupServer() + const webSocketMock = vi.fn() + + beforeAll(() => server.listen()) + beforeEach(() => { + vi.stubGlobal('WebSocket', webSocketMock) + }) + afterEach(() => { + vi.clearAllMocks() + server.resetHandlers() + }) + afterAll(() => server.close()) + + it('Translates HTTP to ws', async () => { + const client = hc('http://localhost') + client.index.$ws() + expect(webSocketMock).toHaveBeenCalledWith('ws://localhost/index') + }) + + it('Translates HTTPS to wss', async () => { + const client = hc('https://localhost') + client.index.$ws() + expect(webSocketMock).toHaveBeenCalledWith('wss://localhost/index') + }) + + it('Keeps ws unchanged', async () => { + const client = hc('ws://localhost') + client.index.$ws() + expect(webSocketMock).toHaveBeenCalledWith('ws://localhost/index') + }) + + it('Keeps wss unchanged', async () => { + const client = hc('wss://localhost') + client.index.$ws() + expect(webSocketMock).toHaveBeenCalledWith('wss://localhost/index') + }) +}) diff --git a/src/client/client.ts b/src/client/client.ts index 8165c57cc..5c609ea61 100644 --- a/src/client/client.ts +++ b/src/client/client.ts @@ -3,7 +3,13 @@ import type { ValidationTargets } from '../types' import { serialize } from '../utils/cookie' import type { UnionToIntersection } from '../utils/types' import type { Callback, Client, ClientRequestOptions } from './types' -import { deepMerge, mergePath, removeIndexString, replaceUrlParam } from './utils' +import { + deepMerge, + mergePath, + removeIndexString, + replaceUrlParam, + replaceUrlProtocol, +} from './utils' const createProxy = (callback: Callback, path: string[]) => { const proxy: unknown = new Proxy(() => {}, { @@ -147,8 +153,11 @@ export const hc = >( return new URL(url) } if (method === 'ws') { - const targetUrl = - opts.args[0] && opts.args[0].param ? replaceUrlParam(url, opts.args[0].param) : url + const targetUrl = replaceUrlProtocol( + opts.args[0] && opts.args[0].param ? replaceUrlParam(url, opts.args[0].param) : url, + 'ws' + ) + return new WebSocket(targetUrl) } diff --git a/src/client/utils.test.ts b/src/client/utils.test.ts index 32ddfe415..3a345f5b2 100644 --- a/src/client/utils.test.ts +++ b/src/client/utils.test.ts @@ -1,4 +1,10 @@ -import { deepMerge, mergePath, removeIndexString, replaceUrlParam } from './utils' +import { + deepMerge, + mergePath, + removeIndexString, + replaceUrlParam, + replaceUrlProtocol, +} from './utils' describe('mergePath', () => { it('Should merge paths correctly', () => { @@ -42,6 +48,32 @@ describe('replaceUrlParams', () => { }) }) +describe('replaceUrlProtocol', () => { + it('Should replace http to ws', () => { + const url = 'http://localhost' + const newUrl = replaceUrlProtocol(url, 'ws') + expect(newUrl).toBe('ws://localhost') + }) + + it('Should replace https to wss', () => { + const url = 'https://localhost' + const newUrl = replaceUrlProtocol(url, 'ws') + expect(newUrl).toBe('wss://localhost') + }) + + it('Should replace ws to http', () => { + const url = 'ws://localhost' + const newUrl = replaceUrlProtocol(url, 'http') + expect(newUrl).toBe('http://localhost') + }) + + it('Should replace wss to https', () => { + const url = 'wss://localhost' + const newUrl = replaceUrlProtocol(url, 'http') + expect(newUrl).toBe('https://localhost') + }) +}) + describe('removeIndexString', () => { it('Should remove last `/index` string', () => { let url = 'http://localhost/index' diff --git a/src/client/utils.ts b/src/client/utils.ts index d46ce242f..2f6887242 100644 --- a/src/client/utils.ts +++ b/src/client/utils.ts @@ -15,6 +15,15 @@ export const replaceUrlParam = (urlString: string, params: Record { + switch (protocol) { + case 'ws': + return urlString.replace(/^http/, 'ws') + case 'http': + return urlString.replace(/^ws/, 'http') + } +} + export const removeIndexString = (urlSting: string) => { return urlSting.replace(/\/index$/, '') }