1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
use std::{borrow::Cow, num::TryFromIntError, ops::ControlFlow, sync::Arc};

use crate::{
    frame::{frame_errors::CqlRequestSerializationError, types::SerialConsistency},
    types::serialize::row::SerializedValues,
};
use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::{
    frame::request::{RequestOpcode, SerializableRequest},
    frame::types,
};

use super::{DeserializableRequest, RequestDeserializationError};

// Bits of the `<flags>` byte that follows the consistency in query parameters.
const FLAG_VALUES: u8 = 0b0000_0001;
const FLAG_SKIP_METADATA: u8 = 0b0000_0010;
const FLAG_PAGE_SIZE: u8 = 0b0000_0100;
const FLAG_WITH_PAGING_STATE: u8 = 0b0000_1000;
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0b0001_0000;
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0b0010_0000;
const FLAG_WITH_NAMES_FOR_VALUES: u8 = 0b0100_0000;
// Every flag bit this module knows about; `QueryParameters::deserialize`
// rejects a flags byte carrying any bit outside this mask.
const ALL_FLAGS: u8 = FLAG_WITH_NAMES_FOR_VALUES
    | FLAG_WITH_DEFAULT_TIMESTAMP
    | FLAG_WITH_SERIAL_CONSISTENCY
    | FLAG_WITH_PAGING_STATE
    | FLAG_PAGE_SIZE
    | FLAG_SKIP_METADATA
    | FLAG_VALUES;

/// A QUERY request: the CQL statement text plus the parameters it is
/// executed with. Serialized as the statement (a long string) followed by
/// the parameters.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct Query<'q> {
    /// Text of the CQL statement; either borrowed from the caller or owned
    /// (deserialization always produces an owned copy).
    pub contents: Cow<'q, str>,
    /// Consistency, paging, timestamp and bound values for this execution.
    pub parameters: QueryParameters<'q>,
}

impl SerializableRequest for Query<'_> {
    const OPCODE: RequestOpcode = RequestOpcode::Query;

    /// Writes the statement text as a long string, then the query parameters.
    ///
    /// Each failure is first tagged with the matching
    /// [`QuerySerializationError`] variant and then converted into
    /// [`CqlRequestSerializationError`] by `?`.
    fn serialize(&self, buf: &mut Vec<u8>) -> Result<(), CqlRequestSerializationError> {
        let contents_written = types::write_long_string(&self.contents, buf);
        contents_written.map_err(QuerySerializationError::StatementStringSerialization)?;

        let parameters_written = self.parameters.serialize(buf);
        parameters_written.map_err(QuerySerializationError::QueryParametersSerialization)?;

        Ok(())
    }
}

impl DeserializableRequest for Query<'_> {
    fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
        let contents = Cow::Owned(types::read_long_string(buf)?.to_owned());
        let parameters = QueryParameters::deserialize(buf)?;

        Ok(Self {
            contents,
            parameters,
        })
    }
}

/// Execution parameters of a QUERY request, serialized right after the
/// statement text: consistency, a flags byte, then the optional sections
/// whose flags are set.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
pub struct QueryParameters<'a> {
    /// Consistency level; always written, ahead of the flags byte.
    pub consistency: types::Consistency,
    /// Serial consistency; when `Some`, `FLAG_WITH_SERIAL_CONSISTENCY` is set
    /// and the value is written.
    pub serial_consistency: Option<types::SerialConsistency>,
    /// Default timestamp; when `Some`, `FLAG_WITH_DEFAULT_TIMESTAMP` is set and
    /// it is written as a long. NOTE(review): the unit is not visible here — confirm.
    pub timestamp: Option<i64>,
    /// Requested page size; when `Some`, `FLAG_PAGE_SIZE` is set and it is
    /// written as an int.
    pub page_size: Option<i32>,
    /// Where to resume a paged query. `PagingState::start()` carries no bytes,
    /// so no paging state (and no `FLAG_WITH_PAGING_STATE`) is sent.
    pub paging_state: PagingState,
    /// When `true`, `FLAG_SKIP_METADATA` is set in the flags byte.
    pub skip_metadata: bool,
    /// Bound values; written (together with `FLAG_VALUES`) only when non-empty.
    /// Either borrowed from the caller or owned.
    pub values: Cow<'a, SerializedValues>,
}

impl Default for QueryParameters<'_> {
    fn default() -> Self {
        Self {
            consistency: Default::default(),
            serial_consistency: None,
            timestamp: None,
            page_size: None,
            paging_state: PagingState::start(),
            skip_metadata: false,
            values: Cow::Borrowed(SerializedValues::EMPTY),
        }
    }
}

impl QueryParameters<'_> {
    /// Writes these parameters into `buf`: the consistency, the flags byte,
    /// then every optional section whose flag was set, in ascending flag-bit
    /// order (values, page size, paging state, serial consistency, timestamp).
    ///
    /// # Errors
    /// Returns [`QueryParametersSerializationError::BadPagingState`] when
    /// `types::write_bytes` rejects the paging state (presumably because its
    /// length does not fit the length prefix — TODO confirm).
    pub fn serialize(
        &self,
        buf: &mut impl BufMut,
    ) -> Result<(), QueryParametersSerializationError> {
        types::write_consistency(self.consistency, buf);

        let paging_state_bytes = self.paging_state.as_bytes_slice();
        let has_values = !self.values.is_empty();

        // Each optional section contributes its bit only if it is written below.
        let flags = [
            (has_values, FLAG_VALUES),
            (self.skip_metadata, FLAG_SKIP_METADATA),
            (self.page_size.is_some(), FLAG_PAGE_SIZE),
            (paging_state_bytes.is_some(), FLAG_WITH_PAGING_STATE),
            (self.serial_consistency.is_some(), FLAG_WITH_SERIAL_CONSISTENCY),
            (self.timestamp.is_some(), FLAG_WITH_DEFAULT_TIMESTAMP),
        ]
        .iter()
        .filter(|(enabled, _)| *enabled)
        .fold(0u8, |acc, (_, bit)| acc | bit);
        buf.put_u8(flags);

        if has_values {
            self.values.write_to_request(buf);
        }
        if let Some(size) = self.page_size {
            types::write_int(size, buf);
        }
        if let Some(state) = paging_state_bytes {
            types::write_bytes(state, buf)?;
        }
        if let Some(serial) = self.serial_consistency {
            types::write_serial_consistency(serial, buf);
        }
        if let Some(ts) = self.timestamp {
            types::write_long(ts, buf);
        }

        Ok(())
    }
}

impl QueryParameters<'_> {
    /// Reads query parameters from the front of `buf`, advancing it past the
    /// consumed bytes. Sections are read in the same order they are written:
    /// consistency, flags, then values, page size, paging state, serial
    /// consistency and timestamp, each only if its flag is set.
    ///
    /// # Errors
    /// - [`RequestDeserializationError::UnknownFlags`] if the flags byte has any
    ///   bit outside `ALL_FLAGS` (only the unknown bits are reported);
    /// - [`RequestDeserializationError::NamedValuesUnsupported`] if
    ///   `FLAG_WITH_NAMES_FOR_VALUES` is set;
    /// - [`RequestDeserializationError::ExpectedSerialConsistency`] if the serial
    ///   consistency field holds a value `SerialConsistency::try_from` rejects;
    /// - any error propagated from the `types::read_*` helpers or
    ///   `SerializedValues::new_from_frame`.
    ///
    /// # Panics
    /// NOTE(review): the flags byte is read with `Buf::get_u8`, which panics if
    /// `buf` is already exhausted, so input truncated right after the
    /// consistency panics instead of returning an error.
    pub fn deserialize(buf: &mut &[u8]) -> Result<Self, RequestDeserializationError> {
        let consistency = types::read_consistency(buf)?;

        let flags = buf.get_u8();
        let is_set = |bit: u8| (flags & bit) != 0;

        let unknown_flags = flags & !ALL_FLAGS;
        if unknown_flags != 0 {
            return Err(RequestDeserializationError::UnknownFlags {
                flags: unknown_flags,
            });
        }
        if is_set(FLAG_WITH_NAMES_FOR_VALUES) {
            return Err(RequestDeserializationError::NamedValuesUnsupported);
        }

        let values = if is_set(FLAG_VALUES) {
            SerializedValues::new_from_frame(buf)?
        } else {
            SerializedValues::new()
        };

        let page_size = if is_set(FLAG_PAGE_SIZE) {
            Some(types::read_int(buf)?)
        } else {
            None
        };

        let paging_state = if is_set(FLAG_WITH_PAGING_STATE) {
            PagingState::new_from_raw_bytes(types::read_bytes(buf)?)
        } else {
            PagingState::start()
        };

        let serial_consistency = if is_set(FLAG_WITH_SERIAL_CONSISTENCY) {
            let raw = types::read_consistency(buf)?;
            let serial = SerialConsistency::try_from(raw)
                .map_err(|_| RequestDeserializationError::ExpectedSerialConsistency(raw))?;
            Some(serial)
        } else {
            None
        };

        let timestamp = if is_set(FLAG_WITH_DEFAULT_TIMESTAMP) {
            Some(types::read_long(buf)?)
        } else {
            None
        };

        Ok(Self {
            consistency,
            serial_consistency,
            timestamp,
            page_size,
            paging_state,
            skip_metadata: is_set(FLAG_SKIP_METADATA),
            values: Cow::Owned(values),
        })
    }
}

/// Outcome of fetching one page: either more pages remain, together with the
/// [PagingState] needed to request the next one, or the query is exhausted.
#[derive(Debug, Clone)]
pub enum PagingStateResponse {
    HasMorePages { state: PagingState },
    NoMorePages,
}

impl PagingStateResponse {
    /// Returns `true` when no further pages remain, and `false` while there is
    /// a [PagingState] from which the query can be resumed.
    #[inline]
    pub fn finished(&self) -> bool {
        match self {
            Self::HasMorePages { .. } => false,
            Self::NoMorePages => true,
        }
    }

    /// Builds a response from raw paging state bytes. `Some(bytes)` yields
    /// [`PagingStateResponse::HasMorePages`] wrapping those bytes, and `None`
    /// yields [`PagingStateResponse::NoMorePages`].
    pub(crate) fn new_from_raw_bytes(raw_paging_state: Option<&[u8]>) -> Self {
        raw_paging_state.map_or(Self::NoMorePages, |raw_bytes| Self::HasMorePages {
            state: PagingState::new_from_raw_bytes(raw_bytes),
        })
    }

    /// Consumes the response and returns a [ControlFlow]. `Break(())` means the
    /// query is done. `Continue(state)` carries the [PagingState] to fetch the
    /// next page with.
    #[inline]
    pub fn into_paging_control_flow(self) -> ControlFlow<(), PagingState> {
        match self {
            Self::NoMorePages => ControlFlow::Break(()),
            Self::HasMorePages { state } => ControlFlow::Continue(state),
        }
    }

    /// Moves the current response out, leaving
    /// `PagingStateResponse::NoMorePages` in its place.
    ///
    /// Only for use in driver's inner code, as an optimisation.
    #[doc(hidden)]
    pub fn take(&mut self) -> Self {
        std::mem::replace(self, PagingStateResponse::NoMorePages)
    }
}

/// Position within a paged query: where fetching result rows resumes on the
/// next request.
///
/// The bytes are held in an `Arc`, so cloning only bumps a reference count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PagingState(Option<Arc<[u8]>>);

impl PagingState {
    /// The initial state of a paged query that has not fetched any page yet.
    #[inline]
    pub fn start() -> Self {
        PagingState(None)
    }

    /// Exposes the raw bytes backing this state, e.g. to persist them and
    /// later rebuild the state with [Self::new_from_raw_bytes].
    /// `None` means this is [PagingState::start()].
    #[inline]
    pub fn as_bytes_slice(&self) -> Option<&Arc<[u8]>> {
        let PagingState(inner) = self;
        inner.as_ref()
    }

    /// Rebuilds a state from raw bytes, such as ones previously obtained
    /// through [Self::as_bytes_slice].
    #[inline]
    pub fn new_from_raw_bytes(raw_bytes: impl Into<Arc<[u8]>>) -> Self {
        let shared: Arc<[u8]> = raw_bytes.into();
        PagingState(Some(shared))
    }
}

impl Default for PagingState {
    /// Same as [PagingState::start()].
    fn default() -> Self {
        PagingState::start()
    }
}

/// An error type returned when serialization of a QUERY request fails.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum QuerySerializationError {
    /// Failed to serialize the query parameters that follow the statement text.
    #[error("Invalid query parameters: {0}")]
    QueryParametersSerialization(QueryParametersSerializationError),

    /// Failed to serialize the CQL statement string. The `TryFromIntError` is
    /// the one returned by `types::write_long_string` (presumably the string
    /// length not fitting its length prefix — TODO confirm).
    #[error("Failed to serialize a statement content: {0}")]
    StatementStringSerialization(TryFromIntError),
}

/// An error type returned when serialization of query parameters fails.
#[non_exhaustive]
#[derive(Error, Debug, Clone)]
pub enum QueryParametersSerializationError {
    /// Failed to serialize the paging state. `#[from]` lets `?` turn the
    /// `TryFromIntError` returned by `types::write_bytes` into this variant
    /// (presumably raised when the byte length does not fit its length
    /// prefix — TODO confirm).
    #[error("Malformed paging state: {0}")]
    BadPagingState(#[from] TryFromIntError),
}