1use packet_encoding::{Decodable, Encodable};
6use std::collections::BTreeMap;
7
8use crate::error::{Error, PacketError};
9use crate::header::{ConnectionIdentifier, Header, HeaderIdentifier, SingleResponseMode};
10
11#[derive(Clone, Debug, Default, PartialEq)]
15pub struct HeaderSet {
16 ids: BTreeMap<HeaderIdentifier, Header>,
17}
18
19impl HeaderSet {
20 pub fn new() -> Self {
21 Self { ids: BTreeMap::new() }
22 }
23
24 pub fn from_headers(headers: Vec<Header>) -> Result<Self, Error> {
25 let mut set = Self::new();
26 for header in headers {
27 set.add(header)?;
28 }
29 Ok(set)
30 }
31
32 pub fn from_header(header: Header) -> Self {
33 Self::from_headers(vec![header]).expect("single header always valid")
34 }
35
36 fn is_special_encoding_id(id: &HeaderIdentifier) -> bool {
38 use HeaderIdentifier::*;
39 match id {
40 ConnectionId | Target | Body | EndOfBody => true,
41 _ => false,
42 }
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.ids.is_empty()
47 }
48
49 pub fn contains_header(&self, id: &HeaderIdentifier) -> bool {
50 self.ids.contains_key(id)
51 }
52
53 #[cfg(test)]
54 pub fn contains_headers(&self, ids: &Vec<HeaderIdentifier>) -> bool {
55 for id in ids {
56 if !self.contains_header(id) {
57 return false;
58 }
59 }
60 true
61 }
62
63 pub fn get(&self, id: &HeaderIdentifier) -> Option<&Header> {
64 self.ids.get(id)
65 }
66
67 pub fn add(&mut self, header: Header) -> Result<(), Error> {
72 let id = header.identifier();
73 match self.get(&id) {
76 Some(h) if *h == header => return Ok(()),
77 Some(_h) => return Err(Error::AlreadyExists(id)),
78 None => {}
79 }
80
81 use HeaderIdentifier::*;
84 match id {
85 ConnectionId if self.contains_header(&Target) => {
86 return Err(Error::IncompatibleHeaders(ConnectionId, Target));
87 }
88 Target if self.contains_header(&ConnectionId) => {
89 return Err(Error::IncompatibleHeaders(Target, ConnectionId));
90 }
91 Body if self.contains_header(&EndOfBody) => {
92 return Err(Error::IncompatibleHeaders(Body, EndOfBody));
93 }
94 EndOfBody if self.contains_header(&Body) => {
95 return Err(Error::IncompatibleHeaders(EndOfBody, Body));
96 }
97 _ => {}
98 }
99 let _ = self.ids.insert(id, header);
100 Ok(())
101 }
102
103 pub fn try_append(&mut self, other: HeaderSet) -> Result<(), Error> {
106 for (_, header) in other.ids.into_iter() {
107 self.add(header)?;
108 }
109 Ok(())
110 }
111
112 pub fn remove_body(&mut self, final_: bool) -> Result<Vec<u8>, Error> {
116 if final_ {
117 let Some(Header::EndOfBody(end_of_body)) = self.remove(&HeaderIdentifier::EndOfBody)
118 else {
119 return Err(PacketError::data("missing end of body header").into());
120 };
121 Ok(end_of_body)
122 } else {
123 let Some(Header::Body(body)) = self.remove(&HeaderIdentifier::Body) else {
124 return Err(PacketError::data("missing body header").into());
125 };
126 Ok(body)
127 }
128 }
129
130 pub fn remove(&mut self, id: &HeaderIdentifier) -> Option<Header> {
133 self.ids.remove(id)
134 }
135
136 pub fn try_add_srm(&mut self, local: SingleResponseMode) -> Result<SingleResponseMode, Error> {
142 if let Some(Header::SingleResponseMode(srm)) =
144 self.get(&HeaderIdentifier::SingleResponseMode)
145 {
146 if *srm == SingleResponseMode::Enable && local != SingleResponseMode::Enable {
148 return Err(Error::SrmNotSupported);
149 }
150 return Ok(*srm);
152 }
153
154 if local == SingleResponseMode::Enable {
156 self.add(SingleResponseMode::Enable.into())?;
157 }
158 Ok(local)
159 }
160
161 pub fn try_add_connection_id(
162 &mut self,
163 id: &Option<ConnectionIdentifier>,
164 ) -> Result<(), Error> {
165 if let Some(id) = id {
166 self.add(Header::ConnectionId(*id))?;
167 }
168 Ok(())
169 }
170}
171
172impl Encodable for HeaderSet {
173 type Error = PacketError;
174
175 fn encoded_len(&self) -> usize {
176 self.ids.iter().map(|(_, h)| h.encoded_len()).sum()
177 }
178
179 fn encode(&self, buf: &mut [u8]) -> Result<(), Self::Error> {
180 if buf.len() < self.encoded_len() {
181 return Err(PacketError::BufferTooSmall);
182 }
183
184 let mut start_idx = 0;
185 if let Some(header) = self.get(&HeaderIdentifier::ConnectionId) {
188 header.encode(&mut buf[start_idx..])?;
189 start_idx += header.encoded_len();
190 }
191 if let Some(header) = self.get(&HeaderIdentifier::Target) {
192 header.encode(&mut buf[start_idx..])?;
193 start_idx += header.encoded_len();
194 }
195
196 for (id, header) in &self.ids {
207 if !Self::is_special_encoding_id(&id) {
208 header.encode(&mut buf[start_idx..])?;
209 start_idx += header.encoded_len();
210 }
211 }
212
213 if let Some(header) = self.get(&HeaderIdentifier::Body) {
215 header.encode(&mut buf[start_idx..])?;
216 start_idx += header.encoded_len();
217 }
218 if let Some(header) = self.get(&HeaderIdentifier::EndOfBody) {
219 header.encode(&mut buf[start_idx..])?;
220 }
221
222 Ok(())
223 }
224}
225
226impl Decodable for HeaderSet {
227 type Error = PacketError;
228
229 fn decode(buf: &[u8]) -> Result<Self, Self::Error> {
230 let mut headers = Self::new();
231 let mut start_idx = 0;
232 while start_idx < buf.len() {
233 let header = Header::decode(&buf[start_idx..])?;
234 start_idx += header.encoded_len();
235 headers.add(header).map_err(|e| PacketError::data(format!("{e:?}")))?;
236 }
237 Ok(headers)
238 }
239}
240
/// Test helper: asserts that `headers` holds a Body header whose payload equals `expected`.
///
/// Panics, reporting the caller's location, if the Body header is missing or its payload
/// differs.
#[cfg(test)]
#[track_caller]
pub fn expect_body(headers: &HeaderSet, expected: Vec<u8>) {
    let header = headers.get(&HeaderIdentifier::Body).expect("contains body");
    if let Header::Body(payload) = header {
        assert_eq!(payload, &expected);
    } else {
        panic!("Expected body, got: {header:?}");
    }
}
249
/// Test helper: asserts that `headers` holds an End of Body header whose payload equals
/// `expected`.
///
/// Panics, reporting the caller's location, if the End of Body header is missing or its
/// payload differs.
#[cfg(test)]
#[track_caller]
pub fn expect_end_of_body(headers: &HeaderSet, expected: Vec<u8>) {
    let header = headers.get(&HeaderIdentifier::EndOfBody).expect("contains end of body");
    if let Header::EndOfBody(payload) = header {
        assert_eq!(payload, &expected);
    } else {
        panic!("Expected end of body, got: {header:?}");
    }
}
258
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;

    #[fuchsia::test]
    fn add_duplicate_header_is_ok() {
        let mut headers = HeaderSet::new();
        let header = Header::ConnectionId(ConnectionIdentifier(1));
        headers.add(header.clone()).expect("can add header");
        assert!(headers.contains_header(&HeaderIdentifier::ConnectionId));
        // Re-adding an identical header succeeds and the header is still present.
        assert_matches!(headers.add(header), Ok(_));
        assert!(headers.contains_header(&HeaderIdentifier::ConnectionId));
    }

    #[fuchsia::test]
    fn add_existing_header_is_error() {
        let mut headers = HeaderSet::new();
        headers.add(Header::ConnectionId(ConnectionIdentifier(2))).expect("can add header");
        assert!(headers.contains_header(&HeaderIdentifier::ConnectionId));
        // A different header with the same identifier is rejected.
        assert_matches!(
            headers.add(Header::ConnectionId(ConnectionIdentifier(3))),
            Err(Error::AlreadyExists(HeaderIdentifier::ConnectionId))
        );
    }

    #[fuchsia::test]
    fn try_append_success() {
        let mut headers1 = HeaderSet::from_header(Header::name("foo"));
        let headers2 = HeaderSet::from_header(Header::Description("bar".into()));
        let () = headers1.try_append(headers2).expect("valid headers");
        assert!(headers1.contains_header(&HeaderIdentifier::Name));
        assert!(headers1.contains_header(&HeaderIdentifier::Description));
    }

    #[fuchsia::test]
    fn try_append_error() {
        let mut headers1 = HeaderSet::from_header(Header::name("foo"));
        let headers2 = HeaderSet::from_header(Header::name("bar"));
        assert_matches!(headers1.try_append(headers2), Err(Error::AlreadyExists(_)));
    }

    #[fuchsia::test]
    fn add_incompatible_header_is_error() {
        // Target cannot be added alongside Connection ID.
        let mut headers = HeaderSet::from_header(Header::ConnectionId(ConnectionIdentifier(2)));
        assert_matches!(
            headers.add(Header::Target("123".into())),
            Err(Error::IncompatibleHeaders(..))
        );

        // Body cannot be added alongside End of Body.
        let mut headers = HeaderSet::from_header(Header::EndOfBody(vec![1]));
        assert_matches!(headers.add(Header::Body(vec![2])), Err(Error::IncompatibleHeaders(..)));

        // End of Body cannot be added alongside Body.
        let mut headers = HeaderSet::from_header(Header::Body(vec![1]));
        assert_matches!(
            headers.add(Header::EndOfBody(vec![2])),
            Err(Error::IncompatibleHeaders(..))
        );
    }

    #[fuchsia::test]
    fn remove_headers() {
        let mut headers =
            HeaderSet::from_headers(vec![Header::Count(123), Header::name("123")]).unwrap();
        assert!(headers.contains_header(&HeaderIdentifier::Count));
        assert!(headers.contains_header(&HeaderIdentifier::Name));
        assert!(headers.remove(&HeaderIdentifier::Count).is_some());
        assert!(!headers.contains_header(&HeaderIdentifier::Count));
        // Removing an absent header yields None.
        assert!(headers.remove(&HeaderIdentifier::Count).is_none());
        assert!(headers.remove(&HeaderIdentifier::Name).is_some());
        assert!(!headers.contains_header(&HeaderIdentifier::Name));
    }

    #[fuchsia::test]
    fn remove_body_headers() {
        let mut body_header = HeaderSet::from_header(Header::Body(vec![1, 2]));
        let mut end_of_body_header = HeaderSet::from_header(Header::EndOfBody(vec![7, 8, 9]));

        let eob = end_of_body_header.remove_body(true).expect("end of body exists");
        assert_eq!(eob, vec![7, 8, 9]);
        // The End of Body header was consumed by the first removal.
        assert_matches!(
            end_of_body_header.remove_body(true),
            Err(Error::Packet(PacketError::Data(_)))
        );

        let b = body_header.remove_body(false).expect("body exists");
        assert_eq!(b, vec![1, 2]);
        // The Body header was consumed by the first removal.
        assert_matches!(body_header.remove_body(false), Err(Error::Packet(PacketError::Data(_))));

        // Asking for End of Body when only Body is present is an error.
        let mut headers = HeaderSet::from_headers(vec![Header::Body(vec![1])]).unwrap();
        assert_matches!(headers.remove_body(true), Err(Error::Packet(PacketError::Data(_))));

        // Asking for Body when only End of Body is present is an error.
        let mut headers = HeaderSet::from_headers(vec![Header::EndOfBody(vec![1])]).unwrap();
        assert_matches!(headers.remove_body(false), Err(Error::Packet(PacketError::Data(_))));
    }

    #[fuchsia::test]
    fn try_add_srm_success() {
        // No SRM header present, local Enable: an Enable header is added.
        let mut headers = HeaderSet::new();
        let result = headers.try_add_srm(SingleResponseMode::Enable).expect("can add SRM");
        assert_eq!(result, SingleResponseMode::Enable);
        assert_matches!(
            headers.get(&HeaderIdentifier::SingleResponseMode),
            Some(Header::SingleResponseMode(SingleResponseMode::Enable))
        );

        // No SRM header present, local Disable: nothing is added.
        let mut headers = HeaderSet::new();
        let result = headers.try_add_srm(SingleResponseMode::Disable).expect("can add SRM");
        assert_eq!(result, SingleResponseMode::Disable);
        assert_matches!(headers.get(&HeaderIdentifier::SingleResponseMode), None);

        // Existing Enable, local Enable: the existing header is kept.
        let mut headers = HeaderSet::from_header(SingleResponseMode::Enable.into());
        let result = headers.try_add_srm(SingleResponseMode::Enable).expect("can add SRM");
        assert_eq!(result, SingleResponseMode::Enable);
        assert_matches!(
            headers.get(&HeaderIdentifier::SingleResponseMode),
            Some(Header::SingleResponseMode(SingleResponseMode::Enable))
        );

        // Existing Disable, local Disable: the existing header is kept.
        let mut headers = HeaderSet::from_header(SingleResponseMode::Disable.into());
        let result = headers.try_add_srm(SingleResponseMode::Disable).expect("can add SRM");
        assert_eq!(result, SingleResponseMode::Disable);
        assert_matches!(
            headers.get(&HeaderIdentifier::SingleResponseMode),
            Some(Header::SingleResponseMode(SingleResponseMode::Disable))
        );

        // Existing Disable, local Enable: the existing Disable wins.
        let mut headers = HeaderSet::from_header(SingleResponseMode::Disable.into());
        let result = headers.try_add_srm(SingleResponseMode::Enable).expect("can add SRM");
        assert_eq!(result, SingleResponseMode::Disable);
        assert_matches!(
            headers.get(&HeaderIdentifier::SingleResponseMode),
            Some(Header::SingleResponseMode(SingleResponseMode::Disable))
        );
    }

    #[fuchsia::test]
    fn try_add_srm_error() {
        // Existing Enable cannot be honored when local SRM is Disable. The set is unchanged.
        let mut headers = HeaderSet::from_header(SingleResponseMode::Enable.into());
        let result = headers.try_add_srm(SingleResponseMode::Disable);
        assert_matches!(result, Err(Error::SrmNotSupported));
        assert_matches!(
            headers.get(&HeaderIdentifier::SingleResponseMode),
            Some(Header::SingleResponseMode(SingleResponseMode::Enable))
        );
    }

    #[fuchsia::test]
    fn try_add_connection_id_success() {
        let mut headers = HeaderSet::new();

        // `None` is a no-op.
        let () = headers.try_add_connection_id(&None).expect("success");
        assert!(!headers.contains_header(&HeaderIdentifier::ConnectionId));

        let () = headers.try_add_connection_id(&Some(ConnectionIdentifier(11))).expect("success");
        assert!(headers.contains_header(&HeaderIdentifier::ConnectionId));
    }

    #[fuchsia::test]
    fn try_add_connection_id_error() {
        // A different Connection ID is already present.
        let mut headers = HeaderSet::from_header(Header::ConnectionId(ConnectionIdentifier(10)));
        assert_matches!(
            headers.try_add_connection_id(&Some(ConnectionIdentifier(11))),
            Err(Error::AlreadyExists(_))
        );

        // Connection ID is incompatible with an existing Target.
        let mut headers = HeaderSet::from_header(Header::Target("foo".into()));
        assert_matches!(
            headers.try_add_connection_id(&Some(ConnectionIdentifier(1))),
            Err(Error::IncompatibleHeaders(..))
        );
    }

    #[fuchsia::test]
    fn encode_header_set() {
        let headers = HeaderSet::from_headers(vec![
            Header::ConnectionId(ConnectionIdentifier(1)),
            Header::EndOfBody(vec![1, 2, 3]),
        ])
        .expect("can build header set");

        assert_eq!(headers.encoded_len(), 11);
        let mut buf = vec![0; headers.encoded_len()];
        headers.encode(&mut buf[..]).expect("can encode headers");
        let expected_buf = [
            0xcb, 0x00, 0x00, 0x00, 0x01, // Connection ID = 1
            0x49, 0x00, 0x06, 0x01, 0x02, 0x03, // End of Body, length 6, payload [1, 2, 3]
        ];
        assert_eq!(buf, expected_buf);
    }

    #[fuchsia::test]
    fn encode_header_set_enforces_ordering() {
        let headers = HeaderSet::from_headers(vec![
            Header::Body(vec![1, 2, 3]),
            Header::name("2"),
            Header::ConnectionId(ConnectionIdentifier(1)),
            Header::SingleResponseMode(SingleResponseMode::Enable),
        ])
        .expect("can build header set");

        assert_eq!(headers.encoded_len(), 20);
        let mut buf = vec![0; headers.encoded_len()];
        headers.encode(&mut buf[..]).expect("can encode headers");
        // Connection ID comes first and Body last, regardless of insertion order.
        let expected_buf = [
            0xcb, 0x00, 0x00, 0x00, 0x01, // Connection ID = 1
            0x01, 0x00, 0x07, 0x00, 0x32, 0x00, 0x00, // Name "2", length 7
            0x97, 0x01, // SRM = Enable
            0x48, 0x00, 0x06, 0x01, 0x02, 0x03, // Body, length 6, payload [1, 2, 3]
        ];
        assert_eq!(buf, expected_buf);
    }

    #[fuchsia::test]
    fn encode_header_set_buffer_too_small() {
        let headers = HeaderSet::from_header(Header::EndOfBody(vec![1, 2, 3]));
        let mut buf = vec![0; headers.encoded_len() - 1];
        assert_matches!(headers.encode(&mut buf[..]), Err(PacketError::BufferTooSmall));
    }

    #[fuchsia::test]
    fn decode_header_set() {
        let buf = [
            0x05, 0x00, 0x09, 0x00, 0x68, 0x00, 0x65, 0x00, 0x00, // Description "he", length 9
            0xd6, 0x00, 0x00, 0x00, 0x05, // Permissions = 5
            0x97, 0x01, // SRM = Enable
        ];
        let headers = HeaderSet::decode(&buf[..]).expect("can decode into headers");
        let expected_description = Header::Description("he".into());
        let expected_permissions = Header::Permissions(5);
        let expected_srm = Header::SingleResponseMode(SingleResponseMode::Enable);
        let expected_headers = HeaderSet::from_headers(vec![
            expected_description,
            expected_permissions,
            expected_srm,
        ])
        .unwrap();
        assert_eq!(headers, expected_headers);
    }

    #[fuchsia::test]
    fn decode_partial_header_set_error() {
        let buf = [
            0xd6, 0x00, 0x00, 0x00, 0x09, // Permissions = 9
            0x97, 0x01, // SRM = Enable
            0xc4, 0x00, // Header 0xc4 followed by too few bytes
        ];
        let headers = HeaderSet::decode(&buf[..]);
        assert_matches!(headers, Err(PacketError::BufferTooSmall));
    }

    #[fuchsia::test]
    fn encode_decode_round_trip() {
        let headers = HeaderSet::from_headers(vec![
            Header::ConnectionId(ConnectionIdentifier(1)),
            Header::Description("bar".into()),
            Header::EndOfBody(vec![1, 2, 3]),
        ])
        .expect("can build header set");
        let mut buf = vec![0; headers.encoded_len()];
        headers.encode(&mut buf[..]).expect("can encode headers");
        let decoded = HeaderSet::decode(&buf[..]).expect("can decode headers");
        assert_eq!(decoded, headers);
    }
}