From 809e2ceaf95ca37c68c2924a8648b314cefca47b Mon Sep 17 00:00:00 2001 From: Matt Diehl Date: Fri, 29 May 2026 12:14:50 -0700 Subject: [PATCH] Fix macOS relay HOL and leak, Linux relay cleanup issues. --- Package.resolved | 19 +- .../Containerization/UnixSocketRelay.swift | 16 +- .../UnixSocketRelayManager.swift | 3 - .../Socket/BidirectionalRelay.swift | 425 +++++++++++------- .../BidirectionalRelayTests.swift | 345 ++++++++++++++ vminitd/Sources/VminitdCore/VsockProxy.swift | 10 +- 6 files changed, 638 insertions(+), 180 deletions(-) create mode 100644 Tests/ContainerizationOSTests/BidirectionalRelayTests.swift diff --git a/Package.resolved b/Package.resolved index abcc214e..7302387c 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "f2af83112ef9c25538d60f115c1d21ccfa89e850a8685333af1b3492ff8cda36", + "originHash" : "3eb53ed7f842d2bfb606d5acdfdb3db4278338a0a8ef9ca39548ea23c856f0a6", "pins" : [ { "identity" : "async-http-client", @@ -19,15 +19,6 @@ "version" : "2.3.0" } }, - { - "identity" : "grpc-swift-nio-transport", - "kind" : "remoteSourceControl", - "location" : "https://github.com/grpc/grpc-swift-nio-transport.git", - "state" : { - "revision" : "f37e0c2d293cea668b11e10e1fb1c24cb40781ff", - "version" : "2.4.4" - } - }, { "identity" : "grpc-swift-protobuf", "kind" : "remoteSourceControl", @@ -159,8 +150,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio.git", "state" : { - "revision" : "4e8f4b1c9adaa59315c523540c1ff2b38adc20a9", - "version" : "2.87.0" + "revision" : "f71c8d2a5e74a2c6d11a0fbe324774b5d6084237", + "version" : "2.99.0" } }, { @@ -177,8 +168,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-http2.git", "state" : { - "revision" : "5e9e99ec96c53bc2c18ddd10c1e25a3cd97c55e5", - "version" : "1.38.0" + "revision" : "81cc18264f92cd307ff98430f89372711d4f6fe9", + "version" : "1.43.0" } }, { diff --git a/Sources/Containerization/UnixSocketRelay.swift 
b/Sources/Containerization/UnixSocketRelay.swift index e7a3304a..4069bf4e 100644 --- a/Sources/Containerization/UnixSocketRelay.swift +++ b/Sources/Containerization/UnixSocketRelay.swift @@ -25,7 +25,6 @@ package final class UnixSocketRelay: Sendable { private let port: UInt32 private let configuration: UnixSocketConfiguration private let vm: any VirtualMachineInstance - private let queue: DispatchQueue private let log: Logger? private let state: Mutex @@ -39,13 +38,11 @@ package final class UnixSocketRelay: Sendable { port: UInt32, socket: UnixSocketConfiguration, vm: any VirtualMachineInstance, - queue: DispatchQueue, log: Logger? = nil ) throws { self.port = port self.configuration = socket self.vm = vm - self.queue = queue self.log = log self.state = Mutex(.init()) } @@ -224,7 +221,6 @@ extension UnixSocketRelay { let relay = BidirectionalRelay( fd1: hostFd, fd2: guestFd, - queue: queue, log: log ) @@ -232,6 +228,16 @@ extension UnixSocketRelay { $0.activeRelays[relayID] = relay } - relay.start() + do { + try relay.start() + } catch { + state.withLock { $0.activeRelays[relayID] = nil } + throw error + } + + Task { + await relay.waitForCompletion() + state.withLock { $0.activeRelays[relayID] = nil } + } } } diff --git a/Sources/Containerization/UnixSocketRelayManager.swift b/Sources/Containerization/UnixSocketRelayManager.swift index ceae276c..f1f9d3d7 100644 --- a/Sources/Containerization/UnixSocketRelayManager.swift +++ b/Sources/Containerization/UnixSocketRelayManager.swift @@ -21,13 +21,11 @@ import Logging package actor UnixSocketRelayManager { private let vm: any VirtualMachineInstance private var relays: [String: UnixSocketRelay] - private let queue: DispatchQueue private let log: Logger? init(vm: any VirtualMachineInstance, log: Logger? 
= nil) { self.vm = vm self.relays = [:] - self.queue = DispatchQueue(label: "com.apple.containerization.socket-relay") self.log = log } } @@ -45,7 +43,6 @@ extension UnixSocketRelayManager { port: port, socket: socket, vm: vm, - queue: queue, log: log ) diff --git a/Sources/ContainerizationOS/Socket/BidirectionalRelay.swift b/Sources/ContainerizationOS/Socket/BidirectionalRelay.swift index 96420f97..649a212a 100644 --- a/Sources/ContainerizationOS/Socket/BidirectionalRelay.swift +++ b/Sources/ContainerizationOS/Socket/BidirectionalRelay.swift @@ -34,16 +34,64 @@ import Foundation #endif /// Manages bidirectional data relay between two file descriptors using `DispatchSource`. +/// +/// Uses non-blocking I/O with backpressure: when a destination fd's buffer is full, +/// the relay suspends reading from the source and installs a `DispatchSourceWrite` +/// to resume once the destination is writable again. This prevents blocking the +/// dispatch queue and avoids head-of-line blocking across connections. +/// +/// ## Concurrency model +/// +/// The class has two distinct synchronization domains: +/// +/// - **Serial dispatch queue** — owns all I/O state: the `Direction` objects (`d1`, `d2`) +/// and their read buffers (`buf1`, `buf2`). Every event handler, cancel handler, and +/// `stop()` call runs on this queue. No locks are needed for that state because the +/// queue is the exclusive executor. Fields in this domain are marked `nonisolated(unsafe)`. +/// +/// - **Mutexes** — protect the two pieces of state that cross the queue boundary: +/// `activeDirections` (written by `start()`, which may run off-queue) and +/// `completionState` (read by `waitForCompletion()` from any async context). public final class BidirectionalRelay: Sendable { private let fd1: Int32 private let fd2: Int32 private let log: Logger? private let queue: DispatchQueue + private static let queueKey = DispatchSpecificKey() - // `DispatchSourceRead` is thread-safe. 
- private struct ConnectionSources: @unchecked Sendable { - let source1: DispatchSourceRead - let source2: DispatchSourceRead + /// Owns one direction of the relay: its read source, optional write source, and + /// any data buffered due to backpressure. + /// + /// All methods must be called only from the relay's serial dispatch queue. + private final class Direction { + var readSource: DispatchSourceRead? + var writeSource: DispatchSourceWrite? + var pendingData: [UInt8] = [] + var pendingOffset: Int = 0 + private var readSuspended = false + + func suspendRead() { + guard let src = readSource, !src.isCancelled, !readSuspended else { return } + readSuspended = true + src.suspend() + } + + func resumeRead() { + guard let src = readSource, !src.isCancelled, readSuspended else { return } + readSuspended = false + src.resume() + } + + /// Resumes the read source before cancelling it if it is suspended. + /// GCD does not deliver a cancel handler for a suspended source until it is resumed. + func cancelRead() { + guard let src = readSource, !src.isCancelled else { return } + if readSuspended { + readSuspended = false + src.resume() + } + src.cancel() + } } private enum CompletionState { @@ -52,12 +100,31 @@ public final class BidirectionalRelay: Sendable { case completed } - private let state: Mutex - private let completionState: Mutex + private enum CopyResult { + case ok + case blocked + case eof + } - // The buffers aren't used concurrently. - private nonisolated(unsafe) let buffer1: UnsafeMutableBufferPointer - private nonisolated(unsafe) let buffer2: UnsafeMutableBufferPointer + // Queue-owned state. Written by start() before activate(), so all subsequent + // accesses from event/cancel handlers observe the initialized values without + // additional synchronization. nonisolated(unsafe) declares that we are taking + // responsibility for this; the serial queue is the enforcing mechanism. 
+ private nonisolated(unsafe) let d1 = Direction() // fd1 → fd2 + private nonisolated(unsafe) let d2 = Direction() // fd2 → fd1 + private nonisolated(unsafe) let buf1: UnsafeMutableBufferPointer + private nonisolated(unsafe) let buf2: UnsafeMutableBufferPointer + + // Counts active read sources. Set to 2 in start() (possibly off-queue) and + // decremented in cancel handlers (always on the queue). The Mutex provides the + // cross-thread visibility guarantee for the initial write from start(). Whichever + // cancel handler drives the count to zero calls closeBothFds() exactly once — + // no cross-source isCancelled checks, no possibility of double-close. + private let activeDirections: Mutex + + // May be read from any async context (waitForCompletion) and written from the + // queue (closeBothFds), so it needs a Mutex rather than queue-only protection. + private let completionState: Mutex /// Creates a new bidirectional relay between two file descriptors. /// @@ -75,84 +142,71 @@ public final class BidirectionalRelay: Sendable { self.fd1 = fd1 self.fd2 = fd2 self.queue = queue ?? DispatchQueue(label: "com.apple.containerization.bidirectional-relay") + self.queue.setSpecific(key: Self.queueKey, value: ()) self.log = log - self.state = Mutex(nil) + self.activeDirections = Mutex(0) self.completionState = Mutex(.pending) let pageSize = Int(getpagesize()) - self.buffer1 = UnsafeMutableBufferPointer.allocate(capacity: pageSize) - self.buffer2 = UnsafeMutableBufferPointer.allocate(capacity: pageSize) + self.buf1 = UnsafeMutableBufferPointer.allocate(capacity: pageSize) + self.buf2 = UnsafeMutableBufferPointer.allocate(capacity: pageSize) } deinit { - buffer1.deallocate() - buffer2.deallocate() + buf1.deallocate() + buf2.deallocate() } - /// Starts the bidirectional relay to copy data from fd1 to fd2 and from fd2 to fd1. 
- public func start() { - let source1 = DispatchSource.makeReadSource(fileDescriptor: fd1, queue: queue) - let source2 = DispatchSource.makeReadSource(fileDescriptor: fd2, queue: queue) - state.withLock { - $0 = ConnectionSources(source1: source1, source2: source2) - } - - source1.setEventHandler { [self] in - self.fdCopyHandler( - buffer: self.buffer1, - source: source1, - from: self.fd1, - to: self.fd2 + private static func setNonBlocking(_ fd: Int32) throws { + let flags = fcntl(fd, F_GETFL) + guard flags != -1 else { + throw ContainerizationError( + .internalError, + message: "fcntl F_GETFL failed on fd \(fd): errno \(errno)" ) } - - source2.setEventHandler { [self] in - self.fdCopyHandler( - buffer: self.buffer2, - source: source2, - from: self.fd2, - to: self.fd1 + guard fcntl(fd, F_SETFL, flags | O_NONBLOCK) != -1 else { + throw ContainerizationError( + .internalError, + message: "fcntl F_SETFL O_NONBLOCK failed on fd \(fd): errno \(errno)" ) } + } - // Only close underlying fds when both sources are at EOF. - // Ensure that one of the cancel handlers will see both sources cancelled. - source1.setCancelHandler { [self] in - self.log?.debug( - "source1 cancel received", - metadata: ["fd1": "\(self.fd1)", "fd2": "\(self.fd2)"] - ) + /// Starts the bidirectional relay to copy data between fd1 and fd2. 
+ public func start() throws { + try Self.setNonBlocking(fd1) + try Self.setNonBlocking(fd2) - self.state.withLock { _ in - if source2.isCancelled { - self.closeBothFds() - } - } - } + let src1 = DispatchSource.makeReadSource(fileDescriptor: fd1, queue: queue) + let src2 = DispatchSource.makeReadSource(fileDescriptor: fd2, queue: queue) + d1.readSource = src1 + d2.readSource = src2 + activeDirections.withLock { $0 = 2 } - source2.setCancelHandler { [self] in - self.log?.debug( - "source2 cancel received", - metadata: ["fd1": "\(self.fd1)", "fd2": "\(self.fd2)"] - ) + src1.setEventHandler { [self] in handleRead(d1, from: fd1, to: fd2, buffer: buf1) } + src2.setEventHandler { [self] in handleRead(d2, from: fd2, to: fd1, buffer: buf2) } - self.state.withLock { _ in - if source1.isCancelled { - self.closeBothFds() - } - } + src1.setCancelHandler { [self] in + d1.writeSource?.cancel() + d1.writeSource = nil + directionFinished() + } + src2.setCancelHandler { [self] in + d2.writeSource?.cancel() + d2.writeSource = nil + directionFinished() } - source1.activate() - source2.activate() + src1.activate() + src2.activate() } /// Stops the relay and closes both file descriptors. 
public func stop() { - state.withLock { sources in - sources?.source1.cancel() - sources?.source2.cancel() - sources = nil + runOnQueue { + d1.cancelRead() + d2.cancelRead() } } @@ -172,136 +226,193 @@ public final class BidirectionalRelay: Sendable { } } - private func fdCopyHandler( - buffer: UnsafeMutableBufferPointer, - source: DispatchSourceRead, - from sourceFd: Int32, - to destinationFd: Int32 + private func runOnQueue(_ work: () -> Void) { + if DispatchQueue.getSpecific(key: Self.queueKey) != nil { + work() + } else { + queue.sync(execute: work) + } + } + + private func directionFinished() { + let remaining = activeDirections.withLock { count -> Int in + count -= 1 + return count + } + if remaining == 0 { + closeBothFds() + } + } + + private func handleRead( + _ dir: Direction, + from srcFd: Int32, + to dstFd: Int32, + buffer: UnsafeMutableBufferPointer ) { - if source.data == 0 { - log?.debug( - "source EOF", - metadata: [ - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", - ] - ) - if !source.isCancelled { + do { + switch try Self.copy(buffer: buffer, from: srcFd, to: dstFd, direction: dir) { + case .ok: + break + + case .eof: log?.debug( - "canceling DispatchSourceRead", - metadata: [ - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", - ] + "source EOF", + metadata: ["sourceFd": "\(srcFd)", "destinationFd": "\(dstFd)"] ) - source.cancel() - if shutdown(destinationFd, Int32(SHUT_WR)) != 0 { + dir.cancelRead() + if shutdown(dstFd, Int32(SHUT_WR)) != 0 { log?.debug( - "failed to shut down writes", - metadata: [ - "errno": "\(errno)", - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", - ] + "shutdown(SHUT_WR) failed", + metadata: ["fd": "\(dstFd)", "errno": "\(errno)"] ) } - } - return - } - do { - log?.trace( - "source copy", - metadata: [ - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", - "size": "\(source.data)", - ] - ) - try Self.fileDescriptorCopy( - buffer: buffer, - size: 
source.data, - from: sourceFd, - to: destinationFd - ) + case .blocked: + log?.debug( + "write blocked, applying backpressure", + metadata: [ + "sourceFd": "\(srcFd)", + "destinationFd": "\(dstFd)", + "pendingBytes": "\(dir.pendingData.count)", + ] + ) + dir.suspendRead() + installWriteSource(for: dir, from: srcFd, to: dstFd) + } } catch { log?.warning( - "file descriptor copy failed", + "I/O error", metadata: [ + "sourceFd": "\(srcFd)", + "destinationFd": "\(dstFd)", "error": "\(error)", - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", ] ) - if !source.isCancelled { - source.cancel() - if shutdown(destinationFd, Int32(SHUT_RDWR)) != 0 { - log?.warning( - "failed to shut down destination after I/O error", - metadata: [ - "errno": "\(errno)", - "sourceFd": "\(sourceFd)", - "destinationFd": "\(destinationFd)", - ] - ) - } + dir.cancelRead() + if shutdown(dstFd, Int32(SHUT_RDWR)) != 0 { + log?.warning( + "shutdown(SHUT_RDWR) failed", + metadata: ["fd": "\(dstFd)", "errno": "\(errno)"] + ) } } } - private static func fileDescriptorCopy( - buffer: UnsafeMutableBufferPointer, - size: UInt, - from sourceFd: Int32, - to destinationFd: Int32 - ) throws { - let bufferSize = buffer.count - var readBytesRemaining = min(Int(size), bufferSize) - - guard let baseAddr = buffer.baseAddress else { - throw ContainerizationError( - .invalidState, - message: "buffer has no base address" + private func installWriteSource(for dir: Direction, from srcFd: Int32, to dstFd: Int32) { + let ws = DispatchSource.makeWriteSource(fileDescriptor: dstFd, queue: queue) + dir.writeSource = ws + ws.setEventHandler { [self] in drainPending(dir: dir, from: srcFd, to: dstFd) } + // No cancel handler: clearing pendingData from a cancel handler would race with + // a newly installed write source if drainPending completes and the read source + // immediately produces another blocked write, installing a fresh write source + // before the old cancel handler fires. 
pendingData is cleared explicitly by
+        // drainPending on success, and freed with Direction when the relay is torn down.
+        ws.activate()
+    }
+
+    private func drainPending(dir: Direction, from srcFd: Int32, to dstFd: Int32) {
+        let remaining = dir.pendingData.count - dir.pendingOffset
+        guard remaining > 0 else { return }
+
+        let (n, err) = dir.pendingData.withUnsafeBufferPointer { buf in
+            (write(dstFd, buf.baseAddress!.advanced(by: dir.pendingOffset), remaining), errno)
+        }
+
+        if n > 0 {
+            dir.pendingOffset += n
+            if dir.pendingOffset >= dir.pendingData.count {
+                dir.writeSource?.cancel()
+                dir.writeSource = nil
+                dir.pendingData = []
+                dir.pendingOffset = 0
+                log?.debug(
+                    "backpressure relieved, resuming reads",
+                    metadata: ["sourceFd": "\(srcFd)", "destinationFd": "\(dstFd)"]
+                )
+                dir.resumeRead()
+            }
+        } else if n == -1 && (err == EAGAIN || err == EWOULDBLOCK || err == EINTR) {
+            // Transient (spurious write-ready or interrupted call), same set copy() retries on; wait for the next event.
+        } else {
+            log?.warning(
+                "write error during pending drain",
+                metadata: ["destinationFd": "\(dstFd)", "errno": "\(err)"]
             )
+            dir.writeSource?.cancel()
+            dir.writeSource = nil
+            dir.cancelRead()
+            if shutdown(dstFd, Int32(SHUT_RDWR)) != 0 {
+                log?.warning(
+                    "shutdown(SHUT_RDWR) failed after drain error",
+                    metadata: ["fd": "\(dstFd)", "errno": "\(errno)"]
+                )
+            }
+        }
+    }
+
+    /// Drains srcFd into dstFd in a loop until EAGAIN/EWOULDBLOCK or EOF.
+    ///
+    /// Looping until EAGAIN is required on Linux: libdispatch uses FIONREAD to decide
+    /// whether to fire the read event, so when the only remaining readable condition is
+    /// EOF (FIONREAD == 0), the event is suppressed. Reading in a loop here ensures we
+    /// observe read() == 0 on the same handler invocation that drained the last bytes.
+ private static func copy( + buffer: UnsafeMutableBufferPointer, + from srcFd: Int32, + to dstFd: Int32, + direction: Direction + ) throws -> CopyResult { + guard let base = buffer.baseAddress else { + throw ContainerizationError(.invalidState, message: "buffer has no base address") } - while readBytesRemaining > 0 { - let readResult = read(sourceFd, baseAddr, min(bufferSize, readBytesRemaining)) - if readResult <= 0 { + readLoop: while true { + let nr = read(srcFd, base, buffer.count) + if nr == 0 { return .eof } + if nr < 0 { + if errno == EAGAIN || errno == EWOULDBLOCK { return .ok } + if errno == EINTR { continue readLoop } throw ContainerizationError( .internalError, - message: "zero byte read or error in socket relay: fd \(sourceFd), result \(readResult)" + message: "read failed: fd \(srcFd), errno \(errno)" ) } - readBytesRemaining -= readResult - var writeBytesRemaining = readResult - var writeOffset = 0 - while writeBytesRemaining > 0 { - let writeResult = write(destinationFd, baseAddr.advanced(by: writeOffset), writeBytesRemaining) - if writeResult <= 0 { + var offset = 0 + while offset < nr { + let nw = write(dstFd, base.advanced(by: offset), nr - offset) + if nw > 0 { + offset += nw + } else if nw < 0 { + if errno == EINTR { continue } + if errno == EAGAIN || errno == EWOULDBLOCK { + direction.pendingData = Array( + UnsafeBufferPointer(start: base.advanced(by: offset), count: nr - offset) + ) + direction.pendingOffset = 0 + return .blocked + } throw ContainerizationError( .internalError, - message: "zero byte write or error in socket relay: fd \(destinationFd), result \(writeResult)" + message: "write failed: fd \(dstFd), errno \(errno)" + ) + } else { + throw ContainerizationError( + .internalError, + message: "zero-byte write on fd \(dstFd)" ) } - writeBytesRemaining -= writeResult - writeOffset += writeResult } } } private func closeBothFds() { - log?.debug( - "close file descriptors", - metadata: ["fd1": "\(fd1)", "fd2": "\(fd2)"] - ) + 
log?.debug("closing fds", metadata: ["fd1": "\(fd1)", "fd2": "\(fd2)"]) close(fd1) close(fd2) completionState.withLock { state in - if case .waiting(let c) = state { - c.resume() - } + if case .waiting(let c) = state { c.resume() } state = .completed } } diff --git a/Tests/ContainerizationOSTests/BidirectionalRelayTests.swift b/Tests/ContainerizationOSTests/BidirectionalRelayTests.swift new file mode 100644 index 00000000..51894025 --- /dev/null +++ b/Tests/ContainerizationOSTests/BidirectionalRelayTests.swift @@ -0,0 +1,345 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the Containerization project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation +import Testing + +@testable import ContainerizationOS + +#if canImport(Darwin) +import Darwin +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#endif + +@Suite("BidirectionalRelay tests") +final class BidirectionalRelayTests { + + /// Creates a Unix domain socket pair and returns (fd0, fd1). 
+    private func makeSocketPair() throws -> (Int32, Int32) {
+        var fds: [Int32] = [0, 0]
+        #if canImport(Glibc)  // glibc imports SOCK_STREAM as an enum (__socket_type); Darwin and Musl expose a plain Int32.
+        let result = socketpair(AF_UNIX, Int32(SOCK_STREAM.rawValue), 0, &fds)
+        #else
+        let result = socketpair(AF_UNIX, SOCK_STREAM, 0, &fds)
+        #endif
+        try #require(result == 0, "socketpair should succeed, errno: \(errno)")
+        return (fds[0], fds[1])
+    }
+
+    /// Writes all bytes to a file descriptor, retrying on partial writes; fails the test if a write returns <= 0.
+    private func writeAll(fd: Int32, data: [UInt8]) throws {
+        var offset = 0
+        while offset < data.count {
+            let n = data.withUnsafeBufferPointer { buf in
+                write(fd, buf.baseAddress!.advanced(by: offset), data.count - offset)
+            }
+            try #require(n > 0, "write failed, errno: \(errno)")
+            offset += n
+        }
+    }
+
+    /// Reads exactly `count` bytes from a file descriptor with a timeout.
+    /// Returns the data read, or nil if the timeout expires, EOF is reached, or a read fails first.
+    private func readWithTimeout(fd: Int32, count: Int, timeoutSeconds: Double) -> [UInt8]? {
+        // Temporarily make fd non-blocking so the retry loop below can honour the deadline.
+        let flags = fcntl(fd, F_GETFL)
+        _ = fcntl(fd, F_SETFL, flags | O_NONBLOCK)
+        defer { _ = fcntl(fd, F_SETFL, flags) }
+
+        var result = [UInt8](repeating: 0, count: count)
+        var totalRead = 0
+        let deadline = Date().addingTimeInterval(timeoutSeconds)
+
+        while totalRead < count && Date() < deadline {
+            let n = result.withUnsafeMutableBufferPointer { buf in
+                read(fd, buf.baseAddress!.advanced(by: totalRead), count - totalRead)
+            }
+            if n > 0 {
+                totalRead += n
+            } else if n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK) {
+                // Not ready yet, brief sleep before retry.
+                usleep(1000)  // 1ms
+            } else {
+                break
+            }
+        }
+        return totalRead == count ? result : nil
+    }
+
+    /// Sets a small send buffer on a socket to make it fill quickly.
+ private func setSendBufferSize(fd: Int32, size: Int32) { + var bufSize = size + setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufSize, socklen_t(MemoryLayout.size)) + } + + // MARK: - Test 1: Basic relay + + @Test + func testBasicRelay() throws { + // Create two socketpairs: + // pair1: (a0) --- relay ---> (b0) + // pair2: (a1) <-- relay --- (b1) + // The relay connects a1 <-> b0. + // Write to a0, read from b1 (data flows: a0 → a1 → relay → b0 → b1). + let (a0, a1) = try makeSocketPair() + let (b0, b1) = try makeSocketPair() + defer { + close(a0) + close(b1) + } + + let relay = BidirectionalRelay(fd1: a1, fd2: b0) + try relay.start() + + let testData: [UInt8] = Array("Hello, BidirectionalRelay!".utf8) + try writeAll(fd: a0, data: testData) + + let received = readWithTimeout(fd: b1, count: testData.count, timeoutSeconds: 2.0) + #expect(received == testData, "Data should pass through the relay") + + // Test the reverse direction: write to b1, read from a0. + let reverseData: [UInt8] = Array("Reverse direction".utf8) + try writeAll(fd: b1, data: reverseData) + + let reverseReceived = readWithTimeout(fd: a0, count: reverseData.count, timeoutSeconds: 2.0) + #expect(reverseReceived == reverseData, "Data should flow in reverse through the relay") + + relay.stop() + } + + // MARK: - Test 2: Cross-connection head-of-line blocking + + @Test + func testNoCrossConnectionBlocking() throws { + // Two relays sharing a single serial queue (simulating the old architecture). + // One relay's destination is saturated (not drained). + // The other relay should still work — proving per-connection isolation. 
+ let sharedQueue = DispatchQueue(label: "test.shared-queue") + + // Relay 1: a0 → a1 --relay1--> b0 → b1 (b1 won't be read, causing backpressure) + let (a0, a1) = try makeSocketPair() + let (b0, b1) = try makeSocketPair() + + // Relay 2: c0 → c1 --relay2--> d0 → d1 (d1 will be read normally) + let (c0, c1) = try makeSocketPair() + let (d0, d1) = try makeSocketPair() + + defer { + close(a0) + close(b1) + close(c0) + close(d1) + } + + // Shrink send buffers to make them fill quickly. + setSendBufferSize(fd: b0, size: 4096) + + let relay1 = BidirectionalRelay(fd1: a1, fd2: b0, queue: sharedQueue) + let relay2 = BidirectionalRelay(fd1: c1, fd2: d0, queue: sharedQueue) + + try relay1.start() + try relay2.start() + + // Saturate relay1: write data into a0 but never read from b1. + // This fills b0's send buffer, triggering backpressure on relay1. + let largeData = [UInt8](repeating: 0x41, count: 65536) + // Use non-blocking write to a0 so we don't block this test thread. + let a0flags = fcntl(a0, F_GETFL) + _ = fcntl(a0, F_SETFL, a0flags | O_NONBLOCK) + _ = largeData.withUnsafeBufferPointer { buf in + write(a0, buf.baseAddress!, buf.count) + } + _ = fcntl(a0, F_SETFL, a0flags) // restore + + // Give relay1 time to process and get blocked. + usleep(100_000) // 100ms + + // Now test relay2: it should still work despite relay1 being backpressured. 
+ let testData: [UInt8] = Array("relay2 works!".utf8) + try writeAll(fd: c0, data: testData) + + let received = readWithTimeout(fd: d1, count: testData.count, timeoutSeconds: 2.0) + #expect(received != nil, "Relay2 should not be blocked by Relay1's backpressure") + if let received { + #expect(received == testData, "Relay2 data should be correct") + } + + relay1.stop() + relay2.stop() + } + + // MARK: - Test 3: Backpressure recovery + + @Test + func testBackpressureRecovery() throws { + let (a0, a1) = try makeSocketPair() + let (b0, b1) = try makeSocketPair() + defer { + close(a0) + close(b1) + } + + // Shrink b0's send buffer so backpressure kicks in quickly. + setSendBufferSize(fd: b0, size: 4096) + + let relay = BidirectionalRelay(fd1: a1, fd2: b0) + try relay.start() + + // Write enough data to trigger backpressure (more than the send buffer). + let totalBytes = 32768 + let sendData = [UInt8]((0.. 0 { + offset += n + } else if n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000) + } else { + break + } + } + } + writeThread.start() + + // Read from b1 (drain) — this should relieve backpressure. 
+ var received = [UInt8]() + let readBuf = UnsafeMutableBufferPointer.allocate(capacity: 4096) + defer { readBuf.deallocate() } + + let deadline = Date().addingTimeInterval(5.0) + let b1flags = fcntl(b1, F_GETFL) + _ = fcntl(b1, F_SETFL, b1flags | O_NONBLOCK) + + while received.count < totalBytes && Date() < deadline { + let n = read(b1, readBuf.baseAddress!, readBuf.count) + if n > 0 { + received.append(contentsOf: UnsafeBufferPointer(start: readBuf.baseAddress!, count: n)) + } else if n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(1000) + } else { + break + } + } + + #expect(received.count == totalBytes, "All bytes should be received after backpressure recovery (got \(received.count)/\(totalBytes))") + #expect(received == sendData, "Received data should match sent data") + + relay.stop() + } + + // MARK: - Test 4: EOF handling + + @Test + func testEOFHandling() async throws { + let (a0, a1) = try makeSocketPair() + let (b0, b1) = try makeSocketPair() + + let relay = BidirectionalRelay(fd1: a1, fd2: b0) + try relay.start() + + // Write some data, then close one end. + let testData: [UInt8] = Array("goodbye".utf8) + try writeAll(fd: a0, data: testData) + close(a0) + + // Read the data from the other end. + let received = readWithTimeout(fd: b1, count: testData.count, timeoutSeconds: 2.0) + #expect(received == testData, "Data should arrive before EOF") + + // Close b1 as well so both directions see EOF. + // (a0 closed → a1 reads EOF → source1 done; + // b1 closed → b0 reads EOF → source2 done → relay complete) + close(b1) + + // The relay should detect EOF on both directions and complete. + let completed = await withTaskGroup(of: Bool.self) { group in + group.addTask { + await relay.waitForCompletion() + return true + } + group.addTask { + try? await Task.sleep(nanoseconds: 3_000_000_000) // 3 seconds + return false + } + let result = await group.next()! 
+            group.cancelAll()
+            return result
+        }
+
+        #expect(completed, "Relay should complete after both sides reach EOF")
+    }
+
+    // MARK: - Test 5: Stop while under backpressure
+
+    @Test
+    func testStopWhileBackpressured() async throws {
+        // Verify that stop() works correctly when a read source is suspended
+        // due to backpressure. Previously, cancelling a suspended dispatch source
+        // would never deliver the cancel handler, leaking file descriptors.
+        let (a0, a1) = try makeSocketPair()
+        let (b0, b1) = try makeSocketPair()
+
+        // Shrink b0's send buffer so backpressure kicks in quickly.
+        setSendBufferSize(fd: b0, size: 4096)
+
+        let relay = BidirectionalRelay(fd1: a1, fd2: b0)
+        try relay.start()
+
+        // Write enough to trigger backpressure but don't read from b1.
+        let largeData = [UInt8](repeating: 0x42, count: 65536)
+        let a0flags = fcntl(a0, F_GETFL)
+        _ = fcntl(a0, F_SETFL, a0flags | O_NONBLOCK)
+        _ = largeData.withUnsafeBufferPointer { buf in
+            write(a0, buf.baseAddress!, buf.count)
+        }
+
+        // Give relay time to enter backpressure state (readSource suspended).
+        usleep(100_000)  // 100ms
+
+        // Stop the relay while it's backpressured. This should not hang or leak.
+        relay.stop()
+
+        // Close the external ends. stop() only cancels the read sources; the relay closes a1/b0 itself from its cancel handlers.
+        close(a0)
+        close(b1)
+
+        // The relay completes once both cancel handlers have run and its fds are closed.
+        let completed = await withTaskGroup(of: Bool.self) { group in
+            group.addTask {
+                await relay.waitForCompletion()
+                return true
+            }
+            group.addTask {
+                try? await Task.sleep(nanoseconds: 3_000_000_000)  // 3 seconds
+                return false
+            }
+            let result = await group.next()!
+ group.cancelAll() + return result + } + + #expect(completed, "Relay should complete after stop() even when backpressured") + } +} diff --git a/vminitd/Sources/VminitdCore/VsockProxy.swift b/vminitd/Sources/VminitdCore/VsockProxy.swift index e18ebc76..393fad13 100644 --- a/vminitd/Sources/VminitdCore/VsockProxy.swift +++ b/vminitd/Sources/VminitdCore/VsockProxy.swift @@ -43,6 +43,7 @@ actor VsockProxy { private var listener: Socket? private var task: Task<(), Never>? + private var connectionTasks: [UUID: Task<(), Never>] = [:] init( id: String, @@ -98,6 +99,9 @@ extension VsockProxy { try listener.close() + for (_, t) in connectionTasks { t.cancel() } + connectionTasks.removeAll() + if action == .dial { let fm = FileManager.default if fm.fileExists(atPath: path.path) { @@ -153,7 +157,9 @@ extension VsockProxy { let task = Task { do { for try await conn in stream { - Task { + let connID = UUID() + let connTask = Task { + defer { self.connectionTasks[connID] = nil } log?.debug( "accepting connection", metadata: [ @@ -171,6 +177,8 @@ extension VsockProxy { self.log?.error("failed to handle connection: \(error)") } } + // Safe: actor serialization ensures this runs before connTask can execute its defer. + connectionTasks[connID] = connTask } } catch { self.log?.error("failed to accept connection: \(error)")