use core::fmt; use std::{ fs::File, io::{self, PipeReader, PipeWriter, Read, Write}, os::fd::{AsFd, BorrowedFd}, process::Stdio, sync::{ Arc, atomic::{AtomicBool, Ordering::SeqCst}, }, }; use nix::poll::{PollFd, PollFlags}; pub enum Input { Stdin, Pipe(PipeReader), File(File), } pub enum Output { Stdout, Pipe(PipeWriter), File(File), } impl From for Stdio { fn from(value: Input) -> Self { match value { Input::Stdin => Stdio::inherit(), Input::Pipe(reader) => reader.into(), Input::File(file) => file.into(), } } } impl From for Stdio { fn from(value: Output) -> Stdio { match value { Output::Stdout => Stdio::inherit(), Output::Pipe(writer) => writer.into(), Output::File(file) => file.into(), } } } impl Input { pub fn try_clone(&self) -> io::Result { Ok(match self { Input::Stdin => Input::Stdin, Input::Pipe(pr) => Input::Pipe(pr.try_clone()?), Input::File(f) => Input::File(f.try_clone()?), }) } } impl Output { pub fn try_clone(&self) -> io::Result { Ok(match self { Output::Stdout => Output::Stdout, Output::Pipe(pw) => Output::Pipe(pw.try_clone()?), Output::File(f) => Output::File(f.try_clone()?), }) } } pub struct Canceler { tx: PipeWriter, } impl Canceler { pub fn cancel(&mut self) { let _ = self.tx.write(b"."); } } pub struct InputReader { input: Input, cancel: PipeReader, canceled: Arc, } impl InputReader { pub fn new(input: Input) -> (InputReader, Canceler) { let (cancel, tx) = std::io::pipe().unwrap(); ( Self { input, cancel, canceled: Arc::new(AtomicBool::new(false)), }, Canceler { tx }, ) } pub fn try_clone(&self) -> io::Result { let input = self.input.try_clone()?; let cancel = self.cancel.try_clone()?; let canceled = self.canceled.clone(); Ok(Self { input, cancel, canceled, }) } } const TIMEOUT_MS: u16 = 1000; enum PollStatus { Cancel, Ready, Wait, } fn check<'a>( canceled: &AtomicBool, cancel: &PipeReader, fd: BorrowedFd<'a>, flags: PollFlags, ) -> PollStatus { if canceled.load(SeqCst) { return PollStatus::Cancel; } let mut poll_fds = [ PollFd::new(cancel.as_fd(), PollFlags::POLLIN), PollFd::new(fd, flags), ]; if nix::poll::poll(&mut poll_fds, TIMEOUT_MS).is_err() { canceled.store(true, SeqCst); return PollStatus::Cancel; }; if let Some(event) = poll_fds[0].revents() { if event.contains(PollFlags::POLLIN) { canceled.store(true, SeqCst); return PollStatus::Cancel; } } if let Some(event) = poll_fds[1].revents() { if event.contains(flags) { return PollStatus::Ready; } } PollStatus::Wait } impl InputReader { fn poll(&mut self) -> PollStatus { let stdin = io::stdin(); let read_fd = match &self.input { Input::Stdin => stdin.as_fd(), Input::Pipe(pipe) => pipe.as_fd(), Input::File(file) => file.as_fd(), }; check(&*self.canceled, &self.cancel, read_fd, PollFlags::POLLIN) } } #[derive(Debug, Clone, Copy)] struct Canceled; impl fmt::Display for Canceled { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "canceled") } } impl std::error::Error for Canceled {} impl Read for InputReader { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { match self.poll() { PollStatus::Cancel => return Err(io::Error::new(io::ErrorKind::Other, Canceled)), PollStatus::Ready => (), PollStatus::Wait => continue, } return match &mut self.input { Input::Stdin => io::stdin().read(buf), Input::Pipe(reader) => reader.read(buf), Input::File(file) => file.read(buf), }; } } } pub struct OutputWriter { output: Output, cancel: PipeReader, canceled: Arc, } impl OutputWriter { pub fn new(output: Output) -> (Self, Canceler) { let (cancel, tx) = std::io::pipe().unwrap(); ( Self { output, cancel, canceled: Arc::new(AtomicBool::new(false)), }, Canceler { tx }, ) } fn poll(&mut self) -> PollStatus { let stdout = io::stdout(); let write_fd = match &self.output { Output::Stdout => stdout.as_fd(), Output::Pipe(pipe) => pipe.as_fd(), Output::File(file) => file.as_fd(), }; check( &mut self.canceled, &self.cancel, write_fd, PollFlags::POLLOUT, ) } pub fn try_clone(&self) -> io::Result { let output = self.output.try_clone()?; let cancel = self.cancel.try_clone()?; let canceled = self.canceled.clone(); Ok(Self { output, cancel, canceled, }) } } impl Write for OutputWriter { fn write(&mut self, buf: &[u8]) -> io::Result { loop { match self.poll() { PollStatus::Cancel => return Err(io::Error::new(io::ErrorKind::Other, Canceled)), PollStatus::Ready => (), PollStatus::Wait => continue, } return match &mut self.output { Output::Stdout => io::stdout().write(buf), Output::Pipe(writer) => writer.write(buf), Output::File(file) => file.write(buf), }; } } fn flush(&mut self) -> io::Result<()> { match &mut self.output { Output::Stdout => io::stdout().flush(), Output::Pipe(writer) => writer.flush(), Output::File(file) => file.flush(), } } } impl From for Input { fn from(value: InputReader) -> Self { value.input } } impl From for Output { fn from(value: OutputWriter) -> Self { value.output } }