Tage Johansson
2023-Aug-20 14:16 UTC
[Libguestfs] [libnbd PATCH v8 06/10] rust: async: Create an async friendly handle type
Create another handle type: AsyncHandle, which makes use of Rust's built-in asynchronous functions (see <https://doc.rust-lang.org/std/keyword.async.html>) and runs on top of the Tokio runtime (see <https://docs.rs/tokio>). For every asynchronous command, like aio_connect(), a corresponding `async` method is created on the handle. In this case it would be: `async fn connect(...) -> Result<(), ...>`. When called, it will poll the file descriptor until the command is complete, and then return with a result. All the synchronous counterparts (like nbd_connect()) are excluded from this handle type because they are unnecessary and might interfere with the polling performed by the Tokio runtime. For more details about how the asynchronous commands are executed, please see the comments in rust/src/async_handle.rs. --- generator/Rust.ml | 249 +++++++++++++++++++++++++++++++++++- generator/Rust.mli | 2 + generator/generator.ml | 2 + rust/Cargo.toml | 4 +- rust/Makefile.am | 2 + rust/src/async_handle.rs | 268 +++++++++++++++++++++++++++++++++++++++ rust/src/lib.rs | 8 ++ rust/src/utils.rs | 9 ++ scripts/git.orderfile | 1 + 9 files changed, 538 insertions(+), 7 deletions(-) create mode 100644 rust/src/async_handle.rs diff --git a/generator/Rust.ml b/generator/Rust.ml index 431c814..1bc81f0 100644 --- a/generator/Rust.ml +++ b/generator/Rust.ml @@ -61,11 +61,12 @@ let print_rust_flags { flag_prefix; flags } let rec to_upper_snake_case s let s = String.uppercase_ascii s in let s = explode s in - let s = filter_map ( - function - |'-' -> Some "_" | ':' -> None - | ch -> Some (String.make 1 ch) - ) s in + let s + filter_map + (function + | '-' -> Some "_" | ':' -> None | ch -> Some (String.make 1 ch)) + s + in String.concat "" s (* Split a string into a list of chars.
In later OCaml we could @@ -75,7 +76,7 @@ and explode str let r = ref [] in for i = 0 to String.length str - 1 do let c = String.unsafe_get str i in - r := c :: !r; + r := c :: !r done; List.rev !r @@ -564,3 +565,239 @@ let generate_rust_bindings () pr "impl Handle {\n"; List.iter print_rust_handle_method handle_calls; pr "}\n\n" + +(*********************************************************) +(* The rest of the file conserns the asynchronous API. *) +(* *) +(* See the comments in rust/src/async_handle.rs for more *) +(* information about how it works. *) +(*********************************************************) + +let excluded_handle_calls : NameSet.t + NameSet.of_list + [ + "aio_get_fd"; + "aio_get_direction"; + "aio_notify_read"; + "aio_notify_write"; + "clear_debug_callback"; + "get_debug"; + "poll"; + "poll2"; + "set_debug"; + "set_debug_callback"; + ] + +(* A mapping with names as keys. *) +module NameMap = Map.Make (String) + +(* Strip "aio_" from the beginning of a string. *) +let strip_aio name : string + if String.starts_with ~prefix:"aio_" name then + String.sub name 4 (String.length name - 4) + else failwithf "Asynchronous call %s must begin with aio_" name + +(* A map with all asynchronous handle calls. The keys are names with "aio_" + stripped, the values are a tuple with the actual name (with "aio_"), the + [call] and the [async_kind]. *) +let async_handle_calls : ((string * call) * async_kind) NameMap.t + handle_calls + |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls)) + |> List.filter_map (fun (name, call) -> + call.async_kind + |> Option.map (fun async_kind -> + (strip_aio name, ((name, call), async_kind)))) + |> List.to_seq |> NameMap.of_seq + +(* A mapping with all synchronous (not asynchronous) handle calls. Excluded + are also all synchronous calls that has an asynchronous counterpart. So if + "foo" is the name of a handle call and an asynchronous call "aio_foo" + exists, then "foo" will not b in this map. 
*) +let sync_handle_calls : call NameMap.t + handle_calls + |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls)) + |> List.filter (fun (name, _) -> + (not (NameMap.mem name async_handle_calls)) + && not + (String.starts_with ~prefix:"aio_" name + && NameMap.mem (strip_aio name) async_handle_calls)) + |> List.to_seq |> NameMap.of_seq + +(* Get the Rust type for an argument in the asynchronous API. Like + [rust_arg_type] but no static lifetime on some buffers. *) +let rust_async_arg_type : arg -> string = function + | BytesPersistIn _ -> "&[u8]" + | BytesPersistOut _ -> "&mut [u8]" + | x -> rust_arg_type x + +(* Get the Rust type for an optional argument in the asynchronous API. Like + [rust_optarg_type] but no static lifetime on some closures. *) +let rust_async_optarg_type : optarg -> string = function + | OClosure x -> sprintf "Option<%s>" (rust_async_arg_type (Closure x)) + | x -> rust_optarg_type x + +(* A string of the argument list for a method on the handle, with both + mandotory and optional arguments. *) +let rust_async_handle_call_args { args; optargs } : string + let rust_args_names + List.map rust_arg_name args @ List.map rust_optarg_name optargs + and rust_args_types + List.map rust_async_arg_type args + @ List.map rust_async_optarg_type optargs + in + String.concat ", " + (List.map2 (sprintf "%s: %s") rust_args_names rust_args_types) + +(* Print the Rust function for a not asynchronous handle call. *) +let print_rust_sync_handle_call name call + print_rust_handle_call_comment call; + pr "pub fn %s(&self, %s) -> %s\n" name + (rust_async_handle_call_args call) + (rust_ret_type call); + print_ffi_call name "self.data.handle.handle" call; + pr "\n" + +(* Print the Rust function for an asynchronous handle call with a completion + callback. (Note that "callback" might be abbreviated with "cb" in the + following code. *) +let print_rust_async_handle_call_with_completion_cb name (aio_name, call) + (* An array of all optional arguments. 
Useful because we need to deel with + the index of the completion callback. *) + let optargs = Array.of_list call.optargs in + (* The index of the completion callback in [optargs] *) + let completion_cb_index + Array.find_map + (fun (i, optarg) -> + match optarg with + | OClosure { cbname } -> + if cbname = "completion" then Some i else None + | _ -> None) + (Array.mapi (fun x y -> (x, y)) optargs) + in + let completion_cb_index + match completion_cb_index with + | Some x -> x + | None -> + failwithf + "The handle call %s is claimed to have a completion callback among \ + its optional arguments by the async_kind field, but so does not \ + seem to be the case." + aio_name + in + let optargs_before_completion_cb + Array.to_list (Array.sub optargs 0 completion_cb_index) + and optargs_after_completion_cb + Array.to_list + (Array.sub optargs (completion_cb_index + 1) + (Array.length optargs - (completion_cb_index + 1))) + in + (* All optional arguments excluding the completion callback. *) + let optargs_without_completion_cb + optargs_before_completion_cb @ optargs_after_completion_cb + in + print_rust_handle_call_comment call; + pr "pub async fn %s(&self, %s) -> SharedResult<()> {\n" name + (rust_async_handle_call_args + { call with optargs = optargs_without_completion_cb }); + pr " // A oneshot channel to notify when the call is completed.\n"; + pr " let (ret_tx, ret_rx) = oneshot::channel::<SharedResult<()>>();\n"; + pr " let (ccb_tx, mut ccb_rx) = oneshot::channel::<c_int>();\n"; + (* Completion callback: *) + pr " let %s = Some(utils::fn_once_to_fn_mut(|err: &mut i32| {\n" + (rust_optarg_name (Array.get optargs completion_cb_index)); + pr " ccb_tx.send(*err).ok();\n"; + pr " 1\n"; + pr " }));\n"; + (* End of completion callback. 
*) + print_ffi_call aio_name "self.data.handle.handle" call; + pr "?;\n"; + pr " let mut ret_tx = Some(ret_tx);\n"; + pr " let completion_predicate = \n"; + pr " move |_handle: &Handle, res: &SharedResult<()>| {\n"; + pr " let ret = match res {\n"; + pr " Err(e) if e.is_fatal() => res.clone(),\n"; + pr " _ => {\n"; + pr " let Ok(errno) = ccb_rx.try_recv() else { return false; };\n"; + pr " if errno == 0 {\n"; + pr " Ok(())\n"; + pr " } else {\n"; + pr " if let Err(e) = res {\n"; + pr " Err(e.clone())\n"; + pr " } else {\n"; + pr " Err(Arc::new("; + pr " Error::Recoverable(ErrorKind::from_errno(errno))))\n"; + pr " }\n"; + pr " }\n"; + pr " },\n"; + pr " };\n"; + pr " ret_tx.take().unwrap().send(ret).ok();\n"; + pr " true\n"; + pr " };\n"; + pr " self.add_command(completion_predicate)?;\n"; + pr " ret_rx.await.unwrap()\n"; + pr "}\n\n" + +(* Print a Rust function for an asynchronous handle call which signals + completion by changing state. The predicate is a call like + "aio_is_connecting" which should get the value (like false) for the call to + be complete. 
*) +let print_rust_async_handle_call_changing_state name (aio_name, call) + (predicate, value) + let value = if value then "true" else "false" in + print_rust_handle_call_comment call; + pr "pub async fn %s(&self, %s) -> SharedResult<()>\n" name + (rust_async_handle_call_args call); + pr "{\n"; + print_ffi_call aio_name "self.data.handle.handle" call; + pr "?;\n"; + pr " let (ret_tx, ret_rx) = oneshot::channel::<SharedResult<()>>();\n"; + pr " let mut ret_tx = Some(ret_tx);\n"; + pr " let completion_predicate = \n"; + pr " move |handle: &Handle, res: &SharedResult<()>| {\n"; + pr " let ret = if let Err(_) = res {\n"; + pr " res.clone()\n"; + pr " } else {\n"; + pr " if handle.%s() != %s { return false; }\n" predicate value; + pr " else { Ok(()) }\n"; + pr " };\n"; + pr " ret_tx.take().unwrap().send(ret).ok();\n"; + pr " true\n"; + pr " };\n"; + pr " self.add_command(completion_predicate)?;\n"; + pr " ret_rx.await.unwrap()\n"; + pr "}\n\n" + +(* Print an impl with all handle calls. *) +let print_rust_async_handle_impls () + pr "impl AsyncHandle {\n"; + NameMap.iter print_rust_sync_handle_call sync_handle_calls; + NameMap.iter + (fun name (call, async_kind) -> + match async_kind with + | WithCompletionCallback -> + print_rust_async_handle_call_with_completion_cb name call + | ChangesState (predicate, value) -> + print_rust_async_handle_call_changing_state name call + (predicate, value)) + async_handle_calls; + pr "}\n\n" + +let print_rust_async_imports () + pr "use crate::{*, types::*};\n"; + pr "use os_socketaddr::OsSocketAddr;\n"; + pr "use std::ffi::*;\n"; + pr "use std::mem;\n"; + pr "use std::net::SocketAddr;\n"; + pr "use std::os::fd::{AsRawFd, OwnedFd};\n"; + pr "use std::os::unix::prelude::*;\n"; + pr "use std::path::PathBuf;\n"; + pr "use std::ptr;\n"; + pr "use std::sync::Arc;\n"; + pr "use tokio::sync::oneshot;\n"; + pr "\n" + +let generate_rust_async_bindings () + generate_header CStyle ~copyright:"Tage Johansson"; + pr "\n"; + print_rust_async_imports 
(); + print_rust_async_handle_impls () diff --git a/generator/Rust.mli b/generator/Rust.mli index 450e4ca..0960170 100644 --- a/generator/Rust.mli +++ b/generator/Rust.mli @@ -18,3 +18,5 @@ (* Print all flag-structs, enums, constants and handle calls in Rust code. *) val generate_rust_bindings : unit -> unit + +val generate_rust_async_bindings : unit -> unit diff --git a/generator/generator.ml b/generator/generator.ml index 8c9a585..11bec4d 100644 --- a/generator/generator.ml +++ b/generator/generator.ml @@ -69,3 +69,5 @@ let () RustSys.generate_rust_sys_bindings; output_to ~formatter:(Some Rustfmt) "rust/src/bindings.rs" Rust.generate_rust_bindings; + output_to ~formatter:(Some Rustfmt) "rust/src/async_bindings.rs" + Rust.generate_rust_async_bindings; diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 01555de..a9b5988 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -48,9 +48,11 @@ os_socketaddr = "0.2.4" thiserror = "1.0.40" log = { version = "0.4.19", optional = true } libc = "0.2.147" +tokio = { optional = true, version = "1.29.1", default-features = false, features = ["rt", "sync", "net"] } +epoll = "4.3.3" [features] -default = ["log"] +default = ["log", "tokio"] [dev-dependencies] anyhow = "1.0.72" diff --git a/rust/Makefile.am b/rust/Makefile.am index 7098c9a..2b5b85b 100644 --- a/rust/Makefile.am +++ b/rust/Makefile.am @@ -19,6 +19,7 @@ include $(top_srcdir)/subdir-rules.mk generator_built = \ libnbd-sys/src/generated.rs \ + src/async_bindings.rs \ src/bindings.rs \ $(NULL) @@ -30,6 +31,7 @@ source_files = \ src/handle.rs \ src/types.rs \ src/utils.rs \ + src/async_handle.rs \ examples/connect-command.rs \ examples/get-size.rs \ examples/fetch-first-sector.rs \ diff --git a/rust/src/async_handle.rs b/rust/src/async_handle.rs new file mode 100644 index 0000000..4223b80 --- /dev/null +++ b/rust/src/async_handle.rs @@ -0,0 +1,268 @@ +// nbd client library in userspace +// Copyright Tage Johansson +// +// This library is free software; you can 
redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +// This module implements an asynchronous handle working on top of the +// [Tokio](https://tokio.rs) runtime. When the handle is created, +// a "polling task" is spawned on the Tokio runtime. The purpose of that +// "polling task" is to call `aio_notify_*` when appropriate. It shares a +// reference to the handle as well as some other things with the handle in the +// [HandleData] struct. The "polling task" is sleeping when no command is in +// flight, but wakes up as soon as any command is issued. +// +// The commands are implemented as +// [`async fn`s](https://doc.rust-lang.org/std/keyword.async.html) +// in async_bindings.rs. When a new command is issued, it registers a +// completion predicate with [Handle::add_command]. That predicate takes a +// reference to the handle and should return [true] iff the command is complete. +// Whenever some work is performed in the polling task, the completion +// predicates for all pending commands are called. 
+ +#![allow(unused_imports)] // XXX: remove this +use crate::sys; +use crate::Handle; +use crate::{Error, FatalErrorKind, Result}; +use crate::{AIO_DIRECTION_BOTH, AIO_DIRECTION_READ, AIO_DIRECTION_WRITE}; +use epoll::Events; +use std::sync::Arc; +use std::sync::Mutex; +use tokio::io::{unix::AsyncFd, Interest, Ready as IoReady}; +use tokio::sync::{broadcast, Notify}; +use tokio::task; + +/// A custom result type with a shared [crate::Error] as default error type. +pub type SharedResult<T, E = Arc<Error>> = Result<T, E>; + +/// An NBD handle using Rust's `async` functionality on top of the +/// [Tokio](https://docs.rs/tokio/) runtime. +pub struct AsyncHandle { + /// Data shared both by this struct and the polling task. + pub(crate) data: Arc<HandleData>, + + /// A task which soely purpose is to poll the NBD handle. + polling_task: tokio::task::AbortHandle, +} + +pub(crate) struct HandleData { + /// The underliing handle. + pub handle: Handle, + + /// A list of all pending commands. + /// + /// For every pending command (commands in flight), a predicate will be + /// stored in this list. Whenever some progress is made on the file + /// descriptor, the predicate is called with a reference to the handle + /// and a reference to the result of that call to `aio_notify_*`. + /// Iff the predicate returns [true], the command is considered completed + /// and removed from this list. + /// + /// If The polling task dies for some reason, this [SharedResult] will be + /// set to some error. + pub pending_commands: Mutex< + SharedResult< + Vec< + Box< + dyn FnMut(&Handle, &SharedResult<()>) -> bool + + Send + + Sync + + 'static, + >, + >, + >, + >, + + /// A notifier used by commands to notify the polling task when a new + /// asynchronous command is issued. 
+ pub new_command: Notify, +} + +impl AsyncHandle { + pub fn new() -> Result<Self> { + let handle_data = Arc::new(HandleData { + handle: Handle::new()?, + pending_commands: Mutex::new(Ok(Vec::new())), + new_command: Notify::new(), + }); + + let handle_data_2 = handle_data.clone(); + let polling_task = task::spawn(async move { + // The polling task should never finish without an error. If the + // handle is dropped, the task is aborted so it'll not return in + // that case either. + let Err(err) = polling_task(&handle_data_2).await else { + unreachable!() + }; + let err = Arc::new(Error::Fatal(err)); + // Call the completion predicates for all pending commands with the + // error. + let mut pending_cmds + handle_data_2.pending_commands.lock().unwrap(); + let res = Err(err); + for f in pending_cmds.as_mut().unwrap().iter_mut() { + f(&handle_data_2.handle, &res); + } + *pending_cmds = Err(res.unwrap_err()); + }) + .abort_handle(); + Ok(Self { + data: handle_data, + polling_task, + }) + } + + /// Get the underliing C pointer to the handle. + pub(crate) fn raw_handle(&self) -> *mut sys::nbd_handle { + self.data.handle.raw_handle() + } + + /// Call this method when a new command is issued. As argument is passed a + /// predicate which should return [true] iff the command is completed. + pub(crate) fn add_command( + &self, + mut completion_predicate: impl FnMut(&Handle, &SharedResult<()>) -> bool + + Send + + Sync + + 'static, + ) -> SharedResult<()> { + if !completion_predicate(&self.data.handle, &Ok(())) { + let mut pending_cmds_lock + self.data.pending_commands.lock().unwrap(); + pending_cmds_lock + .as_mut() + .map_err(|e| e.clone())? + .push(Box::new(completion_predicate)); + self.data.new_command.notify_one(); + } + Ok(()) + } +} + +impl Drop for AsyncHandle { + fn drop(&mut self) { + self.polling_task.abort(); + } +} + +/// Get the read/write direction that the handle wants on the file descriptor. 
+fn get_fd_interest(handle: &Handle) -> Option<Interest> { + match handle.aio_get_direction() { + 0 => None, + AIO_DIRECTION_READ => Some(Interest::READABLE), + AIO_DIRECTION_WRITE => Some(Interest::WRITABLE), + AIO_DIRECTION_BOTH => Some(Interest::READABLE | Interest::WRITABLE), + _ => unreachable!(), + } +} + +/// A task that will run as long as the handle is alive. It will poll the +/// file descriptor when new data is availlable. +async fn polling_task(handle_data: &HandleData) -> Result<(), FatalErrorKind> { + let HandleData { + handle, + pending_commands, + new_command, + } = handle_data; + let fd = handle.aio_get_fd().map_err(Error::to_fatal)?; + // XXX: Might the file descriptor ever be changed? + let tokio_fd = AsyncFd::new(fd)?; + let epfd = epoll::create(false)?; + epoll::ctl( + epfd, + epoll::ControlOptions::EPOLL_CTL_ADD, + fd, + epoll::Event::new(Events::EPOLLIN | Events::EPOLLOUT, 42), + )?; + + // The following loop does approximately the following things: + // + // 1. Determine what Libnbd wants to do next on the file descriptor, + // (read/write/both/none), and store that in [fd_interest]. + // 2. Wait for either: + // a) That interest to be available on the file descriptor in which case: + // I. Call the correct `aio_notify_*` method. + // II. Execute step 1. + // III. Send the result of the call to `aio_notify_*` on + // [result_channel] to notify pending commands that some progress + // has been made. + // IV. Resume execution from step 2. + // b) A notification was received on [new_command] signaling that a new + // command was registered and that the intrest on the file descriptor + // might has changed. Resume execution from step 1. + loop { + let Some(fd_interest) = get_fd_interest(handle) else { + // The handle does not wait for any data of the file descriptor, + // so we wait until some command is issued. 
+ new_command.notified().await; + continue; + }; + + if pending_commands + .lock() + .unwrap() + .as_ref() + .unwrap() + .is_empty() + { + // No command is pending so there is no point to do anything. + new_command.notified().await; + continue; + } + + // Wait for the requested interest to be available on the fd. + let mut ready_guard = tokio_fd.ready(fd_interest).await?; + let readyness = ready_guard.ready(); + let res = if readyness.is_readable() && fd_interest.is_readable() { + handle.aio_notify_read() + } else if readyness.is_writable() && fd_interest.is_writable() { + handle.aio_notify_write() + } else { + continue; + }; + let res = match res { + Ok(()) => Ok(()), + Err(e @ Error::Recoverable(_)) => Err(Arc::new(e)), + Err(Error::Fatal(e)) => return Err(e), + }; + + // Call the completion predicates of all pending commands. + let mut pending_cmds_lock = pending_commands.lock().unwrap(); + let pending_cmds = pending_cmds_lock.as_mut().unwrap(); + let mut i = 0; + while i < pending_cmds.len() { + if (pending_cmds[i])(handle, &res) { + let _ = pending_cmds.swap_remove(i); + } else { + i += 1; + } + } + drop(pending_cmds_lock); + + // Use epoll to check the current read/write availability on the fd. + // This is needed because Tokio does only support edge-triggered + // notifications but Libnbd requires level-triggered notifications. + let mut revent = epoll::Event { data: 0, events: 0 }; + // Setting timeout to 0 means that it will return immediately. 
+ epoll::wait(epfd, 0, std::slice::from_mut(&mut revent))?; + let revents = Events::from_bits(revent.events).unwrap(); + if !revents.contains(Events::EPOLLIN) { + ready_guard.clear_ready_matching(IoReady::READABLE); + } + if !revents.contains(Events::EPOLLOUT) { + ready_guard.clear_ready_matching(IoReady::WRITABLE); + } + ready_guard.retain_ready(); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index a6f3131..56316b4 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -17,11 +17,19 @@ #![deny(warnings)] +#[cfg(feature = "tokio")] +mod async_bindings; +#[cfg(feature = "tokio")] +mod async_handle; mod bindings; mod error; mod handle; pub mod types; mod utils; +#[cfg(feature = "tokio")] +pub use async_bindings::*; +#[cfg(feature = "tokio")] +pub use async_handle::{AsyncHandle, SharedResult}; pub use bindings::*; pub use error::{Error, ErrorKind, FatalErrorKind, Result}; pub use handle::Handle; diff --git a/rust/src/utils.rs b/rust/src/utils.rs index b8200c1..8984ebb 100644 --- a/rust/src/utils.rs +++ b/rust/src/utils.rs @@ -21,3 +21,12 @@ use std::ffi::c_void; pub unsafe extern "C" fn drop_data<T>(data: *mut c_void) { drop(Box::from_raw(data as *mut T)) } + +/// Turn a [FnOnce] (with a single `&mut` argument) to a [FnMut] +/// which panics on the second invocation. +pub fn fn_once_to_fn_mut<T, U>( + f: impl FnOnce(&mut T) -> U, +) -> impl FnMut(&mut T) -> U { + let mut f = Some(f); + move |x| (f.take().unwrap())(x) +} diff --git a/scripts/git.orderfile b/scripts/git.orderfile index b988d87..60ec56d 100644 --- a/scripts/git.orderfile +++ b/scripts/git.orderfile @@ -69,6 +69,7 @@ rust/src/types.rs rust/src/utils.rs rust/src/lib.rs rust/src/handle.rs +rust/src/async_handle.rs rust/libnbd-sys/* rust/examples/* rust/tests/* -- 2.41.0
Eric Blake
2023-Aug-24 21:55 UTC
[Libguestfs] [libnbd PATCH v8 06/10] rust: async: Create an async friendly handle type
On Sun, Aug 20, 2023 at 02:16:25PM +0000, Tage Johansson wrote:> Create another handle type: AsyncHandle, which makes use of Rust's > builtin asynchronous functions (see > <https://doc.rust-lang.org/std/keyword.async.html>) and runs on top of > the Tokio runtime (see <https://docs.rs/tokio>). For every asynchronous > command, like aio_connect(), a corresponding `async` method is created > on the handle. In this case it would be: > async fn connect(...) -> Result<(), ...> > When called, it will poll the file descriptor until the command is > complete, and then return with a result. All the synchronous > counterparts (like nbd_connect()) are excluded from this handle type > as they are unnecessary and since they might interfear with the polling > made by the Tokio runtime. For more details about how the asynchronous > commands are executed, please see the comments in > rust/src/async_handle.rs. > --- > generator/Rust.ml | 249 +++++++++++++++++++++++++++++++++++- > generator/Rust.mli | 2 + > generator/generator.ml | 2 + > rust/Cargo.toml | 4 +- > rust/Makefile.am | 2 + > rust/src/async_handle.rs | 268 +++++++++++++++++++++++++++++++++++++++ > rust/src/lib.rs | 8 ++ > rust/src/utils.rs | 9 ++ > scripts/git.orderfile | 1 + > 9 files changed, 538 insertions(+), 7 deletions(-) > create mode 100644 rust/src/async_handle.rs > > diff --git a/generator/Rust.ml b/generator/Rust.ml > index 431c814..1bc81f0 100644 > --- a/generator/Rust.ml > +++ b/generator/Rust.ml > @@ -61,11 +61,12 @@ let print_rust_flags { flag_prefix; flags } > let rec to_upper_snake_case s > let s = String.uppercase_ascii s in > let s = explode s in > - let s = filter_map ( > - function > - |'-' -> Some "_" | ':' -> None > - | ch -> Some (String.make 1 ch) > - ) s in > + let s > + filter_map > + (function > + | '-' -> Some "_" | ':' -> None | ch -> Some (String.make 1 ch)) > + s > + in > String.concat "" sThis looks like it is just reformatting. 
While cleanup patches are okay (and I trust Rich's take on OCaml style more than my own), it's cleaner to do them in separate patches.> > (* Split a string into a list of chars. In later OCaml we could > @@ -75,7 +76,7 @@ and explode str > let r = ref [] in > for i = 0 to String.length str - 1 do > let c = String.unsafe_get str i in > - r := c :: !r; > + r := c :: !r > done; > List.rev !r > > @@ -564,3 +565,239 @@ let generate_rust_bindings () > pr "impl Handle {\n"; > List.iter print_rust_handle_method handle_calls; > pr "}\n\n" > + > +(*********************************************************) > +(* The rest of the file conserns the asynchronous API. *)concerns> +(* *) > +(* See the comments in rust/src/async_handle.rs for more *) > +(* information about how it works. *) > +(*********************************************************) > + > +let excluded_handle_calls : NameSet.t > + NameSet.of_list > + [ > + "aio_get_fd"; > + "aio_get_direction"; > + "aio_notify_read"; > + "aio_notify_write"; > + "clear_debug_callback"; > + "get_debug"; > + "poll"; > + "poll2"; > + "set_debug"; > + "set_debug_callback"; > + ] > + > +(* A mapping with names as keys. *) > +module NameMap = Map.Make (String) > + > +(* Strip "aio_" from the beginning of a string. *) > +let strip_aio name : string > + if String.starts_with ~prefix:"aio_" name then > + String.sub name 4 (String.length name - 4) > + else failwithf "Asynchronous call %s must begin with aio_" name > + > +(* A map with all asynchronous handle calls. The keys are names with "aio_" > + stripped, the values are a tuple with the actual name (with "aio_"), the > + [call] and the [async_kind]. 
*) > +let async_handle_calls : ((string * call) * async_kind) NameMap.tDo we need a 2-deep nested tuple, or can we use (string * call * async_kind)?> + handle_calls > + |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls)) > + |> List.filter_map (fun (name, call) -> > + call.async_kind > + |> Option.map (fun async_kind -> > + (strip_aio name, ((name, call), async_kind)))) > + |> List.to_seq |> NameMap.of_seq > + > +(* A mapping with all synchronous (not asynchronous) handle calls. Excluded > + are also all synchronous calls that has an asynchronous counterpart. So ifs/has/have/> + "foo" is the name of a handle call and an asynchronous call "aio_foo" > + exists, then "foo" will not b in this map. *)s/b /be /> +let sync_handle_calls : call NameMap.t > + handle_calls > + |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls)) > + |> List.filter (fun (name, _) -> > + (not (NameMap.mem name async_handle_calls)) > + && not > + (String.starts_with ~prefix:"aio_" name > + && NameMap.mem (strip_aio name) async_handle_calls)) > + |> List.to_seq |> NameMap.of_seq > + > +(* Get the Rust type for an argument in the asynchronous API. Like > + [rust_arg_type] but no static lifetime on some buffers. *) > +let rust_async_arg_type : arg -> string = function > + | BytesPersistIn _ -> "&[u8]" > + | BytesPersistOut _ -> "&mut [u8]" > + | x -> rust_arg_type x > + > +(* Get the Rust type for an optional argument in the asynchronous API. Like > + [rust_optarg_type] but no static lifetime on some closures. *) > +let rust_async_optarg_type : optarg -> string = function > + | OClosure x -> sprintf "Option<%s>" (rust_async_arg_type (Closure x)) > + | x -> rust_optarg_type x > + > +(* A string of the argument list for a method on the handle, with both > + mandotory and optional arguments. 
*) > +let rust_async_handle_call_args { args; optargs } : string > + let rust_args_names > + List.map rust_arg_name args @ List.map rust_optarg_name optargs > + and rust_args_types > + List.map rust_async_arg_type args > + @ List.map rust_async_optarg_type optargs > + in > + String.concat ", " > + (List.map2 (sprintf "%s: %s") rust_args_names rust_args_types) > + > +(* Print the Rust function for a not asynchronous handle call. *)s/not asynchronous/synchronous/> +let print_rust_sync_handle_call name call > + print_rust_handle_call_comment call; > + pr "pub fn %s(&self, %s) -> %s\n" name > + (rust_async_handle_call_args call) > + (rust_ret_type call); > + print_ffi_call name "self.data.handle.handle" call; > + pr "\n" > + > +(* Print the Rust function for an asynchronous handle call with a completion > + callback. (Note that "callback" might be abbreviated with "cb" in the > + following code. *) > +let print_rust_async_handle_call_with_completion_cb name (aio_name, call) > + (* An array of all optional arguments. Useful because we need to deel withs/deel/deal/> + the index of the completion callback. *) > + let optargs = Array.of_list call.optargs in > + (* The index of the completion callback in [optargs] *) > + let completion_cb_index > + Array.find_map > + (fun (i, optarg) -> > + match optarg with > + | OClosure { cbname } -> > + if cbname = "completion" then Some i else None > + | _ -> None) > + (Array.mapi (fun x y -> (x, y)) optargs) > + in > + let completion_cb_index > + match completion_cb_index with > + | Some x -> x > + | None -> > + failwithf > + "The handle call %s is claimed to have a completion callback among \ > + its optional arguments by the async_kind field, but so does not \s/so/that/> + seem to be the case." 
> + aio_name > + in > +++ b/rust/Cargo.toml > @@ -48,9 +48,11 @@ os_socketaddr = "0.2.4" > thiserror = "1.0.40" > log = { version = "0.4.19", optional = true } > libc = "0.2.147" > +tokio = { optional = true, version = "1.29.1", default-features = false, features = ["rt", "sync", "net"] } > +epoll = "4.3.3" > > [features] > -default = ["log"] > +default = ["log", "tokio"]It looks like you intend for tokio to be an optional dependency (you always get the bare-bones Rust bindings, but if tokio is installed, you also get the AsyncHandle bindings). Do we need to document that in README at all? Is there an easy way to set up CI tests to cover builds both with and without tokio, so that we can ensure we don't break builds on someone who chooses not to install the optional dependency?> +++ b/rust/src/async_handle.rs > @@ -0,0 +1,268 @@> + > +#![allow(unused_imports)] // XXX: remove thisHow hard is it to fix this line?> +use crate::sys; > +use crate::Handle; > +use crate::{Error, FatalErrorKind, Result}; > +use crate::{AIO_DIRECTION_BOTH, AIO_DIRECTION_READ, AIO_DIRECTION_WRITE}; > +use epoll::Events; > +use std::sync::Arc; > +use std::sync::Mutex; > +use tokio::io::{unix::AsyncFd, Interest, Ready as IoReady}; > +use tokio::sync::{broadcast, Notify}; > +use tokio::task; > + > +/// A custom result type with a shared [crate::Error] as default error type. > +pub type SharedResult<T, E = Arc<Error>> = Result<T, E>; > + > +/// An NBD handle using Rust's `async` functionality on top of the > +/// [Tokio](https://docs.rs/tokio/) runtime. > +pub struct AsyncHandle { > + /// Data shared both by this struct and the polling task. > + pub(crate) data: Arc<HandleData>, > + > + /// A task which soely purpose is to poll the NBD handle.s/soely/sole/> + polling_task: tokio::task::AbortHandle, > +} > + > +pub(crate) struct HandleData { > + /// The underliing handle. > + pub handle: Handle, > + > + /// A list of all pending commands. 
> + /// > + /// For every pending command (commands in flight), a predicate will be > + /// stored in this list. Whenever some progress is made on the file > + /// descriptor, the predicate is called with a reference to the handle > + /// and a reference to the result of that call to `aio_notify_*`. > + /// Iff the predicate returns [true], the command is considered completed > + /// and removed from this list. > + /// > + /// If The polling task dies for some reason, this [SharedResult] will be > + /// set to some error. > + pub pending_commands: Mutex< > + SharedResult< > + Vec< > + Box< > + dyn FnMut(&Handle, &SharedResult<()>) -> bool > + + Send > + + Sync > + + 'static, > + >, > + >, > + >, > + >, > + > + /// A notifier used by commands to notify the polling task when a new > + /// asynchronous command is issued. > + pub new_command: Notify, > +} > + > +impl AsyncHandle { > + pub fn new() -> Result<Self> { > + let handle_data = Arc::new(HandleData { > + handle: Handle::new()?, > + pending_commands: Mutex::new(Ok(Vec::new())), > + new_command: Notify::new(), > + }); > + > + let handle_data_2 = handle_data.clone(); > + let polling_task = task::spawn(async move { > + // The polling task should never finish without an error. If the > + // handle is dropped, the task is aborted so it'll not return ins/it'll not/it won't/> + // that case either. > + let Err(err) = polling_task(&handle_data_2).await else { > + unreachable!() > + }; > + let err = Arc::new(Error::Fatal(err)); > + // Call the completion predicates for all pending commands with the > + // error. > + let mut pending_cmds > + handle_data_2.pending_commands.lock().unwrap(); > + let res = Err(err); > + for f in pending_cmds.as_mut().unwrap().iter_mut() { > + f(&handle_data_2.handle, &res); > + } > + *pending_cmds = Err(res.unwrap_err()); > + }) > + .abort_handle(); > + Ok(Self { > + data: handle_data, > + polling_task, > + }) > + } > + > + /// Get the underliing C pointer to the handle. 
> + pub(crate) fn raw_handle(&self) -> *mut sys::nbd_handle { > + self.data.handle.raw_handle() > + } > + > + /// Call this method when a new command is issued. As argument is passed aNot sure if you meant 'An argument is passed' or something else here.> + /// predicate which should return [true] iff the command is completed. > + pub(crate) fn add_command( > + &self, > + mut completion_predicate: impl FnMut(&Handle, &SharedResult<()>) -> bool > + + Send > + + Sync > + + 'static, > + ) -> SharedResult<()> { > + if !completion_predicate(&self.data.handle, &Ok(())) { > + let mut pending_cmds_lock > + self.data.pending_commands.lock().unwrap(); > + pending_cmds_lock > + .as_mut() > + .map_err(|e| e.clone())? > + .push(Box::new(completion_predicate)); > + self.data.new_command.notify_one(); > + } > + Ok(()) > + } > +} > + > +impl Drop for AsyncHandle { > + fn drop(&mut self) { > + self.polling_task.abort(); > + } > +} > + > +/// Get the read/write direction that the handle wants on the file descriptor. > +fn get_fd_interest(handle: &Handle) -> Option<Interest> { > + match handle.aio_get_direction() { > + 0 => None, > + AIO_DIRECTION_READ => Some(Interest::READABLE), > + AIO_DIRECTION_WRITE => Some(Interest::WRITABLE), > + AIO_DIRECTION_BOTH => Some(Interest::READABLE | Interest::WRITABLE), > + _ => unreachable!(), > + } > +} > + > +/// A task that will run as long as the handle is alive. It will poll the > +/// file descriptor when new data is availlable.available> +async fn polling_task(handle_data: &HandleData) -> Result<(), FatalErrorKind> { > + let HandleData { > + handle, > + pending_commands, > + new_command, > + } = handle_data; > + let fd = handle.aio_get_fd().map_err(Error::to_fatal)?; > + // XXX: Might the file descriptor ever be changed? > + let tokio_fd = AsyncFd::new(fd)?;Regarding the XXX, my understanding is that aio_get_fd() returns the same fd for the life of the handle once a connection is established, so you can drop the comment. 
Changing the fd would imply creating a new socket, but we don't have automatic internal reconnect built into libnbd at this time. (You can do external reconnect by opening a new NBD handle - but then it's obvious that you will call aio_get_fd() on the new handle)> + let epfd = epoll::create(false)?; > + epoll::ctl( > + epfd, > + epoll::ControlOptions::EPOLL_CTL_ADD, > + fd, > + epoll::Event::new(Events::EPOLLIN | Events::EPOLLOUT, 42), > + )?; > + > + // The following loop does approximately the following things: > + // > + // 1. Determine what Libnbd wants to do next on the file descriptor, > + // (read/write/both/none), and store that in [fd_interest]. > + // 2. Wait for either: > + // a) That interest to be available on the file descriptor in which case: > + // I. Call the correct `aio_notify_*` method. > + // II. Execute step 1. > + // III. Send the result of the call to `aio_notify_*` on > + // [result_channel] to notify pending commands that some progress > + // has been made. > + // IV. Resume execution from step 2. > + // b) A notification was received on [new_command] signaling that a new > + // command was registered and that the intrest on the file descriptorinterest> + // might has changed. Resume execution from step 1.s/has/have/> + loop { > + let Some(fd_interest) = get_fd_interest(handle) else { > + // The handle does not wait for any data of the file descriptor, > + // so we wait until some command is issued. > + new_command.notified().await; > + continue; > + }; > + > + if pending_commands > + .lock() > + .unwrap() > + .as_ref() > + .unwrap() > + .is_empty() > + { > + // No command is pending so there is no point to do anything. > + new_command.notified().await; > + continue; > + } > + > + // Wait for the requested interest to be available on the fd. 
> + let mut ready_guard = tokio_fd.ready(fd_interest).await?; > + let readyness = ready_guard.ready(); The typical spelling is "readiness" (s/readyness/readiness/); that rename would also affect later lines of code. > + let res = if readyness.is_readable() && fd_interest.is_readable() { > + handle.aio_notify_read() > + } else if readyness.is_writable() && fd_interest.is_writable() { > + handle.aio_notify_write() > + } else { > + continue; > + }; > + let res = match res { > + Ok(()) => Ok(()), > + Err(e @ Error::Recoverable(_)) => Err(Arc::new(e)), > + Err(Error::Fatal(e)) => return Err(e), > + }; > + > + // Call the completion predicates of all pending commands. > + let mut pending_cmds_lock = pending_commands.lock().unwrap(); > + let pending_cmds = pending_cmds_lock.as_mut().unwrap(); > + let mut i = 0; > + while i < pending_cmds.len() { > + if (pending_cmds[i])(handle, &res) { > + let _ = pending_cmds.swap_remove(i); > + } else { > + i += 1; > + } > + } > + drop(pending_cmds_lock); > + > + // Use epoll to check the current read/write availability on the fd. > + // This is needed because Tokio does only support edge-triggered s/does only support/supports only/ > + // notifications but Libnbd requires level-triggered notifications. > + let mut revent = epoll::Event { data: 0, events: 0 }; > + // Setting timeout to 0 means that it will return immediately.
> + epoll::wait(epfd, 0, std::slice::from_mut(&mut revent))?; > + let revents = Events::from_bits(revent.events).unwrap(); > + if !revents.contains(Events::EPOLLIN) { > + ready_guard.clear_ready_matching(IoReady::READABLE); > + } > + if !revents.contains(Events::EPOLLOUT) { > + ready_guard.clear_ready_matching(IoReady::WRITABLE); > + } > + ready_guard.retain_ready(); > + } > +} > diff --git a/rust/src/lib.rs b/rust/src/lib.rs > index a6f3131..56316b4 100644 > --- a/rust/src/lib.rs > +++ b/rust/src/lib.rs > @@ -17,11 +17,19 @@ > > #![deny(warnings)] > > +#[cfg(feature = "tokio")] > +mod async_bindings; > +#[cfg(feature = "tokio")] > +mod async_handle; > mod bindings; > mod error; > mod handle; > pub mod types; > mod utils; > +#[cfg(feature = "tokio")] > +pub use async_bindings::*; > +#[cfg(feature = "tokio")] > +pub use async_handle::{AsyncHandle, SharedResult}; > pub use bindings::*; > pub use error::{Error, ErrorKind, FatalErrorKind, Result}; > pub use handle::Handle; > diff --git a/rust/src/utils.rs b/rust/src/utils.rs > index b8200c1..8984ebb 100644 > --- a/rust/src/utils.rs > +++ b/rust/src/utils.rs > @@ -21,3 +21,12 @@ use std::ffi::c_void; > pub unsafe extern "C" fn drop_data<T>(data: *mut c_void) { > drop(Box::from_raw(data as *mut T)) > } > + > +/// Turn a [FnOnce] (with a single `&mut` argument) to a [FnMut] > +/// which panics on the second invocation. 
> +pub fn fn_once_to_fn_mut<T, U>( > + f: impl FnOnce(&mut T) -> U, > +) -> impl FnMut(&mut T) -> U { > + let mut f = Some(f); > + move |x| (f.take().unwrap())(x) > +} > diff --git a/scripts/git.orderfile b/scripts/git.orderfile > index b988d87..60ec56d 100644 > --- a/scripts/git.orderfile > +++ b/scripts/git.orderfile > @@ -69,6 +69,7 @@ rust/src/types.rs > rust/src/utils.rs > rust/src/lib.rs > rust/src/handle.rs > +rust/src/async_handle.rs > rust/libnbd-sys/* > rust/examples/* > rust/tests/* > -- > 2.41.0 > > _______________________________________________ > Libguestfs mailing list > Libguestfs at redhat.com > https://listman.redhat.com/mailman/listinfo/libguestfs > I'm still learning Rust, so a lot of this I just have to trust, but overall the patch seems like a good framework. While I definitely found some typos to fix, I'm less certain about whether there are any major implementation flaws. -- Eric Blake, Principal Software Engineer Red Hat, Inc. Virtualization: qemu.org | libguestfs.org