diff --git a/Cargo.lock b/Cargo.lock index 531cb28db1..d0d6afd300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -408,12 +408,6 @@ dependencies = [ "powerfmt", ] -[[package]] -name = "dyn-clone" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" - [[package]] name = "either" version = "1.12.0" @@ -621,7 +615,6 @@ dependencies = [ "build-time", "cfg-if", "crossbeam-utils", - "dyn-clone", "fdt", "float-cmp", "free-list", diff --git a/Cargo.toml b/Cargo.toml index e23c08b637..ab6639488f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,7 +84,6 @@ bitflags = "2.6" build-time = "0.1.3" cfg-if = "1" crossbeam-utils = { version = "0.8", default-features = false } -dyn-clone = "1.0" fdt = { version = "0.1", features = ["pretty-printing"] } free-list = "0.3" fuse-abi = { version = "0.1", features = ["zerocopy"], optional = true } diff --git a/src/fd/mod.rs b/src/fd/mod.rs index 8981252480..164c07a375 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -6,7 +6,6 @@ use core::task::Poll::{Pending, Ready}; use core::time::Duration; use async_trait::async_trait; -use dyn_clone::DynClone; #[cfg(any(feature = "tcp", feature = "udp"))] use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; @@ -144,7 +143,7 @@ impl Default for AccessPermission { } #[async_trait] -pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { +pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { /// check if an IO event is possible async fn poll(&self, _event: PollEvent) -> io::Result { Ok(PollEvent::empty()) @@ -187,7 +186,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { /// `accept` a connection on a socket #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] - async fn accept(&self) -> io::Result { + async fn accept(&self) -> io::Result<(Arc, Endpoint)> { Err(io::Error::EINVAL) } @@ -398,11 +397,6 @@ pub(crate) fn 
insert_object(obj: Arc) -> io::Result) -> io::Result<()> { - block_on(core_scheduler().replace_object(fd, obj), None) -} - // The dup system call allocates a new file descriptor that refers // to the same open file description as the descriptor oldfd. The new // file descriptor number is guaranteed to be the lowest-numbered diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 680a192f33..ba721e85c2 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -1,7 +1,8 @@ use alloc::boxed::Box; +use alloc::collections::VecDeque; +use alloc::sync::Arc; use core::future; -use core::ops::DerefMut; -use core::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, Ordering}; +use core::sync::atomic::{AtomicU16, Ordering}; use core::task::Poll; use async_trait::async_trait; @@ -10,7 +11,7 @@ use smoltcp::socket::tcp; use smoltcp::time::Duration; use crate::executor::block_on; -use crate::executor::network::{now, Handle, NetworkState, NIC}; +use crate::executor::network::{now, Handle, NIC}; use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption}; use crate::{io, DEFAULT_KEEP_ALIVE_INTERVAL}; @@ -29,26 +30,29 @@ fn get_ephemeral_port() -> u16 { #[derive(Debug)] pub struct Socket { - handle: Handle, - port: AtomicU16, - backlog: AtomicU32, - nonblocking: AtomicBool, + handle: VecDeque, + port: u16, + is_nonblocking: bool, + is_listen: bool, } impl Socket { - pub fn new(handle: Handle) -> Self { + pub fn new(h: Handle) -> Self { + let mut handle = VecDeque::new(); + handle.push_back(h); + Self { handle, - port: AtomicU16::new(0), - backlog: AtomicU32::new(0), - nonblocking: AtomicBool::new(false), + port: 0, + is_nonblocking: false, + is_listen: false, } } fn with(&self, f: impl FnOnce(&mut tcp::Socket<'_>) -> R) -> R { let mut guard = NIC.lock(); let nic = guard.as_nic_mut().unwrap(); - let result = f(nic.get_mut_socket::>(self.handle)); + let result = f(nic.get_mut_socket::>(*self.handle.get(0).unwrap())); nic.poll_common(now()); 
result @@ -57,14 +61,14 @@ impl Socket { fn with_context(&self, f: impl FnOnce(&mut tcp::Socket<'_>, &mut iface::Context) -> R) -> R { let mut guard = NIC.lock(); let nic = guard.as_nic_mut().unwrap(); - let (s, cx) = nic.get_socket_and_context::>(self.handle); + let (s, cx) = nic.get_socket_and_context::>(*self.handle.get(0).unwrap()); let result = f(s, cx); nic.poll_common(now()); result } - async fn async_close(&self) -> io::Result<()> { + async fn close(&self) -> io::Result<()> { future::poll_fn(|_cx| { self.with(|socket| { if socket.is_active() { @@ -77,6 +81,18 @@ impl Socket { }) .await?; + if self.handle.len() > 1 { + let mut guard = NIC.lock(); + let nic = guard.as_nic_mut().unwrap(); + + for handle in self.handle.iter().skip(1) { + let socket = nic.get_mut_socket::>(*handle); + if socket.is_active() { + socket.close(); + } + } + } + future::poll_fn(|cx| { self.with(|socket| { if !socket.is_active() { @@ -90,10 +106,7 @@ impl Socket { }) .await } -} -#[async_trait] -impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { future::poll_fn(|cx| { self.with(|socket| match socket.state() { @@ -124,9 +137,7 @@ impl ObjectInterface for Socket { _ => { let mut available = PollEvent::empty(); - if socket.can_recv() - || socket.may_recv() && self.backlog.load(Ordering::Acquire) > 0 - { + if socket.can_recv() || socket.may_recv() && self.is_listen { // In case, we just establish a fresh connection in non-blocking mode, we try to read data. 
available.insert( PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND, @@ -241,10 +252,10 @@ impl ObjectInterface for Socket { Ok(pos) } - async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] if let ListenEndpoint::Ip(endpoint) = endpoint { - self.port.store(endpoint.port, Ordering::Release); + self.port = endpoint.port; Ok(()) } else { Err(io::Error::EIO) @@ -276,11 +287,11 @@ impl ObjectInterface for Socket { } } - async fn accept(&self) -> io::Result { + async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> { future::poll_fn(|cx| { self.with(|socket| match socket.state() { tcp::State::Closed => { - let _ = socket.listen(self.port.load(Ordering::Acquire)); + let _ = socket.listen(self.port); Poll::Ready(()) } tcp::State::Listen | tcp::State::Established => Poll::Ready(()), @@ -312,12 +323,35 @@ impl ObjectInterface for Socket { }) .await?; + let connection_handle = self.handle.pop_front().unwrap(); let mut guard = NIC.lock(); let nic = guard.as_nic_mut().map_err(|_| io::Error::EIO)?; - let socket = nic.get_mut_socket::>(self.handle); + let socket = nic.get_mut_socket::>(connection_handle); socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL))); + let endpoint = Endpoint::Ip(socket.remote_endpoint().unwrap()); + let nagle_enabled = socket.nagle_enabled(); + + // fill up queue for pending connections + let new_handle = nic.create_tcp_handle().unwrap(); + self.handle.push_back(new_handle); + let socket = nic.get_mut_socket::>(new_handle); + socket.set_nagle_enabled(nagle_enabled); + // Re-arming this backlog slot is best effort (as in the `Closed` arm above): + // propagating an error here would drop and leak the connection handle that + // was already popped from `self.handle`; a closed slot is re-listened later. + let _ = socket.listen(self.port); + + let mut handle = VecDeque::new(); + handle.push_back(connection_handle); + + let socket = Socket { + handle, + port: self.port, + is_nonblocking: self.is_nonblocking, + is_listen: false, + }; -
Ok(Endpoint::Ip(socket.remote_endpoint().unwrap())) + Ok((socket, endpoint)) } async fn getpeername(&self) -> io::Result> { @@ -333,31 +367,20 @@ impl ObjectInterface for Socket { } async fn is_nonblocking(&self) -> io::Result { - Ok(self.nonblocking.load(Ordering::Acquire)) + Ok(self.is_nonblocking) } - async fn listen(&self, backlog: i32) -> io::Result<()> { + async fn listen(&mut self, backlog: i32) -> io::Result<()> { + let nagle_enabled = self.with(|socket| socket.nagle_enabled()); + self.with(|socket| { if !socket.is_open() { if backlog > 0 { - self.backlog - .store(backlog.try_into().unwrap(), Ordering::Relaxed); - socket - .listen(self.port.load(Ordering::Acquire)) + .listen(self.port) .map(|_| ()) .map_err(|_| io::Error::EIO)?; - let rx_size = socket.recv_queue(); - let tx_size = socket.send_queue(); - let is_nagle = socket.nagle_enabled(); - for _ in 1..backlog { - let rx_buffer = tcp::SocketBuffer::new(vec![0; rx_size]); - let tx_buffer = tcp::SocketBuffer::new(vec![0; tx_size]); - let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer); - tcp_socket.set_nagle_enabled(is_nagle); - } - Ok(()) } else { Err(io::Error::EINVAL) @@ -365,7 +388,23 @@ impl ObjectInterface for Socket { } else { Err(io::Error::EIO) } - }) + })?; + self.is_listen = true; + + let mut guard = NIC.lock(); + let nic = guard.as_nic_mut().unwrap(); + for _ in 1..backlog { + let handle = nic.create_tcp_handle().unwrap(); + self.handle.push_back(handle); + + let s = nic.get_mut_socket::>(handle); + s.set_nagle_enabled(nagle_enabled); + s.listen(self.port) + .map(|_| ()) + .map_err(|_| io::Error::EIO)?; + } + + Ok(()) } async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> { @@ -401,14 +440,14 @@ impl ObjectInterface for Socket { } } - async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + async fn ioctl(&mut self, cmd: IoCtl, value: bool) -> io::Result<()> { if cmd == IoCtl::NonBlocking { if value { trace!("set device to nonblocking mode"); - 
self.nonblocking.store(true, Ordering::Release); + self.is_nonblocking = true; } else { trace!("set device to blocking mode"); - self.nonblocking.store(false, Ordering::Release); + self.is_nonblocking = false; } Ok(()) @@ -418,37 +457,73 @@ impl ObjectInterface for Socket { } } -impl Clone for Socket { - fn clone(&self) -> Self { +impl Drop for Socket { + fn drop(&mut self) { + let _ = block_on(self.close(), None); + let mut guard = NIC.lock(); + for h in self.handle.iter() { + guard.as_nic_mut().unwrap().destroy_socket(*h); + } + } +} - let handle = if let NetworkState::Initialized(nic) = guard.deref_mut() { - nic.create_tcp_handle().unwrap() - } else { - panic!("Unable to create handle"); - }; +#[async_trait] +impl ObjectInterface for async_lock::RwLock { + async fn poll(&self, event: PollEvent) -> io::Result { + self.read().await.poll(event).await + } - drop(guard); - let port = self.port.load(Ordering::Acquire); - let backlog = self.backlog.load(Ordering::Acquire); - let obj = Self { - handle, - port: AtomicU16::new(port), - backlog: AtomicU32::new(backlog), - nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)), - }; + async fn read(&self, buffer: &mut [u8]) -> io::Result { + self.read().await.read(buffer).await + } - if port > 0 { - let _ = block_on(obj.listen(backlog.try_into().unwrap()), None); - } + async fn write(&self, buffer: &[u8]) -> io::Result { + self.read().await.write(buffer).await + } - obj + async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + self.write().await.bind(endpoint).await } -} -impl Drop for Socket { - fn drop(&mut self) { - let _ = block_on(self.async_close(), None); - NIC.lock().as_nic_mut().unwrap().destroy_socket(self.handle); + async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + self.read().await.connect(endpoint).await + } + + async fn accept(&self) -> io::Result<(Arc, Endpoint)> { + let (handle, endpoint) = self.write().await.accept().await?; + 
Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint)) + } + + async fn getpeername(&self) -> io::Result> { + self.read().await.getpeername().await + } + + async fn getsockname(&self) -> io::Result> { + self.read().await.getsockname().await + } + + async fn is_nonblocking(&self) -> io::Result { + self.read().await.is_nonblocking().await + } + + async fn listen(&self, backlog: i32) -> io::Result<()> { + self.write().await.listen(backlog).await + } + + async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> { + self.read().await.setsockopt(opt, optval).await + } + + async fn getsockopt(&self, opt: SocketOption) -> io::Result { + self.read().await.getsockopt(opt).await + } + + async fn shutdown(&self, how: i32) -> io::Result<()> { + self.read().await.shutdown(how).await + } + + async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + self.write().await.ioctl(cmd, value).await } } diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index b492be884f..5ad51b1c12 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -1,7 +1,7 @@ use alloc::boxed::Box; +use alloc::sync::Arc; use alloc::vec::Vec; use core::future; -use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use core::task::Poll; use async_trait::async_trait; @@ -17,7 +17,7 @@ use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io::{self, Error}; #[derive(Debug)] -pub(crate) struct VsockListenEndpoint { +pub struct VsockListenEndpoint { pub port: u32, pub cid: Option, } @@ -29,7 +29,7 @@ impl VsockListenEndpoint { } #[derive(Debug)] -pub(crate) struct VsockEndpoint { +pub struct VsockEndpoint { pub port: u32, pub cid: u32, } @@ -40,31 +40,38 @@ impl VsockEndpoint { } } +#[derive(Debug)] +pub struct NullSocket; + +impl NullSocket { + pub const fn new() -> Self { + Self {} + } +} + +#[async_trait] +impl ObjectInterface for async_lock::RwLock {} + #[derive(Debug)] pub struct Socket { - port: AtomicU32, - cid: 
AtomicU32, - nonblocking: AtomicBool, + port: u32, + cid: u32, + nonblocking: bool, } impl Socket { pub fn new() -> Self { Self { - port: AtomicU32::new(0), - cid: AtomicU32::new(u32::MAX), - nonblocking: AtomicBool::new(false), + port: 0, + cid: u32::MAX, + nonblocking: false, } } -} -#[async_trait] -impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { - let port = self.port.load(Ordering::Acquire); - future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); - let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + let raw = guard.get_mut_socket(self.port).ok_or(Error::EINVAL)?; match raw.state { VsockState::Shutdown | VsockState::ReceiveRequest => { @@ -130,14 +137,14 @@ impl ObjectInterface for Socket { .await } - async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> { match endpoint { ListenEndpoint::Vsock(ep) => { - self.port.store(ep.port, Ordering::Release); + self.port = ep.port; if let Some(cid) = ep.cid { - self.cid.store(cid, Ordering::Release); + self.cid = cid; } else { - self.cid.store(u32::MAX, Ordering::Release); + self.cid = u32::MAX; } VSOCK_MAP.lock().bind(ep.port) } @@ -146,13 +153,13 @@ impl ObjectInterface for Socket { } } - async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> { match endpoint { Endpoint::Vsock(ep) => { const HEADER_SIZE: usize = core::mem::size_of::(); let port = VSOCK_MAP.lock().connect(ep.port, ep.cid)?; - self.port.store(port, Ordering::Release); - self.port.store(ep.cid, Ordering::Release); + self.port = port; + self.cid = ep.cid; future::poll_fn(|_cx| { if let Some(mut driver_guard) = hardware::get_vsock_driver().unwrap().try_lock() @@ -204,9 +211,8 @@ impl ObjectInterface for Socket { } async fn getpeername(&self) -> io::Result> { - let port = self.port.load(Ordering::Acquire); let guard = VSOCK_MAP.lock(); - let raw =
guard.get_socket(port).ok_or(Error::EINVAL)?; + let raw = guard.get_socket(self.port).ok_or(Error::EINVAL)?; Ok(Some(Endpoint::Vsock(VsockEndpoint::new( raw.remote_port, @@ -218,22 +224,22 @@ impl ObjectInterface for Socket { let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid(); Ok(Some(Endpoint::Vsock(VsockEndpoint::new( - self.port.load(Ordering::Acquire), + self.port, local_cid.try_into().unwrap(), )))) } async fn is_nonblocking(&self) -> io::Result { - Ok(self.nonblocking.load(Ordering::Acquire)) + Ok(self.nonblocking) } async fn listen(&self, _backlog: i32) -> io::Result<()> { Ok(()) } - async fn accept(&self) -> io::Result { - let port = self.port.load(Ordering::Acquire); - let cid = self.cid.load(Ordering::Acquire); + async fn accept(&mut self) -> io::Result<(NullSocket, Endpoint)> { + let port = self.port; + let cid = self.cid; let endpoint = future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); @@ -283,21 +289,21 @@ impl ObjectInterface for Socket { }) .await?; - Ok(Endpoint::Vsock(endpoint)) + Ok((NullSocket::new(), Endpoint::Vsock(endpoint))) } async fn shutdown(&self, _how: i32) -> io::Result<()> { Ok(()) } - async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + async fn ioctl(&mut self, cmd: IoCtl, value: bool) -> io::Result<()> { if cmd == IoCtl::NonBlocking { if value { trace!("set vsock device to nonblocking mode"); - self.nonblocking.store(true, Ordering::Release); + self.nonblocking = true; } else { trace!("set vsock device to blocking mode"); - self.nonblocking.store(false, Ordering::Release); + self.nonblocking = false; } Ok(()) @@ -310,7 +316,7 @@ impl ObjectInterface for Socket { // https://github.com/rust-lang/rust-clippy/issues/11380 #[allow(clippy::needless_pass_by_ref_mut)] async fn read(&self, buffer: &mut [u8]) -> io::Result { - let port = self.port.load(Ordering::Acquire); + let port = self.port; future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); let raw = 
guard.get_mut_socket(port).ok_or(Error::EINVAL)?; @@ -348,7 +354,7 @@ impl ObjectInterface for Socket { } async fn write(&self, buffer: &[u8]) -> io::Result { - let port = self.port.load(Ordering::Acquire); + let port = self.port; future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; @@ -400,20 +406,61 @@ impl ObjectInterface for Socket { } } -impl Clone for Socket { - fn clone(&self) -> Self { - Self { - port: AtomicU32::new(self.port.load(Ordering::Acquire)), - cid: AtomicU32::new(self.cid.load(Ordering::Acquire)), - nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)), - } - } -} - impl Drop for Socket { fn drop(&mut self) { - let port = self.port.load(Ordering::Acquire); let mut guard = VSOCK_MAP.lock(); - guard.remove_socket(port); + guard.remove_socket(self.port); + } +} + +#[async_trait] +impl ObjectInterface for async_lock::RwLock { + async fn poll(&self, event: PollEvent) -> io::Result { + self.read().await.poll(event).await + } + + async fn read(&self, buffer: &mut [u8]) -> io::Result { + self.read().await.read(buffer).await + } + + async fn write(&self, buffer: &[u8]) -> io::Result { + self.read().await.write(buffer).await + } + + async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + self.write().await.bind(endpoint).await + } + + async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + self.write().await.connect(endpoint).await + } + + async fn accept(&self) -> io::Result<(Arc, Endpoint)> { + let (handle, endpoint) = self.write().await.accept().await?; + Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint)) + } + + async fn getpeername(&self) -> io::Result> { + self.read().await.getpeername().await + } + + async fn getsockname(&self) -> io::Result> { + self.read().await.getsockname().await + } + + async fn is_nonblocking(&self) -> io::Result { + self.read().await.is_nonblocking().await + } + + async fn listen(&self, backlog: i32) -> 
io::Result<()> { + self.write().await.listen(backlog).await + } + + async fn shutdown(&self, how: i32) -> io::Result<()> { + self.read().await.shutdown(how).await + } + + async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + self.write().await.ioctl(cmd, value).await } } diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 3b58803d9a..342ec1abe1 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -537,26 +537,6 @@ impl PerCoreScheduler { .await } - /// Replace an existing IO interface by a new one - #[allow(dead_code)] - pub async fn replace_object( - &self, - fd: FileDescriptor, - obj: Arc, - ) -> io::Result<()> { - future::poll_fn(|cx| { - without_interrupts(|| { - let borrowed = self.current_task.borrow(); - let mut pinned_obj = core::pin::pin!(borrowed.object_map.write()); - - let mut guard = ready!(pinned_obj.as_mut().poll(cx)); - guard.insert(fd, obj.clone()); - Ready(Ok(())) - }) - }) - .await - } - /// Duplicate a IO interface and returns a new file descriptor as /// identifier to the new copy pub async fn dup_object(&self, fd: FileDescriptor) -> io::Result { diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index 5550a75176..fe5a460cef 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -21,8 +21,7 @@ use crate::fd::socket::udp; #[cfg(feature = "vsock")] use crate::fd::socket::vsock::{self, VsockEndpoint, VsockListenEndpoint}; use crate::fd::{ - get_object, insert_object, replace_object, Endpoint, ListenEndpoint, ObjectInterface, - SocketOption, + get_object, insert_object, Endpoint, ListenEndpoint, ObjectInterface, SocketOption, }; use crate::io; use crate::syscalls::{block_on, IoCtl}; @@ -423,13 +422,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 #[cfg(feature = "vsock")] if domain == AF_VSOCK && type_.intersects(SockType::SOCK_STREAM) { - let socket = vsock::Socket::new(); + let socket = Arc::new(async_lock::RwLock::new(vsock::Socket::new())); if 
type_.contains(SockType::SOCK_NONBLOCK) { block_on(socket.ioctl(IoCtl::NonBlocking, true), None).unwrap(); } - let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + let fd = insert_object(socket).expect("FD is already used"); return fd; } @@ -445,13 +444,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 if type_.contains(SockType::SOCK_DGRAM) { let handle = nic.create_udp_handle().unwrap(); drop(guard); - let socket = udp::Socket::new(handle); + let socket = Arc::new(udp::Socket::new(handle)); if type_.contains(SockType::SOCK_NONBLOCK) { block_on(socket.ioctl(IoCtl::NonBlocking, true), None).unwrap(); } - let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + let fd = insert_object(socket).expect("FD is already used"); return fd; } @@ -460,13 +459,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 if type_.contains(SockType::SOCK_STREAM) { let handle = nic.create_tcp_handle().unwrap(); drop(guard); - let socket = tcp::Socket::new(handle); + let socket = Arc::new(async_lock::RwLock::new(tcp::Socket::new(handle))); if type_.contains(SockType::SOCK_NONBLOCK) { block_on(socket.ioctl(IoCtl::NonBlocking, true), None).unwrap(); } - let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + let fd = insert_object(socket).expect("FD is already used"); return fd; } @@ -485,12 +484,10 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut |v| { block_on((*v).accept(), None).map_or_else( |e| -num::ToPrimitive::to_i32(&e).unwrap(), - |endpoint| match endpoint { + |(obj, endpoint)| match endpoint { #[cfg(any(feature = "tcp", feature = "udp"))] Endpoint::Ip(endpoint) => { - let new_obj = dyn_clone::clone_box(&*v); - replace_object(fd, Arc::from(new_obj)).unwrap(); - let new_fd = insert_object(v).unwrap(); + let new_fd = insert_object(obj).unwrap(); if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; 
@@ -517,8 +514,6 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut } #[cfg(feature = "vsock")] Endpoint::Vsock(endpoint) => { - //let new_obj = dyn_clone::clone_box(&*v); - //replace_object(fd, Arc::from(new_obj)).unwrap(); let new_fd = insert_object(v.clone()).unwrap(); if !addr.is_null() && !addrlen.is_null() {