use anyhow::{format_err, Context as _, Error};
use fidl_fuchsia_recovery::{FactoryResetMarker, FactoryResetProxy};
use fidl_fuchsia_update_channel::{ProviderMarker, ProviderProxy};
use fuchsia_component::client::connect_to_protocol;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::path::PathBuf;
use tracing::{info, warn};
/// File name (relative to `config_data_dir`) of the configuration that maps
/// each channel to its forced-FDR index.
const CONFIGURED_INDEX_FILE: &str = "forced-fdr-channel-indices.config";
/// File name (relative to `data_dir`) of the index recorded on this device.
const DEVICE_INDEX_FILE: &str = "stored-index.json";
/// Versioned on-disk format of the channel-indices config file
/// (`CONFIGURED_INDEX_FILE` under `config_data_dir`).
///
/// Serde represents this adjacently tagged: a `"version"` key carrying the
/// variant name and a `"content"` key carrying its fields. Unknown fields are
/// rejected (`deny_unknown_fields`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "version", content = "content", deny_unknown_fields)]
enum ChannelIndices {
/// Version "1": channel name -> configured index. A device on that channel
/// whose stored index is lower than this value gets a forced FDR.
#[serde(rename = "1")]
Version1 { channel_indices: HashMap<String, i32> },
}
/// Versioned on-disk format of the per-device index file
/// (`DEVICE_INDEX_FILE` under `data_dir`).
///
/// Uses the same adjacently tagged `"version"`/`"content"` layout as
/// `ChannelIndices`, with unknown fields rejected.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "version", content = "content", deny_unknown_fields)]
enum StoredIndex {
/// Version "1": the channel the index was recorded for, and the index itself.
#[serde(rename = "1")]
Version1 { channel: String, index: i32 },
}
/// Everything the forced-FDR check needs: where to read and write files, and
/// the service clients it talks to.
///
/// Built by `ForcedFDR::new` for production and by `ForcedFDR::new_mock` in
/// tests.
struct ForcedFDR {
/// Directory holding the per-device stored index (`/data` in `new`).
data_dir: PathBuf,
/// Directory holding the channel-indices config (`/config/data` in `new`).
config_data_dir: PathBuf,
/// Update-channel `Provider` client, used to read the current channel.
info_proxy: ProviderProxy,
/// `FactoryReset` client, used to request the reset.
factory_reset_proxy: FactoryResetProxy,
}
impl ForcedFDR {
fn new() -> Result<Self, Error> {
let info_proxy = connect_to_protocol::<ProviderMarker>()?;
let factory_reset_proxy = connect_to_protocol::<FactoryResetMarker>()?;
Ok(ForcedFDR {
data_dir: "/data".into(),
config_data_dir: "/config/data".into(),
info_proxy,
factory_reset_proxy,
})
}
#[cfg(test)]
fn new_mock(
data_dir: PathBuf,
config_data_dir: PathBuf,
) -> (
Self,
fidl_fuchsia_update_channel::ProviderRequestStream,
fidl_fuchsia_recovery::FactoryResetRequestStream,
) {
let (info_proxy, info_stream) =
fidl::endpoints::create_proxy_and_stream::<ProviderMarker>();
let (fdr_proxy, fdr_stream) =
fidl::endpoints::create_proxy_and_stream::<FactoryResetMarker>();
(
ForcedFDR { data_dir, config_data_dir, info_proxy, factory_reset_proxy: fdr_proxy },
info_stream,
fdr_stream,
)
}
}
/// Public entry point: runs the forced-FDR check and never fails.
///
/// Any error is logged at `info` level under the `forced-fdr` tag and then
/// dropped. This includes benign outcomes that `run` reports as errors, such as
/// "FDR not required".
pub async fn perform_fdr_if_necessary() {
    if let Err(err) = perform_fdr_if_necessary_impl().await {
        info!(tag = "forced-fdr", ?err);
    }
}
/// Connects to the production services, then runs the forced-FDR logic.
///
/// # Errors
///
/// Connection failures are wrapped with "Failed to connect to required
/// services". Errors from `run` are propagated unchanged.
async fn perform_fdr_if_necessary_impl() -> Result<(), Error> {
    match ForcedFDR::new() {
        Ok(forced_fdr) => run(forced_fdr).await,
        Err(err) => Err(err.context("Failed to connect to required services")),
    }
}
/// Decides whether this device must be factory reset, and resets it if so.
///
/// Steps:
/// 1. Read the current update channel from the `Provider` service.
/// 2. Load the configured per-channel indices from `config_data_dir`.
/// 3. Stop if the current channel has no configured index.
/// 4. Read the index stored on the device for this channel. If it cannot be
///    read for any reason (missing file, parse failure, or a stored channel
///    that differs from the current one), write the configured index and
///    return without resetting.
/// 5. Trigger a factory reset only when the stored index is strictly lower
///    than the configured one.
///
/// # Errors
///
/// "Nothing to do" outcomes are also returned as `Err`, e.g.
/// "Not in forced FDR allowlist" and "FDR not required".
/// `perform_fdr_if_necessary` only logs them.
async fn run(fdr: ForcedFDR) -> Result<(), Error> {
    let current_channel =
        get_current_channel(&fdr).await.context("Failed to get current channel")?;

    let channel_indices = get_channel_indices(&fdr).context("Channel indices not available")?;

    if !is_channel_in_allowlist(&channel_indices, &current_channel) {
        return Err(format_err!("Not in forced FDR allowlist"));
    }

    // Defensive only: the allowlist check above already guarantees the key exists.
    let channel_index = get_channel_index(&channel_indices, &current_channel)
        .ok_or_else(|| format_err!("Not in forced FDR allowlist."))?;

    let device_index = match get_stored_index(&fdr, &current_channel) {
        Ok(index) => index,
        Err(err) => {
            info!(%err, "Unable to read stored index");
            // Without a usable stored index there is nothing to compare against.
            // Record the configured index so a later bump of the configured value
            // can be detected.
            // NOTE(review): presumably this is the expected path right after an
            // FDR clears /data. Confirm.
            return write_stored_index(&fdr, &current_channel, channel_index)
                .context("Failed to write device index");
        }
    };

    if device_index >= channel_index {
        return Err(format_err!("FDR not required"));
    }

    trigger_fdr(&fdr).await.context("Failed to trigger FDR")?;
    Ok(())
}
fn get_channel_indices(fdr: &ForcedFDR) -> Result<HashMap<String, i32>, Error> {
let f = open_channel_indices_file(fdr)?;
match serde_json::from_reader(std::io::BufReader::new(f))? {
ChannelIndices::Version1 { channel_indices } => Ok(channel_indices),
}
}
fn open_channel_indices_file(fdr: &ForcedFDR) -> Result<File, Error> {
Ok(fs::File::open(fdr.config_data_dir.join(CONFIGURED_INDEX_FILE))?)
}
/// Asks the update-channel `Provider` for the name of the current channel.
///
/// # Errors
///
/// Fails if the FIDL call fails.
async fn get_current_channel(fdr: &ForcedFDR) -> Result<String, Error> {
    let channel = fdr.info_proxy.get_current().await?;
    Ok(channel)
}
/// Returns `true` if `channel` has an entry in the configured channel indices.
///
/// Takes `&str` rather than `&String` so it accepts any borrowed string.
/// Existing `&String` arguments still work through deref coercion.
fn is_channel_in_allowlist(allowlist: &HashMap<String, i32>, channel: &str) -> bool {
    allowlist.contains_key(channel)
}
/// Returns the configured forced-FDR index for `channel`, or `None` if the
/// channel has no entry.
///
/// Takes `&str` rather than `&String`. Existing `&String` arguments still work
/// through deref coercion.
fn get_channel_index(channel_indices: &HashMap<String, i32>, channel: &str) -> Option<i32> {
    channel_indices.get(channel).copied()
}
async fn trigger_fdr(fdr: &ForcedFDR) -> Result<i32, Error> {
warn!("Triggering FDR. SSH keys will be lost");
Ok(fdr.factory_reset_proxy.reset().await?)
}
fn get_stored_index(fdr: &ForcedFDR, current_channel: &String) -> Result<i32, Error> {
let f = open_stored_index_file(fdr)?;
match serde_json::from_reader(std::io::BufReader::new(f))? {
StoredIndex::Version1 { channel, index } => {
if *current_channel != channel {
return Err(format_err!("Mismatch between stored and current channel"));
}
Ok(index)
}
}
}
/// Opens `DEVICE_INDEX_FILE` inside the configured `data_dir`.
fn open_stored_index_file(fdr: &ForcedFDR) -> Result<File, Error> {
    let path = fdr.data_dir.join(DEVICE_INDEX_FILE);
    File::open(path).map_err(Into::into)
}
fn write_stored_index(fdr: &ForcedFDR, channel: &String, index: i32) -> Result<(), Error> {
info!("Writing index {} for channel {}", index, channel);
let stored_index = StoredIndex::Version1 { channel: channel.to_string(), index };
let contents = serde_json::to_string(&stored_index)?;
fs::write(fdr.data_dir.join(DEVICE_INDEX_FILE), contents)?;
Ok(())
}
// Unit tests live in the separate `forced_fdr_test` child module. As a child
// module it can reach this file's private items, and it is compiled only for
// test builds.
#[cfg(test)]
mod forced_fdr_test;