1use alloc::vec::Vec;
4use core::num::NonZeroUsize;
5use core::{fmt, mem};
6#[cfg(feature = "std")]
7use std::error::Error as StdError;
8
9use super::UnbufferedConnectionCommon;
10use crate::Error;
11use crate::client::ClientConnectionData;
12use crate::msgs::deframer::buffers::DeframerSliceBuffer;
13use crate::server::ServerConnectionData;
14
15impl UnbufferedConnectionCommon<ClientConnectionData> {
16    pub fn process_tls_records<'c, 'i>(
19        &'c mut self,
20        incoming_tls: &'i mut [u8],
21    ) -> UnbufferedStatus<'c, 'i, ClientConnectionData> {
22        self.process_tls_records_common(incoming_tls, |_| false, |_, _| unreachable!())
23    }
24}
25
26impl UnbufferedConnectionCommon<ServerConnectionData> {
27    pub fn process_tls_records<'c, 'i>(
30        &'c mut self,
31        incoming_tls: &'i mut [u8],
32    ) -> UnbufferedStatus<'c, 'i, ServerConnectionData> {
33        self.process_tls_records_common(
34            incoming_tls,
35            |conn| conn.peek_early_data().is_some(),
36            |conn, incoming_tls| ReadEarlyData::new(conn, incoming_tls).into(),
37        )
38    }
39}
40
impl<Data> UnbufferedConnectionCommon<Data> {
    /// Shared driver behind the client and server `process_tls_records` entry points.
    ///
    /// Loops until there is a state to report. On every iteration the following
    /// conditions are checked in order, and the first one that holds decides the
    /// outcome:
    ///
    /// 1. `early_data_available` returns `true`: report the state built by
    ///    `early_data_state`.
    /// 2. `received_plaintext` is not empty: report `ReadTraffic`.
    /// 3. a chunk can be popped from `sendable_tls`: report `EncodeTlsData`.
    /// 4. a message can be deframed from `incoming_tls`: process it and loop again.
    /// 5. `wants_write` is set: report `TransmitTlsData`.
    /// 6. `has_received_close_notify` is set and `PeerClosed` was not emitted yet:
    ///    report `PeerClosed` (and remember that it was emitted).
    /// 7. `has_received_close_notify` and `has_sent_close_notify` are both set:
    ///    report `Closed`.
    /// 8. `may_send_application_data` is set: report `WriteTraffic`.
    /// 9. otherwise: report `BlockedHandshake`.
    ///
    /// A deframing failure, a message-processing failure, or a connection state that
    /// is already an error ends the call immediately with `state` set to `Err`.
    ///
    /// In every case `discard` is taken from the slice buffer's `pending_discard()`.
    fn process_tls_records_common<'c, 'i>(
        &'c mut self,
        incoming_tls: &'i mut [u8],
        mut early_data_available: impl FnMut(&mut Self) -> bool,
        early_data_state: impl FnOnce(&'c mut Self, &'i mut [u8]) -> ConnectionState<'c, 'i, Data>,
    ) -> UnbufferedStatus<'c, 'i, Data> {
        let mut buffer = DeframerSliceBuffer::new(incoming_tls);
        let mut buffer_progress = self.core.hs_deframer.progress();

        let (discard, state) = loop {
            // Early data takes precedence over everything else.
            if early_data_available(self) {
                break (
                    buffer.pending_discard(),
                    early_data_state(self, incoming_tls),
                );
            }

            // Already-available plaintext is handed out before touching more input.
            if !self
                .core
                .common_state
                .received_plaintext
                .is_empty()
            {
                break (
                    buffer.pending_discard(),
                    ReadTraffic::new(self, incoming_tls).into(),
                );
            }

            // Queued outgoing TLS data is surfaced one chunk at a time.
            if let Some(chunk) = self
                .core
                .common_state
                .sendable_tls
                .pop()
            {
                break (
                    buffer.pending_discard(),
                    EncodeTlsData::new(self, chunk).into(),
                );
            }

            // Once `has_received_close_notify` is set, no further input is deframed.
            let deframer_output = if self
                .core
                .common_state
                .has_received_close_notify
            {
                None
            } else {
                match self
                    .core
                    .deframe(None, buffer.filled_mut(), &mut buffer_progress)
                {
                    Err(err) => {
                        // Carry over the discard recorded so far before bailing out.
                        buffer.queue_discard(buffer_progress.take_discard());
                        return UnbufferedStatus {
                            discard: buffer.pending_discard(),
                            state: Err(err),
                        };
                    }
                    Ok(r) => r,
                }
            };

            if let Some(msg) = deframer_output {
                // Move the current state out, leaving a placeholder error behind
                // until the real state (or error) is written back below.
                let mut state =
                    match mem::replace(&mut self.core.state, Err(Error::HandshakeNotComplete)) {
                        Ok(state) => state,
                        Err(e) => {
                            // The connection was already in an error state: restore
                            // that error and report it again.
                            buffer.queue_discard(buffer_progress.take_discard());
                            self.core.state = Err(e.clone());
                            return UnbufferedStatus {
                                discard: buffer.pending_discard(),
                                state: Err(e),
                            };
                        }
                    };

                match self.core.process_msg(msg, state, None) {
                    Ok(new) => state = new,

                    Err(e) => {
                        // Store the failure so that later calls see it as well.
                        buffer.queue_discard(buffer_progress.take_discard());
                        self.core.state = Err(e.clone());
                        return UnbufferedStatus {
                            discard: buffer.pending_discard(),
                            state: Err(e),
                        };
                    }
                }

                buffer.queue_discard(buffer_progress.take_discard());

                // Write back the successor state and go around the loop again.
                self.core.state = Ok(state);
            } else if self.wants_write {
                break (
                    buffer.pending_discard(),
                    TransmitTlsData { conn: self }.into(),
                );
            } else if self
                .core
                .common_state
                .has_received_close_notify
                && !self.emitted_peer_closed_state
            {
                // `PeerClosed` is reported at most once per connection.
                self.emitted_peer_closed_state = true;
                break (buffer.pending_discard(), ConnectionState::PeerClosed);
            } else if self
                .core
                .common_state
                .has_received_close_notify
                && self
                    .core
                    .common_state
                    .has_sent_close_notify
            {
                break (buffer.pending_discard(), ConnectionState::Closed);
            } else if self
                .core
                .common_state
                .may_send_application_data
            {
                break (
                    buffer.pending_discard(),
                    ConnectionState::WriteTraffic(WriteTraffic { conn: self }),
                );
            } else {
                // Nothing to read, encode, transmit or write.
                break (buffer.pending_discard(), ConnectionState::BlockedHandshake);
            }
        };

        UnbufferedStatus {
            discard,
            state: Ok(state),
        }
    }
}
178
/// Outcome of a single `process_tls_records` call.
#[must_use]
#[derive(Debug)]
pub struct UnbufferedStatus<'c, 'i, Data> {
    /// Value of the deframer slice buffer's `pending_discard()` when the call returned.
    ///
    /// NOTE(review): presumably the number of bytes at the front of `incoming_tls`
    /// that the caller should discard before the next call — confirm against
    /// `DeframerSliceBuffer`.
    pub discard: usize,

    /// The state the connection is in, or the error that stopped processing.
    pub state: Result<ConnectionState<'c, 'i, Data>, Error>,
}
200
/// The state reported by a `process_tls_records` call.
///
/// The variants are listed here in declaration order; the order in which they are
/// checked is defined by `process_tls_records_common`.
#[non_exhaustive] pub enum ConnectionState<'c, 'i, Data> {
    /// Plaintext is queued in `received_plaintext`; the wrapped handle reads it.
    ReadTraffic(ReadTraffic<'c, 'i, Data>),

    /// Reported once, the first time the processing loop reaches the point where
    /// `has_received_close_notify` is set and `PeerClosed` was not emitted before.
    PeerClosed,

    /// Both `has_received_close_notify` and `has_sent_close_notify` are set.
    Closed,

    /// Early data is available; the wrapped handle reads it.
    ///
    /// Only the server-side `process_tls_records` produces this state.
    ReadEarlyData(ReadEarlyData<'c, 'i, Data>),

    /// A chunk was popped from `sendable_tls`; the wrapped handle copies it into a
    /// caller-provided buffer.
    EncodeTlsData(EncodeTlsData<'c, Data>),

    /// `wants_write` is set, which happens after a successful `EncodeTlsData::encode`;
    /// calling `done` on the wrapped handle clears it again.
    TransmitTlsData(TransmitTlsData<'c, Data>),

    /// None of the other states applies.
    ///
    /// NOTE(review): presumably means more handshake bytes are needed from the
    /// peer — confirm.
    BlockedHandshake,

    /// `may_send_application_data` is set; the wrapped handle encrypts outgoing
    /// application data.
    WriteTraffic(WriteTraffic<'c, Data>),
}
276
277impl<'c, 'i, Data> From<ReadTraffic<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
278    fn from(v: ReadTraffic<'c, 'i, Data>) -> Self {
279        Self::ReadTraffic(v)
280    }
281}
282
283impl<'c, 'i, Data> From<ReadEarlyData<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
284    fn from(v: ReadEarlyData<'c, 'i, Data>) -> Self {
285        Self::ReadEarlyData(v)
286    }
287}
288
289impl<'c, Data> From<EncodeTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
290    fn from(v: EncodeTlsData<'c, Data>) -> Self {
291        Self::EncodeTlsData(v)
292    }
293}
294
295impl<'c, Data> From<TransmitTlsData<'c, Data>> for ConnectionState<'c, '_, Data> {
296    fn from(v: TransmitTlsData<'c, Data>) -> Self {
297        Self::TransmitTlsData(v)
298    }
299}
300
301impl<Data> fmt::Debug for ConnectionState<'_, '_, Data> {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        match self {
304            Self::ReadTraffic(..) => f.debug_tuple("ReadTraffic").finish(),
305
306            Self::PeerClosed => write!(f, "PeerClosed"),
307
308            Self::Closed => write!(f, "Closed"),
309
310            Self::ReadEarlyData(..) => f.debug_tuple("ReadEarlyData").finish(),
311
312            Self::EncodeTlsData(..) => f.debug_tuple("EncodeTlsData").finish(),
313
314            Self::TransmitTlsData(..) => f
315                .debug_tuple("TransmitTlsData")
316                .finish(),
317
318            Self::BlockedHandshake => f
319                .debug_tuple("BlockedHandshake")
320                .finish(),
321
322            Self::WriteTraffic(..) => f.debug_tuple("WriteTraffic").finish(),
323        }
324    }
325}
326
/// Handle for reading the plaintext chunks queued on the connection.
pub struct ReadTraffic<'c, 'i, Data> {
    // Connection whose `received_plaintext` queue is drained by `next_record`.
    conn: &'c mut UnbufferedConnectionCommon<Data>,
    // Stored but never read in this file.
    // NOTE(review): presumably keeps `incoming_tls` mutably borrowed while this
    // handle is alive — confirm intent.
    _incoming_tls: &'i mut [u8],

    // Owns the chunk most recently popped by `next_record`, so that the returned
    // `AppDataRecord` can borrow its payload from here.
    chunk: Option<Vec<u8>>,
}
337
338impl<'c, 'i, Data> ReadTraffic<'c, 'i, Data> {
339    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, _incoming_tls: &'i mut [u8]) -> Self {
340        Self {
341            conn,
342            _incoming_tls,
343            chunk: None,
344        }
345    }
346
347    pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
350        self.chunk = self
351            .conn
352            .core
353            .common_state
354            .received_plaintext
355            .pop();
356        self.chunk.as_ref().map(|chunk| {
357            Ok(AppDataRecord {
358                discard: 0,
359                payload: chunk,
360            })
361        })
362    }
363
364    pub fn peek_len(&self) -> Option<NonZeroUsize> {
368        self.conn
369            .core
370            .common_state
371            .received_plaintext
372            .peek()
373            .and_then(|ch| NonZeroUsize::new(ch.len()))
374    }
375}
376
/// Handle for reading the early data available on the connection.
pub struct ReadEarlyData<'c, 'i, Data> {
    // Connection whose early data is popped by `next_record`.
    conn: &'c mut UnbufferedConnectionCommon<Data>,

    // Stored but never read in this file.
    // NOTE(review): presumably keeps `incoming_tls` mutably borrowed while this
    // handle is alive — confirm intent.
    _incoming_tls: &'i mut [u8],

    // Owns the chunk most recently popped by `next_record`, so that the returned
    // `AppDataRecord` can borrow its payload from here.
    chunk: Option<Vec<u8>>,
}
388
389impl<'c, 'i> ReadEarlyData<'c, 'i, ServerConnectionData> {
390    fn new(
391        conn: &'c mut UnbufferedConnectionCommon<ServerConnectionData>,
392        _incoming_tls: &'i mut [u8],
393    ) -> Self {
394        Self {
395            conn,
396            _incoming_tls,
397            chunk: None,
398        }
399    }
400
401    pub fn next_record(&mut self) -> Option<Result<AppDataRecord<'_>, Error>> {
404        self.chunk = self.conn.pop_early_data();
405        self.chunk.as_ref().map(|chunk| {
406            Ok(AppDataRecord {
407                discard: 0,
408                payload: chunk,
409            })
410        })
411    }
412
413    pub fn peek_len(&self) -> Option<NonZeroUsize> {
417        self.conn
418            .peek_early_data()
419            .and_then(|ch| NonZeroUsize::new(ch.len()))
420    }
421}
422
/// A chunk of plaintext handed out by `ReadTraffic::next_record` or
/// `ReadEarlyData::next_record`.
pub struct AppDataRecord<'i> {
    /// Always `0` for records created in this file.
    ///
    /// NOTE(review): presumably a count of `incoming_tls` bytes the caller should
    /// discard after handling this record — confirm.
    pub discard: usize,

    /// The plaintext bytes, borrowed from the handle that produced the record.
    pub payload: &'i [u8],
}
434
/// Handle for encrypting outgoing application data and queueing a `close_notify`.
pub struct WriteTraffic<'c, Data> {
    // Connection whose core performs the encryption.
    conn: &'c mut UnbufferedConnectionCommon<Data>,
}
439
440impl<Data> WriteTraffic<'_, Data> {
441    pub fn encrypt(
446        &mut self,
447        application_data: &[u8],
448        outgoing_tls: &mut [u8],
449    ) -> Result<usize, EncryptError> {
450        self.conn
451            .core
452            .maybe_refresh_traffic_keys();
453        self.conn
454            .core
455            .common_state
456            .write_plaintext(application_data.into(), outgoing_tls)
457    }
458
459    pub fn queue_close_notify(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncryptError> {
464        self.conn
465            .core
466            .common_state
467            .eager_send_close_notify(outgoing_tls)
468    }
469
470    pub fn refresh_traffic_keys(self) -> Result<(), Error> {
482        self.conn.core.refresh_traffic_keys()
483    }
484}
485
/// Handle holding one chunk of outgoing TLS data that has to be copied into a
/// caller-provided buffer.
pub struct EncodeTlsData<'c, Data> {
    // Connection whose `wants_write` flag is set once the chunk has been encoded.
    conn: &'c mut UnbufferedConnectionCommon<Data>,
    // The pending chunk; `None` after a successful `encode`.
    chunk: Option<Vec<u8>>,
}
491
492impl<'c, Data> EncodeTlsData<'c, Data> {
493    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, chunk: Vec<u8>) -> Self {
494        Self {
495            conn,
496            chunk: Some(chunk),
497        }
498    }
499
500    pub fn encode(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncodeError> {
505        let Some(chunk) = self.chunk.take() else {
506            return Err(EncodeError::AlreadyEncoded);
507        };
508
509        let required_size = chunk.len();
510
511        if required_size > outgoing_tls.len() {
512            self.chunk = Some(chunk);
513            Err(InsufficientSizeError { required_size }.into())
514        } else {
515            let written = chunk.len();
516            outgoing_tls[..written].copy_from_slice(&chunk);
517
518            self.conn.wants_write = true;
519
520            Ok(written)
521        }
522    }
523}
524
/// Handle reported while the connection's `wants_write` flag is set.
pub struct TransmitTlsData<'c, Data> {
    // Connection whose `wants_write` flag is cleared by `done`.
    pub(crate) conn: &'c mut UnbufferedConnectionCommon<Data>,
}
529
530impl<Data> TransmitTlsData<'_, Data> {
531    pub fn done(self) {
533        self.conn.wants_write = false;
534    }
535
536    pub fn may_encrypt_app_data(&mut self) -> Option<WriteTraffic<'_, Data>> {
540        if self
541            .conn
542            .core
543            .common_state
544            .may_send_application_data
545        {
546            Some(WriteTraffic { conn: self.conn })
547        } else {
548            None
549        }
550    }
551}
552
/// Errors returned by `EncodeTlsData::encode`.
#[derive(Debug)]
pub enum EncodeError {
    /// The provided buffer is shorter than the pending chunk.
    InsufficientSize(InsufficientSizeError),

    /// The chunk was already encoded by an earlier successful call.
    AlreadyEncoded,
}
562
563impl From<InsufficientSizeError> for EncodeError {
564    fn from(v: InsufficientSizeError) -> Self {
565        Self::InsufficientSize(v)
566    }
567}
568
569impl fmt::Display for EncodeError {
570    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
571        match self {
572            Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
573                f,
574                "cannot encode due to insufficient size, {required_size} bytes are required"
575            ),
576            Self::AlreadyEncoded => "cannot encode, data has already been encoded".fmt(f),
577        }
578    }
579}
580
// With the `std` feature enabled, `EncodeError` also implements the standard error
// trait; every trait method keeps its default implementation.
#[cfg(feature = "std")]
impl StdError for EncodeError {}
583
/// Errors returned by `WriteTraffic::encrypt` and `WriteTraffic::queue_close_notify`.
#[derive(Debug)]
pub enum EncryptError {
    /// The provided buffer is too small; the wrapped error carries the required size.
    InsufficientSize(InsufficientSizeError),

    /// Displayed as "encrypter has been exhausted".
    ///
    /// NOTE(review): presumably raised when no further records may be encrypted
    /// with the current keys — confirm where this is produced.
    EncryptExhausted,
}
593
594impl From<InsufficientSizeError> for EncryptError {
595    fn from(v: InsufficientSizeError) -> Self {
596        Self::InsufficientSize(v)
597    }
598}
599
600impl fmt::Display for EncryptError {
601    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
602        match self {
603            Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
604                f,
605                "cannot encrypt due to insufficient size, {required_size} bytes are required"
606            ),
607            Self::EncryptExhausted => f.write_str("encrypter has been exhausted"),
608        }
609    }
610}
611
// With the `std` feature enabled, `EncryptError` also implements the standard error
// trait; every trait method keeps its default implementation.
#[cfg(feature = "std")]
impl StdError for EncryptError {}
614
/// Reports that a caller-provided output buffer was too small.
#[derive(Clone, Copy, Debug)]
pub struct InsufficientSizeError {
    /// Number of bytes the operation needs, as shown in the `Display` messages of
    /// `EncodeError` and `EncryptError`. For `EncodeTlsData::encode` this is the
    /// length of the pending chunk.
    pub required_size: usize,
}