scylla_cql/utils/
parse.rs

1//! Simple general-purpose recursive-descent parser.
2//! Used for parsing strings in the CQL protocol.
3
4use std::fmt::Display;
5
6/// An error that can occur during parsing.
7#[derive(Copy, Clone)]
8#[non_exhaustive]
9pub struct ParseError {
10    remaining: usize,
11    cause: ParseErrorCause,
12}
13
14impl ParseError {
15    /// Given the original string, returns the 1-based position
16    /// of the error in characters.
17    /// If an incorrect string was given, the function may return 0.
18    pub fn calculate_position(&self, original: &str) -> Option<usize> {
19        calculate_position(original, self.remaining)
20    }
21
22    /// Returns the error cause.
23    pub fn get_cause(&self) -> ParseErrorCause {
24        self.cause
25    }
26}
27
/// The reason why parsing failed.
///
/// Kept deliberately small (only `'static` string slices) so that
/// an error can be created and thrown away cheaply.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseErrorCause {
    /// The input did not contain the given string where it was required.
    Expected(&'static str),
    /// Any other failure; the string describes it.
    Other(&'static str),
}
38
39impl Display for ParseErrorCause {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            ParseErrorCause::Expected(e) => write!(f, "expected {e:?}"),
43            ParseErrorCause::Other(e) => f.write_str(e),
44        }
45    }
46}
47
48/// Result of a parsing operation.
49pub type ParseResult<T> = Result<T, ParseError>;
50
/// A utility type for building simple recursive-descent parsers.
///
/// Basically, a wrapper over `&str` with nice methods that help with parsing.
/// The wrapped slice is the part of the input that has not been parsed yet.
// `Debug` is derived so that parser states can be inspected with `{:?}`
// and used in assertions while debugging a parser.
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct ParserState<'s> {
    /// The not-yet-parsed remainder of the input.
    pub(crate) s: &'s str,
}
59
60impl<'s> ParserState<'s> {
61    /// Creates a new parser from given input string.
62    pub fn new(s: &'s str) -> Self {
63        Self { s }
64    }
65
66    /// Applies given parsing function until it returns false
67    /// and returns the final parser state.
68    pub fn parse_while(
69        self,
70        mut parser: impl FnMut(Self) -> ParseResult<(bool, Self)>,
71    ) -> ParseResult<Self> {
72        let mut me = self;
73        loop {
74            let (proceed, new_me) = parser(me)?;
75            if !proceed {
76                return Ok(new_me);
77            }
78            me = new_me;
79        }
80    }
81
82    /// If the input string contains given string at the beginning,
83    /// returns a new parser state with given string skipped.
84    /// Otherwise, returns an error.
85    pub fn accept(self, part: &'static str) -> ParseResult<Self> {
86        match self.s.strip_prefix(part) {
87            Some(s) => Ok(Self { s }),
88            None => Err(self.error(ParseErrorCause::Expected(part))),
89        }
90    }
91
92    /// Returns new parser state with whitespace skipped from the beginning.
93    pub fn skip_white(self) -> Self {
94        let (_, me) = self.take_while(char::is_whitespace);
95        me
96    }
97
98    /// Parses a sequence of digits as an integer.
99    /// Consumes characters until it finds a character that is not a digit.
100    ///
101    /// An error is returned if:
102    /// * The first character is not a digit
103    /// * The integer is larger than u16
104    pub fn parse_u16(self) -> ParseResult<(u16, Self)> {
105        let (digits, p) = self.take_while(|c| c.is_ascii_digit());
106        if let Ok(value) = digits.parse() {
107            Ok((value, p))
108        } else {
109            Err(p.error(ParseErrorCause::Other("Expected 16-bit unsigned integer")))
110        }
111    }
112
113    /// Skips characters from the beginning while they satisfy given predicate
114    /// and returns new parser state which
115    pub fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) {
116        let idx = self.s.find(move |c| !pred(c)).unwrap_or(self.s.len());
117        let new = Self { s: &self.s[idx..] };
118        (&self.s[..idx], new)
119    }
120
121    /// Returns the number of remaining bytes to parse.
122    pub fn get_remaining(self) -> usize {
123        self.s.len()
124    }
125
126    /// Returns true if the input string was parsed completely.
127    pub fn is_at_eof(self) -> bool {
128        self.s.is_empty()
129    }
130
131    /// Returns an error with given cause, associated with given position.
132    pub fn error(self, cause: ParseErrorCause) -> ParseError {
133        ParseError {
134            remaining: self.get_remaining(),
135            cause,
136        }
137    }
138
139    /// Given the original string, returns the 1-based position
140    /// of the error in characters.
141    /// If an incorrect string was given, the function may return None.
142    pub fn calculate_position(self, original: &str) -> Option<usize> {
143        calculate_position(original, self.get_remaining())
144    }
145}
146
/// Converts "`remaining` bytes of `original` are still unparsed" into the
/// 1-based character position of the first unparsed character.
///
/// Returns `None` if `remaining` exceeds the byte length of `original`,
/// or if the resulting byte offset is not on a character boundary.
fn calculate_position(original: &str, remaining: usize) -> Option<usize> {
    original
        .len()
        .checked_sub(remaining)
        .and_then(|consumed| original.get(..consumed))
        .map(|parsed| parsed.chars().count() + 1)
}