diff --git a/vmm/src/serial_manager.rs b/vmm/src/serial_manager.rs index 638a2ff37..b8b8c4192 100644 --- a/vmm/src/serial_manager.rs +++ b/vmm/src/serial_manager.rs @@ -176,6 +176,8 @@ pub struct SerialManager { handle: Option>, pty_write_out: Option>, socket_path: Option, + /// Tracks the active TCP client so `Drop` can shut it down before joining. + active_tcp_stream: Arc>>, } impl SerialManager { @@ -278,6 +280,7 @@ impl SerialManager { handle: None, pty_write_out, socket_path, + active_tcp_stream: Arc::new(Mutex::new(None)), })) } @@ -320,6 +323,9 @@ impl SerialManager { let pty_write_out = self.pty_write_out.clone(); let mut reader: Option = None; let mut reader_tcp: Option = None; + // Keep `active_tcp_stream` and `reader_tcp` in sync: the former is the + // shutdown handle, the latter is the worker's read-side owner. + let active_tcp_stream = self.active_tcp_stream.clone(); // In case of PTY, we want to be able to detect a connection on the // other end of the PTY. This is done by detecting there's no event @@ -428,6 +434,7 @@ impl SerialManager { previous_reader .shutdown(Shutdown::Both) .map_err(Error::AcceptConnection)?; + active_tcp_stream.lock().unwrap().take(); if let Some(distributor) = &write_distributor { distributor.remove_writer(TypeId::of::()); } @@ -441,6 +448,8 @@ impl SerialManager { // Accept them, create a reader and a writer. let (tcp_stream, _) = listener.accept().map_err(Error::AcceptConnection)?; + let active_tcp_stream_handle = + tcp_stream.try_clone().map_err(Error::CloneStream)?; let writer = tcp_stream.try_clone().map_err(Error::CloneStream)?; @@ -454,6 +463,8 @@ impl SerialManager { ), ) .map_err(Error::Epoll)?; + *active_tcp_stream.lock().unwrap() = + Some(active_tcp_stream_handle); reader_tcp = Some(tcp_stream); if let Some(distributor) = &write_distributor { distributor.add_writer(writer); @@ -497,6 +508,7 @@ impl SerialManager { .shutdown(Shutdown::Both) .map_err(Error::ShutdownConnection)?; reader_tcp = None; + active_tcp_stream.lock().unwrap().take(); if let Some(distributor) = &write_distributor { @@ -566,6 +578,10 @@ impl SerialManager { impl Drop for SerialManager { fn drop(&mut self) { + if let Some(tcp_stream) = self.active_tcp_stream.lock().unwrap().take() { + tcp_stream.shutdown(Shutdown::Both).ok(); + } + self.kill_evt.write(1).ok(); if let Some(handle) = self.handle.take() { handle.join().ok();