1use fidl::encoding::{
6 AtRestFlags, DynamicFlags, ALLOC_PRESENT_U32, ALLOC_PRESENT_U64, MAGIC_NUMBER_INITIAL,
7};
8use fidl::AsHandleRef as _;
9
10use std::collections::HashMap;
11
12use crate::error::{Error, Result};
13use crate::library;
14use crate::util::*;
15use crate::value::Value;
16
/// A deferred encoding step: a one-shot closure that receives the encode
/// buffer and a recursion counter and reports success or failure.
///
/// The `encode_*` methods write their inline bytes immediately and hand back
/// one of these for any work that must run later; the caller decides when to
/// invoke it.
type DeferCallback<'n, 't> = dyn FnOnce(&mut EncodeBuffer<'n>, RecursionCounter) -> Result<()> + 't;
19
20fn combine_calls<'n: 't, 't>(calls: Vec<Box<DeferCallback<'n, 't>>>) -> Box<DeferCallback<'n, 't>> {
22 Box::new(move |this, counter| {
23 for call in calls {
24 call(this, counter)?;
25 }
26
27 Ok(())
28 })
29}
30
/// The kind of handle the FIDL type being encoded calls for; passed to
/// `encode_handle` so it can check the supplied value against it.
enum HandleType<'s> {
    /// The client end of a channel; carries the protocol name taken from the
    /// endpoint type.
    ClientEnd(&'s str),
    /// The server end of a channel; carries the protocol name taken from the
    /// endpoint type.
    ServerEnd(&'s str),
    /// A plain handle type with no associated protocol.
    Bare,
}
36
/// Accumulates the output of an encode operation.
struct EncodeBuffer<'n> {
    /// Namespace used to resolve identifiers (and inline sizes) while encoding.
    ns: &'n library::Namespace,
    /// The bytes encoded so far.
    bytes: Vec<u8>,
    /// The handle dispositions collected so far, in the order they were pushed.
    handles: Vec<fidl::HandleDisposition<'static>>,
}
43
44impl<'n> EncodeBuffer<'n> {
45 fn align_8(&mut self) {
47 self.bytes
48 .extend(std::iter::repeat(0u8).take(alignment_padding_for_size(self.bytes.len())));
49 }
50
51 fn encode_transaction<'n_i: 't, 't>(
52 ns: &'n_i library::Namespace,
53 txid: u32,
54 protocol_name: &str,
55 direction: Direction,
56 method_name: &str,
57 value: Value,
58 ) -> Result<(Vec<u8>, Vec<fidl::HandleDisposition<'static>>)> {
59 let mut buf = EncodeBuffer { ns, bytes: Vec::new(), handles: Vec::new() };
60
61 let protocol = match ns.lookup(protocol_name)? {
62 library::LookupResult::Protocol(i) => Ok(i),
63 _ => Err(Error::LibraryError(format!("Could not find protocol '{}'.", protocol_name))),
64 }?;
65
66 let method = protocol.methods.get(method_name).ok_or_else(|| {
67 Error::LibraryError(format!(
68 "Could not find method '{}' on protocol '{}'",
69 method_name, protocol_name
70 ))
71 })?;
72
73 let (ty, has) = match direction {
74 Direction::Request => {
75 if !method.has_response && txid != 0 {
76 return Err(Error::EncodeError(
77 "Non-zero transaction ID for one-way method.".to_owned(),
78 ));
79 }
80 (method.request.as_ref(), method.has_request)
81 }
82 Direction::Response => (method.response.as_ref(), method.has_response),
83 };
84
85 let dynamic_flags =
86 if method.strict { DynamicFlags::empty() } else { DynamicFlags::FLEXIBLE };
87
88 if !has {
89 return Err(Error::LibraryError(format!(
90 "Method '{}' on protocol '{}' has no {}",
91 method_name,
92 protocol_name,
93 direction.to_string()
94 )));
95 }
96
97 buf.bytes.extend(&txid.to_le_bytes());
98 buf.bytes.extend(&AtRestFlags::USE_V2_WIRE_FORMAT.bits().to_le_bytes());
99 buf.bytes.push(dynamic_flags.bits());
100 buf.bytes.push(MAGIC_NUMBER_INITIAL);
101 buf.bytes.extend(&method.ordinal.to_le_bytes());
102
103 if let Some(ty) = ty {
104 buf.encode_type(ty, value)?(&mut buf, RecursionCounter::new())?
105 } else if !matches!(value, Value::Null) {
106 return Err(Error::EncodeError("Value must be null.".to_owned()));
107 } else {
108 };
109 buf.align_8();
110 Ok((buf.bytes, buf.handles))
111 }
112
113 fn encode_struct_nonnull<'t>(
114 &mut self,
115 st: &'n library::Struct,
116 value: Value,
117 start_offset: usize,
118 ) -> Result<Box<DeferCallback<'n, 't>>>
119 where
120 'n: 't,
121 {
122 let start_offset = self.bytes.len() - start_offset;
123
124 let values = match value {
125 Value::Object(s) => Ok(s),
126 _ => Err(Error::EncodeError("Value is not a struct.".to_owned())),
127 }?;
128
129 let mut values = {
130 let mut map = HashMap::with_capacity(values.len());
131
132 for (k, v) in values {
133 map.insert(k, v);
134 }
135
136 map
137 };
138
139 let mut calls = Vec::new();
140
141 for member in &st.members {
142 let value = values.remove(&member.name).unwrap_or(Value::Null);
143 self.bytes.extend(
144 std::iter::repeat(0u8).take(member.offset - (self.bytes.len() - start_offset)),
145 );
146 calls.push(self.encode_type(&member.ty, value)?);
147 }
148
149 if let Some((name, _)) = values.into_iter().next() {
150 Err(Error::EncodeError(format!("Unknown struct member: {}", name)))
151 } else {
152 self.bytes
153 .extend(std::iter::repeat(0u8).take(st.size - (self.bytes.len() - start_offset)));
154 Ok(combine_calls(calls))
155 }
156 }
157
158 fn encode_type<'t>(
159 &mut self,
160 ty: &'n library::Type,
161 value: Value,
162 ) -> Result<Box<DeferCallback<'n, 't>>>
163 where
164 'n: 't,
165 {
166 use library::Type::*;
167
168 match ty {
169 Unknown(_) | UnknownString(_) => {
170 return Err(Error::LibraryError("Unknown type".to_owned()))
171 }
172 Bool => self.encode_raw(if bool::try_from(value)? { &[1u8] } else { &[0u8] }),
173 U8 => self.encode_raw(&u8::try_from(value)?.to_le_bytes()),
174 U16 => self.encode_raw(&u16::try_from(value)?.to_le_bytes()),
175 U32 => self.encode_raw(&u32::try_from(value)?.to_le_bytes()),
176 U64 => self.encode_raw(&u64::try_from(value)?.to_le_bytes()),
177 I8 => self.encode_raw(&i8::try_from(value)?.to_le_bytes()),
178 I16 => self.encode_raw(&i16::try_from(value)?.to_le_bytes()),
179 I32 => self.encode_raw(&i32::try_from(value)?.to_le_bytes()),
180 I64 => self.encode_raw(&i64::try_from(value)?.to_le_bytes()),
181 F32 => self.encode_raw(&f32::try_from(value)?.to_le_bytes()),
182 F64 => self.encode_raw(&f64::try_from(value)?.to_le_bytes()),
183 Array(ty, size) => self.encode_array(ty, *size, value),
184 Vector { ty, nullable, element_count } => {
185 self.encode_vector(ty, *nullable, value, *element_count)
186 }
187 String { nullable, byte_count } => self.encode_string(*nullable, value, *byte_count),
188 Handle { object_type, rights, nullable } => {
189 self.encode_handle(*object_type, *rights, HandleType::Bare, *nullable, value)
190 }
191 FrameworkError => self.encode_raw(&[0, 0, 0, 0]),
192 Endpoint { role, protocol, rights, nullable } => self.encode_handle(
193 fidl::ObjectType::CHANNEL,
194 *rights,
195 match role {
196 library::EndpointRole::Client => HandleType::ClientEnd(protocol),
197 library::EndpointRole::Server => HandleType::ServerEnd(protocol),
198 },
199 *nullable,
200 value,
201 ),
202 Identifier { name, nullable } => self.encode_identifier(name.clone(), *nullable, value),
203 }
204 }
205
206 fn encode_raw<'t>(&mut self, data: &[u8]) -> Result<Box<DeferCallback<'n, 't>>>
207 where
208 'n: 't,
209 {
210 self.bytes.extend(data);
211 Ok(Box::new(|_, _| Ok(())))
212 }
213
214 fn encode_array<'t>(
215 &mut self,
216 ty: &'n library::Type,
217 size: usize,
218 value: Value,
219 ) -> Result<Box<DeferCallback<'n, 't>>>
220 where
221 'n: 't,
222 {
223 let values = if let Value::List(v) = value {
224 Ok(v)
225 } else {
226 Err(Error::EncodeError("Expected a list".to_owned()))
227 }?;
228
229 if values.len() != size {
230 return Err(Error::EncodeError(format!("Expected list of length {}", size)));
231 }
232
233 let mut calls = Vec::with_capacity(size);
234
235 for value in values {
236 calls.push(self.encode_type(ty, value)?);
237 }
238
239 Ok(combine_calls(calls))
240 }
241
242 fn encode_vector<'t>(
243 &mut self,
244 ty: &'n library::Type,
245 nullable: bool,
246 value: Value,
247 element_count: Option<usize>,
248 ) -> Result<Box<DeferCallback<'n, 't>>>
249 where
250 'n: 't,
251 {
252 let values = match (value, nullable) {
253 (Value::Null, true) => Ok(None),
254 (Value::Null, false) => {
255 Err(Error::EncodeError("Got null for non-nullable list".to_owned()))
256 }
257 (Value::List(v), _) => Ok(Some(v)),
258 _ => Err(Error::EncodeError("Expected a list".to_owned())),
259 }?;
260
261 if let Some(values) = values {
262 if element_count.map(|x| x < values.len()).unwrap_or(false) {
263 return Err(Error::EncodeError("Vector exceeded max size".to_owned()));
264 }
265
266 self.bytes.extend(&(values.len() as u64).to_le_bytes());
267 self.bytes.extend(&ALLOC_PRESENT_U64.to_le_bytes());
268 Ok(Box::new(move |this, counter| {
269 let counter = counter.next()?;
270 let mut calls = Vec::with_capacity(values.len());
271
272 for value in values {
273 calls.push(this.encode_type(ty, value)?);
274 }
275
276 this.align_8();
277
278 for call in calls {
279 call(this, counter)?;
280 }
281
282 Ok(())
283 }))
284 } else {
285 self.bytes.extend(std::iter::repeat(0u8).take(16));
286 Ok(Box::new(|_, _| Ok(())))
287 }
288 }
289
290 fn encode_string<'t>(
291 &mut self,
292 nullable: bool,
293 value: Value,
294 byte_count: Option<usize>,
295 ) -> Result<Box<DeferCallback<'n, 't>>>
296 where
297 'n: 't,
298 {
299 let string = match (value, nullable) {
300 (Value::Null, true) => Ok(None),
301 (Value::Null, false) => {
302 Err(Error::EncodeError("Got null for non-nullable string".to_owned()))
303 }
304 (Value::String(s), _) => Ok(Some(s)),
305 _ => Err(Error::EncodeError("Expected a string".to_owned())),
306 }?;
307
308 if let Some(string) = string {
309 if byte_count.map(|x| x < string.len()).unwrap_or(false) {
310 return Err(Error::EncodeError("String exceeded max size".to_owned()));
311 }
312
313 self.bytes.extend(&(string.len() as u64).to_le_bytes());
314 self.bytes.extend(&ALLOC_PRESENT_U64.to_le_bytes());
315 Ok(Box::new(move |this, counter| {
316 let _counter = counter.next()?;
317 this.bytes.extend(string.as_bytes());
318 this.align_8();
319 Ok(())
320 }))
321 } else {
322 self.bytes.extend(std::iter::repeat(0u8).take(16));
323 Ok(Box::new(|_, _| Ok(())))
324 }
325 }
326
/// Encodes a handle or channel endpoint.
///
/// `object_type` and `rights` come from the FIDL type and are recorded in the
/// resulting handle disposition; `expect` says whether the type is a bare
/// handle or a client/server end of a particular protocol.
///
/// Inline, a present handle is `ALLOC_PRESENT_U32` and an absent one is a zero
/// `u32`. The handle itself is appended to `handles` by the returned callback.
///
/// # Errors
///
/// `EncodeError` if the value is null or an invalid handle for a non-nullable
/// type, if the object type or protocol does not match, if the wrong end of a
/// channel is supplied, or if the value is not a handle at all.
fn encode_handle<'t>(
    &mut self,
    object_type: fidl::ObjectType,
    rights: fidl::Rights,
    expect: HandleType<'_>,
    nullable: bool,
    value: Value,
) -> Result<Box<DeferCallback<'n, 't>>>
where
    'n: 't,
{
    // NOTE: arm order matters. The guarded "invalid handle" arms must come
    // before the general arms for the same variants.
    let handle_op = match (value, nullable, expect) {
        // Null and invalid handles both encode as "absent" when nullable.
        (Value::Null, true, _) => Ok(None),
        (Value::Handle(h, _), true, _) if h.is_invalid() => Ok(None),
        (Value::ServerEnd(h, _), true, _) if h.as_handle_ref().is_invalid() => Ok(None),
        (Value::ClientEnd(h, _), true, _) if h.as_handle_ref().is_invalid() => Ok(None),
        // ...and are rejected when the type is not nullable.
        (Value::Handle(h, _), false, _) if h.is_invalid() => {
            Err(Error::EncodeError("Got invalid handle for non-nullable handle".to_owned()))
        }
        (Value::ServerEnd(h, _), false, _) if h.as_handle_ref().is_invalid() => {
            Err(Error::EncodeError("Got invalid handle for non-nullable handle".to_owned()))
        }
        (Value::ClientEnd(h, _), false, _) if h.as_handle_ref().is_invalid() => {
            Err(Error::EncodeError("Got invalid handle for non-nullable handle".to_owned()))
        }
        (Value::Null, false, _) => {
            Err(Error::EncodeError("Got null for non-nullable handle".to_owned()))
        }
        // A valid plain handle: its recorded type must equal the expected type,
        // except that a handle recorded as `ObjectType::NONE` is accepted for
        // any expected type.
        // NOTE(review): when the *expected* type is `NONE`, a handle with a
        // concrete recorded type is rejected here — confirm that is intended.
        (Value::Handle(h, s), _, _) => {
            if s != object_type && s != fidl::ObjectType::NONE {
                Err(Error::EncodeError(format!(
                    "Expected object type {object_type:?} got {s:?}"
                )))
            } else {
                Ok(Some(fidl::HandleOp::Move(h)))
            }
        }
        // The right end of a channel: protocol and object type must match.
        (Value::ServerEnd(h, s), _, HandleType::ServerEnd(expect))
        | (Value::ClientEnd(h, s), _, HandleType::ClientEnd(expect)) => {
            if expect != s {
                Err(Error::EncodeError(format!(
                    "Expected endpoint for protocol {expect}, got one for {s}"
                )))
            } else if object_type != fidl::ObjectType::CHANNEL {
                Err(Error::EncodeError(format!(
                    "Expected object type {object_type:?} got channel for protocol {s}"
                )))
            } else {
                Ok(Some(fidl::HandleOp::Move(h.into())))
            }
        }
        // The wrong end of a channel: always an error; the checks only select
        // the most specific message.
        (Value::ServerEnd(_, s), _, HandleType::ClientEnd(expect))
        | (Value::ClientEnd(_, s), _, HandleType::ServerEnd(expect)) => {
            if expect != s {
                Err(Error::EncodeError(format!(
                    "Expected endpoint for protocol {expect}, got one for {s}"
                )))
            } else if object_type != fidl::ObjectType::CHANNEL {
                Err(Error::EncodeError(format!(
                    "Expected object type {object_type:?} got channel for protocol {s}"
                )))
            } else {
                Err(Error::EncodeError(format!("Got wrong end of channel for {expect}")))
            }
        }
        // An endpoint supplied where a bare handle is expected: accepted as
        // long as the expected object type is a channel.
        (Value::ServerEnd(h, s), _, HandleType::Bare)
        | (Value::ClientEnd(h, s), _, HandleType::Bare) => {
            if object_type != fidl::ObjectType::CHANNEL {
                Err(Error::EncodeError(format!(
                    "Expected object type {object_type:?} got channel for protocol {s}"
                )))
            } else {
                Ok(Some(fidl::HandleOp::Move(h.into())))
            }
        }
        _ => Err(Error::EncodeError("Expected a handle".to_owned())),
    }?;

    if let Some(handle_op) = handle_op {
        // Present: write the presence marker now, move the handle into the
        // handle list when the deferred callback runs.
        self.bytes.extend(&ALLOC_PRESENT_U32.to_le_bytes());
        Ok(Box::new(move |this, _| {
            this.handles.push(fidl::HandleDisposition::new(
                handle_op,
                object_type,
                rights,
                fidl::Status::OK,
            ));
            Ok(())
        }))
    } else {
        // Absent: a zero u32 and nothing deferred.
        self.bytes.extend(&0u32.to_le_bytes());
        Ok(Box::new(|_, _| Ok(())))
    }
}
421
422 fn encode_identifier<'t>(
423 &mut self,
424 name: String,
425 nullable: bool,
426 value: Value,
427 ) -> Result<Box<DeferCallback<'n, 't>>>
428 where
429 'n: 't,
430 {
431 use library::LookupResult::*;
432 match (self.ns.lookup(&name)?, nullable) {
433 (Bits(b), false) => self.encode_bits(b, value),
434 (Enum(e), false) => self.encode_enum(e, value),
435 (Table(t), false) => self.encode_table(t, value),
436 (Struct(s), nullable) => self.encode_struct(s, nullable, value),
437 (Union(u), nullable) => self.encode_union(u, nullable, value),
438 (Protocol(_), _) => Err(Error::LibraryError(format!(
439 "Protocol names cannot be used as identifiers: {}",
440 name
441 ))),
442 _ => Err(Error::LibraryError(format!("Type {} shouldn't be nullable", name))),
443 }
444 }
445
446 fn encode_bits<'t>(
447 &mut self,
448 bits: &'n library::Bits,
449 value: Value,
450 ) -> Result<Box<DeferCallback<'n, 't>>>
451 where
452 'n: 't,
453 {
454 let value = match value {
455 Value::Bits(name, inner) => {
456 if name == bits.name {
457 *inner
458 } else {
459 return Err(Error::EncodeError(format!(
460 "Expected {}, got {}",
461 bits.name, name
462 )));
463 }
464 }
465 _ => value,
466 };
467
468 let data = u64::try_from(&value).unwrap_or(0);
471
472 if bits.strict && data & !bits.mask != 0 {
473 Err(Error::EncodeError(format!("Invalid bits set on {}", bits.name)))
474 } else {
475 self.encode_type(&bits.ty, value)
476 }
477 }
478
479 fn encode_enum<'t>(
480 &mut self,
481 en: &'n library::Enum,
482 value: Value,
483 ) -> Result<Box<DeferCallback<'n, 't>>>
484 where
485 'n: 't,
486 {
487 let value = match value {
488 Value::Enum(name, inner) => {
489 if name == en.name {
490 *inner
491 } else {
492 return Err(Error::EncodeError(format!("Expected {}, got {}", en.name, name)));
493 }
494 }
495 _ => value,
496 };
497
498 for item in &en.members {
499 if !en.strict || item.value.cast_equals(&value) {
500 return self.encode_type(&en.ty, value);
501 }
502 }
503
504 Err(Error::EncodeError("Invalid enum variant".to_owned()))
505 }
506
507 fn encode_struct<'t>(
508 &mut self,
509 st: &'n library::Struct,
510 nullable: bool,
511 value: Value,
512 ) -> Result<Box<DeferCallback<'n, 't>>>
513 where
514 'n: 't,
515 {
516 let value = match (value, nullable) {
517 (Value::Null, true) => Ok(None),
518 (Value::Null, false) => Err(Error::EncodeError("Struct can't be null".to_owned())),
519 (value, _) => Ok(Some(value)),
520 }?;
521
522 if let Some(value) = value {
523 if nullable {
524 self.bytes.extend(&ALLOC_PRESENT_U64.to_le_bytes());
525 Ok(Box::new(move |this, counter| {
526 let counter = counter.next()?;
527 let call = this.encode_struct_nonnull(st, value, 0)?;
528 this.align_8();
529 call(this, counter)
530 }))
531 } else {
532 self.encode_struct_nonnull(st, value, 0)
533 }
534 } else {
535 self.bytes.extend(&0u64.to_le_bytes());
536 Ok(Box::new(|_, _| Ok(())))
537 }
538 }
539
540 fn encode_envelope<'t>(
541 &mut self,
542 ty: &'n library::Type,
543 value: Value,
544 ) -> Result<Box<DeferCallback<'n, 't>>>
545 where
546 'n: 't,
547 {
548 let header_pos = self.bytes.len();
549 self.bytes.extend(&[0u8; 8]);
550
551 if let Value::Null = value {
552 Ok(Box::new(|_, _| Ok(())))
553 } else {
554 Ok(Box::new(move |this, counter| {
555 let counter = counter.next()?;
556 let start = this.bytes.len();
557 let handle_start = this.handles.len();
558
559 let header = if ty.inline_size(this.ns)? > 4 {
560 let call = this.encode_type(ty, value)?;
561
562 this.align_8();
563 call(this, counter)?;
564 let size = (this.bytes.len() - start) as u32;
565 let handle_count = (this.handles.len() - handle_start) as u16;
566
567 debug_assert!(size > 0 || handle_count > 0);
568 let mut header = Vec::new();
569 header.extend(&size.to_le_bytes());
570 header.extend(&handle_count.to_le_bytes());
571 header.extend(&0u16.to_le_bytes());
572 header
573 } else {
574 let mut header_buf =
575 EncodeBuffer { ns: this.ns, bytes: Vec::new(), handles: Vec::new() };
576 header_buf.encode_type(ty, value)?(&mut header_buf, counter)?;
577 let EncodeBuffer { bytes: mut header, handles, .. } = header_buf;
578 header.resize(4, 0);
579 header.extend(&(handles.len() as u16).to_le_bytes());
580 header.extend(&1u16.to_le_bytes());
581 this.handles.extend(handles);
582 header
583 };
584
585 this.bytes.splice(header_pos..(header_pos + header.len()), header.into_iter());
586 Ok(())
587 }))
588 }
589 }
590
591 fn encode_union<'t>(
592 &mut self,
593 union: &'n library::TableOrUnion,
594 nullable: bool,
595 value: Value,
596 ) -> Result<Box<DeferCallback<'n, 't>>>
597 where
598 'n: 't,
599 {
600 let entry = match value {
601 Value::Null => Ok(None),
602 Value::Union(u, n, b) if *u == union.name => Ok(Some((n, *b))),
603 _ => Err(Error::EncodeError(format!("Expected {}", union.name))),
604 }?;
605
606 if let Some((variant, value)) = entry {
607 for member in union.members.values() {
608 if *member.name == *variant {
609 self.bytes.extend(&member.ordinal.to_le_bytes());
610 return self.encode_envelope(&member.ty, value);
611 }
612 }
613
614 Err(Error::EncodeError(format!("Unrecognized union variant: '{}'", variant)))
615 } else if nullable {
616 self.bytes.extend(std::iter::repeat(0u8).take(16));
617 Ok(Box::new(|_, _| Ok(())))
618 } else {
619 Err(Error::EncodeError("Got null for non-nullable Union".to_owned()))
620 }
621 }
622
623 fn encode_table<'t>(
624 &mut self,
625 table: &'n library::TableOrUnion,
626 value: Value,
627 ) -> Result<Box<DeferCallback<'n, 't>>>
628 where
629 'n: 't,
630 {
631 let values = match value {
632 Value::Object(values) => Ok(values),
633 _ => Err(Error::EncodeError(format!("Could not convert to {}", table.name))),
634 }?;
635
636 let mut values_array = Vec::new();
637 for (value_name, value) in values {
638 for (&ord, member) in &table.members {
639 let array_idx = usize::try_from(ord - 1).unwrap();
640 if values_array.len() <= array_idx {
641 values_array.resize_with(array_idx + 1, || None);
642 }
643 if *member.name == value_name {
644 values_array[array_idx] = Some((&member.ty, value));
645 break;
646 }
647 }
648 }
649
650 while values_array.last().map(|x| x.is_none()).unwrap_or(false) {
651 values_array.pop();
652 }
653
654 self.bytes.extend(&(values_array.len() as u64).to_le_bytes());
655 self.bytes.extend(&ALLOC_PRESENT_U64.to_le_bytes());
656
657 Ok(Box::new(move |this, counter| {
658 let counter = counter.next()?;
659 let mut calls = Vec::with_capacity(values_array.len());
660
661 for slot in values_array.into_iter() {
662 if let Some((ty, item)) = slot {
663 calls.push(this.encode_envelope(ty, item)?);
664 } else {
665 this.bytes.extend(&0u64.to_le_bytes());
666 }
667 }
668
669 for call in calls {
670 call(this, counter)?;
671 }
672
673 Ok(())
674 }))
675 }
676}
677
678pub fn encode_request(
680 ns: &library::Namespace,
681 txid: u32,
682 protocol_name: &str,
683 method_name: &str,
684 value: Value,
685) -> Result<(Vec<u8>, Vec<fidl::HandleDisposition<'static>>)> {
686 EncodeBuffer::encode_transaction(
687 ns,
688 txid,
689 protocol_name,
690 Direction::Request,
691 method_name,
692 value,
693 )
694}
695
696pub fn encode_response(
698 ns: &library::Namespace,
699 txid: u32,
700 protocol_name: &str,
701 method_name: &str,
702 value: Value,
703) -> Result<(Vec<u8>, Vec<fidl::HandleDisposition<'static>>)> {
704 EncodeBuffer::encode_transaction(
705 ns,
706 txid,
707 protocol_name,
708 Direction::Response,
709 method_name,
710 value,
711 )
712}
713
714pub fn encode(
716 ns: &library::Namespace,
717 type_name: &str,
718 nullable: bool,
719 value: Value,
720) -> Result<(Vec<u8>, Vec<fidl::HandleDisposition<'static>>)> {
721 let mut buf = EncodeBuffer { ns, bytes: Vec::new(), handles: Vec::new() };
722 let cb = buf.encode_identifier(type_name.to_owned(), nullable, value)?;
723 buf.align_8();
724 cb(&mut buf, RecursionCounter::new()).map(|_| (buf.bytes, buf.handles))
725}