Nothing Special   »   [go: up one dir, main page]

Skip to content

Commit

Permalink
move check for a non-blocking device to the object interface
Browse files Browse the repository at this point in the history
  • Loading branch information
stlankes committed Sep 23, 2024
1 parent 61f8812 commit c0dec02
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 114 deletions.
8 changes: 4 additions & 4 deletions src/fd/eventfd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ impl ObjectInterface for EventFd {
cx.wake_by_ref();
}
Poll::Ready(Ok(len))
} else if self.flags.contains(EventFlags::EFD_NONBLOCK) {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
guard.read_queue.push_back(cx.waker().clone());
Poll::Pending
Expand Down Expand Up @@ -115,6 +117,8 @@ impl ObjectInterface for EventFd {
}

Poll::Ready(Ok(len))
} else if self.flags.contains(EventFlags::EFD_NONBLOCK) {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
guard.write_queue.push_back(cx.waker().clone());
Poll::Pending
Expand Down Expand Up @@ -163,8 +167,4 @@ impl ObjectInterface for EventFd {
})
.await
}

async fn is_nonblocking(&self) -> io::Result<bool> {
Ok(self.flags.contains(EventFlags::EFD_NONBLOCK))
}
}
32 changes: 3 additions & 29 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use async_trait::async_trait;
use smoltcp::wire::{IpEndpoint, IpListenEndpoint};

use crate::arch::kernel::core_local::core_scheduler;
use crate::executor::{block_on, poll_on};
use crate::executor::block_on;
use crate::fs::{DirectoryEntry, FileAttr, SeekWhence};
use crate::io;

Expand Down Expand Up @@ -166,12 +166,6 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
Err(io::Error::EINVAL)
}

/// `is_nonblocking` returns `true` if `read`, `write`, `recv` and `send` operations
/// don't block.
async fn is_nonblocking(&self) -> io::Result<bool> {
Ok(false)
}

/// `fstat`
async fn fstat(&self) -> io::Result<FileAttr> {
Err(io::Error::EINVAL)
Expand Down Expand Up @@ -271,17 +265,7 @@ pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result<usize> {
return Ok(0);
}

if block_on(obj.is_nonblocking(), None)? {
poll_on(obj.read(buf), Some(Duration::ZERO)).map_err(|x| {
if x == io::Error::ETIME {
io::Error::EAGAIN
} else {
x
}
})
} else {
block_on(obj.read(buf), None)
}
block_on(obj.read(buf), None)
}

pub(crate) fn lseek(fd: FileDescriptor, offset: isize, whence: SeekWhence) -> io::Result<isize> {
Expand All @@ -297,17 +281,7 @@ pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result<usize> {
return Ok(0);
}

if block_on(obj.is_nonblocking(), None)? {
poll_on(obj.write(buf), Some(Duration::ZERO)).map_err(|x| {
if x == io::Error::ETIME {
io::Error::EAGAIN
} else {
x
}
})
} else {
block_on(obj.write(buf), None)
}
block_on(obj.write(buf), None)
}

async fn poll_fds(fds: &mut [PollFd]) -> io::Result<u64> {
Expand Down
12 changes: 4 additions & 8 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ impl Socket {
})
.map_err(|_| io::Error::EIO),
)
} else if self.is_nonblocking {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
socket.register_recv_waker(cx.waker());
Poll::Pending
Expand Down Expand Up @@ -232,6 +234,8 @@ impl Socket {
// we have already sent some data => return 0 as a signal to stop the
// async write
Poll::Ready(Ok(0))
} else if self.is_nonblocking {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
socket.register_send_waker(cx.waker());
Poll::Pending
Expand Down Expand Up @@ -374,10 +378,6 @@ impl Socket {
.map(Endpoint::Ip))
}

async fn is_nonblocking(&self) -> io::Result<bool> {
Ok(self.is_nonblocking)
}

async fn listen(&mut self, backlog: i32) -> io::Result<()> {
let nagle_enabled = self.with(|socket| socket.nagle_enabled());

Expand Down Expand Up @@ -511,10 +511,6 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.getsockname().await
}

async fn is_nonblocking(&self) -> io::Result<bool> {
self.read().await.is_nonblocking().await
}

async fn listen(&self, backlog: i32) -> io::Result<()> {
self.write().await.listen(backlog).await
}
Expand Down
40 changes: 22 additions & 18 deletions src/fd/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ impl ObjectInterface for async_lock::RwLock<NullSocket> {}
pub struct Socket {
port: u32,
cid: u32,
nonblocking: bool,
is_nonblocking: bool,
}

impl Socket {
pub fn new() -> Self {
Self {
port: 0,
cid: u32::MAX,
nonblocking: false,
is_nonblocking: false,
}
}

Expand Down Expand Up @@ -229,10 +229,6 @@ impl Socket {
))))
}

async fn is_nonblocking(&self) -> io::Result<bool> {
Ok(self.nonblocking)
}

async fn listen(&self, _backlog: i32) -> io::Result<()> {
Ok(())
}
Expand All @@ -247,8 +243,12 @@ impl Socket {

match raw.state {
VsockState::Listen => {
raw.rx_waker.register(cx.waker());
Poll::Pending
if self.is_nonblocking {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
raw.rx_waker.register(cx.waker());
Poll::Pending
}
}
VsockState::ReceiveRequest => {
let result = {
Expand Down Expand Up @@ -300,10 +300,10 @@ impl Socket {
if cmd == IoCtl::NonBlocking {
if value {
trace!("set vsock device to nonblocking mode");
self.nonblocking = true;
self.is_nonblocking = true;
} else {
trace!("set vsock device to blocking mode");
self.nonblocking = false;
self.is_nonblocking = false;
}

Ok(())
Expand All @@ -326,8 +326,12 @@ impl Socket {
let len = core::cmp::min(buffer.len(), raw.buffer.len());

if len == 0 {
raw.rx_waker.register(cx.waker());
Poll::Pending
if self.is_nonblocking {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
raw.rx_waker.register(cx.waker());
Poll::Pending
}
} else {
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
buffer[..len].copy_from_slice(tmp.as_slice());
Expand Down Expand Up @@ -363,8 +367,12 @@ impl Socket {
match raw.state {
VsockState::Connected => {
if diff >= raw.peer_buf_alloc {
raw.tx_waker.register(cx.waker());
Poll::Pending
if self.is_nonblocking {
Poll::Ready(Err(io::Error::EAGAIN))
} else {
raw.tx_waker.register(cx.waker());
Poll::Pending
}
} else {
const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
Expand Down Expand Up @@ -448,10 +456,6 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.getsockname().await
}

async fn is_nonblocking(&self) -> io::Result<bool> {
self.read().await.is_nonblocking().await
}

async fn listen(&self, backlog: i32) -> io::Result<()> {
self.write().await.listen(backlog).await
}
Expand Down
61 changes: 6 additions & 55 deletions src/syscalls/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use core::ffi::{c_char, c_void};
use core::mem::size_of;
#[allow(unused_imports)]
use core::ops::DerefMut;
use core::time::Duration;

use cfg_if::cfg_if;
#[cfg(any(feature = "tcp", feature = "udp"))]
Expand All @@ -23,7 +22,6 @@ use crate::fd::socket::vsock::{self, VsockEndpoint, VsockListenEndpoint};
use crate::fd::{
get_object, insert_object, Endpoint, ListenEndpoint, ObjectInterface, SocketOption,
};
use crate::io;
use crate::syscalls::{block_on, IoCtl};

pub const AF_INET: i32 = 0;
Expand Down Expand Up @@ -634,25 +632,8 @@ pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: so
obj.map_or_else(
|e| -num::ToPrimitive::to_i32(&e).unwrap(),
|v| {
let non_blocking = block_on((*v).is_nonblocking(), None).unwrap();
let timeout = if non_blocking {
Some(Duration::ZERO)
} else {
None
};

block_on((*v).connect(endpoint), timeout).map_or_else(
|e| {
let e = if e == io::Error::ETIME {
io::Error::EAGAIN
} else {
e
};

-num::ToPrimitive::to_i32(&e).unwrap()
},
|_| 0,
)
block_on((*v).connect(endpoint), None)
.map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0)
},
)
}
Expand Down Expand Up @@ -958,23 +939,8 @@ pub unsafe extern "C" fn sys_sendto(
obj.map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|v| {
let non_blocking = block_on((*v).is_nonblocking(), None).unwrap();
let timeout = if non_blocking {
Some(Duration::ZERO)
} else {
None
};

block_on((*v).sendto(slice, endpoint), timeout).map_or_else(
|e| {
let e = if non_blocking && e == io::Error::ETIME {
io::Error::EAGAIN
} else {
e
};

-num::ToPrimitive::to_isize(&e).unwrap()
},
block_on((*v).sendto(slice, endpoint), None).map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|v| v.try_into().unwrap(),
)
},
Expand All @@ -999,23 +965,8 @@ pub unsafe extern "C" fn sys_recvfrom(
obj.map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|v| {
let non_blocking = block_on((*v).is_nonblocking(), None).unwrap();
let timeout = if non_blocking {
Some(Duration::ZERO)
} else {
None
};

block_on((*v).recvfrom(slice), timeout).map_or_else(
|e| {
let e = if non_blocking && e == io::Error::ETIME {
io::Error::EAGAIN
} else {
e
};

-num::ToPrimitive::to_isize(&e).unwrap()
},
block_on((*v).recvfrom(slice), None).map_or_else(
|e| -num::ToPrimitive::to_isize(&e).unwrap(),
|(len, endpoint)| {
if !addr.is_null() && !addrlen.is_null() {
#[allow(unused_variables)]
Expand Down

0 comments on commit c0dec02

Please sign in to comment.