From 15501132916dfbc24f23b619e6d5408f258fc0d9 Mon Sep 17 00:00:00 2001 From: Jonas Maier <> Date: Wed, 11 Mar 2026 12:30:07 +0100 Subject: can wait for threads & processes with a timeout now --- Cargo.toml | 2 +- src/ctrlc.rs | 63 ++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 6 ++++ src/run/mod.rs | 41 ++++++++++++++++++++------- src/wait/child.rs | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/wait/mod.rs | 5 ++++ src/wait/thread.rs | 51 ++++++++++++++++++++++++++++++++++ 7 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 src/ctrlc.rs create mode 100644 src/wait/child.rs create mode 100644 src/wait/mod.rs create mode 100644 src/wait/thread.rs diff --git a/Cargo.toml b/Cargo.toml index 936e3ca..f7a738c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,4 @@ libc = "0.2.182" sqlite = "0.37.0" termios = "0.3" pish_derive = { path = "./pish_derive" } -nix = { version = "0.31.2", features = ["poll"] } +nix = { version = "0.31.2", features = ["poll", "signal"] } diff --git a/src/ctrlc.rs b/src/ctrlc.rs new file mode 100644 index 0000000..4c0153a --- /dev/null +++ b/src/ctrlc.rs @@ -0,0 +1,63 @@ +use crate::Session; +use libc::c_int; +use nix::sys::signal::*; +use std::sync::*; + +static SESSION: Mutex>>> = Mutex::new(None); + +fn handle() { + let Ok(mut se) = SESSION.lock() else { return }; + let Some(se) = se.as_mut() else { return }; + let Ok(mut se) = se.lock() else { return }; + se.ctrlc.pressed = true; +} + +extern "C" fn c_handle(_signal: c_int) { + // cannot propagate panic into C-land + let _ = std::panic::catch_unwind(|| { + if let Err(e) = std::panic::catch_unwind(handle) { + eprintln!("{e:?}"); // might panic + } + }); +} + +#[derive(Default)] +pub struct CtrlC { + pressed: bool, +} + +struct Teardown; +impl Drop for Teardown { + fn drop(&mut self) { + teardown(); + } +} + +fn teardown() { + unsafe { + let _ = signal(Signal::SIGINT, SigHandler::SigDfl); + } + if let Ok(mut se) = SESSION.lock() { + *se = None; + } +} + +#[must_use] +pub fn setup(session: Arc>) -> impl Drop { + *SESSION.lock().unwrap() = Some(session); + unsafe { + signal(Signal::SIGINT, SigHandler::Handler(c_handle)) + .expect("failed to set ctrl+c signal handler"); + } + Teardown +} + +pub fn peek(session: &Session) -> bool { + session.ctrlc.pressed +} + +pub fn pop(session: &mut Session) -> bool { + let x = session.ctrlc.pressed; + session.ctrlc.pressed = false; + x +} diff --git a/src/main.rs b/src/main.rs index 897cdb9..a7138f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ use std::time::Duration; pub mod basedir; pub mod completion; +pub mod ctrlc; pub mod cursor; pub mod date; pub mod defer; @@ -27,11 +28,13 @@ pub mod raw; pub mod reload; pub mod run; pub mod serialization; +pub mod wait; use linebuf::LineBuf; use raw::*; use crate::completion::PathCache; +use crate::ctrlc::CtrlC; use crate::cursor::{Direction, move_cursor}; use crate::history::HistoryEntry; use crate::parse::{Ast, PreExpansion}; @@ -82,6 +85,7 @@ pub struct Session { funs: HashMap>, socket_running: Option, path_cache: PathCache, + ctrlc: CtrlC, /// n before end of history.len() /// 0 == not checking history @@ -278,6 +282,7 @@ fn event_loop() { vars: HashMap::new(), funs: HashMap::new(), path_cache: Default::default(), + ctrlc: Default::default(), }; print!("{}", se.prompt()); @@ -286,6 +291,7 @@ fn event_loop() { completion::populate_path_cache(session.clone()); let _sock_dropper = export_fun::listen(session.clone()); + let _ctrlc = ctrlc::setup(session.clone()); loop { let mut buf = [0u8; 1]; diff --git a/src/run/mod.rs b/src/run/mod.rs index 5666574..6af1717 100644 --- a/src/run/mod.rs +++ b/src/run/mod.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use std::fs::File; use std::io::{PipeReader, PipeWriter}; use std::path::PathBuf; -use std::process::Child; use std::sync::{Arc, Mutex}; use std::thread::JoinHandle; use crate::parse::{self, Ast, PostExpansion, PreExpansion}; +use crate::wait::{ChildWaiter, ThreadWaiter}; use crate::*; mod builtin; @@ -147,18 +147,24 @@ impl Write for Output { } enum SpawnedCmd { - Builtin(JoinHandle>), - Fun(JoinHandle>), - Child(Child), + Builtin(ThreadWaiter>), + Fun(ThreadWaiter>), + Child(ChildWaiter), SpawnError(io::Error), + Joined(Result<(), ExecError>), } impl SpawnedCmd { fn join(self) -> Result<(), ExecError> { match self { - SpawnedCmd::Builtin(handle) => handle.join().map_err(|_| ExecError::Panic)??, - SpawnedCmd::Fun(handle) => handle.join().map_err(|_| ExecError::Panic)??, - SpawnedCmd::Child(mut child) => { + SpawnedCmd::Builtin(handle) => { + handle.into_inner().join().map_err(|_| ExecError::Panic)?? + } + SpawnedCmd::Fun(handle) => { + handle.into_inner().join().map_err(|_| ExecError::Panic)?? + } + SpawnedCmd::Child(child) => { + let mut child = child.into_inner(); let exit_code = child.wait()?; match exit_code.code() { Some(0) => (), @@ -167,9 +173,24 @@ impl SpawnedCmd { } } SpawnedCmd::SpawnError(err) => Err(err)?, + SpawnedCmd::Joined(res) => res?, } Ok(()) } + + /// returns whether the spawned command is already joined + fn join_timeout(&mut self, timeout_ms: u16) -> bool { + match self { + SpawnedCmd::Builtin(tw) => tw.try_join(timeout_ms), + SpawnedCmd::Fun(tw) => tw.try_join(timeout_ms), + SpawnedCmd::Child(child) => match child.wait(timeout_ms) { + Ok(None) => false, + _ => true, + }, + SpawnedCmd::SpawnError(_) => true, + SpawnedCmd::Joined(_) => true, + } + } } impl Executor { @@ -192,7 +213,7 @@ impl Executor { CommandKind::Builtin(builtin) => { builtin.special(self.se.clone(), &args[1..]); let cloned_session = self.se.clone(); - let handle = std::thread::spawn(move || { + let handle = wait::spawn(move || { builtin.io(cloned_session, &args[1..], &mut stdin, &mut stdout) }); SpawnedCmd::Builtin(handle) @@ -201,7 +222,7 @@ impl Executor { let mut this = self.clone(); this.args = Some(args); - let handle = std::thread::spawn(move || { + let handle = wait::spawn(move || { let ast = ast.expand(&mut this)?; this.execute(ast, stdin, stdout)?; Ok(()) @@ -221,7 +242,7 @@ impl Executor { crate::export_fun::prepare_command(self.se.clone(), &mut command); match command.spawn() { - Ok(c) => SpawnedCmd::Child(c), + Ok(c) => SpawnedCmd::Child(ChildWaiter::new(c).unwrap()), Err(e) => SpawnedCmd::SpawnError(e), } } diff --git a/src/wait/child.rs b/src/wait/child.rs new file mode 100644 index 0000000..29a7d70 --- /dev/null +++ b/src/wait/child.rs @@ -0,0 +1,81 @@ +//! based on https://www.man7.org/linux/man-pages/man2/pidfd_open.2.html +#![cfg(target_os = "linux")] + +use std::{ + io, + mem::ManuallyDrop, + ops::{Deref, DerefMut}, + os::fd::{BorrowedFd, RawFd}, + process::{Child, ExitStatus}, + ptr, +}; + +use libc::{SYS_pidfd_open, syscall}; +use nix::poll::{PollFd, PollFlags}; + +pub struct ChildWaiter { + fd: RawFd, + child: Child, +} + +#[derive(Debug)] +struct PidFdOpenError; + +impl std::fmt::Display for PidFdOpenError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "pid_fd_open") + } +} + +impl std::error::Error for PidFdOpenError {} + +impl ChildWaiter { + pub fn new(child: Child) -> io::Result { + let fd = unsafe { syscall(SYS_pidfd_open, child.id(), 0) }; + if fd < 0 { + Err(io::Error::new(io::ErrorKind::Other, PidFdOpenError)) + } else { + let fd = fd as RawFd; + Ok(Self { child, fd }) + } + } + + pub fn wait(&mut self, timeout_ms: u16) -> io::Result> { + let mut poll_fds = [PollFd::new( + unsafe { BorrowedFd::borrow_raw(self.fd) }, + PollFlags::POLLIN, + )]; + let _ = nix::poll::poll(&mut poll_fds, timeout_ms); + self.child.try_wait() + } + + pub fn into_inner(self) -> Child { + unsafe { + libc::close(self.fd); + } + let this = ManuallyDrop::new(self); + unsafe { ptr::read(&this.child) } + } +} + +impl Deref for ChildWaiter { + type Target = Child; + + fn deref(&self) -> &Self::Target { + &self.child + } +} + +impl DerefMut for ChildWaiter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.child + } +} + +impl Drop for ChildWaiter { + fn drop(&mut self) { + unsafe { + libc::close(self.fd); + } + } +} diff --git a/src/wait/mod.rs b/src/wait/mod.rs new file mode 100644 index 0000000..63a083b --- /dev/null +++ b/src/wait/mod.rs @@ -0,0 +1,5 @@ +mod child; +mod thread; + +pub use child::*; +pub use thread::*; diff --git a/src/wait/thread.rs b/src/wait/thread.rs new file mode 100644 index 0000000..0eadc3c --- /dev/null +++ b/src/wait/thread.rs @@ -0,0 +1,51 @@ +use std::{ + sync::mpsc::{Receiver, channel}, + thread::JoinHandle, time::Duration, +}; + +use crate::defer; + +pub struct ThreadWaiter { + handle: JoinHandle, + chan: Receiver<()>, + done: bool, +} + +pub fn spawn(fun: F) -> ThreadWaiter +where + T: Send + 'static, + F: Send + 'static, + F: FnOnce() -> T, +{ + let (tx, rx) = channel(); + + let handle = std::thread::spawn(move || { + defer! { + let _ = tx.send(()); + }; + fun() + }); + + ThreadWaiter { + handle, + chan: rx, + done: false, + } +} + +impl ThreadWaiter { + pub fn try_join(&mut self, timeout_ms: u16) -> bool { + if self.done { + return true; + } + + if let Ok(()) = self.chan.recv_timeout(Duration::from_millis(timeout_ms as _)) { + self.done = true; + } + + self.done + } + pub fn into_inner(self) -> JoinHandle { + self.handle + } +} -- cgit v1.2.3