1pub mod debuglog;
9pub mod netboot;
10pub mod tftp;
11
12use thiserror::Error;
13
/// A byte slice `B` whose contents have been validated as UTF-8.
///
/// The constructors in the inherent `impl` (`new` and
/// `new_null_terminated_from_buffer`) check the bytes with
/// `std::str::from_utf8`. `as_str` relies on that check and reinterprets the
/// bytes without re-validating them.
///
/// NOTE(review): the tuple field is visible to this module and its submodules;
/// they must not build a `ValidStr` directly from unvalidated bytes.
struct ValidStr<B>(B);
16
17fn as_buffer_view_mut<'a, B: packet::BufferViewMut<&'a mut [u8]>>(
19 v: B,
20) -> impl packet::BufferViewMut<&'a mut [u8]> {
21 v
22}
23
24fn find_null_termination<B: zerocopy::SplitByteSlice>(b: &B) -> Option<usize> {
25 b.as_ref().iter().enumerate().find_map(|(index, c)| (*c == 0).then_some(index))
26}
27
/// Errors produced when reading a null-terminated UTF-8 string from a buffer.
#[derive(Debug, Eq, PartialEq, Clone, Error)]
pub enum ValidStrError {
    /// The buffer does not contain a null (`0`) byte.
    #[error("missing null termination")]
    NoNullTermination,
    /// The bytes preceding the null terminator are not valid UTF-8.
    #[error("failed to decode: {0}")]
    Encoding(std::str::Utf8Error),
}
35
36impl<B> ValidStr<B>
37where
38 B: zerocopy::SplitByteSlice,
39{
40 fn new(bytes: B) -> Result<Self, std::str::Utf8Error> {
43 match std::str::from_utf8(bytes.as_ref()) {
45 Ok(_) => Ok(Self(bytes)),
46 Err(e) => Err(e),
47 }
48 }
49
50 fn truncate_null(self) -> (Self, B) {
57 let Self(bytes) = self;
58 let split = find_null_termination(&bytes).unwrap_or_else(|| bytes.as_ref().len());
59 let (bytes, rest) = bytes.split_at(split).ok().unwrap();
60 (Self(bytes), rest)
61 }
62
63 fn as_str(&self) -> &str {
64 unsafe { std::str::from_utf8_unchecked(self.0.as_ref()) }
67 }
68
69 fn new_null_terminated_from_buffer<BV: packet::BufferView<B>>(
78 buffer: &mut BV,
79 ) -> Result<Self, ValidStrError> {
80 let v = buffer.as_ref();
81 let eos = find_null_termination(&v).ok_or(ValidStrError::NoNullTermination)?;
82 let bytes = buffer.take_front(eos + 1).unwrap();
84 let (bytes, null_char) = bytes.split_at(eos).ok().unwrap();
85 debug_assert!(
88 matches!(null_char.as_ref(), [0]),
89 "bad null character value: {:?}",
90 null_char.as_ref()
91 );
92 let _ = null_char;
93 Self::new(bytes).map_err(ValidStrError::Encoding)
94 }
95}
96
97impl<B> std::fmt::Debug for ValidStr<B>
98where
99 B: zerocopy::SplitByteSlice,
100{
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
102 self.as_str().fmt(f)
103 }
104}
105
106impl<B> AsRef<str> for ValidStr<B>
107where
108 B: zerocopy::SplitByteSlice,
109{
110 fn as_ref(&self) -> &str {
111 self.as_str()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;

    #[test]
    fn test_new_valid_str() {
        const VALID: &str = "some valid string";
        const INVALID: [u8; 2] = [0xc3, 0x28];
        assert_eq!(
            ValidStr::new(VALID.as_bytes()).expect("can create from valid string").as_str(),
            VALID
        );
        assert_matches!(ValidStr::new(&INVALID[..]), Err(_));
    }

    #[test]
    fn test_truncate_null() {
        const VALID: &str = "some valid string\x00 rest";
        let (trunc, rest) =
            ValidStr::new(VALID.as_bytes()).expect("can create from valid string").truncate_null();
        assert_eq!(trunc.as_str(), "some valid string");
        assert_eq!(rest, "\x00 rest".as_bytes());
    }

    #[test]
    fn test_truncate_null_edge_cases() {
        // No null byte: the whole string is kept and the remainder is empty.
        const NO_NULL: &str = "no null here";
        let (trunc, rest) = ValidStr::new(NO_NULL.as_bytes())
            .expect("can create from valid string")
            .truncate_null();
        assert_eq!(trunc.as_str(), NO_NULL);
        assert!(rest.is_empty());

        // Leading null byte: the string is empty and the remainder is everything.
        const LEADING_NULL: &str = "\x00rest";
        let (trunc, rest) = ValidStr::new(LEADING_NULL.as_bytes())
            .expect("can create from valid string")
            .truncate_null();
        assert_eq!(trunc.as_str(), "");
        assert_eq!(rest, LEADING_NULL.as_bytes());
    }

    #[test]
    fn test_get_from_buffer() {
        fn make_buffer(contents: &[u8]) -> packet::Buf<&[u8]> {
            packet::Buf::new(contents, ..)
        }
        fn get_from_buffer<'a>(
            mut bv: impl packet::BufferView<&'a [u8]>,
        ) -> (Result<ValidStr<&'a [u8]>, ValidStrError>, &'a str) {
            let valid_str = ValidStr::new_null_terminated_from_buffer(&mut bv);
            (valid_str, std::str::from_utf8(bv.into_rest()).unwrap())
        }

        let mut buffer = make_buffer(b"no null termination");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        assert_matches!(valid_str, Err(ValidStrError::NoNullTermination));
        assert_eq!(rest, "no null termination");

        let mut buffer = make_buffer(b"null\x00termination");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        let valid_str = valid_str.expect("can find termination");
        assert_eq!(valid_str.as_str(), "null");
        assert_eq!(rest, "termination");

        let mut buffer = make_buffer(b"");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        assert_matches!(valid_str, Err(ValidStrError::NoNullTermination));
        assert_eq!(rest, "");

        // An empty string right before the terminator is valid.
        let mut buffer = make_buffer(b"\x00rest");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        let valid_str = valid_str.expect("can find termination");
        assert_eq!(valid_str.as_str(), "");
        assert_eq!(rest, "rest");

        // Invalid UTF-8 before the terminator. This documents current behavior:
        // the malformed string and its terminator are consumed even on error.
        let mut buffer = make_buffer(b"\xc3\x28\x00rest");
        let (valid_str, rest) = get_from_buffer(buffer.buffer_view());
        assert_matches!(valid_str, Err(ValidStrError::Encoding(_)));
        assert_eq!(rest, "rest");
    }
}