aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJonas Maier <jonas@x77.dev>2026-03-10 20:36:25 +0100
committerJonas Maier <jonas@x77.dev>2026-03-10 20:39:51 +0100
commitf41cdc03e1b36a65877e009065eb609caf5d1b13 (patch)
tree19af08160dfb3d7ea793449b10fb3183b0738cee
parent83608b29e3959da5e5ee2aed07db106c3b1c338f (diff)
downloadpish-f41cdc03e1b36a65877e009065eb609caf5d1b13.tar.gz
explicit & implicit socket authentication
-rw-r--r--src/export_fun.rs113
-rw-r--r--src/main.rs2
2 files changed, 95 insertions, 20 deletions
diff --git a/src/export_fun.rs b/src/export_fun.rs
index 26ab88c..8576cce 100644
--- a/src/export_fun.rs
+++ b/src/export_fun.rs
@@ -1,5 +1,7 @@
//! allow sub-programs to invoke arbitrary user-defined functions via a unix socket
+use libc::gid_t;
+use libc::uid_t;
use nix::poll::PollFd;
use nix::poll::PollFlags;
use nix::poll::poll;
@@ -24,6 +26,7 @@ use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs::symlink;
use std::os::unix::net::AncillaryData;
use std::os::unix::net::SocketAncillary;
+use std::os::unix::net::SocketCred;
use std::os::unix::net::UnixListener;
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
@@ -36,6 +39,17 @@ use std::sync::mpsc::Receiver;
use std::thread;
use std::time::Duration;
+const PATH: &str = "PATH";
+const DATA: &str = "PISH_DATA";
+
+fn get_uid() -> gid_t {
+ unsafe { libc::getuid() }
+}
+
+fn get_gid() -> gid_t {
+ unsafe { libc::getgid() }
+}
+
fn handle_server(session: Arc<Mutex<Session>>, mut stream: UnixStream) -> io::Result<()> {
// TODO: figure out how to get a reasonable limit on CLI arg len
let mut buf = [0u8; 8192];
@@ -48,14 +62,39 @@ fn handle_server(session: Arc<Mutex<Session>>, mut stream: UnixStream) -> io::Re
let bytelen = stream.recv_vectored_with_ancillary(&mut iov, &mut ancillary)?;
+ let mut explicit_auth = false;
+
for msg in ancillary.messages() {
- if let Ok(AncillaryData::ScmRights(rights)) = msg {
- for fd in rights {
- fds.push(fd);
+ let Ok(msg) = msg else { continue };
+ match msg {
+ AncillaryData::ScmRights(rights) => {
+ for fd in rights {
+ fds.push(fd);
+ }
+ }
+ AncillaryData::ScmCredentials(creds) => {
+ for cred in creds {
+ eprintln!("cred: {}/{}", cred.get_uid(), cred.get_gid());
+ if cred.get_gid() == get_gid() && cred.get_uid() == get_uid() {
+ explicit_auth = true;
+ }
+ }
}
}
}
+ let mut implicit_auth = false;
+ if let Ok(peer) = stream.peer_cred()
+ && peer.uid == get_uid()
+ && peer.gid == get_gid()
+ {
+ implicit_auth = true;
+ }
+
+ if !explicit_auth && !implicit_auth {
+ return Ok(());
+ }
+
if fds.len() != 3 {
// malformed
return Ok(());
@@ -104,11 +143,18 @@ fn handle_server(session: Arc<Mutex<Session>>, mut stream: UnixStream) -> io::Re
Ok(())
}
-fn handle_client(mut stream: UnixStream) -> io::Result<()> {
+fn handle_client(mut stream: UnixStream, uid: uid_t, gid: gid_t) -> io::Result<()> {
// give up all my file descriptors
let mut ancillary_buffer = [0; 128];
let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
- ancillary.add_fds(&[0, 1, 2]);
+ assert!(ancillary.add_fds(&[0, 1, 2]));
+
+ // add credentials
+ let mut creds = SocketCred::new();
+ creds.set_pid(std::process::id() as _);
+ creds.set_uid(uid);
+ creds.set_gid(gid);
+ assert!(ancillary.add_creds(&[creds]));
// cli params
let buf = crate::serialization::serialize_cli_args();
@@ -139,27 +185,56 @@ pub fn prepare_command(session: Arc<Mutex<Session>>, cmd: &mut Command) {
return;
};
- let my_path = std::env::var_os("PATH").expect("no PATH - seriously?");
+ let my_path = std::env::var_os(PATH).expect("no PATH - seriously?");
let mut new_path = sr.bin_path.as_os_str().as_bytes().to_vec();
new_path.push(b':');
new_path.extend_from_slice(my_path.as_bytes());
- cmd.env("PATH", OsStr::from_bytes(&new_path));
- cmd.env("PISH_SOCKET", sr.socket_path.as_os_str());
+
+ // we attach uid/gid such that launched subprocess know as which user to authenticate
+ // it might be sudo-ed, which can again authenticate as the expected uid/gid
+ // a subprocess with CAP_ADMIN can also impersonate uid/gids to the best of my
+ // knowledge
+ let mut data = Vec::new();
+ data.extend_from_slice(sr.socket_path.as_os_str().as_bytes());
+ data.push(b':');
+ data.extend_from_slice(get_uid().to_string().as_bytes());
+ data.push(b':');
+ data.extend_from_slice(get_gid().to_string().as_bytes());
+
+ cmd.env(PATH, OsStr::from_bytes(&new_path));
+ cmd.env(DATA, OsStr::from_bytes(&data));
}
pub fn maybe_run_defined_function() {
- if let Some(program_name) = std::env::args_os().next() {
- let program_name = program_name.as_bytes();
- if !program_name.contains(&b'/')
- && program_name != b"pish"
- && let Some(socket) = std::env::var_os("PISH_SOCKET")
- {
- if let Ok(stream) = UnixStream::connect(socket) {
- let _ = handle_client(stream);
- }
- exit(-1);
- }
+ let Some(program_name) = std::env::args_os().next() else {
+ return;
+ };
+
+ let program_name = program_name.as_bytes();
+
+ if program_name.contains(&b'/') {
+ return;
+ }
+
+ if program_name == b"pish" {
+ return;
+ }
+
+ let Some(data) = std::env::var_os(DATA) else {
+ return;
+ };
+
+ let data = data.as_bytes().to_vec();
+ let mut it = data.split(|x| *x == b':');
+ let socket = it.next().unwrap();
+ let uid: uid_t = String::from_utf8_lossy(it.next().unwrap()).parse().unwrap();
+ let gid: gid_t = String::from_utf8_lossy(it.next().unwrap()).parse().unwrap();
+
+ if let Ok(stream) = UnixStream::connect(OsStr::from_bytes(&socket)) {
+ let _ = handle_client(stream, uid, gid);
}
+
+ exit(-1);
}
fn unique_string() -> String {
diff --git a/src/main.rs b/src/main.rs
index 1a481c2..897cdb9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(unix_socket_ancillary_data)]
+#![feature(unix_socket_ancillary_data, peer_credentials_unix_socket)]
#![allow(clippy::needless_range_loop)]
use std::collections::HashMap;