Skip to main content

bt_gatt/
test_utils.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use bt_common::core::{Address, AddressType};
6use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
7use futures::future::{ready, Ready};
8use futures::stream::{FusedStream, Stream};
9use parking_lot::Mutex;
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::sync::Arc;
12use std::task::{Poll, Waker};
13
14use bt_common::{PeerId, Uuid};
15
16use crate::central::ScanResult;
17use crate::client::CharacteristicNotification;
18
19use crate::periodic_advertising::{PeriodicAdvertising, SyncReport};
20use crate::pii::GetPeerAddr;
21use crate::server::{
22    self, LocalService, NotificationType, ReadResponder, ServiceDefinition, WriteResponder,
23};
24use crate::{types::*, GattTypes, ServerTypes};
25
26#[derive(Default)]
27struct FakePeerServiceInner {
28    // Notifier that's used to send out notification.
29    notifiers: HashMap<Handle, UnboundedSender<Result<CharacteristicNotification>>>,
30
31    // Characteristics to return when `read_characteristic` and `discover_characteristics` are
32    // called.
33    characteristics: HashMap<Handle, (Characteristic, Vec<u8>)>,
34}
35
36#[derive(Clone)]
37pub struct FakePeerService {
38    inner: Arc<Mutex<FakePeerServiceInner>>,
39}
40
41impl FakePeerService {
42    pub fn new() -> Self {
43        Self { inner: Arc::new(Mutex::new(Default::default())) }
44    }
45
46    // Adds a characteristic so that it can be returned when discover/read method is
47    // called.
48    // Also triggers sending a characteristic value change notification to be sent.
49    pub fn add_characteristic(&mut self, char: Characteristic, value: Vec<u8>) {
50        let mut lock = self.inner.lock();
51        let handle = char.handle;
52        lock.characteristics.insert(handle, (char, value.clone()));
53        if let Some(notifier) = lock.notifiers.get_mut(&handle) {
54            notifier
55                .unbounded_send(Ok(CharacteristicNotification {
56                    handle,
57                    value,
58                    maybe_truncated: false,
59                }))
60                .expect("should succeed");
61        }
62    }
63
64    // Sets expected characteristic value so that it can be used for validation when
65    // write method is called.
66    pub fn expect_characteristic_value(&mut self, handle: &Handle, value: Vec<u8>) {
67        let mut lock = self.inner.lock();
68        let Some(char) = lock.characteristics.get_mut(handle) else {
69            panic!("Can't find characteristic {handle:?} to set expected value");
70        };
71        char.1 = value;
72    }
73
74    /// Sends a notification on the characteristic with the provided `handle`.
75    pub fn notify(&self, handle: &Handle, notification: Result<CharacteristicNotification>) {
76        let mut lock = self.inner.lock();
77        if let Some(notifier) = lock.notifiers.get_mut(handle) {
78            notifier.unbounded_send(notification).expect("can send notification");
79        }
80    }
81
82    /// Removes the notification subscription for the characteristic with the
83    /// provided `handle`.
84    pub fn clear_notifier(&self, handle: &Handle) {
85        let mut lock = self.inner.lock();
86        let _ = lock.notifiers.remove(handle);
87    }
88}
89
90impl crate::client::PeerService<FakeTypes> for FakePeerService {
91    fn discover_characteristics(
92        &self,
93        uuid: Option<Uuid>,
94    ) -> <FakeTypes as GattTypes>::CharacteristicDiscoveryFut {
95        let lock = self.inner.lock();
96        let mut result = Vec::new();
97        for (_handle, (char, _value)) in &lock.characteristics {
98            match uuid {
99                Some(uuid) if uuid == char.uuid => result.push(char.clone()),
100                None => result.push(char.clone()),
101                _ => {}
102            }
103        }
104        ready(Ok(result))
105    }
106
107    fn read_characteristic<'a>(
108        &self,
109        handle: &Handle,
110        _offset: u16,
111        buf: &'a mut [u8],
112    ) -> <FakeTypes as GattTypes>::ReadFut<'a> {
113        let read_characteristics = &(*self.inner.lock()).characteristics;
114        let Some((_, value)) = read_characteristics.get(handle) else {
115            return ready(Err(Error::Gatt(GattError::InvalidHandle)));
116        };
117        buf[..value.len()].copy_from_slice(value.as_slice());
118        ready(Ok((value.len(), false)))
119    }
120
121    // For testing, should call `expect_characteristic_value` with the expected
122    // value.
123    fn write_characteristic<'a>(
124        &self,
125        handle: &Handle,
126        _mode: WriteMode,
127        _offset: u16,
128        buf: &'a [u8],
129    ) -> <FakeTypes as GattTypes>::WriteFut<'a> {
130        let expected_characteristics = &(*self.inner.lock()).characteristics;
131        // The write operation was not expected.
132        let Some((_, expected)) = expected_characteristics.get(handle) else {
133            panic!("Write operation to characteristic {handle:?} was not expected");
134        };
135        // Value written was not expected.
136        if buf.len() != expected.len() || &buf[..expected.len()] != expected.as_slice() {
137            panic!("Value written to characteristic {handle:?} was not expected: {buf:?}");
138        }
139        ready(Ok(()))
140    }
141
142    fn read_descriptor<'a>(
143        &self,
144        _handle: &Handle,
145        _offset: u16,
146        _buf: &'a mut [u8],
147    ) -> <FakeTypes as GattTypes>::ReadFut<'a> {
148        todo!()
149    }
150
151    fn write_descriptor<'a>(
152        &self,
153        _handle: &Handle,
154        _offset: u16,
155        _buf: &'a [u8],
156    ) -> <FakeTypes as GattTypes>::WriteFut<'a> {
157        todo!()
158    }
159
160    fn subscribe(&self, handle: &Handle) -> <FakeTypes as GattTypes>::NotificationStream {
161        let (sender, receiver) = unbounded();
162        (*self.inner.lock()).notifiers.insert(*handle, sender);
163        receiver
164    }
165}
166
167#[derive(Clone)]
168pub struct FakeServiceHandle {
169    pub uuid: Uuid,
170    pub is_primary: bool,
171    pub fake_service: FakePeerService,
172}
173
174impl crate::client::PeerServiceHandle<FakeTypes> for FakeServiceHandle {
175    fn uuid(&self) -> Uuid {
176        self.uuid
177    }
178
179    fn is_primary(&self) -> bool {
180        self.is_primary
181    }
182
183    fn connect(&self) -> <FakeTypes as GattTypes>::ServiceConnectFut {
184        futures::future::ready(Ok(self.fake_service.clone()))
185    }
186}
187
188#[derive(Default)]
189struct FakeClientInner {
190    fake_services: Vec<FakeServiceHandle>,
191}
192
193#[derive(Clone)]
194pub struct FakeClient {
195    inner: Arc<Mutex<FakeClientInner>>,
196}
197
198impl FakeClient {
199    pub fn new() -> Self {
200        FakeClient { inner: Arc::new(Mutex::new(FakeClientInner::default())) }
201    }
202
203    /// Add a fake peer service to this client.
204    pub fn add_service(&mut self, uuid: Uuid, is_primary: bool, fake_service: FakePeerService) {
205        self.inner.lock().fake_services.push(FakeServiceHandle { uuid, is_primary, fake_service });
206    }
207}
208
209impl crate::Client<FakeTypes> for FakeClient {
210    fn peer_id(&self) -> PeerId {
211        todo!()
212    }
213
214    fn find_service(&self, uuid: Uuid) -> <FakeTypes as GattTypes>::FindServicesFut {
215        let fake_services = &self.inner.lock().fake_services;
216        let mut filtered_services = Vec::new();
217        for handle in fake_services {
218            if handle.uuid == uuid {
219                filtered_services.push(handle.clone());
220            }
221        }
222
223        futures::future::ready(Ok(filtered_services))
224    }
225}
226
227#[derive(Default, Debug)]
228struct ScannedResultStreamInner {
229    results: VecDeque<Result<ScanResult>>,
230    waker: Option<Waker>,
231}
232
233#[derive(Clone, Debug, Default)]
234pub struct ScannedResultStreamController(Arc<Mutex<ScannedResultStreamInner>>);
235
236impl ScannedResultStreamController {
237    /// Add a single scanned result item to output from the stream.
238    pub fn add_scanned_result(&self, item: Result<ScanResult>) {
239        let mut lock = self.0.lock();
240        lock.results.push_back(item);
241        if let Some(waker) = lock.waker.take() {
242            waker.wake();
243        }
244    }
245}
246
247#[derive(Debug, Default)]
248pub struct ScannedResultStream {
249    inner: Arc<Mutex<ScannedResultStreamInner>>,
250}
251
252impl ScannedResultStream {
253    /// Creates a new ScannedResultStream.
254    /// Client can get a ScannedResultStreamController using the `controller`
255    /// method.
256    pub fn new() -> Self {
257        Self::default()
258    }
259
260    pub fn controller(&self) -> ScannedResultStreamController {
261        ScannedResultStreamController(self.inner.clone())
262    }
263}
264
265impl FusedStream for ScannedResultStream {
266    fn is_terminated(&self) -> bool {
267        self.inner.lock().results.is_empty()
268    }
269}
270
271impl Stream for ScannedResultStream {
272    type Item = Result<ScanResult>;
273
274    fn poll_next(
275        self: std::pin::Pin<&mut Self>,
276        cx: &mut std::task::Context<'_>,
277    ) -> Poll<Option<Self::Item>> {
278        let mut lock = self.inner.lock();
279        match lock.results.pop_front() {
280            Some(result) => Poll::Ready(Some(result)),
281            None => {
282                lock.waker = Some(cx.waker().clone());
283                Poll::Pending
284            }
285        }
286    }
287}
288
289/// Implements a fake [`GetPeerAddr`] that just converts the peer_id into a
290/// public [`Address`] based on the given peer_id.
291pub struct FakeGetPeerAddr;
292
293impl GetPeerAddr for FakeGetPeerAddr {
294    async fn get_peer_address(&self, peer_id: PeerId) -> Result<(Address, AddressType)> {
295        Ok((
296            [
297                peer_id.0 as u8,
298                ((peer_id.0 >> 8) & 0xff) as u8,
299                ((peer_id.0 >> 16) & 0xff) as u8,
300                ((peer_id.0 >> 24) & 0xff) as u8,
301                ((peer_id.0 >> 32) & 0xff) as u8,
302                ((peer_id.0 >> 48) & 0xff) as u8,
303            ],
304            AddressType::Public,
305        ))
306    }
307}
308
309pub struct FakeTypes {}
310
311impl GattTypes for FakeTypes {
312    type Central = FakeCentral;
313    type ScanResultStream = ScannedResultStream;
314    type Client = FakeClient;
315    type ConnectFuture = Ready<Result<FakeClient>>;
316    type PeerServiceHandle = FakeServiceHandle;
317    type FindServicesFut = Ready<Result<Vec<FakeServiceHandle>>>;
318    type PeerService = FakePeerService;
319    type ServiceConnectFut = Ready<Result<FakePeerService>>;
320    type CharacteristicDiscoveryFut = Ready<Result<Vec<Characteristic>>>;
321    type NotificationStream = UnboundedReceiver<Result<CharacteristicNotification>>;
322    type ReadFut<'a> = Ready<Result<(usize, bool)>>;
323    type WriteFut<'a> = Ready<Result<()>>;
324    type PeriodicAdvertising = FakePeriodicAdvertising;
325}
326
327impl ServerTypes for FakeTypes {
328    type Server = FakeServer;
329    type LocalService = FakeLocalService;
330    type LocalServiceFut = Ready<Result<FakeLocalService>>;
331    type ServiceEventStream = UnboundedReceiver<Result<server::ServiceEvent<FakeTypes>>>;
332    type ServiceWriteType = Vec<u8>;
333    type ReadResponder = FakeResponder;
334    type WriteResponder = FakeResponder;
335    type IndicateConfirmationStream = UnboundedReceiver<Result<server::ConfirmationEvent>>;
336}
337
338pub struct FakePeriodicAdvertising;
339
340impl PeriodicAdvertising for FakePeriodicAdvertising {
341    type SyncFut = Ready<Result<Self::SyncStream>>;
342    type SyncStream = futures::stream::Empty<Result<SyncReport>>;
343
344    fn sync_to_advertising_reports(
345        _peer_id: PeerId,
346        _advertising_sid: u8,
347        _config: crate::periodic_advertising::SyncConfiguration,
348    ) -> Self::SyncFut {
349        unimplemented!()
350    }
351}
352
353#[derive(Default)]
354pub struct FakeCentralInner {
355    clients: HashMap<PeerId, FakeClient>,
356}
357
358#[derive(Clone)]
359pub struct FakeCentral {
360    inner: Arc<Mutex<FakeCentralInner>>,
361}
362
363impl FakeCentral {
364    pub fn new() -> Self {
365        Self { inner: Arc::new(Mutex::new(FakeCentralInner::default())) }
366    }
367
368    pub fn add_client(&mut self, peer_id: PeerId, client: FakeClient) {
369        let _ = self.inner.lock().clients.insert(peer_id, client);
370    }
371}
372
373impl crate::Central<FakeTypes> for FakeCentral {
374    fn scan(&self, _filters: &[crate::central::ScanFilter]) -> ScannedResultStream {
375        ScannedResultStream::default()
376    }
377
378    fn connect(&self, peer_id: PeerId) -> <FakeTypes as GattTypes>::ConnectFuture {
379        let clients = &self.inner.lock().clients;
380        let res = match clients.get(&peer_id) {
381            Some(client) => Ok(client.clone()),
382            None => Err(Error::PeerDisconnected(peer_id)),
383        };
384        futures::future::ready(res)
385    }
386
387    fn periodic_advertising(&self) -> Result<<FakeTypes as GattTypes>::PeriodicAdvertising> {
388        unimplemented!()
389    }
390}
391
392#[derive(Debug)]
393pub enum FakeServerEvent {
394    ReadResponded {
395        service_id: server::ServiceId,
396        handle: Handle,
397        value: Result<Vec<u8>>,
398    },
399    WriteResponded {
400        service_id: server::ServiceId,
401        handle: Handle,
402        value: Result<()>,
403    },
404    Notified {
405        service_id: server::ServiceId,
406        handle: Handle,
407        value: Vec<u8>,
408        peers: Vec<PeerId>,
409    },
410    Indicated {
411        service_id: server::ServiceId,
412        handle: Handle,
413        value: Vec<u8>,
414        peers: Vec<PeerId>,
415        confirmations: UnboundedSender<Result<server::ConfirmationEvent>>,
416    },
417    Unpublished {
418        id: server::ServiceId,
419    },
420    Published {
421        id: server::ServiceId,
422        definition: ServiceDefinition,
423    },
424}
425
426#[derive(Debug)]
427struct FakeServerInner {
428    services: HashMap<server::ServiceId, ServiceDefinition>,
429    service_senders:
430        HashMap<server::ServiceId, UnboundedSender<Result<server::ServiceEvent<FakeTypes>>>>,
431    sender: UnboundedSender<FakeServerEvent>,
432    notification_peers: HashSet<PeerId>,
433    indication_peers: HashSet<PeerId>,
434}
435
436#[derive(Clone, Debug)]
437pub struct FakeServer {
438    inner: Arc<Mutex<FakeServerInner>>,
439}
440
441impl server::Server<FakeTypes> for FakeServer {
442    fn prepare(
443        &self,
444        service: server::ServiceDefinition,
445    ) -> <FakeTypes as ServerTypes>::LocalServiceFut {
446        let id = service.id();
447        self.inner.lock().services.insert(id, service);
448        futures::future::ready(Ok(FakeLocalService::new(id, self.inner.clone())))
449    }
450}
451
452impl FakeServer {
453    pub fn new() -> (Self, UnboundedReceiver<FakeServerEvent>) {
454        let (sender, receiver) = futures::channel::mpsc::unbounded();
455        (
456            Self {
457                inner: Arc::new(Mutex::new(FakeServerInner {
458                    services: Default::default(),
459                    service_senders: Default::default(),
460                    sender,
461                    notification_peers: HashSet::new(),
462                    indication_peers: HashSet::new(),
463                })),
464            },
465            receiver,
466        )
467    }
468
469    pub fn service(&self, id: server::ServiceId) -> Option<ServiceDefinition> {
470        self.inner.lock().services.get(&id).cloned()
471    }
472
473    pub fn incoming_write(
474        &self,
475        peer_id: PeerId,
476        id: server::ServiceId,
477        handle: Handle,
478        offset: u32,
479        value: Vec<u8>,
480    ) {
481        // TODO: check that the write is allowed
482        let sender = self.inner.lock().sender.clone();
483        self.inner
484            .lock()
485            .service_senders
486            .get(&id)
487            .unwrap()
488            .unbounded_send(Ok(server::ServiceEvent::Write {
489                peer_id,
490                handle,
491                offset,
492                value,
493                responder: FakeResponder { sender, service_id: id, handle },
494            }))
495            .unwrap();
496    }
497
498    pub fn incoming_read(
499        &self,
500        peer_id: PeerId,
501        id: server::ServiceId,
502        handle: Handle,
503        offset: u32,
504    ) {
505        // TODO: check that the read is allowed
506        let sender = self.inner.lock().sender.clone();
507        self.inner
508            .lock()
509            .service_senders
510            .get(&id)
511            .unwrap()
512            .unbounded_send(Ok(server::ServiceEvent::Read {
513                peer_id,
514                handle,
515                offset,
516                responder: FakeResponder { sender, service_id: id, handle },
517            }))
518            .unwrap();
519    }
520
521    pub fn incoming_client_configuration(
522        &self,
523        peer_id: PeerId,
524        id: server::ServiceId,
525        handle: Handle,
526        notification_type: NotificationType,
527    ) {
528        let mut inner = self.inner.lock();
529        match notification_type {
530            NotificationType::Notify => {
531                inner.notification_peers.insert(peer_id);
532            }
533            NotificationType::Indicate => {
534                inner.indication_peers.insert(peer_id);
535            }
536            NotificationType::Disable => {
537                inner.notification_peers.remove(&peer_id);
538                inner.indication_peers.remove(&peer_id);
539            }
540        }
541        inner
542            .service_senders
543            .get(&id)
544            .unwrap()
545            .unbounded_send(Ok(server::ServiceEvent::ClientConfiguration {
546                peer_id,
547                handle,
548                notification_type,
549            }))
550            .unwrap();
551    }
552}
553
554pub struct FakeLocalService {
555    id: server::ServiceId,
556    inner: Arc<Mutex<FakeServerInner>>,
557}
558
559impl FakeLocalService {
560    fn new(id: server::ServiceId, inner: Arc<Mutex<FakeServerInner>>) -> Self {
561        Self { id, inner }
562    }
563}
564
565impl Drop for FakeLocalService {
566    fn drop(&mut self) {
567        self.inner.lock().services.remove(&self.id);
568    }
569}
570
571impl LocalService<FakeTypes> for FakeLocalService {
572    fn publish(&self) -> <FakeTypes as ServerTypes>::ServiceEventStream {
573        let (sender, receiver) = futures::channel::mpsc::unbounded();
574        let _ = self.inner.lock().service_senders.insert(self.id, sender);
575        let definition = self.inner.lock().services.get(&self.id).unwrap().clone();
576        self.inner
577            .lock()
578            .sender
579            .unbounded_send(FakeServerEvent::Published { id: self.id, definition })
580            .unwrap();
581        receiver
582    }
583
584    fn notify(&self, characteristic: &Handle, data: &[u8], peers: &[PeerId]) {
585        let inner = self.inner.lock();
586        let peers_to_notify: HashSet<_> = if peers.is_empty() {
587            inner.notification_peers.clone()
588        } else {
589            peers.iter().filter(|p| inner.notification_peers.contains(p)).cloned().collect()
590        };
591
592        if !peers_to_notify.is_empty() {
593            inner
594                .sender
595                .unbounded_send(FakeServerEvent::Notified {
596                    service_id: self.id,
597                    handle: *characteristic,
598                    value: data.into(),
599                    peers: peers_to_notify.into_iter().collect(),
600                })
601                .unwrap();
602        }
603    }
604
605    fn indicate(
606        &self,
607        characteristic: &Handle,
608        data: &[u8],
609        peers: &[PeerId],
610    ) -> <FakeTypes as ServerTypes>::IndicateConfirmationStream {
611        let (sender, receiver) = futures::channel::mpsc::unbounded();
612        let inner = self.inner.lock();
613        let peers_to_indicate: HashSet<_> = if peers.is_empty() {
614            inner.indication_peers.clone()
615        } else {
616            peers.iter().filter(|p| inner.indication_peers.contains(p)).cloned().collect()
617        };
618
619        if !peers_to_indicate.is_empty() {
620            inner
621                .sender
622                .unbounded_send(FakeServerEvent::Indicated {
623                    service_id: self.id,
624                    handle: *characteristic,
625                    value: data.into(),
626                    peers: peers_to_indicate.into_iter().collect(),
627                    confirmations: sender,
628                })
629                .unwrap();
630        }
631        receiver
632    }
633}
634
635pub struct FakeResponder {
636    sender: UnboundedSender<FakeServerEvent>,
637    service_id: server::ServiceId,
638    handle: Handle,
639}
640
641impl ReadResponder for FakeResponder {
642    fn respond(self, value: &[u8]) {
643        self.sender
644            .unbounded_send(FakeServerEvent::ReadResponded {
645                service_id: self.service_id,
646                handle: self.handle,
647                value: Ok(value.into()),
648            })
649            .unwrap();
650    }
651
652    fn error(self, error: GattError) {
653        self.sender
654            .unbounded_send(FakeServerEvent::ReadResponded {
655                service_id: self.service_id,
656                handle: self.handle,
657                value: Err(Error::Gatt(error)),
658            })
659            .unwrap();
660    }
661}
662
663impl WriteResponder for FakeResponder {
664    fn acknowledge(self) {
665        self.sender
666            .unbounded_send(FakeServerEvent::WriteResponded {
667                service_id: self.service_id,
668                handle: self.handle,
669                value: Ok(()),
670            })
671            .unwrap();
672    }
673
674    fn error(self, error: GattError) {
675        self.sender
676            .unbounded_send(FakeServerEvent::WriteResponded {
677                service_id: self.service_id,
678                handle: self.handle,
679                value: Err(Error::Gatt(error)),
680            })
681            .unwrap();
682    }
683}