diff --git a/src/websocket.ts b/src/websocket.ts index fe0cd35..69c778c 100644 --- a/src/websocket.ts +++ b/src/websocket.ts @@ -57,11 +57,45 @@ type UpgradeWebSocketOptions = { onError: (err: unknown) => void } -const rejectUpgradeRequest = (socket: Duplex, status: number) => { +// Hop-by-hop headers per RFC 9110 Section 7.6.1 +// (https://www.rfc-editor.org/rfc/rfc9110.html#name-connection) plus +// `keep-alive` (commonly treated as hop-by-hop by HTTP implementations) and +// WebSocket handshake headers managed by `ws` itself. These must not be +// forwarded onto the upgrade response or the handshake will be corrupted. +const responseHeadersToSkip = new Set([ + 'connection', + 'content-length', + 'keep-alive', + 'proxy-authenticate', + 'proxy-authorization', + 'te', + 'trailer', + 'transfer-encoding', + 'upgrade', + 'sec-websocket-accept', + 'sec-websocket-extensions', + 'sec-websocket-protocol', +]) + +const appendResponseHeaders = (headers: string[], responseHeaders?: Headers) => { + if (!responseHeaders) { + return + } + responseHeaders.forEach((value, key) => { + if (responseHeadersToSkip.has(key.toLowerCase())) { + return + } + headers.push(`${key}: ${value}`) + }) +} + +const rejectUpgradeRequest = (socket: Duplex, status: number, responseHeaders?: Headers) => { + const responseLines = ['Connection: close', 'Content-Length: 0'] + appendResponseHeaders(responseLines, responseHeaders) + socket.end( `HTTP/1.1 ${status.toString()} ${STATUS_CODES[status] ?? ''}\r\n` + - 'Connection: close\r\n' + - 'Content-Length: 0\r\n' + + `${responseLines.join('\r\n')}\r\n` + '\r\n' ) } @@ -121,6 +155,7 @@ export const setupWebSocket = (options: { } let status = 400 + let responseHeaders: Headers | undefined try { const response = (await fetchCallback( createUpgradeRequest(request), @@ -128,6 +163,7 @@ export const setupWebSocket = (options: { )) as Response if (response instanceof Response) { status = response.status + responseHeaders = response.headers } } catch { if (server.listenerCount('upgrade') === 1) { @@ -141,14 +177,25 @@ export const setupWebSocket = (options: { if (!waiter || waiter.connectionSymbol !== env[CONNECTION_SYMBOL_KEY]) { waiterMap.delete(request) if (server.listenerCount('upgrade') === 1) { - rejectUpgradeRequest(socket, status) + rejectUpgradeRequest(socket, status, responseHeaders) } return } - wss.handleUpgrade(request, socket, head, (ws) => { - wss.emit('connection', ws, request) - }) + const addResponseHeaders = (headers: string[]) => { + appendResponseHeaders(headers, responseHeaders) + } + + // `headers` is emitted synchronously inside `handleUpgrade`, so this + // listener cannot leak across concurrent upgrades on the shared `wss`. + wss.on('headers', addResponseHeaders) + try { + wss.handleUpgrade(request, socket, head, (ws) => { + wss.emit('connection', ws, request) + }) + } finally { + wss.off('headers', addResponseHeaders) + } }) server.on('close', () => { diff --git a/test/websocket.test.ts b/test/websocket.test.ts index 830ad9b..a6d3c75 100644 --- a/test/websocket.test.ts +++ b/test/websocket.test.ts @@ -69,6 +69,73 @@ describe('WebSocket', () => { } }) + it('should forward response headers set by middleware on a successful upgrade', async () => { + const app = new Hono() + + app.use(async (c, next) => { + await next() + c.header('x-auth-result', c.req.header('authorization') ? 'authorized' : 'missing') + }) + app.get( + '/ws', + upgradeWebSocket(() => ({})) + ) + + const { server, address } = await startServer(app) + + try { + const ws = new WebSocket(`ws://127.0.0.1:${address.port}/ws`, { + headers: { + authorization: 'Bearer token', + }, + }) + const responseHeaders = await new Promise>( + (resolve, reject) => { + ws.once('upgrade', (response) => { + resolve(response.headers) + }) + ws.once('error', reject) + } + ) + expect(responseHeaders['x-auth-result']).toBe('authorized') + ws.close() + } finally { + await new Promise((resolve) => server.close(() => resolve())) + } + }) + + it('should forward response headers when the upgrade is rejected', async () => { + const app = new Hono() + + app.get( + '/ws', + () => + new Response(null, { + status: 401, + headers: { + 'x-auth-result': 'missing', + }, + }) + ) + + const { server, address } = await startServer(app) + + try { + const ws = new WebSocket(`ws://127.0.0.1:${address.port}/ws`) + const responseHeaders = await new Promise>( + (resolve, reject) => { + ws.once('unexpected-response', (_, response) => { + resolve(response.headers) + }) + ws.once('open', () => reject(new Error('WebSocket must not be upgraded'))) + } + ) + expect(responseHeaders['x-auth-result']).toBe('missing') + } finally { + await new Promise((resolve) => server.close(() => resolve())) + } + }) + it('should not block other upgrade listeners', async () => { const app = new Hono() const { server, address } = await startServer(app)