netsvc_proto/
lib.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! This crate contains facilities to interact with netsvc over the
6//! network.
7
8pub mod debuglog;
9pub mod netboot;
10pub mod tftp;
11
12use thiserror::Error;
13
/// Witness wrapper recording that the wrapped byte container holds a valid
/// UTF-8 string.
///
/// The container is expected to be a [`zerocopy::SplitByteSlice`]; that bound
/// is placed on the `impl` blocks rather than on the type itself.
struct ValidStr<B>(B);
16
17/// Helper to convince the compiler we're holding buffer views.
18fn as_buffer_view_mut<'a, B: packet::BufferViewMut<&'a mut [u8]>>(
19    v: B,
20) -> impl packet::BufferViewMut<&'a mut [u8]> {
21    v
22}
23
24fn find_null_termination<B: zerocopy::SplitByteSlice>(b: &B) -> Option<usize> {
25    b.as_ref().iter().enumerate().find_map(|(index, c)| (*c == 0).then_some(index))
26}
27
28#[derive(Debug, Eq, PartialEq, Clone, Error)]
29pub enum ValidStrError {
30    #[error("missing null termination")]
31    NoNullTermination,
32    #[error("failed to decode: {0}")]
33    Encoding(std::str::Utf8Error),
34}
35
36impl<B> ValidStr<B>
37where
38    B: zerocopy::SplitByteSlice,
39{
40    /// Attempts to create a new `ValidStr` that wraps all the contents of
41    /// `bytes`.
42    fn new(bytes: B) -> Result<Self, std::str::Utf8Error> {
43        // NB: map doesn't work here because of lifetimes.
44        match std::str::from_utf8(bytes.as_ref()) {
45            Ok(_) => Ok(Self(bytes)),
46            Err(e) => Err(e),
47        }
48    }
49
50    /// Splits this `ValidStr` into a valid string up to the first null
51    /// character and the rest of the internal container if there is one.
52    ///
53    /// The returned `ValidStr` is guaranteed to not contain a null character,
54    /// and the returned tail `ByteSlice` may either be a slice starting with a
55    /// null character or an empty slice.
56    fn truncate_null(self) -> (Self, B) {
57        let Self(bytes) = self;
58        let split = find_null_termination(&bytes).unwrap_or_else(|| bytes.as_ref().len());
59        let (bytes, rest) = bytes.split_at(split).ok().unwrap();
60        (Self(bytes), rest)
61    }
62
63    fn as_str(&self) -> &str {
64        // safety: ValidStr is a witness type for a valid UTF8 string that
65        // keeps the byte slice reference
66        unsafe { std::str::from_utf8_unchecked(self.0.as_ref()) }
67    }
68
69    /// Attempts to create a new `ValidStr` from the provided `BufferView`,
70    /// consuming the buffer until the first null termination character.
71    ///
72    /// The returned `ValidStr` will not contain the null character, but the
73    /// null character will be consumed from `buffer`.
74    ///
75    /// Note that the bytes might be consumed from the buffer view even in case
76    /// of errors.
77    fn new_null_terminated_from_buffer<BV: packet::BufferView<B>>(
78        buffer: &mut BV,
79    ) -> Result<Self, ValidStrError> {
80        let v = buffer.as_ref();
81        let eos = find_null_termination(&v).ok_or(ValidStrError::NoNullTermination)?;
82        // Unwrap is safe, we just found null termination above.
83        let bytes = buffer.take_front(eos + 1).unwrap();
84        let (bytes, null_char) = bytes.split_at(eos).ok().unwrap();
85        // TODO(https://github.com/rust-lang/rust/issues/82775): Use
86        // debug_assert_matches from std when available.
87        debug_assert!(
88            matches!(null_char.as_ref(), [0]),
89            "bad null character value: {:?}",
90            null_char.as_ref()
91        );
92        let _ = null_char;
93        Self::new(bytes).map_err(ValidStrError::Encoding)
94    }
95}
96
97impl<B> std::fmt::Debug for ValidStr<B>
98where
99    B: zerocopy::SplitByteSlice,
100{
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
102        self.as_str().fmt(f)
103    }
104}
105
106impl<B> AsRef<str> for ValidStr<B>
107where
108    B: zerocopy::SplitByteSlice,
109{
110    fn as_ref(&self) -> &str {
111        self.as_str()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;

    /// `ValidStr::new` accepts valid UTF-8 and rejects an invalid sequence.
    #[test]
    fn test_new_valid_str() {
        const VALID: &str = "some valid string";
        // 0xc3 opens a two-byte sequence but 0x28 is not a continuation byte.
        const INVALID: [u8; 2] = [0xc3, 0x28];
        assert_eq!(
            ValidStr::new(VALID.as_bytes()).expect("can create from valid string").as_str(),
            VALID
        );
        assert_matches!(ValidStr::new(&INVALID[..]), Err(_));
    }

    /// `truncate_null` splits right before the first null character and
    /// leaves the null character at the head of the remainder.
    #[test]
    fn test_truncate_null() {
        const VALID: &str = "some valid string\x00 rest";
        let (trunc, rest) =
            ValidStr::new(VALID.as_bytes()).expect("can create from valid string").truncate_null();
        assert_eq!(trunc.as_str(), "some valid string");
        assert_eq!(rest, "\x00 rest".as_bytes());
    }

    /// Without any null character, `truncate_null` keeps the whole string and
    /// returns an empty remainder.
    #[test]
    fn test_truncate_null_without_null() {
        const VALID: &str = "no null character";
        let (trunc, rest) =
            ValidStr::new(VALID.as_bytes()).expect("can create from valid string").truncate_null();
        assert_eq!(trunc.as_str(), VALID);
        assert_eq!(rest, "".as_bytes());
    }

    /// `new_null_terminated_from_buffer` consumes through the terminator on
    /// success and leaves the buffer untouched when no terminator exists.
    #[test]
    fn test_get_from_buffer() {
        fn make_buffer(contents: &str) -> packet::Buf<&[u8]> {
            packet::Buf::new(contents.as_bytes(), ..)
        }
        // Returns the parse result along with whatever is left in the view.
        fn get_from_buffer<'a>(
            mut bv: impl packet::BufferView<&'a [u8]>,
        ) -> (Result<ValidStr<&'a [u8]>, ValidStrError>, &'a str) {
            let valid_str = ValidStr::new_null_terminated_from_buffer(&mut bv);
            (valid_str, std::str::from_utf8(bv.into_rest()).unwrap())
        }

        let mut buffer = make_buffer("no null termination");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        assert_matches!(valid_str, Err(ValidStrError::NoNullTermination));
        assert_eq!(rest, "no null termination");

        let mut buffer = make_buffer("null\x00termination");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        let valid_str = valid_str.expect("can find termination");
        assert_matches!(valid_str.as_str(), "null");
        assert_eq!(rest, "termination");

        let mut buffer = make_buffer("");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        assert_matches!(valid_str, Err(ValidStrError::NoNullTermination));
        assert_eq!(rest, "");
    }
}