1use std::io::{self, Read, Write};
6
7use thiserror::Error;
8
9use fuchsia_trace as ftrace;
10
11use crate::{Fixed, NewId, ObjectId};
12
/// The wire type of a message argument, without any payload.
///
/// Passed to `Message::read_arg` to select how the next argument is decoded, and
/// reported by `MismatchedArgKind` when an argument has an unexpected type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ArgKind {
    /// Signed 32-bit integer.
    Int,
    /// Unsigned 32-bit integer.
    Uint,
    /// Fixed-point number, encoded as 32 bits.
    Fixed,
    /// Length-prefixed, NUL-terminated string padded to a 4-byte boundary.
    String,
    /// Object id, encoded as a 32-bit word.
    Object,
    /// Id for a newly created object, encoded as a 32-bit word.
    NewId,
    /// Length-prefixed byte array padded to a 4-byte boundary.
    Array,
    /// Handle carried in the message's handle list rather than its bytes.
    Handle,
}
24
25#[derive(Debug)]
26pub enum Arg {
27 Int(i32),
28 Uint(u32),
29 Fixed(Fixed),
30 String(String),
31 Object(ObjectId),
32 NewId(NewId),
33 Array(Array),
34 Handle(zx::Handle),
35}
36
37impl Arg {
38 pub fn kind(&self) -> ArgKind {
39 match self {
40 Arg::Int(_) => ArgKind::Int,
41 Arg::Uint(_) => ArgKind::Uint,
42 Arg::Fixed(_) => ArgKind::Fixed,
43 Arg::String(_) => ArgKind::String,
44 Arg::Object(_) => ArgKind::Object,
45 Arg::NewId(_) => ArgKind::NewId,
46 Arg::Array(_) => ArgKind::Array,
47 Arg::Handle(_) => ArgKind::Handle,
48 }
49 }
50}
51
/// Generates `pub fn $name(self) -> $type`, which returns the payload of the
/// `Arg::$enumtype` variant and panics if the argument holds any other variant.
macro_rules! impl_unwrap_arg {
    ($name:ident, $type:ty, $enumtype:ident) => {
        /// Returns the wrapped value of the expected variant.
        ///
        /// # Panics
        ///
        /// Panics if the argument is a different variant. The panic location is the
        /// caller's, matching `Option::unwrap`.
        #[track_caller]
        pub fn $name(self) -> $type {
            match self {
                Arg::$enumtype(x) => x,
                other => panic!(
                    "Argument is not of the required type: expected {:?}, found {:?}",
                    stringify!($enumtype),
                    other
                ),
            }
        }
    };
}
65
66impl Arg {
67 impl_unwrap_arg!(unwrap_int, i32, Int);
68 impl_unwrap_arg!(unwrap_uint, u32, Uint);
69 impl_unwrap_arg!(unwrap_fixed, Fixed, Fixed);
70 impl_unwrap_arg!(unwrap_object, ObjectId, Object);
71 impl_unwrap_arg!(unwrap_new_id, ObjectId, NewId);
72 impl_unwrap_arg!(unwrap_string, String, String);
73 impl_unwrap_arg!(unwrap_array, Array, Array);
74 impl_unwrap_arg!(unwrap_handle, zx::Handle, Handle);
75}
76
77#[derive(Debug, Error)]
78#[error("Argument is not of the required type: expected {:?}, found {:?}", expected, found)]
79pub struct MismatchedArgKind {
80 pub expected: ArgKind,
81 pub found: ArgKind,
82}
83
/// Generates `pub fn $name(self) -> Result<$type, MismatchedArgKind>`, which returns the
/// payload of `Arg::$enumtype`, or an error naming both the expected and actual kinds.
macro_rules! impl_as_arg {
    ($name:ident, $type:ty, $enumtype:ident) => {
        pub fn $name(self) -> Result<$type, MismatchedArgKind> {
            match self {
                Arg::$enumtype(value) => Ok(value),
                other => {
                    let found = other.kind();
                    Err(MismatchedArgKind { expected: ArgKind::$enumtype, found })
                }
            }
        }
    };
}
98
99impl Arg {
100 impl_as_arg!(as_int, i32, Int);
101 impl_as_arg!(as_uint, u32, Uint);
102 impl_as_arg!(as_fixed, Fixed, Fixed);
103 impl_as_arg!(as_object, ObjectId, Object);
104 impl_as_arg!(as_new_id, ObjectId, NewId);
105 impl_as_arg!(as_string, String, String);
106 impl_as_arg!(as_array, Array, Array);
107 impl_as_arg!(as_handle, zx::Handle, Handle);
108}
109
/// Implements `From<$payload> for Arg` by wrapping the value in `Arg::$variant`.
macro_rules! impl_from_for_arg {
    ($variant:ident, $payload:ty) => {
        impl From<$payload> for Arg {
            fn from(value: $payload) -> Self {
                Self::$variant(value)
            }
        }
    };
}
120impl_from_for_arg!(Int, i32);
121impl_from_for_arg!(Uint, u32);
122impl_from_for_arg!(Fixed, Fixed);
123impl_from_for_arg!(String, String);
124impl_from_for_arg!(Array, Array);
125impl_from_for_arg!(Handle, zx::Handle);
126
/// Decoded form of the 8-byte header that precedes each message's arguments.
///
/// On the wire this is two 32-bit words: `sender`, then `length << 16 | opcode`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MessageHeader {
    /// Object id stored in the first header word.
    pub sender: u32,
    /// Lower 16 bits of the second header word.
    pub opcode: u16,
    /// Upper 16 bits of the second header word; presumably the full message size in
    /// bytes, header included (TODO confirm against the protocol spec).
    pub length: u16,
}
133
134#[derive(Debug)]
135pub struct Message {
136 byte_buf: io::Cursor<Vec<u8>>,
137 handle_buf: Vec<zx::Handle>,
138}
139
/// Returns how many zero bytes (0..=3) must follow `size` bytes to reach the next
/// multiple of 4.
fn compute_padding(size: u64) -> usize {
    match (size % 4) as usize {
        0 => 0,
        rem => 4 - rem,
    }
}
143
144impl Message {
145 pub fn new() -> Self {
147 Message { byte_buf: io::Cursor::new(Vec::new()), handle_buf: Vec::new() }
148 }
149
150 pub fn from_parts(bytes: Vec<u8>, handles: Vec<zx::Handle>) -> Self {
151 Message { byte_buf: io::Cursor::new(bytes), handle_buf: handles }
152 }
153
154 pub fn is_empty(&self) -> bool {
156 self.byte_buf.get_ref().len() as u64 == self.byte_buf.position()
157 && self.handle_buf.is_empty()
158 }
159
160 pub fn clear(&mut self) {
161 self.byte_buf.set_position(0);
162 self.byte_buf.get_mut().truncate(0);
163 self.handle_buf.truncate(0);
164 }
165
166 pub fn rewind(&mut self) {
167 self.byte_buf.set_position(0);
168 }
169
170 pub fn bytes(&self) -> &[u8] {
171 self.byte_buf.get_ref().as_slice()
172 }
173
174 pub fn take(self) -> (Vec<u8>, Vec<zx::Handle>) {
175 (self.byte_buf.into_inner(), self.handle_buf)
176 }
177
178 pub fn write_arg(&mut self, arg: Arg) -> io::Result<()> {
181 ftrace::duration!(c"wayland", c"Message::write_arg");
182 match arg {
183 Arg::Int(i) => self.byte_buf.write_all(&i.to_ne_bytes()[..]),
184 Arg::Uint(i) => self.byte_buf.write_all(&i.to_ne_bytes()[..]),
185 Arg::Fixed(i) => self.byte_buf.write_all(&i.bits().to_ne_bytes()[..]),
186 Arg::Object(i) => self.byte_buf.write_all(&i.to_ne_bytes()[..]),
187 Arg::NewId(i) => self.byte_buf.write_all(&i.to_ne_bytes()[..]),
188 Arg::String(s) => self.write_slice(s.as_bytes(), true),
189 Arg::Array(a) => self.write_slice(a.as_slice(), false),
190 Arg::Handle(h) => {
191 self.handle_buf.push(h);
192 Ok(())
193 }
194 }
195 }
196
197 fn read_i32(&mut self) -> io::Result<i32> {
198 let mut buf = [0; std::mem::size_of::<i32>()];
199 let () = self.byte_buf.read_exact(&mut buf)?;
200 Ok(i32::from_ne_bytes(buf))
201 }
202
203 fn read_u32(&mut self) -> io::Result<u32> {
204 let mut buf = [0; std::mem::size_of::<u32>()];
205 let () = self.byte_buf.read_exact(&mut buf)?;
206 Ok(u32::from_ne_bytes(buf))
207 }
208
209 pub fn read_arg(&mut self, arg: ArgKind) -> io::Result<Arg> {
211 ftrace::duration!(c"wayland", c"Message::read_arg");
212 match arg {
213 ArgKind::Int => self.read_i32().map(Arg::Int),
214 ArgKind::Uint => self.read_u32().map(Arg::Uint),
215 ArgKind::Fixed => self.read_i32().map(|i| Arg::Fixed(i.into())),
216 ArgKind::Object => self.read_u32().map(Arg::Object),
217 ArgKind::NewId => self.read_u32().map(Arg::NewId),
218 ArgKind::String => self
219 .read_slice(true)
220 .map(|vec| String::from_utf8_lossy(vec.as_slice()).to_string())
221 .map(Arg::String),
222 ArgKind::Array => self.read_slice(false).map(|v| Arg::Array(v.into())),
223 ArgKind::Handle => {
224 if !self.handle_buf.is_empty() {
225 Ok(Arg::Handle(self.handle_buf.remove(0)))
226 } else {
227 Err(io::Error::new(
228 io::ErrorKind::UnexpectedEof,
229 "Unable to read handle from Message",
230 ))
231 }
232 }
233 }
234 }
235
236 pub fn read_args(&mut self, args: &[ArgKind]) -> io::Result<Vec<Arg>> {
238 ftrace::duration!(c"wayland", c"Message::read_args", "len" => args.len() as u64);
239 args.iter().map(|arg| self.read_arg(*arg)).collect()
240 }
241
242 pub fn write_args(&mut self, args: Vec<Arg>) -> io::Result<()> {
244 ftrace::duration!(c"wayland", c"Message::write_args", "len" => args.len() as u64);
245 args.into_iter().try_for_each(|arg| self.write_arg(arg))?;
246 Ok(())
247 }
248
249 pub fn peek_header(&mut self) -> io::Result<MessageHeader> {
252 let pos = self.byte_buf.position();
253 let header = self.read_header();
254 self.byte_buf.set_position(pos);
255 header
256 }
257
258 pub fn read_header(&mut self) -> io::Result<MessageHeader> {
259 let sender = self.read_u32()?;
260 let word = self.read_u32()?;
261 Ok(MessageHeader { sender, length: (word >> 16) as u16, opcode: word as u16 })
262 }
263
264 pub fn write_header(&mut self, header: &MessageHeader) -> io::Result<()> {
265 self.byte_buf.write_all(&header.sender.to_ne_bytes()[..])?;
266 self.byte_buf
267 .write_all(&((header.length as u32) << 16 | header.opcode as u32).to_ne_bytes()[..])?;
268 Ok(())
269 }
270
271 fn read_slice(&mut self, null_term: bool) -> io::Result<Vec<u8>> {
272 let pos = self.byte_buf.position();
273 let len = self.read_u32()?;
274 let mut vec: Vec<u8> = Vec::with_capacity(len as usize);
275 if len == 0 {
276 return Ok(vec);
277 }
278
279 vec.resize(len as usize, 0);
280 self.byte_buf.read_exact(vec.as_mut_slice())?;
281
282 if null_term {
283 match vec.pop() {
284 Some(term) => {
285 if term != b'\0' {
286 return Err(io::Error::new(
287 io::ErrorKind::InvalidData,
288 format!("Expected null terminator; found {}", term),
289 ));
290 }
291 }
292 None => {
293 return Err(io::Error::new(
294 io::ErrorKind::UnexpectedEof,
295 "Missing null terminator on string argument",
296 ));
297 }
298 }
299 }
300
301 let pad = compute_padding(self.byte_buf.position() - pos);
302 for _ in 0..pad {
303 let mut buf = [0; 1];
304 self.byte_buf.read_exact(&mut buf)?;
305 }
306 assert!(self.byte_buf.position() % 4 == 0);
307 Ok(vec)
308 }
309
310 fn write_slice(&mut self, s: &[u8], null_term: bool) -> io::Result<()> {
311 let pos = self.byte_buf.position();
312 let mut len: u32 = s.len() as u32;
313 if null_term {
314 len += 1;
315 }
316
317 self.byte_buf.write_all(&len.to_ne_bytes())?;
318 self.byte_buf.write_all(s)?;
319 if null_term {
320 self.byte_buf.write_all(&0u8.to_ne_bytes()[..])?;
321 }
322
323 let pad = compute_padding(self.byte_buf.position() - pos);
325 for _ in 0..pad {
326 self.byte_buf.write_all(&0u8.to_ne_bytes()[..])?;
327 }
328 assert!(self.byte_buf.position() % 4 == 0);
329 Ok(())
330 }
331}
332
333impl From<zx::MessageBuf> for Message {
335 fn from(buf: zx::MessageBuf) -> Self {
336 let (bytes, handles) = buf.split();
337 Message::from_parts(bytes, handles)
338 }
339}
340
341#[derive(Debug)]
359pub struct Array(Message);
360
361impl Array {
362 pub fn new() -> Self {
364 Self(Message::new())
365 }
366
367 pub fn from_vec(v: Vec<u8>) -> Self {
369 Self(Message::from_parts(v, vec![]))
370 }
371
372 pub fn as_slice(&self) -> &[u8] {
374 self.0.bytes()
375 }
376
377 pub fn into_vec(self) -> Vec<u8> {
379 let (bytes, _) = self.0.take();
380 bytes
381 }
382
383 pub fn len(&self) -> usize {
385 self.0.bytes().len()
386 }
387
388 pub fn push<T: Into<Arg>>(&mut self, arg: T) -> io::Result<()> {
393 let arg = arg.into();
394 if let Arg::Handle(_) = &arg {
395 Err(io::Error::new(io::ErrorKind::InvalidInput, "Arrays cannot contain handles"))
396 } else {
397 self.0.write_arg(arg)
398 }
399 }
400
401 pub fn read_arg(&mut self, kind: ArgKind) -> io::Result<Arg> {
406 if kind == ArgKind::Handle {
407 Err(io::Error::new(io::ErrorKind::InvalidInput, "Arrays cannot contain handles"))
408 } else {
409 self.0.read_arg(kind)
410 }
411 }
412}
413
414impl From<Vec<u8>> for Array {
415 fn from(v: Vec<u8>) -> Self {
416 Self::from_vec(v)
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    use anyhow::Error;

    /// If `$e` matches `$p`, evaluates `$a` with the pattern's bindings in scope;
    /// otherwise panics.
    macro_rules! assert_matches(
        ($e:expr, $p:pat => $a:expr) => (
            match $e {
                $p => $a,
                _ => panic!("Failed to match!"),
            }
        )
    );

    const UINT_VALUE: u32 = 0x1234567;
    const INT_VALUE: i32 = -12345678;
    const FIXED_VALUE: i32 = 0x11223344;
    const OBJECT_VALUE: u32 = 0x88775566;
    const NEW_ID_VALUE: u32 = 0x55443322;
    const STRING_VALUE: &str = "Hello from a test";
    const ARRAY_VALUE: &[u8] = &[0, 1, 2, 3, 4, 5, 6];

    /// Round-trips one argument of every kind through a single message.
    #[test]
    fn sanity() {
        let (h1, _h2) = zx::Socket::create_stream();

        let mut message = Message::new();
        assert!(message.write_arg(Arg::Uint(UINT_VALUE)).is_ok());
        assert!(message.write_arg(Arg::Int(INT_VALUE)).is_ok());
        assert!(message.write_arg(Arg::Fixed(Fixed::from_bits(FIXED_VALUE))).is_ok());
        assert!(message.write_arg(Arg::Object(OBJECT_VALUE)).is_ok());
        assert!(message.write_arg(Arg::NewId(NEW_ID_VALUE)).is_ok());
        assert!(message.write_arg(Arg::String(STRING_VALUE.to_owned())).is_ok());
        assert!(message.write_arg(Arg::Array(ARRAY_VALUE.to_owned().into())).is_ok());
        assert!(message.write_arg(Arg::Handle(h1.into())).is_ok());

        let (bytes, handles) = message.take();
        assert_eq!(1, handles.len());
        const INT_SIZE: u32 = 4;
        const UINT_SIZE: u32 = 4;
        const FIXED_SIZE: u32 = 4;
        const OBJECT_SIZE: u32 = 4;
        const NEW_ID_SIZE: u32 = 4;
        // 4-byte length + 17 chars + NUL = 22, padded to 24.
        const STRING_SIZE: u32 = 24;
        // 4-byte length + 7 bytes = 11, padded to 12.
        const VEC_SIZE: u32 = 12;
        let expected_size =
            INT_SIZE + UINT_SIZE + FIXED_SIZE + OBJECT_SIZE + NEW_ID_SIZE + STRING_SIZE + VEC_SIZE;
        assert_eq!(expected_size as usize, bytes.len());

        let mut message = Message::from_parts(bytes, handles);
        let arg = message.read_arg(ArgKind::Uint).unwrap();
        assert_matches!(arg, Arg::Uint(x) => assert_eq!(x, UINT_VALUE));
        let arg = message.read_arg(ArgKind::Int).unwrap();
        assert_matches!(arg, Arg::Int(x) => assert_eq!(x, INT_VALUE));
        let arg = message.read_arg(ArgKind::Fixed).unwrap();
        assert_matches!(arg, Arg::Fixed(x) => assert_eq!(x, Fixed::from_bits(FIXED_VALUE)));
        let arg = message.read_arg(ArgKind::Object).unwrap();
        assert_matches!(arg, Arg::Object(x) => assert_eq!(x, OBJECT_VALUE));
        let arg = message.read_arg(ArgKind::NewId).unwrap();
        assert_matches!(arg, Arg::NewId(x) => assert_eq!(x, NEW_ID_VALUE));
        let arg = message.read_arg(ArgKind::String).unwrap();
        assert_matches!(arg, Arg::String(ref x) => assert_eq!(x, STRING_VALUE));
        let arg = message.read_arg(ArgKind::Array).unwrap();
        assert_matches!(arg, Arg::Array(ref x) => assert_eq!(x.as_slice(), ARRAY_VALUE));
    }

    /// Strings of every length modulo 4 encode with correct padding and decode back.
    #[test]
    fn string_test() {
        // (value, encoded size): 4-byte length + bytes + NUL, padded to a multiple of 4.
        let cases: &[(&str, usize)] =
            &[("", 8), ("1", 8), ("22", 8), ("333", 8), ("4444", 12), ("55555", 12)];

        let mut message = Message::new();
        for &(s, expected_len) in cases {
            message.clear();
            assert!(message.write_arg(Arg::String(s.to_owned())).is_ok());
            assert_eq!(expected_len, message.bytes().len(), "encoded size of {:?}", s);
            message.rewind();
            let arg = message.read_arg(ArgKind::String).unwrap();
            assert_matches!(arg, Arg::String(ref decoded) => assert_eq!(decoded, s));
        }

        // A zero length (no bytes at all, not even a NUL) decodes as the empty string.
        message.clear();
        assert!(message.write_arg(Arg::Uint(0)).is_ok());
        assert_eq!(4, message.bytes().len());
        message.rewind();
        let arg = message.read_arg(ArgKind::String).unwrap();
        assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, ""));
    }

    /// A slice whose declared length exceeds the available bytes is an EOF error.
    #[test]
    fn truncated_slice_is_an_error() {
        let mut bytes = 1000u32.to_ne_bytes().to_vec();
        bytes.extend_from_slice(&[1, 2, 3, 4]);
        let mut message = Message::from_parts(bytes, vec![]);
        let err = message.read_arg(ArgKind::Array).unwrap_err();
        assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
    }

    /// A non-empty string whose last byte is not NUL is rejected as invalid data.
    #[test]
    fn string_without_null_terminator_is_an_error() {
        let mut bytes = 4u32.to_ne_bytes().to_vec();
        bytes.extend_from_slice(b"abcd");
        let mut message = Message::from_parts(bytes, vec![]);
        let err = message.read_arg(ArgKind::String).unwrap_err();
        assert_eq!(std::io::ErrorKind::InvalidData, err.kind());
    }

    /// Reading a handle from a message with none left is an EOF error.
    #[test]
    fn read_handle_without_handles_is_an_error() {
        let mut message = Message::new();
        let err = message.read_arg(ArgKind::Handle).unwrap_err();
        assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
    }

    /// Peeking repeatedly never consumes the header.
    #[test]
    fn peek_header() -> Result<(), Error> {
        let header = MessageHeader { sender: 3, opcode: 2, length: 8 };
        let mut message = Message::new();
        message.write_header(&header)?;
        message.rewind();

        assert_eq!(header, message.peek_header()?);
        assert_eq!(header, message.peek_header()?);
        assert_eq!(header, message.peek_header()?);
        assert_eq!(header, message.peek_header()?);
        Ok(())
    }

    /// `is_empty` requires both the bytes and the handles to be fully consumed.
    #[test]
    fn empty_message() {
        let (h1, h2) = zx::Channel::create();

        let message = Message::new();
        assert!(message.is_empty());
        let message = Message::from_parts(vec![], vec![]);
        assert!(message.is_empty());
        let message = Message::from_parts(vec![1], vec![]);
        assert!(!message.is_empty());
        let message = Message::from_parts(vec![], vec![h1.into()]);
        assert!(!message.is_empty());

        let mut message = Message::from_parts(vec![0, 0, 0, 0], vec![h2.into()]);
        assert!(!message.is_empty());
        let _ = message.read_arg(ArgKind::Uint).unwrap();
        assert!(!message.is_empty());
        let _ = message.read_arg(ArgKind::Handle).unwrap();
        assert!(message.is_empty());
    }

    /// Values pushed into an array survive a round trip through a message.
    #[test]
    fn array_read_write() -> Result<(), Error> {
        let mut message = Message::new();
        let mut array = Array::new();
        array.push(3)?;
        array.push(-2)?;
        array.push(Fixed::from_float(-2.0))?;
        message.write_arg(array.into())?;
        message.rewind();

        let mut array = message.read_arg(ArgKind::Array)?.as_array()?;
        assert_eq!(3, array.read_arg(ArgKind::Uint)?.as_uint()?);
        assert_eq!(-2, array.read_arg(ArgKind::Int)?.as_int()?);
        assert_eq!(Fixed::from_float(-2.), array.read_arg(ArgKind::Fixed)?.as_fixed()?);

        Ok(())
    }
}