1#![recursion_limit = "256"]
7
8use perfetto_protos::perfetto::protos::{
9 DisableTracingRequest, EnableTracingRequest, FreeBuffersRequest, GetAsyncCommandRequest,
10 GetAsyncCommandResponse, InitializeConnectionRequest, InitializeConnectionResponse, IpcFrame,
11 ReadBuffersRequest, RegisterDataSourceRequest, RegisterDataSourceResponse, ipc_frame,
12};
13use prost::Message;
14use starnix_core::task::{CurrentTask, EventHandler, Waiter};
15use starnix_core::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
16use starnix_core::vfs::socket::{
17 SocketDomain, SocketFile, SocketPeer, SocketProtocol, SocketType, resolve_unix_socket_address,
18};
19use starnix_core::vfs::{FileHandle, FsStr};
20use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Unlocked};
21use starnix_uapi::errors::Errno;
22use starnix_uapi::open_flags::OpenFlags;
23use starnix_uapi::vfs::FdEvents;
24use std::collections::VecDeque;
25use thiserror::Error;
26
27pub struct IpcConnection {
33 file: FileHandle,
35 request_id: u64,
38}
39
40#[derive(Error, Debug)]
41pub enum IpcWriteError {
42 #[error(transparent)]
43 Encode(#[from] prost::EncodeError),
44 #[error(transparent)]
45 Write(#[from] Errno),
46 #[error("TooLong: {0} exceeds max for u32")]
47 TooLong(usize),
48}
49
50#[derive(Error, Debug)]
51pub enum IpcReadError {
52 #[error(transparent)]
53 Decode(#[from] prost::DecodeError),
54 #[error(transparent)]
55 Read(#[from] Errno),
56}
57
58#[derive(Error, Debug)]
59pub enum InvokeMethodError {
60 #[error("could not not look up method name: {0}")]
61 InvalidMethod(String),
62 #[error(transparent)]
63 IpcWrite(#[from] IpcWriteError),
64 #[error(transparent)]
65 IpcRead(#[from] IpcReadError),
66 #[error("unexpected response: {0}")]
67 InvalidResponse(String),
68}
69
70#[derive(Error, Debug)]
71pub enum ProducerError {
72 #[error(transparent)]
73 InvokeMethod(#[from] InvokeMethodError),
74 #[error(transparent)]
75 IpcWrite(#[from] IpcWriteError),
76 #[error(transparent)]
77 IpcRead(#[from] IpcReadError),
78 #[error("unexpected response: {0}")]
79 InvalidResponse(String),
80 #[error(transparent)]
81 Decode(#[from] prost::DecodeError),
82}
83
84impl IpcConnection {
85 pub fn new(file: FileHandle) -> Self {
86 Self { file, request_id: 0 }
87 }
88
89 pub fn bind_service<L>(
90 &mut self,
91 service_name: &str,
92 locked: &mut Locked<L>,
93 current_task: &CurrentTask,
94 ) -> Result<(), IpcWriteError>
95 where
96 L: LockEqualOrBefore<FileOpsCore>,
97 {
98 let bind_service_message = IpcFrame {
101 request_id: Some(1),
102 msg: Some(ipc_frame::Msg::MsgBindService(ipc_frame::BindService {
103 service_name: Some(service_name.to_string()),
104 })),
105 ..Default::default()
106 };
107 self.write_frame(bind_service_message, locked, current_task)
108 }
109
110 pub fn invoke_method<L>(
111 &mut self,
112 service_id: u32,
113 method_id: u32,
114 arguments: Option<Vec<u8>>,
115 locked: &mut Locked<L>,
116 current_task: &CurrentTask,
117 ) -> Result<(), IpcWriteError>
118 where
119 L: LockEqualOrBefore<FileOpsCore>,
120 {
121 let msg = IpcFrame {
122 request_id: Some(self.allocate_request_id()),
123 msg: Some(ipc_frame::Msg::MsgInvokeMethod(ipc_frame::InvokeMethod {
124 service_id: Some(service_id),
125 method_id: Some(method_id),
126 args_proto: arguments,
127 drop_reply: None,
128 })),
129 ..Default::default()
130 };
131 self.write_frame(msg, locked, current_task)
132 }
133
134 fn write_frame<L>(
135 &mut self,
136 frame: IpcFrame,
137 locked: &mut Locked<L>,
138 current_task: &CurrentTask,
139 ) -> Result<(), IpcWriteError>
140 where
141 L: LockEqualOrBefore<FileOpsCore>,
142 {
143 let frame_len = u32::try_from(frame.encoded_len())
145 .map_err(|_| IpcWriteError::TooLong(frame.encoded_len()))?;
146 let mut bind_service_bytes =
147 Vec::with_capacity(frame.encoded_len() + std::mem::size_of::<u32>());
148 bind_service_bytes.extend_from_slice(&frame_len.to_le_bytes());
149 frame.encode(&mut bind_service_bytes)?;
150 let mut bind_service_buffer: VecInputBuffer = bind_service_bytes.into();
151 self.file.write(locked, current_task, &mut bind_service_buffer)?;
152 Ok(())
153 }
154
155 fn allocate_request_id(&mut self) -> u64 {
157 let id = self.request_id;
158 self.request_id += 1;
159 id
160 }
161}
162
163pub struct FrameReader {
169 file: FileHandle,
171 read_buffer: VecOutputBuffer,
176 data: VecDeque<u8>,
178 next_message_size: Option<usize>,
181}
182
183impl FrameReader {
184 pub fn new(file: FileHandle) -> Self {
185 Self {
186 file,
187 read_buffer: VecOutputBuffer::new(4096),
188 data: VecDeque::with_capacity(4096),
189 next_message_size: None,
190 }
191 }
192
193 pub fn next_frame_blocking<L>(
195 &mut self,
196 locked: &mut Locked<L>,
197 current_task: &CurrentTask,
198 ) -> Result<IpcFrame, IpcReadError>
199 where
200 L: LockEqualOrBefore<FileOpsCore>,
201 {
202 loop {
203 if self.next_message_size.is_none() && self.data.len() >= 4 {
204 let len_bytes: [u8; 4] = self
205 .data
206 .drain(..4)
207 .collect::<Vec<_>>()
208 .try_into()
209 .expect("self.data has at least 4 elements");
210 self.next_message_size = Some(u32::from_le_bytes(len_bytes) as usize);
211 }
212 if let Some(message_size) = self.next_message_size {
213 if self.data.len() >= message_size {
214 let message: Vec<u8> = self.data.drain(..message_size).collect();
215 self.next_message_size = None;
216 return Ok(IpcFrame::decode(message.as_slice())?);
217 }
218 }
219
220 let waiter = Waiter::new();
221 self.file.wait_async(
222 locked,
223 current_task,
224 &waiter,
225 FdEvents::POLLIN,
226 EventHandler::None,
227 );
228 while self.file.query_events(locked, current_task)? & FdEvents::POLLIN
229 != FdEvents::POLLIN
230 {
231 waiter.wait(locked, current_task)?;
232 }
233 self.file.read(locked, current_task, &mut self.read_buffer)?;
234 self.data.extend(self.read_buffer.data());
235 self.read_buffer.reset();
236 }
237 }
238}
239
240pub struct Consumer {
242 conn_file: FileHandle,
245 frame_reader: FrameReader,
247 bind_service_reply: ipc_frame::BindServiceReply,
250 request_id: u64,
252}
253
254impl Consumer {
255 pub fn new(
258 locked: &mut Locked<Unlocked>,
259 current_task: &CurrentTask,
260 socket_path: &FsStr,
261 ) -> Result<Self, anyhow::Error> {
262 let conn_file = SocketFile::new_socket(
263 locked,
264 current_task,
265 SocketDomain::Unix,
266 SocketType::Stream,
267 OpenFlags::RDWR,
268 SocketProtocol::from_raw(0),
269 false,
270 )?;
271 let conn = SocketFile::get_from_file(&conn_file)?;
272 let peer =
273 SocketPeer::Handle(resolve_unix_socket_address(locked, current_task, socket_path)?);
274 conn.connect(locked, current_task, peer)?;
275 let mut frame_reader = FrameReader::new(conn_file.clone());
276 let mut request_id = 1;
277
278 let bind_service_message = IpcFrame {
279 request_id: Some(request_id),
280 data_for_testing: Vec::new(),
281 msg: Some(ipc_frame::Msg::MsgBindService(ipc_frame::BindService {
282 service_name: Some("ConsumerPort".to_string()),
283 })),
284 };
285 request_id += 1;
286 let mut bind_service_bytes =
287 Vec::with_capacity(bind_service_message.encoded_len() + std::mem::size_of::<u32>());
288 bind_service_bytes.extend_from_slice(
289 &u32::try_from(bind_service_message.encoded_len()).unwrap().to_le_bytes(),
290 );
291 bind_service_message.encode(&mut bind_service_bytes)?;
292 let mut bind_service_buffer: VecInputBuffer = bind_service_bytes.into();
293 conn.file().write(locked, current_task, &mut bind_service_buffer)?;
294
295 let reply_frame = frame_reader.next_frame_blocking(locked, current_task)?;
296
297 let bind_service_reply = match reply_frame.msg {
298 Some(ipc_frame::Msg::MsgBindServiceReply(reply)) => reply,
299 m => return Err(anyhow::anyhow!("Got unexpected reply message: {:?}", m)),
300 };
301
302 Ok(Self { conn_file, frame_reader, bind_service_reply, request_id })
303 }
304
305 fn send_message<L>(
306 &mut self,
307 locked: &mut Locked<L>,
308 current_task: &CurrentTask,
309 msg: ipc_frame::Msg,
310 ) -> Result<u64, anyhow::Error>
311 where
312 L: LockEqualOrBefore<FileOpsCore>,
313 {
314 let request_id = self.request_id;
315 let frame =
316 IpcFrame { request_id: Some(request_id), data_for_testing: Vec::new(), msg: Some(msg) };
317
318 self.request_id += 1;
319
320 let mut frame_bytes = Vec::with_capacity(frame.encoded_len() + std::mem::size_of::<u32>());
321 frame_bytes.extend_from_slice(&u32::try_from(frame.encoded_len())?.to_le_bytes());
322 frame.encode(&mut frame_bytes)?;
323 let mut buffer: VecInputBuffer = frame_bytes.into();
324 self.conn_file.write(locked, current_task, &mut buffer)?;
325
326 Ok(request_id)
327 }
328
329 fn method_id(&self, name: &str) -> Result<u32, anyhow::Error> {
330 for method in &self.bind_service_reply.methods {
331 if let Some(method_name) = method.name.as_ref() {
332 if method_name == name {
333 if let Some(id) = method.id {
334 return Ok(id);
335 } else {
336 return Err(anyhow::anyhow!(
337 "Matched method name {} but found no id",
338 method_name
339 ));
340 }
341 }
342 }
343 }
344 Err(anyhow::anyhow!("Did not find method {}", name))
345 }
346
347 pub fn enable_tracing<L>(
348 &mut self,
349 locked: &mut Locked<L>,
350 current_task: &CurrentTask,
351 req: EnableTracingRequest,
352 ) -> Result<u64, anyhow::Error>
353 where
354 L: LockEqualOrBefore<FileOpsCore>,
355 {
356 let method_id = self.method_id("EnableTracing")?;
357 let mut encoded_args: Vec<u8> = Vec::with_capacity(req.encoded_len());
358 req.encode(&mut encoded_args)?;
359
360 self.send_message(
361 locked,
362 current_task,
363 ipc_frame::Msg::MsgInvokeMethod(ipc_frame::InvokeMethod {
364 service_id: self.bind_service_reply.service_id,
365 method_id: Some(method_id),
366 args_proto: Some(encoded_args),
367 drop_reply: None,
368 }),
369 )
370 }
371
372 pub fn disable_tracing<L>(
373 &mut self,
374 locked: &mut Locked<L>,
375 current_task: &CurrentTask,
376 req: DisableTracingRequest,
377 ) -> Result<u64, anyhow::Error>
378 where
379 L: LockEqualOrBefore<FileOpsCore>,
380 {
381 let method_id = self.method_id("DisableTracing")?;
382 let mut encoded_args: Vec<u8> = Vec::with_capacity(req.encoded_len());
383 req.encode(&mut encoded_args)?;
384
385 self.send_message(
386 locked,
387 current_task,
388 ipc_frame::Msg::MsgInvokeMethod(ipc_frame::InvokeMethod {
389 service_id: self.bind_service_reply.service_id,
390 method_id: Some(method_id),
391 args_proto: Some(encoded_args),
392 drop_reply: None,
393 }),
394 )
395 }
396
397 pub fn read_buffers<L>(
398 &mut self,
399 locked: &mut Locked<L>,
400 current_task: &CurrentTask,
401 req: ReadBuffersRequest,
402 ) -> Result<u64, anyhow::Error>
403 where
404 L: LockEqualOrBefore<FileOpsCore>,
405 {
406 let method_id = self.method_id("ReadBuffers")?;
407 let mut encoded_args: Vec<u8> = Vec::with_capacity(req.encoded_len());
408 req.encode(&mut encoded_args)?;
409
410 self.send_message(
411 locked,
412 current_task,
413 ipc_frame::Msg::MsgInvokeMethod(ipc_frame::InvokeMethod {
414 service_id: self.bind_service_reply.service_id,
415 method_id: Some(method_id),
416 args_proto: Some(encoded_args),
417 drop_reply: None,
418 }),
419 )
420 }
421
422 pub fn free_buffers<L>(
423 &mut self,
424 locked: &mut Locked<L>,
425 current_task: &CurrentTask,
426 req: FreeBuffersRequest,
427 ) -> Result<u64, anyhow::Error>
428 where
429 L: LockEqualOrBefore<FileOpsCore>,
430 {
431 let method_id = self.method_id("FreeBuffers")?;
432 let mut encoded_args: Vec<u8> = Vec::with_capacity(req.encoded_len());
433 req.encode(&mut encoded_args)?;
434
435 self.send_message(
436 locked,
437 current_task,
438 ipc_frame::Msg::MsgInvokeMethod(ipc_frame::InvokeMethod {
439 service_id: self.bind_service_reply.service_id,
440 method_id: Some(method_id),
441 args_proto: Some(encoded_args),
442 drop_reply: None,
443 }),
444 )
445 }
446
447 pub fn next_frame_blocking<L>(
448 &mut self,
449 locked: &mut Locked<L>,
450 current_task: &CurrentTask,
451 ) -> Result<IpcFrame, IpcReadError>
452 where
453 L: LockEqualOrBefore<FileOpsCore>,
454 {
455 self.frame_reader.next_frame_blocking(locked, current_task)
456 }
457}
458
459pub struct Producer {
466 frame_reader: FrameReader,
468
469 ipc_connection: IpcConnection,
471
472 service_id: u32,
475
476 method_map: std::collections::HashMap<String, u32>,
479}
480
481impl Producer {
482 pub fn new<L>(
485 locked: &mut Locked<L>,
486 current_task: &CurrentTask,
487 socket: FileHandle,
488 ) -> Result<Self, ProducerError>
489 where
490 L: LockEqualOrBefore<FileOpsCore>,
491 {
492 let mut producer = Self {
493 frame_reader: FrameReader::new(socket.clone()),
494 ipc_connection: IpcConnection::new(socket),
495 service_id: 0,
496 method_map: std::collections::HashMap::new(),
497 };
498
499 producer.ipc_connection.bind_service("ProducerPort", locked, current_task)?;
502
503 let reply_frame = producer.frame_reader.next_frame_blocking(locked, current_task)?;
507
508 let ipc_frame::BindServiceReply { success, service_id, methods } = match reply_frame.msg {
509 Some(ipc_frame::Msg::MsgBindServiceReply(reply)) => reply,
510 m => {
511 return Err(ProducerError::InvalidResponse(format!(
512 "Got unexpected reply message: {:?}",
513 m
514 )));
515 }
516 };
517
518 if !success.unwrap_or(false) {
519 return Err(ProducerError::InvalidResponse("Bind to socket failed".into()));
520 }
521
522 producer.method_map = methods
524 .into_iter()
525 .flat_map(|ipc_frame::bind_service_reply::MethodInfo { id, name }| match (id, name) {
526 (Some(id), Some(name)) => Some((name, id)),
527 _ => None,
528 })
529 .collect();
530 if let Some(service_id) = service_id {
531 producer.service_id = service_id
532 } else {
533 return Err(ProducerError::InvalidResponse(
534 "BindServiceReply did not include service_id".into(),
535 ));
536 }
537
538 Ok(producer)
539 }
540
541 pub fn initialize_connection<L>(
544 &mut self,
545 request: InitializeConnectionRequest,
546 locked: &mut Locked<L>,
547 current_task: &CurrentTask,
548 ) -> Result<InitializeConnectionResponse, ProducerError>
549 where
550 L: LockEqualOrBefore<FileOpsCore>,
551 {
552 let (Some(reply), has_more) = self.invoke_method(
553 "InitializeConnection",
554 Some(request.encode_to_vec()),
555 locked,
556 current_task,
557 )?
558 else {
559 return Err(ProducerError::InvalidResponse("expected a response but got none".into()));
560 };
561 if has_more {
562 return Err(ProducerError::InvalidResponse(
563 "InitializeConnection should not stream but got a streaming response".into(),
564 ));
565 }
566 Ok(InitializeConnectionResponse::decode(reply.as_ref())?)
567 }
568
569 pub fn register_data_source<L>(
571 &mut self,
572 request: RegisterDataSourceRequest,
573 locked: &mut Locked<L>,
574 current_task: &CurrentTask,
575 ) -> Result<RegisterDataSourceResponse, ProducerError>
576 where
577 L: LockEqualOrBefore<FileOpsCore>,
578 {
579 let (Some(reply), has_more) = self.invoke_method(
580 "RegisterDataSource",
581 Some(request.encode_to_vec()),
582 locked,
583 current_task,
584 )?
585 else {
586 return Err(ProducerError::InvalidResponse(
587 "RegisterDataSource expected a response but got none".into(),
588 ));
589 };
590 if has_more {
591 return Err(ProducerError::InvalidResponse(
592 "RegisterDataSource should not stream but got a streaming response".into(),
593 ));
594 }
595 Ok(RegisterDataSourceResponse::decode(reply.as_ref())?)
596 }
597
598 fn invoke_method<L>(
600 &mut self,
601 method_name: &str,
602 arguments: Option<Vec<u8>>,
603 locked: &mut Locked<L>,
604 current_task: &CurrentTask,
605 ) -> Result<(Option<Vec<u8>>, bool), InvokeMethodError>
606 where
607 L: LockEqualOrBefore<FileOpsCore>,
608 {
609 self.invoke_method_inner(method_name, arguments, locked, current_task)?;
610
611 let reply_frame = self.frame_reader.next_frame_blocking(locked, current_task)?;
612
613 let ipc_frame::InvokeMethodReply { success, has_more, reply_proto } = match reply_frame.msg
614 {
615 Some(ipc_frame::Msg::MsgInvokeMethodReply(reply)) => reply,
616 m => {
617 return Err(InvokeMethodError::InvalidResponse(format!(
618 "unexpected reply message: {:?}",
619 m
620 )));
621 }
622 };
623 match success {
624 Some(true) => Ok((reply_proto, has_more.unwrap_or(false))),
625 _ => {
626 return Err(InvokeMethodError::InvalidResponse(format!(
627 "InvokeMethod Reply did not succeed. Reply: success: {:?}, has_more: {:?}, proto: {:?}",
628 success, has_more, reply_proto,
629 )));
630 }
631 }
632 }
633
634 fn invoke_method_inner<L>(
636 &mut self,
637 method_name: &str,
638 arguments: Option<Vec<u8>>,
639 locked: &mut Locked<L>,
640 current_task: &CurrentTask,
641 ) -> Result<(), InvokeMethodError>
642 where
643 L: LockEqualOrBefore<FileOpsCore>,
644 {
645 let Some(method_id) = self.method_map.get(method_name).copied() else {
646 return Err(InvokeMethodError::InvalidMethod(method_name.into()));
647 };
648 self.ipc_connection.invoke_method(
649 self.service_id,
650 method_id,
651 arguments,
652 locked,
653 current_task,
654 )?;
655 Ok(())
656 }
657
658 pub fn get_command_request<L>(
661 &mut self,
662 locked: &mut Locked<L>,
663 current_task: &CurrentTask,
664 ) -> Result<(), ProducerError>
665 where
666 L: LockEqualOrBefore<FileOpsCore>,
667 {
668 Ok(self.invoke_method_inner(
669 "GetAsyncCommand",
670 Some(GetAsyncCommandRequest {}.encode_to_vec()),
671 locked,
672 current_task,
673 )?)
674 }
675
676 pub fn get_command_response<L>(
678 &mut self,
679 locked: &mut Locked<L>,
680 current_task: &CurrentTask,
681 ) -> Result<(Option<GetAsyncCommandResponse>, bool), ProducerError>
682 where
683 L: LockEqualOrBefore<FileOpsCore>,
684 {
685 let reply_frame = self.frame_reader.next_frame_blocking(locked, current_task)?;
686
687 let ipc_frame::InvokeMethodReply { success, has_more, reply_proto } = match reply_frame.msg
688 {
689 Some(ipc_frame::Msg::MsgInvokeMethodReply(reply)) => reply,
690 m => {
691 return Err(ProducerError::InvalidResponse(format!(
692 "Got unexpected reply message: {:?}",
693 m
694 )));
695 }
696 };
697 if !success.unwrap_or(false) {
698 return Err(ProducerError::InvalidResponse(format!(
699 "InvokeMethod Reply did not include success. Reply: success: {:?}, has_more: {:?}, proto: {:?}",
700 success, has_more, reply_proto
701 )));
702 }
703
704 let Some(reply_proto) = reply_proto else {
705 return Err(ProducerError::InvalidResponse(
706 "InvokeMethod reply didn't include a proto".into(),
707 ));
708 };
709 Ok((
710 Some(GetAsyncCommandResponse::decode(reply_proto.as_ref())?),
711 has_more.unwrap_or(false),
712 ))
713 }
714}