aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/parse/mod.rs19
-rw-r--r--src/parse/regex/byte_range.rs179
-rw-r--r--src/parse/regex/dfa.rs110
-rw-r--r--src/parse/regex/enfa.rs383
-rw-r--r--src/parse/regex/mod.rs201
-rw-r--r--src/run/builtin.rs31
-rw-r--r--src/run/mod.rs7
7 files changed, 914 insertions, 16 deletions
diff --git a/src/parse/mod.rs b/src/parse/mod.rs
index 68a5e56..5815730 100644
--- a/src/parse/mod.rs
+++ b/src/parse/mod.rs
@@ -7,6 +7,8 @@ mod test;
mod span;
+pub mod regex;
+
pub trait Stage: PartialEq {
type Str: std::fmt::Debug + Clone + PartialEq;
}
@@ -1844,7 +1846,7 @@ impl Parse for Pipes<PreExpansion> {
#[derive(Debug, Clone, PartialEq)]
pub struct CaseBranch {
- pub pattern: BString,
+ pub pattern: regex::Pattern,
pub block: Block,
}
@@ -1855,10 +1857,8 @@ pub struct Case<T: Stage> {
}
impl CmdDisplay for CaseBranch {
- fn cdisplay(&self, w: &mut dyn std::io::Write) -> std::io::Result<()> {
- write!(w, "cbranch(b\"{}\", ", self.pattern.escape_ascii())?;
- self.block.cdisplay(w)?;
- write!(w, ")")
+ fn cdisplay(&self, _w: &mut dyn std::io::Write) -> std::io::Result<()> {
+ todo!()
}
}
@@ -1875,14 +1875,7 @@ impl Parse for CaseBranch {
fn parse(b: &mut Cursor<'_>) -> Result<Self> {
b.spaces();
- let mut pattern = Vec::new();
- while b.has() && b.peek() != b'{' {
- pattern.push(b.adv());
- }
- while let Some(b' ' | b'\n' | b'\t' | b'\r') = pattern.last() {
- pattern.pop();
- }
-
+ let pattern = regex::Pattern::parse(b)?;
let block = Block::parse(b)?;
Ok(Self { pattern, block })
diff --git a/src/parse/regex/byte_range.rs b/src/parse/regex/byte_range.rs
new file mode 100644
index 0000000..1ca6d8f
--- /dev/null
+++ b/src/parse/regex/byte_range.rs
@@ -0,0 +1,179 @@
+use std::ops::RangeInclusive;
+
+#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct ByteRange {
+ /// inclusive
+ from: u8,
+ /// inclusive
+ to: u8,
+}
+
+impl From<RangeInclusive<u8>> for ByteRange {
+ fn from(value: RangeInclusive<u8>) -> Self {
+ Self::new_range(*value.start(), *value.end())
+ }
+}
+
+impl ByteRange {
+ pub fn new_range(from: u8, to: u8) -> Self {
+ assert!(from <= to);
+ Self { from, to }
+ }
+
+ #[cfg(test)]
+ pub fn new_single(c: u8) -> Self {
+ Self::new_range(c, c)
+ }
+
+ pub fn contains(&self, c: u8) -> bool {
+ self.from <= c && c <= self.to
+ }
+
+ pub fn overlaps(&self, other: Self) -> bool {
+ self.from.max(other.from) <= self.to.min(other.to)
+ }
+
+ pub fn non_overlapping(sets: Vec<ByteRange>) -> Vec<ByteRange> {
+ let begins = sets.iter().map(|cs| (cs.from, 1));
+ let ends = sets.iter().map(|cs| (cs.to, 2));
+ let mut edges: Vec<_> = begins.chain(ends).collect();
+ edges.sort();
+ edges.iter_mut().for_each(|c| {
+ if c.1 == 2 {
+ c.1 = -1;
+ }
+ });
+
+ let mut last = None;
+ let mut depth = 0;
+ let mut out = Vec::new();
+
+ for (mut loc, delta) in edges {
+ if let Some(last) = last {
+ if last <= loc {
+ out.push(ByteRange::new_range(last, loc));
+ loc = loc + 1;
+ }
+ }
+
+ depth += delta;
+
+ if depth > 0 {
+ last = Some(loc);
+ } else {
+ last = None;
+ }
+ }
+
+ out
+ }
+}
+
+impl std::fmt::Debug for ByteRange {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ if self.from == self.to {
+ write!(f, "{}", [self.from].escape_ascii())
+ } else {
+ write!(
+ f,
+ "{}-{}",
+ [self.from].escape_ascii(),
+ [self.to].escape_ascii()
+ )
+ }
+ }
+}
+
+#[cfg(test)]
+mod non_overlapping_tests {
+ use std::ops::RangeInclusive;
+
+ use super::ByteRange;
+
+ fn middle(r: ByteRange) -> u8 {
+ let a = r.from as u8;
+ let b = r.to as u8;
+ (a + (b - a) / 2) as u8
+ }
+
+ fn prev(c: u8) -> u8 {
+ c - 1
+ }
+
+ fn next(c: u8) -> u8 {
+ c + 1
+ }
+
+ fn run(ranges: Vec<RangeInclusive<u8>>) {
+ let ranges1: Vec<ByteRange> = ranges.into_iter().map(Into::into).collect();
+ let ranges2 = ByteRange::non_overlapping(ranges1.clone());
+
+ let r1 = |c| ranges1.iter().any(|cr| cr.contains(c));
+ let r2 = |c| ranges2.iter().any(|cr| cr.contains(c));
+
+ for &range in ranges1.iter() {
+ assert!(r1(range.from));
+ assert!(r1(range.to));
+ assert!(r1(middle(range)));
+
+ assert!(r2(range.from));
+ assert!(r2(range.to));
+ assert!(r2(middle(range)));
+
+ assert_eq!(r1(prev(range.from)), r2(prev(range.from)));
+ assert_eq!(r1(next(range.from)), r2(next(range.from)));
+ }
+
+ for i in 0..ranges2.len() {
+ for j in 0..i {
+ assert!(
+ !ranges2[i].overlaps(ranges2[j]),
+ "{i} and {j} overlap: {:?}, {:?}",
+ ranges2[i],
+ ranges2[j]
+ );
+ }
+ }
+ }
+
+ #[test]
+ fn overlap_correct() {
+ assert!(ByteRange::new_range(b'a', b'g').overlaps(ByteRange::new_single(b'f')));
+ assert!(!ByteRange::new_range(b'a', b'g').overlaps(ByteRange::new_single(b'h')));
+ }
+
+ #[test]
+ fn empty() {
+ run(vec![]);
+ }
+
+ #[test]
+ fn singleton() {
+ run(vec![b'0'..=b'9']);
+ }
+
+ #[test]
+ fn contained1() {
+ run(vec![b'0'..=b'9', b'5'..=b'6']);
+ }
+
+ #[test]
+ fn contained2() {
+ run(vec![b'5'..=b'6', b'0'..=b'9']);
+ }
+
+ #[test]
+ fn overlap2() {
+ run(vec![b'1'..=b'6', b'4'..=b'9'])
+ }
+
+ #[test]
+ fn overlap3() {
+ run(vec![b'a'..=b'f', b'd'..=b'j', b'g'..=b'm'])
+ }
+
+ #[test]
+ fn overlap4() {
+ run(vec![b'a'..=b'f', b'd'..=b'j', b'g'..=b'm', b'k'..=b'q'])
+ }
+}
diff --git a/src/parse/regex/dfa.rs b/src/parse/regex/dfa.rs
new file mode 100644
index 0000000..aba6238
--- /dev/null
+++ b/src/parse/regex/dfa.rs
@@ -0,0 +1,110 @@
+use core::fmt;
+use std::collections::HashMap;
+
+use super::{
+ byte_range::ByteRange,
+ enfa::{ENFA, MultiState},
+};
+
+pub type StateId = usize;
+
+pub struct State {
+ trans: HashMap<ByteRange, StateId>,
+ default_trans: StateId,
+ accept: bool,
+}
+
+pub struct DFA {
+ start: StateId,
+ states: Vec<State>,
+}
+
+impl fmt::Debug for DFA {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ writeln!(f, "DFA {{")?;
+ for (i, s) in self.states.iter().enumerate() {
+ if self.start == i {
+ write!(f, "-> {i}: ")?;
+ } else {
+ write!(f, " {i}: ")?;
+ }
+
+ for (chr, to) in s.trans.iter() {
+ write!(f, "{chr:?} to {to}, ")?;
+ }
+
+ write!(f, "default to {}", s.default_trans)?;
+ if s.accept {
+ write!(f, ", accept")?;
+ }
+ writeln!(f)?;
+ }
+ writeln!(f, "}}")
+ }
+}
+
+impl DFA {
+ pub fn matches(&self, x: &[u8]) -> bool {
+ let mut state = self.start;
+ 'next_byte: for &b in x.iter() {
+ for (range, &next_state) in self.states[state].trans.iter() {
+ if range.contains(b) {
+ state = next_state;
+ continue 'next_byte;
+ }
+ }
+ state = self.states[state].default_trans;
+ }
+ self.states[state].accept
+ }
+}
+
+impl From<ENFA> for DFA {
+ fn from(mut nfa: ENFA) -> Self {
+ nfa.simplify();
+
+ for s in nfa.states.iter() {
+ if !s.epsilon_trans.is_empty() {
+ panic!(
+ "NFA simplification did not remove epsilon transitions - cannot proceed with powerset construction."
+ );
+ }
+ }
+
+ let mut multi_states = nfa.all_multi_states();
+ multi_states.insert(nfa.void_multi_state());
+ let mut len = 0;
+ let multi_to_dfa: HashMap<MultiState, StateId> = multi_states
+ .clone()
+ .into_iter()
+ .map(|ms| {
+ len += 1;
+ (ms, len - 1)
+ })
+ .collect();
+
+ let void = multi_to_dfa[&nfa.void_multi_state()];
+
+ let mut states: Vec<State> = (0..len)
+ .map(|_| State {
+ trans: HashMap::new(),
+ default_trans: void,
+ accept: false,
+ })
+ .collect();
+
+ for ms in multi_states.iter() {
+ let i: usize = multi_to_dfa[&ms];
+ states[i].accept = ms.accept();
+ for t in ms.possible_transitions() {
+ let k = multi_to_dfa[&ms.transition(t)];
+ states[i].trans.insert(t, k);
+ }
+ }
+
+ Self {
+ start: multi_to_dfa[&nfa.start_multi_state()],
+ states,
+ }
+ }
+}
diff --git a/src/parse/regex/enfa.rs b/src/parse/regex/enfa.rs
new file mode 100644
index 0000000..71998c9
--- /dev/null
+++ b/src/parse/regex/enfa.rs
@@ -0,0 +1,383 @@
+use std::{
+ collections::HashSet,
+ hash::{DefaultHasher, Hash, Hasher},
+};
+
+use super::Pattern;
+use super::byte_range::ByteRange;
+
+/// NFA with epsilon transitions
+#[derive(Clone)]
+pub struct ENFA {
+ pub states: Vec<EState>,
+}
+
+#[derive(Clone)]
+pub struct MultiState<'a> {
+ nfa: &'a ENFA,
+ states: Vec<StateId>,
+ accept: bool,
+ hash: u64,
+}
+
+impl<'a> PartialEq for MultiState<'a> {
+ fn eq(&self, other: &Self) -> bool {
+ (self.nfa as *const ENFA as u64) == (other.nfa as *const ENFA as u64)
+ && self.states == other.states
+ && self.accept == other.accept
+ && self.hash == other.hash
+ }
+}
+impl<'a> Eq for MultiState<'a> {}
+
+impl<'a> MultiState<'a> {
+ pub fn new(nfa: &'a ENFA, mut states: Vec<StateId>) -> Self {
+ states.sort();
+ states.dedup();
+ states.shrink_to_fit();
+
+ let accept = states.iter().any(|&x| nfa.states[x].accept);
+ let mut hasher = DefaultHasher::new();
+ states.hash(&mut hasher);
+ let hash = hasher.finish();
+
+ Self {
+ nfa,
+ states,
+ accept,
+ hash,
+ }
+ }
+
+ /// all the chars that will make an interesting transition
+ pub fn possible_transitions(&self) -> Vec<ByteRange> {
+ let mut vec: Vec<_> = self
+ .states
+ .iter()
+ .flat_map(|&i| self.nfa.states[i].trans.iter().map(|x| x.0))
+ .collect();
+ vec = ByteRange::non_overlapping(vec);
+ vec.sort();
+ vec.dedup();
+ vec.shrink_to_fit();
+ vec
+ }
+
+ pub fn transition(&self, ch: ByteRange) -> Self {
+ let new_states = self
+ .states
+ .iter()
+ .flat_map(|&s| {
+ self.nfa.states[s]
+ .trans
+ .iter()
+ .filter_map(|&(c, k)| if c.overlaps(ch) { Some(k) } else { None })
+ })
+ .collect();
+
+ Self::new(self.nfa, new_states)
+ }
+
+ pub fn accept(&self) -> bool {
+ self.accept
+ }
+}
+
+impl<'a> Hash for MultiState<'a> {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.hash.hash(state)
+ }
+}
+
+macro_rules! set {
+ () => {
+ std::collections::HashSet::new()
+ };
+ ( $( $x:expr ),* ) => {{
+ let mut set = std::collections::HashSet::new();
+ $(
+ set.insert($x);
+ )*
+ set
+ }};
+}
+
+impl ENFA {
+ fn shift(self, amt: usize) -> Vec<EState> {
+ let mut s = self.states;
+
+ for state in s.iter_mut() {
+ let trans = state.trans.iter().map(|(c, id)| (*c, id + amt)).collect();
+ let epsilon_trans = state.epsilon_trans.iter().map(|e| e + amt).collect();
+
+ *state = EState {
+ trans,
+ epsilon_trans,
+ accept: false,
+ }
+ }
+
+ s
+ }
+
+ fn epsilon_dfs(&self, i: StateId, visited: &mut [bool]) {
+ if visited[i] {
+ return;
+ }
+ visited[i] = true;
+ for &k in self.states[i].epsilon_trans.iter() {
+ self.epsilon_dfs(k, visited);
+ }
+ }
+
+ fn resolve_epsilon(&mut self) {
+ // state X --> { state Y, Z, W which get inlined into X }
+ let includes: Vec<Vec<StateId>> = (0..self.states.len())
+ .map(|i| {
+ let mut reach = vec![false; self.states.len()];
+ self.epsilon_dfs(i, &mut reach);
+ reach
+ .into_iter()
+ .enumerate()
+ .filter_map(|(x, r)| if r { Some(x) } else { None })
+ .collect()
+ })
+ .collect();
+
+ // inlining
+ for i in 0..self.states.len() {
+ for &k in includes[i].iter() {
+ let new = self.states[k].trans.clone();
+ self.states[i].trans.extend(new);
+ self.states[i].epsilon_trans.clear();
+ if self.states[k].accept {
+ self.states[i].accept = true;
+ }
+ }
+ }
+ }
+
+ fn remove_unreachable(&mut self) {
+ let mut used = vec![false; self.states.len()];
+ used[0] = true;
+ for s in self.states.iter() {
+ for &i in s.epsilon_trans.iter() {
+ used[i] = true;
+ }
+ for &(_, i) in s.trans.iter() {
+ used[i] = true;
+ }
+ }
+ let mut remap = vec![0; self.states.len()];
+ let mut shift = 0;
+ for i in 0..self.states.len() {
+ if used[i] {
+ remap[i] = i - shift;
+ } else {
+ shift += 1;
+ }
+ }
+ for i in (0..self.states.len()).rev() {
+ if !used[i] {
+ self.states.remove(i);
+ }
+ }
+ for s in self.states.iter_mut() {
+ s.epsilon_trans = s
+ .epsilon_trans
+ .clone()
+ .into_iter()
+ .map(|i| remap[i])
+ .collect();
+ s.trans = s
+ .trans
+ .clone()
+ .into_iter()
+ .map(|(c, i)| (c, remap[i]))
+ .collect();
+ }
+ }
+
+ pub fn simplify(&mut self) {
+ self.resolve_epsilon();
+ self.remove_unreachable();
+ }
+
+ pub fn start_multi_state<'a>(&'a self) -> MultiState<'a> {
+ MultiState::new(self, vec![0])
+ }
+
+ pub fn void_multi_state<'a>(&'a self) -> MultiState<'a> {
+ MultiState::new(self, vec![])
+ }
+
+ pub fn all_multi_states<'a>(&'a self) -> HashSet<MultiState<'a>> {
+ let mut states = set![self.start_multi_state()];
+ let mut q = vec![self.start_multi_state()];
+
+ while let Some(state) = q.pop() {
+ let chars = state.possible_transitions();
+
+ for chr in chars {
+ let new = state.transition(chr);
+
+ if !states.contains(&new) {
+ states.insert(new.clone());
+ q.push(new);
+ }
+ }
+ }
+
+ states
+ }
+
+ fn looping(self) -> Self {
+ let mut states = vec![EState::start()];
+ states.append(&mut self.shift(1));
+ let len = states.len();
+ states[0].epsilon_trans = set![1, len];
+ states[len - 1].epsilon_trans = set![0, len];
+ states.push(EState::terminal());
+ Self { states }
+ }
+
+ fn repeat(self, times: usize) -> Self {
+ let reps = vec![self; times];
+ Self::concat(reps)
+ }
+
+ /// between 0 and x repetitions
+ fn optx(self, x: usize) -> Self {
+ let len = self.states.len();
+ let mut repped = self.repeat(x);
+ assert_eq!(repped.states.len(), x * len);
+ for i in 1..=x {
+ repped.states[0].epsilon_trans.insert(i * len - 1);
+ }
+ repped
+ }
+
+ fn concat(nfas: Vec<Self>) -> Self {
+ if nfas.is_empty() {
+ return Self {
+ states: vec![EState::terminal()],
+ };
+ }
+
+ let mut states: Vec<EState> = Vec::new();
+ for nfa in nfas.into_iter() {
+ let len = states.len();
+ let mut ns = nfa.shift(len);
+ if let Some(n) = states.last_mut() {
+ n.epsilon_trans = set![len];
+ }
+ states.append(&mut ns);
+ }
+
+ let len = states.len();
+ states[len - 1].accept = true;
+
+ Self { states }
+ }
+}
+
+impl std::fmt::Debug for ENFA {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ writeln!(f, "NFA {{")?;
+ for (i, s) in self.states.iter().enumerate() {
+ write!(f, " {i}: ")?;
+ for k in s.epsilon_trans.iter() {
+ write!(f, "~>{k} ")?;
+ }
+ for (c, k) in s.trans.iter() {
+ write!(f, "{c:?}=>{k} ")?;
+ }
+ if s.accept {
+ write!(f, "accept")?;
+ }
+ writeln!(f)?;
+ }
+ write!(f, "}}")
+ }
+}
+
+pub type StateId = usize;
+
+#[derive(Debug, Clone)]
+pub struct EState {
+ pub trans: HashSet<(ByteRange, StateId)>,
+ pub epsilon_trans: HashSet<StateId>,
+ pub accept: bool,
+}
+
+impl EState {
+ fn start() -> Self {
+ Self {
+ trans: HashSet::new(),
+ epsilon_trans: HashSet::new(),
+ accept: false,
+ }
+ }
+ fn terminal() -> Self {
+ Self {
+ trans: HashSet::new(),
+ epsilon_trans: HashSet::new(),
+ accept: true,
+ }
+ }
+}
+
+impl From<Pattern> for ENFA {
+ fn from(value: Pattern) -> Self {
+ match value {
+ Pattern::Byte(c) => Self::from(Pattern::Range(c, c)),
+ Pattern::Range(c1, c2) => Self {
+ states: vec![
+ EState {
+ trans: set![(ByteRange::new_range(c1, c2), 1)],
+ epsilon_trans: set![],
+ accept: false,
+ },
+ EState::terminal(),
+ ],
+ },
+ Pattern::Alt(alts) => {
+ let nfas: Vec<ENFA> = alts.into_iter().map(ENFA::from).collect();
+ let mut states = vec![EState::start()];
+ let mut ends = vec![];
+ for nfa in nfas.into_iter() {
+ let len = states.len();
+ states[0].epsilon_trans.insert(len);
+ states.append(&mut (nfa.shift(len)));
+ ends.push(states.len() - 1);
+ }
+ states.push(EState::terminal());
+ for end in ends.into_iter() {
+ let last = states.len() - 1;
+ states[end].epsilon_trans.insert(last);
+ }
+ Self { states }
+ }
+ Pattern::Concat(seq) => {
+ let nfas: Vec<Self> = seq.into_iter().map(ENFA::from).collect();
+ Self::concat(nfas)
+ }
+ Pattern::Rep(regex, min, None) => {
+ let nfa = ENFA::from(*regex);
+ let base = nfa.clone().repeat(min as usize);
+ let tail = nfa.looping();
+ Self::concat(vec![base, tail])
+ }
+ Pattern::Rep(regex, min, Some(max)) => {
+ assert!(min < max);
+ let nfa = Self::from(*regex);
+ let base = nfa.clone().repeat(min as usize);
+ let tail = nfa.optx((max - min) as usize);
+ Self::concat(vec![base, tail])
+ }
+ Pattern::Nothing => Self {
+ states: vec![EState::terminal()],
+ },
+ }
+ }
+}
diff --git a/src/parse/regex/mod.rs b/src/parse/regex/mod.rs
new file mode 100644
index 0000000..1c761a1
--- /dev/null
+++ b/src/parse/regex/mod.rs
@@ -0,0 +1,201 @@
+use super::{Parse, ParseError, Result};
+
+mod byte_range;
+mod dfa;
+mod enfa;
+
+#[derive(PartialEq, Debug, Clone)]
+pub enum Pattern {
+ Byte(u8),
+ Range(u8, u8),
+ Alt(Vec<Pattern>),
+ Concat(Vec<Pattern>),
+ Rep(Box<Pattern>, u32, Option<u32>),
+ Nothing,
+}
+
+impl Parse for Pattern {
+ fn parse(b: &mut super::Cursor<'_>) -> super::Result<Self> {
+ parse_alt(b)
+ }
+}
+
+fn parse_alt(s: &mut super::Cursor<'_>) -> Result<Pattern> {
+ let mut seqs = vec![];
+ loop {
+ let seq = parse_seq(s)?;
+ if seq != Pattern::Nothing {
+ seqs.push(seq);
+ }
+ if s.has() && s.peek() == b'|' {
+ s.adv();
+ } else {
+ break;
+ }
+ }
+
+ Ok(match seqs.len() {
+ 0 => Pattern::Nothing,
+ 1 => seqs.into_iter().next().unwrap(),
+ _ => Pattern::Alt(seqs),
+ })
+}
+
+fn parse_seq(s: &mut super::Cursor<'_>) -> Result<Pattern> {
+ let mut reps = vec![];
+ loop {
+ let rep = parse_rep(s)?;
+ if rep != Pattern::Nothing {
+ reps.push(rep);
+ } else {
+ break;
+ }
+ }
+
+ Ok(match reps.len() {
+ 0 => Pattern::Nothing,
+ 1 => reps.into_iter().next().unwrap(),
+ _ => Pattern::Concat(reps),
+ })
+}
+
+fn parse_rep(s: &mut super::Cursor<'_>) -> Result<Pattern> {
+ let atom = parse_atom(s)?;
+
+ if atom == Pattern::Nothing {
+ return Ok(atom);
+ }
+
+ if !s.has() {
+ return Ok(atom);
+ }
+
+ match s.peek() {
+ b'*' => {
+ s.adv();
+ Ok(Pattern::Rep(Box::new(atom), 0, None))
+ }
+ b'+' => {
+ s.adv();
+ Ok(Pattern::Rep(Box::new(atom), 1, None))
+ }
+ b'?' => {
+ s.adv();
+ Ok(Pattern::Rep(Box::new(atom), 0, Some(1)))
+ }
+ _ => Ok(atom),
+ }
+
+ // TODO: non-greedy
+}
+
+const SYMBOLS: &[u8] = b"{}[]()*+-?| ";
+fn is_symbol(x: u8) -> bool {
+ SYMBOLS.contains(&x)
+}
+
+fn parse_atom(s: &mut super::Cursor<'_>) -> Result<Pattern> {
+ if !s.has() {
+ return Ok(Pattern::Nothing);
+ }
+
+ match s.peek() {
+ b'[' => {
+ s.adv();
+ let mut ranges = Vec::new();
+ loop {
+ if !s.has() {
+ return Err(ParseError::Eof);
+ }
+
+ let tok = s.adv();
+
+ if tok == b']' {
+ if ranges.is_empty() {
+ todo!("error handling for empty alternative list");
+ }
+ return Ok(Pattern::Alt(ranges));
+ }
+
+ if is_symbol(tok) {
+ return Err(ParseError::Unknown(tok));
+ }
+
+ if s.has() && s.peek() == b'-' {
+ s.adv();
+
+ if !s.has() {
+ return Err(ParseError::Eof);
+ }
+ let tok2 = s.adv();
+
+ if is_symbol(tok2) {
+ return Err(ParseError::Unknown(tok2));
+ }
+
+ ranges.push(Pattern::Range(tok, tok2));
+ } else {
+ ranges.push(Pattern::Byte(tok));
+ }
+ }
+ }
+ b'(' => {
+ s.adv();
+ let inner = Pattern::parse(s)?;
+ if !s.has() {
+ return Err(ParseError::Eof);
+ }
+ if s.adv() != b')' {
+ return Err(ParseError::Expected(')'));
+ }
+ Ok(inner)
+ }
+ x if is_symbol(x) => Ok(Pattern::Nothing),
+ ch => {
+ s.adv();
+ Ok(Pattern::Byte(ch))
+ }
+ }
+}
+
+pub struct CompiledPattern {
+ dfa: dfa::DFA,
+}
+
+impl std::fmt::Debug for CompiledPattern {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.dfa.fmt(f)
+ }
+}
+
+impl Pattern {
+ pub fn compile(self) -> CompiledPattern {
+ let enfa = enfa::ENFA::from(self);
+ let dfa = dfa::DFA::from(enfa);
+ CompiledPattern { dfa }
+ }
+}
+
+impl CompiledPattern {
+ pub fn matches(&self, bytes: &[u8]) -> bool {
+ self.dfa.matches(bytes)
+ }
+}
+
+#[cfg(test)]
+macro_rules! regex_matches {
+ ($regex:literal, $match:literal, $true:literal) => {
+ assert_eq!(
+ Pattern::parse_from_bytes($regex.as_bytes())
+ .unwrap()
+ .compile()
+ .matches($match.as_bytes()),
+ $true
+ )
+ };
+}
+
+#[test]
+fn foo_matches_foo() {
+ regex_matches!("foo", "foo", true);
+}
diff --git a/src/run/builtin.rs b/src/run/builtin.rs
index fab7565..6f7cc5d 100644
--- a/src/run/builtin.rs
+++ b/src/run/builtin.rs
@@ -1098,7 +1098,36 @@ mod dbg {
Ok(())
}
}
+
+ #[derive(Copy, Clone)]
+ pub struct case_match;
+ impl Builtin for case_match {
+ fn name(&self) -> &str {
+ "case_match"
+ }
+
+ fn io(
+ &self,
+ _session: Arc<Mutex<Session>>,
+ args: &[BString],
+ _stdin: &mut dyn Read,
+ stdout: &mut dyn Write,
+ ) -> Result {
+ let regex = match crate::parse::regex::Pattern::parse_from_bytes(&args[0]) {
+ Ok(r) => r,
+ Err(e) => {
+ writeln!(stdout, "not a valid regex: {e:?}")?;
+ return Err(Error::Exit(1));
+ },
+ };
+
+ let compiled = regex.compile();
+ writeln!(stdout, "{compiled:?}")?;
+
+ Ok(())
+ }
+ }
}
#[cfg(debug_assertions)]
-pub use dbg::{debug, re};
+pub use dbg::{debug, re, case_match};
diff --git a/src/run/mod.rs b/src/run/mod.rs
index f86278d..c3ceb76 100644
--- a/src/run/mod.rs
+++ b/src/run/mod.rs
@@ -456,8 +456,9 @@ impl Executor {
stdout: OutputWriter,
) -> SpawnedCmd {
for branch in c.branches.into_iter() {
- // TODO: regex case patterns
- if branch.pattern == c.discriminant {
+ // TODO: do not compile every time
+ let compiled = branch.pattern.compile();
+ if compiled.matches(&c.discriminant) {
return self.execute_block(branch.block, stdin, stdout);
}
}
@@ -732,6 +733,8 @@ const BUILTINS: &[&'static dyn BuiltinClone] = &[
&builtin::logo,
&builtin::export,
&builtin::pish_theme,
+ #[cfg(debug_assertions)]
+ &builtin::case_match,
];
pub fn builtin_map() -> HashMap<BString, &'static dyn BuiltinClone> {