use core::fmt; use std::{ fs::File, io::{self, PipeReader, PipeWriter, Read, Write}, os::fd::{AsFd, BorrowedFd}, process::Stdio, }; use nix::poll::{PollFd, PollFlags}; pub enum Input { Stdin, Pipe(PipeReader), File(File), } pub enum Output { Stdout, Pipe(PipeWriter), File(File), } impl From for Stdio { fn from(value: Input) -> Self { match value { Input::Stdin => Stdio::inherit(), Input::Pipe(reader) => reader.into(), Input::File(file) => file.into(), } } } impl From for Stdio { fn from(value: Output) -> Stdio { match value { Output::Stdout => Stdio::inherit(), Output::Pipe(writer) => writer.into(), Output::File(file) => file.into(), } } } pub struct Canceler { tx: PipeWriter, } impl Canceler { pub fn cancel(&mut self) { let _ = self.tx.write(b"."); } } pub struct InputReader { input: Input, cancel: PipeReader, canceled: bool, } impl InputReader { pub fn new(input: Input) -> (InputReader, Canceler) { let (cancel, tx) = std::io::pipe().unwrap(); ( Self { input, cancel, canceled: false, }, Canceler { tx }, ) } } const TIMEOUT_MS: u16 = 20; enum PollStatus { Cancel, Ready, Wait, } fn check<'a>( canceled: &mut bool, cancel: &PipeReader, fd: BorrowedFd<'a>, flags: PollFlags, ) -> PollStatus { if *canceled { return PollStatus::Cancel; } let mut poll_fds = [ PollFd::new(cancel.as_fd(), PollFlags::POLLIN), PollFd::new(fd, flags), ]; if nix::poll::poll(&mut poll_fds, TIMEOUT_MS).is_err() { *canceled = true; return PollStatus::Cancel; }; if let Some(event) = poll_fds[0].revents() { if event.contains(PollFlags::POLLIN) { *canceled = true; return PollStatus::Cancel; } } if let Some(event) = poll_fds[1].revents() { if event.contains(flags) { return PollStatus::Ready; } } PollStatus::Wait } impl InputReader { fn cancel(&mut self) -> PollStatus { self.canceled = true; PollStatus::Cancel } fn poll(&mut self) -> PollStatus { let stdin = io::stdin(); let read_fd = match &self.input { Input::Stdin => stdin.as_fd(), Input::Pipe(pipe) => pipe.as_fd(), Input::File(file) => file.as_fd(), }; check(&mut self.canceled, &self.cancel, read_fd, PollFlags::POLLIN) } } #[derive(Debug, Clone, Copy)] struct Canceled; impl fmt::Display for Canceled { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "canceled") } } impl std::error::Error for Canceled {} impl Read for InputReader { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { match self.poll() { PollStatus::Cancel => return Err(io::Error::new(io::ErrorKind::Other, Canceled)), PollStatus::Ready => (), PollStatus::Wait => continue, } return match &mut self.input { Input::Stdin => io::stdin().read(buf), Input::Pipe(reader) => reader.read(buf), Input::File(file) => file.read(buf), }; } } } pub struct OutputWriter { output: Output, cancel: PipeReader, canceled: bool, } impl OutputWriter { pub fn new(output: Output) -> (Self, Canceler) { let (cancel, tx) = std::io::pipe().unwrap(); ( Self { output, cancel, canceled: false, }, Canceler { tx }, ) } fn poll(&mut self) -> PollStatus { let stdout = io::stdout(); let write_fd = match &self.output { Output::Stdout => stdout.as_fd(), Output::Pipe(pipe) => pipe.as_fd(), Output::File(file) => file.as_fd(), }; check( &mut self.canceled, &self.cancel, write_fd, PollFlags::POLLOUT, ) } } impl Write for OutputWriter { fn write(&mut self, buf: &[u8]) -> io::Result { loop { match self.poll() { PollStatus::Cancel => return Err(io::Error::new(io::ErrorKind::Other, Canceled)), PollStatus::Ready => (), PollStatus::Wait => continue, } return match &mut self.output { Output::Stdout => io::stdout().write(buf), Output::Pipe(writer) => writer.write(buf), Output::File(file) => file.write(buf), }; } } fn flush(&mut self) -> io::Result<()> { match &mut self.output { Output::Stdout => io::stdout().flush(), Output::Pipe(writer) => writer.flush(), Output::File(file) => file.flush(), } } } impl From for Input { fn from(value: InputReader) -> Self { value.input } } impl From for Output { fn from(value: OutputWriter) -> Self { value.output } }