use aes::cipher::generic_array::GenericArray;
use aes::cipher::inout::InOut;
use aes::cipher::typenum::consts::U16;
use aes::cipher::{
BlockBackend, BlockClosure, BlockDecrypt, BlockEncrypt, BlockSizeUser, KeyInit, KeyIvInit,
StreamCipher as _, StreamCipherSeek,
};
use aes::Aes256;
use anyhow::{anyhow, Error};
use async_trait::async_trait;
use chacha20::{self, ChaCha20};
use fprint::TypeFingerprint;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt as _;
use fxfs_macros::{migrate_nodefault, Migrate};
use serde::de::{Error as SerdeError, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use static_assertions::assert_cfg;
use std::sync::Arc;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
use zx_status as zx;
pub mod ff1;
/// Size in bytes of a raw (unwrapped) key: 256 bits.
pub const KEY_SIZE: usize = 256 / 8;
/// Size in bytes of a wrapped key: the raw key size plus 16 additional bytes.
pub const WRAPPED_KEY_SIZE: usize = KEY_SIZE + 16;
/// Encrypted filenames are zero-padded up to a multiple of this many bytes.
pub const FSCRYPT_PADDING: usize = 16;
/// Size in bytes of the unit that data encryption/decryption processes independently.
const SECTOR_SIZE: u64 = 512;
/// Raw key material.
pub type KeyBytes = [u8; KEY_SIZE];

/// Raw (unwrapped) key material.
///
/// `Debug` is implemented by hand so that the key bytes never end up in logs, `{:?}` output or
/// panic messages (e.g. from `unwrap_err`).
pub struct UnwrappedKey {
    key: KeyBytes,
}

impl UnwrappedKey {
    /// Wraps the raw key bytes `key`.
    pub fn new(key: KeyBytes) -> Self {
        UnwrappedKey { key }
    }

    /// Returns the raw key bytes.
    pub fn key(&self) -> &KeyBytes {
        &self.key
    }
}

impl std::fmt::Debug for UnwrappedKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print secret key material.
        f.debug_struct("UnwrappedKey").field("key", &format_args!("<redacted>")).finish()
    }
}
/// A list of `(key id, key)` pairs produced by unwrapping a set of wrapped keys. A `None` key
/// marks an entry that could not be unwrapped (`Crypt::unwrap_keys` maps `NOT_FOUND` to `None`).
pub type UnwrappedKeys = Vec<(u64, Option<UnwrappedKey>)>;
/// The current version of the wrapped-key byte container.
pub type WrappedKeyBytes = WrappedKeyBytesV32;
/// Wrapped key material: exactly `WRAPPED_KEY_SIZE` bytes (the key size plus 16 bytes).
// NOTE(review): the extra 16 bytes are presumably the wrapping algorithm's tag/overhead — confirm.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct WrappedKeyBytesV32(pub [u8; WRAPPED_KEY_SIZE]);
impl Default for WrappedKeyBytes {
fn default() -> Self {
Self([0u8; WRAPPED_KEY_SIZE])
}
}
impl TryFrom<Vec<u8>> for WrappedKeyBytes {
type Error = anyhow::Error;
fn try_from(buf: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(buf.try_into().map_err(|_| anyhow!("wrapped key wrong length"))?))
}
}
impl From<[u8; WRAPPED_KEY_SIZE]> for WrappedKeyBytes {
fn from(buf: [u8; WRAPPED_KEY_SIZE]) -> Self {
Self(buf)
}
}
impl TypeFingerprint for WrappedKeyBytes {
fn fingerprint() -> String {
"WrappedKeyBytes".to_owned()
}
}
impl std::ops::Deref for WrappedKeyBytes {
type Target = [u8; WRAPPED_KEY_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for WrappedKeyBytes {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Serialize for WrappedKeyBytes {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&self[..])
}
}
impl<'de> Deserialize<'de> for WrappedKeyBytes {
fn deserialize<D>(deserializer: D) -> Result<WrappedKeyBytes, D::Error>
where
D: Deserializer<'de>,
{
struct WrappedKeyVisitor;
impl<'d> Visitor<'d> for WrappedKeyVisitor {
type Value = WrappedKeyBytes;
fn expecting(&self, formatter: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
formatter.write_str("Expected wrapped keys to be 48 bytes")
}
fn visit_bytes<E>(self, bytes: &[u8]) -> Result<WrappedKeyBytes, E>
where
E: SerdeError,
{
self.visit_byte_buf(bytes.to_vec())
}
fn visit_byte_buf<E>(self, bytes: Vec<u8>) -> Result<WrappedKeyBytes, E>
where
E: SerdeError,
{
let orig_len = bytes.len();
let bytes: [u8; WRAPPED_KEY_SIZE] =
bytes.try_into().map_err(|_| SerdeError::invalid_length(orig_len, &self))?;
Ok(WrappedKeyBytes::from(bytes))
}
}
deserializer.deserialize_byte_buf(WrappedKeyVisitor)
}
}
/// The current version of a wrapped key.
pub type WrappedKey = WrappedKeyV40;
/// A wrapped key plus the identifier of the key that wrapped it (the identifier is a `u128`
/// in this version).
#[derive(Clone, Debug, Serialize, Deserialize, TypeFingerprint, PartialEq)]
pub struct WrappedKeyV40 {
    // Identifier of the wrapping key; `WrappedKeys::get_wrapping_key_with_id` hands it out as
    // little-endian bytes.
    pub wrapping_key_id: u128,
    // The wrapped key material.
    pub key: WrappedKeyBytesV32,
}
/// Older wrapped-key layout whose wrapping key identifier is only a `u64`.
///
/// NOTE(review): `Migrate` + `migrate_nodefault` presumably generate the conversion into
/// `WrappedKeyV40` that `WrappedKeysV40::from` relies on (`key.into()`); confirm in fxfs_macros.
#[derive(Default, Clone, Migrate, Debug, Serialize, Deserialize, TypeFingerprint)]
#[migrate_nodefault]
pub struct WrappedKeyV32 {
    pub wrapping_key_id: u64,
    pub key: WrappedKeyBytesV32,
}
/// The current version of an object's list of wrapped keys.
pub type WrappedKeys = WrappedKeysV40;

/// `(key id, wrapped key)` pairs, using the version-40 wrapped key layout.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, TypeFingerprint)]
pub struct WrappedKeysV40(Vec<(u64, WrappedKeyV40)>);

impl From<WrappedKeysV32> for WrappedKeysV40 {
    /// Upgrades every entry to the version-40 layout, keeping key ids and ordering.
    fn from(value: WrappedKeysV32) -> Self {
        let WrappedKeysV32(entries) = value;
        let mut migrated: Vec<(u64, WrappedKeyV40)> = Vec::with_capacity(entries.len());
        for (id, key) in entries {
            migrated.push((id, key.into()));
        }
        WrappedKeysV40(migrated)
    }
}

/// `(key id, wrapped key)` pairs, using the older version-32 wrapped key layout.
#[derive(Clone, Debug, Serialize, Deserialize, TypeFingerprint)]
pub struct WrappedKeysV32(pub Vec<(u64, WrappedKeyV32)>);
impl From<Vec<(u64, WrappedKey)>> for WrappedKeys {
fn from(buf: Vec<(u64, WrappedKey)>) -> Self {
Self(buf)
}
}
impl std::ops::Deref for WrappedKeys {
type Target = Vec<(u64, WrappedKey)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for WrappedKeys {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl WrappedKeys {
pub fn get_wrapping_key_with_id(&self, key_id: u64) -> Option<[u8; 16]> {
let wrapped_key_entry = self.0.iter().find(|(x, _)| *x == key_id);
wrapped_key_entry.map(|(_, wrapped_key)| wrapped_key.wrapping_key_id.to_le_bytes())
}
}
/// One entry of a [`CipherSet`]: a key id plus, when the key is available, the AES-256 cipher
/// built from it.
#[derive(Clone, Debug)]
pub struct Cipher {
    id: u64,
    cipher: Option<Aes256>,
}

impl Cipher {
    /// Builds an available cipher with id `id` from the raw bytes of `key`.
    pub fn new(id: u64, key: &UnwrappedKey) -> Self {
        let aes = Aes256::new(GenericArray::from_slice(key.key()));
        Cipher { id, cipher: Some(aes) }
    }

    /// Builds a placeholder for key `id` whose key material is not available.
    pub fn unavailable(id: u64) -> Self {
        Self { cipher: None, id }
    }

    /// Returns the AES-256 cipher, or `None` for an unavailable key.
    pub fn key(&self) -> Option<&Aes256> {
        let Self { cipher, .. } = self;
        cipher.as_ref()
    }
}
pub struct Key {
keys: Arc<CipherSet>,
index: usize,
}
impl Key {
fn key(&self) -> &Aes256 {
self.keys.0[self.index].cipher.as_ref().unwrap()
}
pub fn key_id(&self) -> u64 {
self.keys.0[self.index].id
}
pub fn encrypt(&self, offset: u64, buffer: &mut [u8]) -> Result<(), Error> {
fxfs_trace::duration!(c"encrypt", "len" => buffer.len());
assert_eq!(offset % SECTOR_SIZE, 0);
let cipher = &self.key();
let mut sector_offset = offset / SECTOR_SIZE;
for sector in buffer.chunks_exact_mut(SECTOR_SIZE as usize) {
let mut tweak = Tweak(sector_offset as u128);
cipher.encrypt_block(GenericArray::from_mut_slice(tweak.as_mut_bytes()));
cipher.encrypt_with_backend(XtsProcessor::new(tweak, sector));
sector_offset += 1;
}
Ok(())
}
pub fn decrypt(&self, offset: u64, buffer: &mut [u8]) -> Result<(), Error> {
fxfs_trace::duration!(c"decrypt", "len" => buffer.len());
assert_eq!(offset % SECTOR_SIZE, 0);
let cipher = &self.key();
let mut sector_offset = offset / SECTOR_SIZE;
for sector in buffer.chunks_exact_mut(SECTOR_SIZE as usize) {
let mut tweak = Tweak(sector_offset as u128);
cipher.encrypt_block(GenericArray::from_mut_slice(tweak.as_mut_bytes()));
cipher.decrypt_with_backend(XtsProcessor::new(tweak, sector));
sector_offset += 1;
}
Ok(())
}
pub fn encrypt_filename(&self, object_id: u64, buffer: &mut Vec<u8>) -> Result<(), Error> {
buffer.resize(buffer.len().next_multiple_of(FSCRYPT_PADDING), 0);
let cipher = self.key();
cipher.encrypt_with_backend(CbcEncryptProcessor::new(Tweak(object_id as u128), buffer));
Ok(())
}
pub fn decrypt_filename(&self, object_id: u64, buffer: &mut Vec<u8>) -> Result<(), Error> {
let cipher = self.key();
cipher.decrypt_with_backend(CbcDecryptProcessor::new(Tweak(object_id as u128), buffer));
if let Some(i) = buffer.iter().rposition(|x| *x != 0) {
let new_len = i + 1;
buffer.truncate(new_len);
}
Ok(())
}
}
/// An ordered collection of ciphers, each identified by a key id. Some may be unavailable.
#[derive(Clone, Debug)]
pub struct CipherSet(Vec<Cipher>);

impl From<Vec<Cipher>> for CipherSet {
    fn from(value: Vec<Cipher>) -> Self {
        Self(value)
    }
}

impl CipherSet {
    /// Builds a set from unwrapped keys, preserving their order. Entries whose key is `None`
    /// become unavailable ciphers.
    pub fn new(keys: &UnwrappedKeys) -> Self {
        Self(
            keys.iter()
                .map(|(id, k)| match k {
                    Some(k) => Cipher::new(*id, k),
                    None => Cipher::unavailable(*id),
                })
                .collect(),
        )
    }

    /// Returns all ciphers in the set, in order.
    pub fn ciphers(&self) -> &[Cipher] {
        &self.0
    }

    /// Returns the position and cipher of the first entry whose key id is `id`.
    pub fn cipher(&self, id: u64) -> Option<(usize, &Cipher)> {
        self.0.iter().enumerate().find(|(_, x)| x.id == id)
    }

    /// Returns true if any entry (available or not) has key id `id`.
    pub fn contains_key_id(&self, id: u64) -> bool {
        self.0.iter().any(|x| x.id == id)
    }

    /// Looks up key `id`. Returns `NotFound` if there is no such entry and `Unavailable` if the
    /// entry has no key material. Otherwise returns a `Key` sharing this set.
    pub fn find_key(self: &Arc<Self>, id: u64) -> FindKeyResult {
        let Some((index, cipher)) = self.cipher(id) else {
            return FindKeyResult::NotFound;
        };
        if cipher.key().is_some() {
            FindKeyResult::Key(Key { keys: Arc::clone(self), index })
        } else {
            FindKeyResult::Unavailable
        }
    }
}
/// Outcome of [`CipherSet::find_key`].
pub enum FindKeyResult {
    /// No entry with the requested key id exists in the set.
    NotFound,
    /// An entry with the requested key id exists, but its key material is unavailable.
    Unavailable,
    /// The requested key, ready to encrypt/decrypt.
    Key(Key),
}
/// A ChaCha20 keystream cipher with an all-zero nonce, positioned at a given keystream offset.
///
/// Encryption and decryption are the same keystream XOR.
pub struct StreamCipher(ChaCha20);

impl StreamCipher {
    /// Creates a cipher keyed by `key` and seeked to keystream position `offset`.
    pub fn new(key: &UnwrappedKey, offset: u64) -> Self {
        let nonce = [0; 12];
        let mut inner = ChaCha20::new(chacha20::Key::from_slice(key.key()), &nonce.into());
        inner.seek(offset);
        StreamCipher(inner)
    }

    /// Encrypts `buffer` in place and advances the keystream position by its length.
    pub fn encrypt(&mut self, buffer: &mut [u8]) {
        fxfs_trace::duration!(c"StreamCipher::encrypt", "len" => buffer.len());
        let StreamCipher(inner) = self;
        inner.apply_keystream(buffer);
    }

    /// Decrypts `buffer` in place and advances the keystream position by its length.
    pub fn decrypt(&mut self, buffer: &mut [u8]) {
        fxfs_trace::duration!(c"StreamCipher::decrypt", "len" => buffer.len());
        let StreamCipher(inner) = self;
        inner.apply_keystream(buffer);
    }

    /// Returns the current keystream position.
    pub fn offset(&self) -> u64 {
        let StreamCipher(inner) = self;
        inner.current_pos()
    }
}
/// What a newly created key is for; passed to [`Crypt::create_key`].
pub enum KeyPurpose {
    /// NOTE(review): presumably keys for file contents — confirm with `Crypt` implementations.
    Data,
    /// NOTE(review): presumably keys for filesystem metadata — confirm with `Crypt` implementations.
    Metadata,
}
/// A backend that creates wrapped keys and unwraps them.
#[async_trait]
pub trait Crypt: Send + Sync {
    /// Creates a new key for `owner` and `purpose`, returning its wrapped and raw forms.
    ///
    /// NOTE(review): `owner` is presumably the id of the object that owns the key — confirm.
    async fn create_key(
        &self,
        owner: u64,
        purpose: KeyPurpose,
    ) -> Result<(WrappedKey, UnwrappedKey), zx::Status>;

    /// Creates a new key for `owner` associated with the wrapping key `wrapping_key_id`,
    /// returning its wrapped and raw forms.
    async fn create_key_with_id(
        &self,
        owner: u64,
        wrapping_key_id: u128,
    ) -> Result<(WrappedKey, UnwrappedKey), zx::Status>;

    /// Unwraps a single wrapped key for `owner`.
    async fn unwrap_key(
        &self,
        wrapped_key: &WrappedKey,
        owner: u64,
    ) -> Result<UnwrappedKey, zx::Status>;

    /// Unwraps every key in `keys` concurrently.
    ///
    /// A key whose unwrap fails with `NOT_FOUND` is reported as `(key id, None)` instead of
    /// failing the call. Any other error aborts the call and is returned. Results are collected
    /// as the unwraps complete, so their order need not match the order of `keys`.
    async fn unwrap_keys(
        &self,
        keys: &WrappedKeys,
        owner: u64,
    ) -> Result<UnwrappedKeys, zx::Status> {
        let futures = FuturesUnordered::new();
        for (key_id, key) in keys.iter() {
            futures.push(async move {
                match self.unwrap_key(key, owner).await {
                    Ok(unwrapped_key) => Ok((*key_id, Some(unwrapped_key))),
                    Err(zx::Status::NOT_FOUND) => Ok((*key_id, None)),
                    Err(e) => Err(e),
                }
            });
        }
        futures.try_collect::<UnwrappedKeys>().await
    }
}
// `Tweak` is both viewed as raw bytes (`as_mut_bytes`) and converted with `to_le_bytes` /
// `from_le_bytes`. `XtsProcessor` also XORs it into data viewed as `u128`. Those views only agree
// on a little-endian target, hence this compile-time check.
assert_cfg!(target_endian = "little");
/// A 128-bit tweak / IV value: an XTS per-sector tweak or a CBC initialization vector.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable)]
#[repr(C)]
struct Tweak(u128);
/// XORs each byte of `b` into the corresponding byte of `a`.
///
/// Only the common prefix is touched: if the slices differ in length, the extra bytes are
/// ignored (in `b`) or left unchanged (in `a`).
pub fn xor_in_place(a: &mut [u8], b: &[u8]) {
    a.iter_mut().zip(b).for_each(|(dst, src)| *dst ^= src);
}
/// CBC-encrypts the whole 16-byte blocks of a buffer in place, with `tweak` as the IV.
struct CbcEncryptProcessor<'a> {
    tweak: Tweak,
    data: &'a mut [u8],
}

impl<'a> CbcEncryptProcessor<'a> {
    /// Prepares `data` for in-place CBC encryption starting from IV `tweak`.
    fn new(tweak: Tweak, data: &'a mut [u8]) -> Self {
        CbcEncryptProcessor { tweak, data }
    }
}

impl BlockSizeUser for CbcEncryptProcessor<'_> {
    type BlockSize = U16;
}

impl BlockClosure for CbcEncryptProcessor<'_> {
    /// Bytes after the last full 16-byte block are left untouched.
    fn call<B: BlockBackend<BlockSize = Self::BlockSize>>(self, backend: &mut B) {
        let CbcEncryptProcessor { tweak, data } = self;
        // `chain` is the previous ciphertext block (the IV before the first block).
        let mut chain = tweak.0;
        for block_bytes in data.chunks_exact_mut(16) {
            xor_in_place(block_bytes, &chain.to_le_bytes());
            let block: &mut GenericArray<u8, U16> = GenericArray::from_mut_slice(block_bytes);
            backend.proc_block(InOut::from(block));
            chain = u128::from_le_bytes(<[u8; 16]>::try_from(&*block_bytes).unwrap());
        }
    }
}
/// CBC-decrypts the whole 16-byte blocks of a buffer in place, with `tweak` as the IV.
struct CbcDecryptProcessor<'a> {
    tweak: Tweak,
    data: &'a mut [u8],
}

impl<'a> CbcDecryptProcessor<'a> {
    /// Prepares `data` for in-place CBC decryption starting from IV `tweak`.
    fn new(tweak: Tweak, data: &'a mut [u8]) -> Self {
        Self { tweak, data }
    }
}

impl BlockSizeUser for CbcDecryptProcessor<'_> {
    type BlockSize = U16;
}

impl BlockClosure for CbcDecryptProcessor<'_> {
    /// Bytes after the last full 16-byte block are left untouched.
    fn call<B: BlockBackend<BlockSize = Self::BlockSize>>(self, backend: &mut B) {
        let Self { mut tweak, data } = self;
        for block in data.chunks_exact_mut(16) {
            // Save this block's ciphertext before it is overwritten. It is the chaining value for
            // the next block. Holding it as a `u128` avoids a heap allocation per block.
            let ciphertext = u128::from_le_bytes(
                <[u8; 16]>::try_from(&*block).expect("chunks_exact_mut yields 16-byte blocks"),
            );
            let chunk = GenericArray::from_mut_slice(block);
            backend.proc_block(InOut::from(chunk));
            xor_in_place(block, &tweak.0.to_le_bytes());
            tweak.0 = ciphertext;
        }
    }
}
/// XTS-style in-place processing of one sector: each 16-byte block is XORed with the tweak, run
/// through the block backend, then XORed with the tweak again. The tweak advances between blocks.
/// Only whole 16-byte blocks are processed (no ciphertext stealing). `Key::encrypt` and
/// `Key::decrypt` pass a tweak that has already been encrypted.
struct XtsProcessor<'a> {
    tweak: Tweak,
    data: &'a mut [u8],
}
impl<'a> XtsProcessor<'a> {
    /// # Panics
    ///
    /// Panics if `data` is not 16-byte aligned. `call` reads and writes it as `u128` values.
    fn new(tweak: Tweak, data: &'a mut [u8]) -> Self {
        assert_eq!(data.as_ptr() as usize & 15, 0, "data must be 16 byte aligned");
        Self { tweak, data }
    }
}
impl BlockSizeUser for XtsProcessor<'_> {
    type BlockSize = U16;
}
impl BlockClosure for XtsProcessor<'_> {
    fn call<B: BlockBackend<BlockSize = Self::BlockSize>>(self, backend: &mut B) {
        let Self { mut tweak, data } = self;
        for chunk in data.chunks_exact_mut(16) {
            let ptr = chunk.as_mut_ptr() as *mut u128;
            // SAFETY: `chunk` is exactly 16 writable bytes borrowed mutably from `data`. `new`
            // asserted that `data` starts 16-byte aligned, and every chunk begins a multiple of 16
            // bytes after it. So `ptr` is valid and suitably aligned for a `u128` read/write. The
            // same pointer is reinterpreted as a 16-byte `GenericArray` and passed as both input
            // and output to `from_raw`, i.e. the block is processed in place. No other reference
            // to these bytes is used while the pointers are live.
            unsafe {
                *ptr ^= tweak.0;
                let chunk = ptr as *mut GenericArray<u8, U16>;
                backend.proc_block(InOut::from_raw(chunk, chunk));
                *ptr ^= tweak.0;
            }
            // Multiply the tweak by x in GF(2^128) (little-endian convention): shift left by one,
            // and if the top bit was set, reduce by XORing in 0x87 (x^128 = x^7 + x^2 + x + 1).
            tweak.0 = (tweak.0 << 1) ^ ((tweak.0 as i128 >> 127) as u128 & 0x87);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Cipher, CipherSet, Key, StreamCipher, UnwrappedKey};
    use std::sync::Arc;

    /// Raw AES-256 key shared by the filename tests.
    const FILENAME_KEY_HEX: &str =
        "1fcdf30b7d191bd95d3161fe08513b864aa15f27f910f1c66eec8cfa93e9893b";
    /// Known ciphertext of "filename" for object id 2 under `FILENAME_KEY_HEX`.
    const ENCRYPTED_FILENAME_HEX: &str = "52d56369103a39b3ea1e09c85dd51546";

    /// Builds a `Key` backed by a single available cipher made from `FILENAME_KEY_HEX`.
    fn filename_key() -> Key {
        let raw_key_bytes: [u8; 32] =
            hex::decode(FILENAME_KEY_HEX).expect("decode failed").try_into().unwrap();
        let unwrapped_key = UnwrappedKey::new(raw_key_bytes);
        let cipher_set = CipherSet::from(vec![Cipher::new(0, &unwrapped_key)]);
        Key { keys: Arc::new(cipher_set), index: 0 }
    }

    #[test]
    fn test_stream_cipher_offset() {
        let key = UnwrappedKey::new([
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
        ]);
        let mut cipher1 = StreamCipher::new(&key, 0);
        let mut p1 = [1, 2, 3, 4];
        let mut c1 = p1;
        cipher1.encrypt(&mut c1);
        let mut cipher2 = StreamCipher::new(&key, 1);
        let p2 = [5, 6, 7, 8];
        let mut c2 = p2;
        cipher2.encrypt(&mut c2);
        let xor_fn = |buf1: &mut [u8], buf2| {
            for (b1, b2) in buf1.iter_mut().zip(buf2) {
                *b1 ^= b2;
            }
        };
        // If the offset were ignored, both keystreams would be identical and cancel out here.
        xor_fn(&mut c1, &c2);
        xor_fn(&mut p1, &p2);
        assert_ne!(c1, p1);
    }

    #[test]
    fn test_encrypt_filename() {
        let key = filename_key();
        let object_id = 2;
        let mut text = b"filename".to_vec();
        key.encrypt_filename(object_id, &mut text).expect("encrypt filename failed");
        assert_eq!(text, hex::decode(ENCRYPTED_FILENAME_HEX).expect("decode failed"));
    }

    #[test]
    fn test_decrypt_filename() {
        let key = filename_key();
        let object_id = 2;
        let mut text = hex::decode(ENCRYPTED_FILENAME_HEX).expect("decode failed");
        key.decrypt_filename(object_id, &mut text).expect("decrypt filename failed");
        assert_eq!(text, b"filename".to_vec());
    }
}