//! allow sub-programs to invoke arbitrary user-defined functions via a unix socket use libc::gid_t; use libc::uid_t; use nix::poll::PollFd; use nix::poll::PollFlags; use nix::poll::poll; use crate::Session; use crate::defer; use crate::rw::*; use crate::run::get_command_kind; use std::env::current_exe; use std::ffi::OsStr; use std::fs; use std::fs::File; use std::io; use std::io::IoSliceMut; use std::io::Read; use std::io::Write; use std::os::fd::AsFd; use std::os::fd::FromRawFd; use std::os::fd::OwnedFd; use std::os::unix::ffi::OsStrExt; use std::os::unix::fs::symlink; use std::os::unix::net::AncillaryData; use std::os::unix::net::SocketAncillary; use std::os::unix::net::SocketCred; use std::os::unix::net::UnixListener; use std::os::unix::net::UnixStream; use std::path::PathBuf; use std::process::Command; use std::process::exit; use std::sync::Arc; use std::sync::Mutex; use std::sync::mpsc; use std::sync::mpsc::Receiver; use std::thread; use std::time::Duration; const PATH: &str = "PATH"; const DATA: &str = "PISH_DATA"; fn get_uid() -> gid_t { unsafe { libc::getuid() } } fn get_gid() -> gid_t { unsafe { libc::getgid() } } fn handle_server(session: Arc>, mut stream: UnixStream) -> io::Result<()> { // TODO: figure out how to get a reasonable limit on CLI arg len let mut buf = [0u8; 8192]; let mut iov = [IoSliceMut::new(&mut buf)]; let mut ancillary_buf = [0u8; 128]; let mut ancillary = SocketAncillary::new(&mut ancillary_buf); let mut fds: Vec = Vec::new(); let bytelen = stream.recv_vectored_with_ancillary(&mut iov, &mut ancillary)?; let mut explicit_auth = false; for msg in ancillary.messages() { let Ok(msg) = msg else { continue }; match msg { AncillaryData::ScmRights(rights) => { for fd in rights { fds.push(fd); } } AncillaryData::ScmCredentials(creds) => { for cred in creds { eprintln!("cred: {}/{}", cred.get_uid(), cred.get_gid()); if cred.get_gid() == get_gid() && cred.get_uid() == get_uid() { explicit_auth = true; } } } } } let mut implicit_auth = false; if let Ok(peer) = stream.peer_cred() && peer.uid == get_uid() && peer.gid == get_gid() { implicit_auth = true; } if !explicit_auth && !implicit_auth { return Ok(()); } if fds.len() != 3 { // malformed return Ok(()); } let Ok(cli_args) = crate::serialization::deserialize_cli_args(&buf[..bytelen]) else { // cli args malformed return Ok(()); }; if cli_args.is_empty() { // malformed return Ok(()); }; let se = session.lock().unwrap(); match get_command_kind(&se, cli_args[0].as_slice()) { crate::run::CommandKind::Fun(_) => (), crate::run::CommandKind::Path(_) | crate::run::CommandKind::Builtin(_) => { return Ok(()); } } drop(se); let stdin = File::from(unsafe { OwnedFd::from_raw_fd(fds[0]) }); let stdout = File::from(unsafe { OwnedFd::from_raw_fd(fds[1]) }); let res = crate::run::Executor::execute_fun( session, cli_args, Input::File(stdin), Output::File(stdout), ); let exit_code = match res { Ok(_) => 0, Err(e) => match e { crate::run::ExecError::ExecError(x) => x, _ => -3, }, }; let _ = stream.set_write_timeout(Some(Duration::from_secs(1))); stream.write_all(&exit_code.to_le_bytes())?; Ok(()) } fn handle_client(mut stream: UnixStream, uid: uid_t, gid: gid_t) -> io::Result<()> { // give up all my file descriptors let mut ancillary_buffer = [0; 128]; let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]); assert!(ancillary.add_fds(&[0, 1, 2])); // add credentials let mut creds = SocketCred::new(); creds.set_pid(std::process::id() as _); creds.set_uid(uid); creds.set_gid(gid); assert!(ancillary.add_creds(&[creds])); // cli params let buf = crate::serialization::serialize_cli_args(); let bufs = &mut [io::IoSlice::new(&buf[..])][..]; // send stream.send_vectored_with_ancillary(bufs, &mut ancillary)?; // recv exit code let mut exit_buf = [0; 4]; let res = stream.read_exact(&mut exit_buf); let exit_code = match res { Ok(_) => i32::from_le_bytes(exit_buf), Err(_) => -2, }; exit(exit_code) } /// sets up the commands `PATH` to allow it to invoke this shell sessions user-defined functions pub fn prepare_command(session: Arc>, cmd: &mut Command) { let Ok(session) = session.lock() else { return; }; let Some(sr) = session.socket_running.as_ref() else { return; }; let my_path = std::env::var_os(PATH).expect("no PATH - seriously?"); let mut new_path = sr.bin_path.as_os_str().as_bytes().to_vec(); new_path.push(b':'); new_path.extend_from_slice(my_path.as_bytes()); // we attach uid/gid such that launched subprocess know as which user to authenticate // it might be sudo-ed, which can again authenticate as the expected uid/gid // a subprocess with CAP_ADMIN can also impersonate uid/gids to the best of my // knowledge let mut data = Vec::new(); data.extend_from_slice(sr.socket_path.as_os_str().as_bytes()); data.push(b':'); data.extend_from_slice(get_uid().to_string().as_bytes()); data.push(b':'); data.extend_from_slice(get_gid().to_string().as_bytes()); cmd.env(PATH, OsStr::from_bytes(&new_path)); cmd.env(DATA, OsStr::from_bytes(&data)); } pub fn maybe_run_defined_function() { let Some(program_name) = std::env::args_os().next() else { return; }; let program_name = program_name.as_bytes(); if program_name.contains(&b'/') { return; } if program_name == b"pish" { return; } let Some(data) = std::env::var_os(DATA) else { return; }; let data = data.as_bytes().to_vec(); let mut it = data.split(|x| *x == b':'); let socket = it.next().unwrap(); let uid: uid_t = String::from_utf8_lossy(it.next().unwrap()).parse().unwrap(); let gid: gid_t = String::from_utf8_lossy(it.next().unwrap()).parse().unwrap(); if let Ok(stream) = UnixStream::connect(OsStr::from_bytes(&socket)) { let _ = handle_client(stream, uid, gid); } exit(-1); } fn unique_string() -> String { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::process; use std::time::{SystemTime, UNIX_EPOCH}; let pid = process::id(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_nanos(); let mut hasher = DefaultHasher::new(); pid.hash(&mut hasher); now.hash(&mut hasher); let hash = hasher.finish(); let hex = format!("{:016x}", hash); hex[..12.min(hex.len())].to_string() } pub struct SocketRunning { bin_path: PathBuf, socket_path: PathBuf, recv: Receiver<()>, } #[must_use] struct SocketDropper { session: Arc>, } impl Drop for SocketDropper { fn drop(&mut self) { // mark socket for closing by `take`ing the socket_running value let session = &self.session; let Ok(mut se) = session.lock() else { return }; let Some(sr) = se.socket_running.take() else { return; }; // need to unlock since background thread also accesses session drop(se); // wait 1s for background to exit if let Err(e) = sr.recv.recv_timeout(Duration::from_secs(1)) { eprintln!( "background thread is still running({e:?}, session might not be cleaned up\r" ); } } } #[must_use] pub fn listen(session: Arc>) -> impl Drop { let session_id = unique_string(); let session_dir = crate::basedir::data_dir().join("session").join(session_id); let bin_dir = session_dir.join("bin"); std::fs::create_dir_all(&bin_dir).unwrap(); let socket_path = session_dir.join("cmd.sock"); let (send, recv) = mpsc::channel(); { let mut se = session.lock().unwrap(); assert!(se.socket_running.is_none()); se.socket_running = Some(SocketRunning { bin_path: bin_dir, socket_path: socket_path.clone(), recv, }); } let se = session.clone(); thread::spawn(move || { defer! { let _ = fs::remove_dir_all(session_dir); let _ = send.send(()); }; let listener = UnixListener::bind(socket_path).unwrap(); listener.set_nonblocking(true).unwrap(); let timeout_ms: u16 = 200; loop { // poll socket with timeout let mut poll_fds = [PollFd::new(listener.as_fd(), PollFlags::POLLIN)]; let is_ready = match poll(&mut poll_fds, timeout_ms) { Ok(0) => false, Ok(_) => true, Err(_) => false, }; // check if we should terminate match se.lock() { Err(_) => break, Ok(se) if se.socket_running.is_none() => break, _ => (), } if is_ready && let Ok((stream, _addr)) = listener.accept() { let se = se.clone(); thread::spawn(move || handle_server(se, stream)); } } }); SocketDropper { session } } fn create_function_hook_res( session: Arc>, fun_name: &[u8], ) -> Result<(), Box> { let session = session.lock().map_err(|e| format!("{e:?}"))?; let sock_run = session.socket_running.as_ref().ok_or("no socket running")?; let exe_path = current_exe()?; let symlink_path = sock_run.bin_path.join(OsStr::from_bytes(fun_name)); symlink(exe_path, symlink_path)?; Ok(()) } pub fn create_function_hook(session: Arc>, fun_name: &[u8]) { if let Err(e) = create_function_hook_res(session, fun_name) { println!("failed to create function hook: {e:?}"); } }