use {
aes::{
cipher::{
generic_array::GenericArray, inout::InOut, typenum::consts::U16, BlockBackend,
BlockClosure, BlockDecrypt, BlockEncrypt, BlockSizeUser, KeyInit, KeyIvInit,
StreamCipher as _, StreamCipherSeek,
},
Aes256,
},
anyhow::{anyhow, Error},
async_trait::async_trait,
chacha20::{ChaCha20, Key},
fprint::TypeFingerprint,
serde::{
de::{Error as SerdeError, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
},
static_assertions::assert_cfg,
zerocopy::{AsBytes, FromBytes, FromZeros, NoCell},
};
pub mod ff1;
/// Length in bytes of a raw (unwrapped) key: 256 bits.
pub const KEY_SIZE: usize = 32;
/// Length in bytes of a wrapped key: the raw key size plus 16 bytes.
///
/// NOTE(review): the extra 16 bytes are presumably added by the wrapping scheme (e.g. an
/// authentication tag) — confirm against the `Crypt` implementations.
pub const WRAPPED_KEY_SIZE: usize = KEY_SIZE + 16;
/// Size of the units that XTS encryption in this module operates on.
const SECTOR_SIZE: u64 = 512;

/// Raw key material of exactly `KEY_SIZE` bytes.
pub type KeyBytes = [u8; KEY_SIZE];

/// Key material in its unwrapped (raw) form.
pub struct UnwrappedKey {
    key: KeyBytes,
}

impl UnwrappedKey {
    /// Wraps the given raw bytes; no validation is performed.
    pub fn new(key: KeyBytes) -> Self {
        Self { key }
    }

    /// Borrows the raw key bytes.
    pub fn key(&self) -> &KeyBytes {
        let Self { key } = self;
        key
    }
}

/// A list of `(key id, unwrapped key)` pairs.
pub type UnwrappedKeys = Vec<(u64, UnwrappedKey)>;
/// Alias for the `V32` layout of wrapped key bytes.
pub type WrappedKeyBytes = WrappedKeyBytesV32;
/// The bytes of a wrapped key: exactly `WRAPPED_KEY_SIZE` bytes.
///
/// `#[repr(transparent)]` gives this the same layout as the inner array.
/// NOTE(review): the `V32` suffix suggests a versioned, persisted type — do not change the layout.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct WrappedKeyBytesV32(pub [u8; WRAPPED_KEY_SIZE]);
impl Default for WrappedKeyBytes {
fn default() -> Self {
Self([0u8; WRAPPED_KEY_SIZE])
}
}
impl TryFrom<Vec<u8>> for WrappedKeyBytes {
type Error = anyhow::Error;
fn try_from(buf: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(buf.try_into().map_err(|_| anyhow!("wrapped key wrong length"))?))
}
}
impl From<[u8; WRAPPED_KEY_SIZE]> for WrappedKeyBytes {
fn from(buf: [u8; WRAPPED_KEY_SIZE]) -> Self {
Self(buf)
}
}
impl TypeFingerprint for WrappedKeyBytes {
    /// Fixed name reported as this type's fingerprint. It deliberately omits the `V32` suffix of
    /// the concrete struct.
    fn fingerprint() -> String {
        String::from("WrappedKeyBytes")
    }
}
impl std::ops::Deref for WrappedKeyBytes {
type Target = [u8; WRAPPED_KEY_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for WrappedKeyBytes {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Serialize for WrappedKeyBytes {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&self[..])
}
}
impl<'de> Deserialize<'de> for WrappedKeyBytes {
fn deserialize<D>(deserializer: D) -> Result<WrappedKeyBytes, D::Error>
where
D: Deserializer<'de>,
{
struct WrappedKeyVisitor;
impl<'d> Visitor<'d> for WrappedKeyVisitor {
type Value = WrappedKeyBytes;
fn expecting(&self, formatter: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
formatter.write_str("Expected wrapped keys to be 48 bytes")
}
fn visit_bytes<E>(self, bytes: &[u8]) -> Result<WrappedKeyBytes, E>
where
E: SerdeError,
{
self.visit_byte_buf(bytes.to_vec())
}
fn visit_byte_buf<E>(self, bytes: Vec<u8>) -> Result<WrappedKeyBytes, E>
where
E: SerdeError,
{
let orig_len = bytes.len();
let bytes: [u8; WRAPPED_KEY_SIZE] =
bytes.try_into().map_err(|_| SerdeError::invalid_length(orig_len, &self))?;
Ok(WrappedKeyBytes::from(bytes))
}
}
deserializer.deserialize_byte_buf(WrappedKeyVisitor)
}
}
/// Alias for the `V32` layout of a wrapped key.
pub type WrappedKey = WrappedKeyV32;
/// A wrapped key together with the id of the key used to wrap it.
///
/// NOTE(review): this derives `Serialize`/`Deserialize` and carries a version suffix, so field
/// order and types presumably form a persisted format — do not reorder or retype fields.
#[derive(Clone, Debug, Serialize, Deserialize, TypeFingerprint, PartialEq)]
pub struct WrappedKeyV32 {
    /// Id of the wrapping key; presumably tells a `Crypt` implementation which key to unwrap
    /// with — confirm against implementations.
    pub wrapping_key_id: u64,
    /// The wrapped key bytes.
    pub key: WrappedKeyBytesV32,
}
/// Alias for the `V32` layout of a list of wrapped keys.
pub type WrappedKeys = WrappedKeysV32;
/// A list of `(key id, wrapped key)` pairs.
///
/// NOTE(review): serialized and versioned — keep the element type and order stable.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, TypeFingerprint)]
pub struct WrappedKeysV32(pub Vec<(u64, WrappedKeyV32)>);
impl From<Vec<(u64, WrappedKey)>> for WrappedKeys {
fn from(buf: Vec<(u64, WrappedKey)>) -> Self {
Self(buf)
}
}
impl std::ops::Deref for WrappedKeys {
type Target = Vec<(u64, WrappedKey)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for WrappedKeys {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// An AES-256 cipher tagged with the id of the key it was built from.
struct XtsCipher {
    // Key id, matched against the `key_id` passed to encrypt/decrypt.
    id: u64,
    // Cipher keyed with that key; used for both the tweak and the data.
    cipher: Aes256,
}
/// A set of AES-256 ciphers for sector-based XTS encryption, selected by key id.
pub struct XtsCipherSet(Vec<XtsCipher>);
impl XtsCipherSet {
pub fn new(keys: &UnwrappedKeys) -> Self {
Self(
keys.iter()
.map(|(id, k)| XtsCipher {
id: *id,
cipher: Aes256::new(GenericArray::from_slice(k.key())),
})
.collect(),
)
}
pub fn decrypt(&self, offset: u64, key_id: u64, buffer: &mut [u8]) -> Result<(), Error> {
fxfs_trace::duration!(c"decrypt", "len" => buffer.len());
assert_eq!(offset % SECTOR_SIZE, 0);
let cipher = &self
.0
.iter()
.find(|cipher| cipher.id == key_id)
.ok_or(anyhow!("Key not found"))?
.cipher;
let mut sector_offset = offset / SECTOR_SIZE;
for sector in buffer.chunks_exact_mut(SECTOR_SIZE as usize) {
let mut tweak = Tweak(sector_offset as u128);
cipher.encrypt_block(GenericArray::from_mut_slice(tweak.as_bytes_mut()));
cipher.decrypt_with_backend(XtsProcessor::new(tweak, sector));
sector_offset += 1;
}
Ok(())
}
pub fn encrypt(&self, offset: u64, key_id: u64, buffer: &mut [u8]) -> Result<(), Error> {
fxfs_trace::duration!(c"encrypt", "len" => buffer.len());
assert_eq!(offset % SECTOR_SIZE, 0);
let cipher = &self
.0
.iter()
.find(|cipher| cipher.id == key_id)
.ok_or(anyhow!("Key not found"))?
.cipher;
let mut sector_offset = offset / SECTOR_SIZE;
for sector in buffer.chunks_exact_mut(SECTOR_SIZE as usize) {
let mut tweak = Tweak(sector_offset as u128);
cipher.encrypt_block(GenericArray::from_mut_slice(tweak.as_bytes_mut()));
cipher.encrypt_with_backend(XtsProcessor::new(tweak, sector));
sector_offset += 1;
}
Ok(())
}
}
/// A ChaCha20 keystream cipher positioned at a byte offset.
///
/// Encryption and decryption are the same operation (XOR with the keystream). Both consume
/// keystream starting at the current position.
pub struct StreamCipher(ChaCha20);

impl StreamCipher {
    /// Creates a cipher keyed with `key` and an all-zero nonce. Its keystream starts at byte
    /// `offset`.
    pub fn new(key: &UnwrappedKey, offset: u64) -> Self {
        let nonce = [0; 12];
        let mut chacha = ChaCha20::new(Key::from_slice(key.key()), &nonce.into());
        chacha.seek(offset);
        Self(chacha)
    }

    /// Encrypts `buffer` in place.
    pub fn encrypt(&mut self, buffer: &mut [u8]) {
        fxfs_trace::duration!(c"StreamCipher::encrypt", "len" => buffer.len());
        let Self(chacha) = self;
        chacha.apply_keystream(buffer);
    }

    /// Decrypts `buffer` in place.
    pub fn decrypt(&mut self, buffer: &mut [u8]) {
        fxfs_trace::duration!(c"StreamCipher::decrypt", "len" => buffer.len());
        let Self(chacha) = self;
        chacha.apply_keystream(buffer);
    }

    /// Current byte position in the keystream.
    pub fn offset(&self) -> u64 {
        self.0.current_pos::<u64>()
    }
}
/// Tells [`Crypt::create_key`] what the new key is intended for.
pub enum KeyPurpose {
    /// A key for data.
    Data,
    /// A key for metadata.
    Metadata,
}
#[async_trait]
pub trait Crypt: Send + Sync {
async fn create_key(
&self,
owner: u64,
purpose: KeyPurpose,
) -> Result<(WrappedKey, UnwrappedKey), Error>;
async fn unwrap_key(&self, wrapped_key: &WrappedKey, owner: u64)
-> Result<UnwrappedKey, Error>;
async fn unwrap_keys(&self, keys: &WrappedKeys, owner: u64) -> Result<UnwrappedKeys, Error> {
let mut futures = vec![];
for (key_id, key) in keys.iter() {
futures.push(async move { Ok((*key_id, self.unwrap_key(key, owner).await?)) });
}
futures::future::try_join_all(futures).await
}
}
// The tweak is converted between `u128` and bytes using native byte order (`as_bytes_mut` when
// encrypting it, raw `u128` XOR in `XtsProcessor::call`). The code relies on that order being
// little-endian, so refuse to build anywhere else.
assert_cfg!(target_endian = "little");
/// Per-sector XTS tweak: starts as the sector number and is encrypted in place before use.
#[derive(AsBytes, FromBytes, FromZeros, NoCell)]
#[repr(C)]
struct Tweak(u128);
/// One sector of data paired with its (already encrypted) XTS tweak.
///
/// Passed to an AES block backend through its `BlockClosure` implementation.
struct XtsProcessor<'a> {
    tweak: Tweak,
    data: &'a mut [u8],
}

impl<'a> XtsProcessor<'a> {
    /// Pairs `tweak` with `data`.
    ///
    /// # Panics
    /// If `data` does not start on a 16-byte boundary. Block processing reinterprets each
    /// 16-byte chunk as a `u128`.
    fn new(tweak: Tweak, data: &'a mut [u8]) -> Self {
        let misalignment = data.as_ptr() as usize % 16;
        assert_eq!(misalignment, 0, "data must be 16 byte aligned");
        XtsProcessor { tweak, data }
    }
}

/// Processes data in 16-byte (AES) blocks.
impl BlockSizeUser for XtsProcessor<'_> {
    type BlockSize = U16;
}
impl BlockClosure for XtsProcessor<'_> {
    /// Runs the XTS construction over `data`, one 16-byte block at a time.
    ///
    /// For each block it computes `block = OP(block ^ T) ^ T`, where `OP` is the backend's block
    /// operation (encrypt or decrypt, depending on which `*_with_backend` method was called). It
    /// then advances the tweak `T` by multiplying it by x in GF(2^128). A trailing partial block
    /// (if any) is not touched.
    fn call<B: BlockBackend<BlockSize = Self::BlockSize>>(self, backend: &mut B) {
        let Self { mut tweak, data } = self;
        for chunk in data.chunks_exact_mut(16) {
            let ptr = chunk.as_mut_ptr() as *mut u128;
            // SAFETY:
            // - `chunk` is exactly 16 bytes and exclusively borrowed, so `ptr` is valid for reads
            //   and writes of one `u128` / `GenericArray<u8, U16>`.
            // - `XtsProcessor::new` asserts that `data` is 16-byte aligned. Every chunk starts a
            //   multiple of 16 bytes later, so `ptr` is aligned for `u128`.
            // - Passing the same pointer as input and output to `InOut::from_raw` requests
            //   in-place processing.
            unsafe {
                *ptr ^= tweak.0;
                let chunk = ptr as *mut GenericArray<u8, U16>;
                backend.proc_block(InOut::from_raw(chunk, chunk));
                *ptr ^= tweak.0;
            }
            // Multiply the tweak by x in GF(2^128): shift left by one bit. If the top bit was set,
            // reduce by XORing in 0x87 (x^7 + x^2 + x + 1). The arithmetic shift of the sign bit
            // yields all-ones or all-zeros, giving a branch-free mask.
            tweak.0 = (tweak.0 << 1) ^ ((tweak.0 as i128 >> 127) as u128 & 0x87);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{StreamCipher, UnwrappedKey};

    #[test]
    fn test_stream_cipher_offset() {
        let key = UnwrappedKey::new([
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
        ]);
        let mut cipher1 = StreamCipher::new(&key, 0);
        let mut p1 = [1, 2, 3, 4];
        // Arrays are `Copy`; a plain copy replaces `.clone()` (clippy::clone_on_copy).
        let mut c1 = p1;
        cipher1.encrypt(&mut c1);

        let mut cipher2 = StreamCipher::new(&key, 1);
        let p2 = [5, 6, 7, 8];
        let mut c2 = p2;
        cipher2.encrypt(&mut c2);

        let xor_fn = |buf1: &mut [u8], buf2| {
            for (b1, b2) in buf1.iter_mut().zip(buf2) {
                *b1 ^= b2;
            }
        };

        // With a shared keystream K, c1 ^ c2 == (p1 ^ K) ^ (p2 ^ K) == p1 ^ p2. Different offsets
        // must yield different keystreams, so the two XORs must differ.
        xor_fn(&mut c1, &c2);
        xor_fn(&mut p1, &p2);
        assert_ne!(c1, p1);
    }

    #[test]
    fn test_stream_cipher_seek_matches_sequential() {
        let key = UnwrappedKey::new([7; 32]);
        // Encrypting zeros exposes the raw keystream bytes 0..8.
        let mut sequential = [0u8; 8];
        StreamCipher::new(&key, 0).encrypt(&mut sequential);

        // A cipher created at offset 3 must produce bytes 3..8 of that same keystream.
        let mut seeked = [0u8; 5];
        let mut cipher = StreamCipher::new(&key, 3);
        cipher.encrypt(&mut seeked);
        assert_eq!(&seeked[..], &sequential[3..]);
        assert_eq!(cipher.offset(), 8);
    }

    #[test]
    fn test_stream_cipher_round_trip() {
        let key = UnwrappedKey::new([9; 32]);
        let plaintext = *b"stream cipher round trip";
        let mut buffer = plaintext;
        StreamCipher::new(&key, 100).encrypt(&mut buffer);
        assert_ne!(buffer, plaintext);
        StreamCipher::new(&key, 100).decrypt(&mut buffer);
        assert_eq!(buffer, plaintext);
    }
}