Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 54 additions & 7 deletions src/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,45 @@ type UpgradeWebSocketOptions = {
onError: (err: unknown) => void
}

const rejectUpgradeRequest = (socket: Duplex, status: number) => {
// Hop-by-hop headers per RFC 9110 Section 7.6.1
// (https://www.rfc-editor.org/rfc/rfc9110.html#name-connection) plus
// `keep-alive` (commonly treated as hop-by-hop by HTTP implementations) and
// WebSocket handshake headers managed by `ws` itself. These must not be
// forwarded onto the upgrade response or the handshake will be corrupted.
const responseHeadersToSkip = new Set([
'connection',
'content-length',
'keep-alive',
'proxy-authenticate',
'proxy-authorization',
'te',
'trailer',
'transfer-encoding',
'upgrade',
'sec-websocket-accept',
'sec-websocket-extensions',
'sec-websocket-protocol',
])

const appendResponseHeaders = (headers: string[], responseHeaders?: Headers) => {
if (!responseHeaders) {
return
}
responseHeaders.forEach((value, key) => {
if (responseHeadersToSkip.has(key.toLowerCase())) {
return
}
headers.push(`${key}: ${value}`)
})
}

const rejectUpgradeRequest = (socket: Duplex, status: number, responseHeaders?: Headers) => {
const responseLines = ['Connection: close', 'Content-Length: 0']
appendResponseHeaders(responseLines, responseHeaders)

socket.end(
`HTTP/1.1 ${status.toString()} ${STATUS_CODES[status] ?? ''}\r\n` +
'Connection: close\r\n' +
'Content-Length: 0\r\n' +
`${responseLines.join('\r\n')}\r\n` +
'\r\n'
)
}
Expand Down Expand Up @@ -121,13 +155,15 @@ export const setupWebSocket = (options: {
}

let status = 400
let responseHeaders: Headers | undefined
try {
const response = (await fetchCallback(
createUpgradeRequest(request),
env as unknown as Parameters<FetchCallback>[1]
)) as Response
if (response instanceof Response) {
status = response.status
responseHeaders = response.headers
}
} catch {
if (server.listenerCount('upgrade') === 1) {
Expand All @@ -141,14 +177,25 @@ export const setupWebSocket = (options: {
if (!waiter || waiter.connectionSymbol !== env[CONNECTION_SYMBOL_KEY]) {
waiterMap.delete(request)
if (server.listenerCount('upgrade') === 1) {
rejectUpgradeRequest(socket, status)
rejectUpgradeRequest(socket, status, responseHeaders)
}
return
}

wss.handleUpgrade(request, socket, head, (ws) => {
wss.emit('connection', ws, request)
})
const addResponseHeaders = (headers: string[]) => {
appendResponseHeaders(headers, responseHeaders)
}

// `headers` is emitted synchronously inside `handleUpgrade`, so this
// listener cannot leak across concurrent upgrades on the shared `wss`.
wss.on('headers', addResponseHeaders)
try {
wss.handleUpgrade(request, socket, head, (ws) => {
wss.emit('connection', ws, request)
})
} finally {
wss.off('headers', addResponseHeaders)
}
})

server.on('close', () => {
Expand Down
67 changes: 67 additions & 0 deletions test/websocket.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,73 @@ describe('WebSocket', () => {
}
})

it('should forward response headers set by middleware on a successful upgrade', async () => {
const app = new Hono()

app.use(async (c, next) => {
await next()
c.header('x-auth-result', c.req.header('authorization') ? 'authorized' : 'missing')
})
app.get(
'/ws',
upgradeWebSocket(() => ({}))
)

const { server, address } = await startServer(app)

try {
const ws = new WebSocket(`ws://127.0.0.1:${address.port}/ws`, {
headers: {
authorization: 'Bearer token',
},
})
const responseHeaders = await new Promise<Record<string, string | string[] | undefined>>(
(resolve, reject) => {
ws.once('upgrade', (response) => {
resolve(response.headers)
})
ws.once('error', reject)
}
)
expect(responseHeaders['x-auth-result']).toBe('authorized')
ws.close()
} finally {
await new Promise<void>((resolve) => server.close(() => resolve()))
}
})

it('should forward response headers when the upgrade is rejected', async () => {
const app = new Hono()

app.get(
'/ws',
() =>
new Response(null, {
status: 401,
headers: {
'x-auth-result': 'missing',
},
})
)

const { server, address } = await startServer(app)

try {
const ws = new WebSocket(`ws://127.0.0.1:${address.port}/ws`)
const responseHeaders = await new Promise<Record<string, string | string[] | undefined>>(
(resolve, reject) => {
ws.once('unexpected-response', (_, response) => {
resolve(response.headers)
})
ws.once('open', () => reject(new Error('WebSocket must not be upgraded')))
}
)
expect(responseHeaders['x-auth-result']).toBe('missing')
} finally {
await new Promise<void>((resolve) => server.close(() => resolve()))
}
})

it('should not block other upgrade listeners', async () => {
const app = new Hono()
const { server, address } = await startServer(app)
Expand Down
Loading