#![warn(missing_docs)]
use ahash::AHashSet;
use once_cell::sync::Lazy;
use serde::de::{Deserializer, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
/// Process-wide intern table for strings too long to be stored inline in a `FlyStr`.
///
/// Each entry owns one strong reference to the shared allocation. `RawRepr::new` looks strings
/// up here (inserting them on a miss) while holding the lock, and `RawRepr`'s `Drop` removes an
/// entry once the only remaining references are the dropping handle and the table's own.
static CACHE: Lazy<Mutex<AHashSet<Storage>>> = Lazy::new(|| {
    let interned: AHashSet<Storage> = AHashSet::new();
    Mutex::new(interned)
});
/// One entry of the global intern table: a strong reference to a heap-allocated string.
///
/// `Borrow<str>` lets the table be queried with a plain `&str`. This is consistent with the
/// derived `Hash`/`Eq`, because `Arc` and `Box` both delegate those traits to the `str` contents.
#[derive(Eq, Hash, PartialEq)]
struct Storage(Arc<Box<str>>);

impl Borrow<str> for Storage {
    #[inline]
    fn borrow(&self) -> &str {
        let Storage(shared) = self;
        // Deref coercion: &Arc<Box<str>> -> &Box<str> -> &str.
        shared
    }
}
/// An immutable, cheaply clonable string that keeps at most one heap copy of each distinct
/// long string.
///
/// A `FlyStr` is exactly as wide as a `usize`. Strings whose UTF-8 encoding is at most one byte
/// shorter than a pointer (7 bytes on 64-bit targets) are stored inline and never allocate.
/// Longer strings are interned in a process-wide cache. Creating a `FlyStr` for contents that
/// are already cached only bumps a reference count. The cached allocation is released when the
/// last `FlyStr` referring to it is dropped.
///
/// Every distinct long string has a single allocation, so equality and hashing compare the
/// pointer (or the inline bytes) instead of scanning the string. Ordering (`Ord`) compares the
/// string contents lexicographically.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct FlyStr(RawRepr);

static_assertions::assert_eq_size!(FlyStr, usize);

impl FlyStr {
    /// Creates a `FlyStr` with the contents of `s`.
    ///
    /// Short strings are copied inline. For longer strings the global cache is consulted. If
    /// the contents are already present, the existing allocation is shared. Otherwise `s` is
    /// converted into an owned `String`, moved into a new shared allocation, and inserted into
    /// the cache.
    ///
    /// # Panics
    ///
    /// Panics if `s` is too long to be stored inline and the global cache's mutex has been
    /// poisoned.
    pub fn new(s: impl AsRef<str> + Into<String>) -> Self {
        Self(RawRepr::new(s))
    }

    /// Returns the contents as a string slice. This does not allocate or take any lock.
    #[inline]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Default for FlyStr {
#[inline]
fn default() -> Self {
Self::new("")
}
}
impl From<&'_ str> for FlyStr {
#[inline]
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl From<&'_ String> for FlyStr {
#[inline]
fn from(s: &String) -> Self {
Self::new(&**s)
}
}
impl From<String> for FlyStr {
#[inline]
fn from(s: String) -> Self {
Self::new(s)
}
}
impl From<Box<str>> for FlyStr {
#[inline]
fn from(s: Box<str>) -> Self {
Self::new(s)
}
}
impl From<&Box<str>> for FlyStr {
#[inline]
fn from(s: &Box<str>) -> Self {
Self::new(&**s)
}
}
/// Copies the contents of a `FlyStr` into a newly allocated `String`.
///
/// Implemented as `From` rather than `Into`, so both `String::from(fly)` and `fly.into()` work.
/// The standard library's blanket impl provides `Into<String>` from this.
impl From<FlyStr> for String {
    #[inline]
    fn from(s: FlyStr) -> Self {
        s.as_str().to_owned()
    }
}

/// Copies the contents of a borrowed `FlyStr` into a newly allocated `String`.
impl From<&'_ FlyStr> for String {
    #[inline]
    fn from(s: &FlyStr) -> Self {
        s.as_str().to_owned()
    }
}
impl Deref for FlyStr {
    type Target = str;

    /// Dereferences to the underlying `str`, making every `&str` method available on `FlyStr`.
    #[inline]
    fn deref(&self) -> &str {
        self.as_str()
    }
}

impl AsRef<str> for FlyStr {
    #[inline]
    fn as_ref(&self) -> &str {
        // Deref coercion routes this through `Deref::deref`, i.e. `as_str`.
        self
    }
}
impl PartialOrd for FlyStr {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // The order is total; defer to `Ord` so the two impls can never disagree.
        Some(Ord::cmp(self, other))
    }
}

impl Ord for FlyStr {
    /// Orders by string contents (lexicographic byte order), not by storage location.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        lhs.cmp(rhs)
    }
}

impl PartialEq<str> for FlyStr {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        other == self.as_str()
    }
}

impl PartialEq<&'_ str> for FlyStr {
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        let other: &str = other;
        self.as_str() == other
    }
}

impl PartialEq<String> for FlyStr {
    #[inline]
    fn eq(&self, other: &String) -> bool {
        self.as_str() == other.as_str()
    }
}

impl PartialOrd<str> for FlyStr {
    #[inline]
    fn partial_cmp(&self, other: &str) -> Option<std::cmp::Ordering> {
        // `str` is totally ordered, so its `partial_cmp` is always `Some(cmp)`.
        Some(self.as_str().cmp(other))
    }
}
impl Debug for FlyStr {
    /// Formats exactly like the underlying `str` (quoted and escaped), honoring formatter flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let text = self.as_str();
        <str as Debug>::fmt(text, f)
    }
}

impl Display for FlyStr {
    /// Formats exactly like the underlying `str`, honoring width/fill/alignment flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let text = self.as_str();
        <str as Display>::fmt(text, f)
    }
}
impl Serialize for FlyStr {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(self.as_str())
}
}
impl<'d> Deserialize<'d> for FlyStr {
fn deserialize<D: Deserializer<'d>>(deserializer: D) -> Result<Self, D::Error> {
deserializer.deserialize_str(FlyStrVisitor)
}
}
struct FlyStrVisitor;
impl Visitor<'_> for FlyStrVisitor {
type Value = FlyStr;
fn expecting(&self, formatter: &mut Formatter<'_>) -> FmtResult {
formatter.write_str("a string")
}
fn visit_borrowed_str<'de, E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.into())
}
}
/// Pointer-sized storage behind a `FlyStr`. It holds either a pointer to an interned heap
/// string or the string bytes themselves.
///
/// The variants are distinguished by the least significant bit of the first byte. On
/// little-endian targets that byte is the pointer's low byte. It is always 0 for heap pointers,
/// because `Box<str>` is more than 1-byte aligned, and always 1 for `InlineRepr::masked_len`.
#[repr(C)]
union RawRepr {
    /// Active when the tag bit is 0: a pointer from `Arc::into_raw` on an `Arc<Box<str>>` whose
    /// allocation is also referenced by the global cache.
    heap: NonNull<Box<str>>,
    /// Active when the tag bit is 1: a length-tagged inline string.
    inline: InlineRepr,
}

// The union stands in for an `Arc`, so it must be exactly that size.
static_assertions::assert_eq_size!(Arc<Box<str>>, RawRepr);
// Alignment above 1 guarantees the low bit of every heap pointer is 0, freeing it for the tag.
const _: () = assert!(std::mem::align_of::<Box<str>>() > 1);
// The tag lives in the first byte, which is the pointer's low byte only on little-endian.
static_assertions::assert_type_eq_all!(byteorder::NativeEndian, byteorder::LittleEndian);

/// Safe view of a `RawRepr` after its tag bit has been inspected.
enum SafeRepr<'a> {
    /// The heap pointer, copied out of the union (no refcount change).
    Heap(NonNull<Box<str>>),
    /// A reference to the inline bytes inside the union.
    Inline(&'a InlineRepr),
}

// SAFETY: a heap `RawRepr` owns a strong reference to an `Arc<Box<str>>`, which is itself
// `Send + Sync`, and inline `RawRepr`s are plain bytes. Refcount changes go through `Arc`'s
// atomic operations, and cache mutation is guarded by `CACHE`'s mutex.
unsafe impl Send for RawRepr {}
unsafe impl Sync for RawRepr {}
impl RawRepr {
    /// Builds the representation for `s`. It is stored inline when it fits; otherwise it is
    /// interned in the global cache.
    ///
    /// # Panics
    ///
    /// Panics if the global cache mutex is poisoned (heap path only), or if the tag-bit
    /// invariant asserted below does not hold.
    fn new(s: impl AsRef<str> + Into<String>) -> Self {
        let text = s.as_ref();

        if text.len() <= MAX_INLINE_SIZE {
            let inline = Self { inline: InlineRepr::new(text) };
            assert!(inline.is_inline(), "least significant bit must be 1 for inline strings");
            return inline;
        }

        // Hold the lock across lookup and insertion, so concurrent callers interning the same
        // contents cannot both miss and allocate separate copies.
        let mut cache = CACHE.lock().unwrap();
        match cache.get(text) {
            // Already interned: share the existing allocation.
            Some(entry) => Self { heap: nonnull_from_arc(Arc::clone(&entry.0)) },
            None => {
                // One strong reference goes to the cache, the other to the returned value.
                let shared = Arc::new(s.into().into_boxed_str());
                cache.insert(Storage(Arc::clone(&shared)));
                let heap = Self { heap: nonnull_from_arc(shared) };
                assert!(!heap.is_inline(), "least significant bit must be 0 for heap strings");
                heap
            }
        }
    }

    /// Reports whether the union currently holds inline bytes, by testing the tag bit.
    #[inline]
    fn is_inline(&self) -> bool {
        // SAFETY: both variants are fully initialized pointer-sized values, so reading the first
        // byte as a `u8` is valid whichever one is active.
        let tag = unsafe { self.inline.masked_len };
        (tag & 1) != 0
    }

    /// Converts the union into a safe enum, based on the tag bit.
    #[inline]
    fn project(&self) -> SafeRepr<'_> {
        if !self.is_inline() {
            // SAFETY: the tag bit is clear, so the union holds a heap pointer.
            return SafeRepr::Heap(unsafe { self.heap });
        }
        // SAFETY: the tag bit is set, so the inline variant is the active one.
        SafeRepr::Inline(unsafe { &self.inline })
    }

    /// Borrows the string contents, wherever they are stored.
    #[inline]
    fn as_str(&self) -> &str {
        match self.project() {
            SafeRepr::Inline(inline) => inline.as_str(),
            SafeRepr::Heap(ptr) => {
                // SAFETY: `ptr` came from `Arc::into_raw` and `self` owns a strong reference, so
                // the `Box<str>` stays alive at least as long as `&self`.
                let text: &str = unsafe { ptr.as_ref() };
                text
            }
        }
    }
}
impl PartialEq for RawRepr {
#[inline]
fn eq(&self, other: &Self) -> bool {
let lhs = unsafe { &self.inline };
let rhs = unsafe { &other.inline };
lhs.eq(rhs)
}
}
impl Eq for RawRepr {}
impl Hash for RawRepr {
fn hash<H: Hasher>(&self, h: &mut H) {
let this = unsafe { &self.inline };
this.hash(h);
}
}
impl Clone for RawRepr {
    fn clone(&self) -> Self {
        match self.project() {
            // Inline strings are plain bytes: copy them.
            SafeRepr::Inline(inline) => Self { inline: *inline },
            SafeRepr::Heap(ptr) => {
                // SAFETY: `ptr` was produced by `Arc::into_raw` and `self` still owns a strong
                // reference, so the allocation is alive. The clone takes ownership of the extra
                // count, which its `Drop` later releases.
                unsafe { Arc::increment_strong_count(ptr.as_ptr() as *const Box<str>) };
                Self { heap: ptr }
            }
        }
    }
}
impl Drop for RawRepr {
    fn drop(&mut self) {
        if self.is_inline() {
            // Inline strings own no resources.
            return;
        }

        // Lock before inspecting the refcount, so a concurrent `RawRepr::new` cannot hand out a
        // new reference to this string between the count check and the eviction.
        let mut cache = CACHE.lock().unwrap();

        // SAFETY: this is a heap repr, so `heap` holds a pointer from `Arc::into_raw` whose
        // strong reference this value owns. Rebuilding the `Arc` takes that reference back so it
        // is released exactly once, below.
        let heap = unsafe { Arc::from_raw(self.heap.as_ptr()) };

        // Two strong references means this handle plus the cache's own. Evict the entry so the
        // allocation is freed when `heap` is dropped.
        if Arc::strong_count(&heap) == 2 {
            assert!(cache.remove(&**heap), "cache must have a reference if refcount is 2");
        }

        // Release our reference while the lock is still held, then unlock.
        drop(heap);
        drop(cache);
    }
}
/// Leaks `a` into a raw, non-null pointer that owns one strong reference.
///
/// That reference must eventually be reclaimed with `Arc::from_raw`.
#[inline]
fn nonnull_from_arc(a: Arc<Box<str>>) -> NonNull<Box<str>> {
    let leaked = Arc::into_raw(a) as *mut Box<str>;
    // SAFETY: `Arc::into_raw` returns a pointer to the value inside a live allocation, which is
    // never null.
    unsafe { NonNull::new_unchecked(leaked) }
}
/// A string short enough to live directly inside a `RawRepr` without allocating.
///
/// Layout (fixed by `#[repr(C)]`): one tag byte followed by `MAX_INLINE_SIZE` content bytes.
/// The tag stores `(len << 1) | 1`, so its low bit is always 1; that is how an inline string is
/// told apart from a heap pointer, whose low bit is 0. Unused content bytes are always zero,
/// so the derived `PartialEq`/`Hash` over all bytes agree with string equality.
#[derive(Clone, Copy, Hash, PartialEq)]
#[repr(C)]
struct InlineRepr {
    /// `(len << 1) | 1`, where `len` is the number of used bytes in `contents`.
    masked_len: u8,
    /// UTF-8 bytes of the string, zero-padded to the full width.
    contents: [u8; MAX_INLINE_SIZE],
}

/// The longest string, in bytes, that can be stored inline: one byte less than a pointer, which
/// leaves room for the tag byte.
const MAX_INLINE_SIZE: usize = std::mem::size_of::<NonNull<Box<str>>>() - 1;

// The length must survive being shifted left by one inside the `u8` tag.
const _: () = assert!((u8::MAX >> 1) as usize >= MAX_INLINE_SIZE);

impl InlineRepr {
    /// Copies `s` into an inline representation.
    ///
    /// # Panics
    ///
    /// Panics if `s` is longer than `MAX_INLINE_SIZE` bytes.
    #[inline]
    fn new(s: &str) -> Self {
        assert!(s.len() <= MAX_INLINE_SIZE);
        // The assert above bounds `len` by `MAX_INLINE_SIZE` (<= 127), so the cast and the shift
        // are lossless.
        let masked_len = ((s.len() as u8) << 1) | 1;
        let mut contents = [0u8; MAX_INLINE_SIZE];
        contents[..s.len()].copy_from_slice(s.as_bytes());
        Self { masked_len, contents }
    }

    /// Borrows the stored string.
    #[inline]
    fn as_str(&self) -> &str {
        let len = usize::from(self.masked_len >> 1);
        // SAFETY: `contents[..len]` is an exact copy of the bytes of the `&str` passed to `new`,
        // which were valid UTF-8.
        unsafe { std::str::from_utf8_unchecked(&self.contents[..len]) }
    }
}
// Unit tests for `FlyStr`. Most of them reset and then inspect the process-wide `CACHE`, so on
// non-Fuchsia hosts they are tagged `#[serial]` to keep them from running concurrently.
#[cfg(test)]
mod tests {
use super::*;
use static_assertions::{const_assert, const_assert_eq};
use std::collections::BTreeSet;
use test_case::test_case;
#[cfg(not(target_os = "fuchsia"))]
use serial_test::serial;
// Empties the global cache. If an earlier test panicked while holding the lock (poisoning the
// mutex), the guard is recovered from the error and the set is cleared anyway.
fn reset_global_cache() {
match CACHE.lock() {
Ok(mut c) => *c = AHashSet::new(),
Err(e) => *e.into_inner() = AHashSet::new(),
}
}
// Number of distinct heap strings currently interned.
fn num_strings_in_global_cache() -> usize {
CACHE.lock().unwrap().len()
}
impl RawRepr {
// Returns the strong count of a heap string's `Arc` (which includes the cache's own
// reference), or `None` for inline strings.
fn refcount(&self) -> Option<usize> {
match self.project() {
SafeRepr::Heap(ptr) => {
// Rebuild a temporary `Arc` without owning a count, then add one count for it, so that
// dropping `tmp` at the end of this arm leaves the real refcount unchanged.
let tmp = unsafe { Arc::from_raw(ptr.as_ptr() as *const Box<str>) };
unsafe { Arc::increment_strong_count(ptr.as_ptr()) };
// Exclude the temporary reference added just above.
let count = Arc::strong_count(&tmp) - 1;
Some(count)
}
SafeRepr::Inline(_) => None,
}
}
}
// Fixtures around the inline/heap boundary. The `const_assert_eq!`s below pin
// `MAX_INLINE_SIZE` to 7, so this module assumes 8-byte pointers.
const SHORT_STRING: &str = "hello";
const_assert!(SHORT_STRING.len() < MAX_INLINE_SIZE);
const MAX_LEN_SHORT_STRING: &str = "hello!!";
const_assert_eq!(MAX_LEN_SHORT_STRING.len(), MAX_INLINE_SIZE);
const MIN_LEN_LONG_STRING: &str = "hello!!!";
const_assert_eq!(MIN_LEN_LONG_STRING.len(), MAX_INLINE_SIZE + 1);
const LONG_STRING: &str = "hello, world!!!!!!!!!!!!!!!!!!!!";
// `Display` and `Debug` output must match the plain `str` for both inline and heap strings.
#[test_case("" ; "empty string")]
#[test_case(SHORT_STRING ; "short strings")]
#[test_case(MAX_LEN_SHORT_STRING ; "max len short strings")]
#[test_case(MIN_LEN_LONG_STRING ; "barely long strings")]
#[test_case(LONG_STRING ; "long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn string_formatting_is_equivalent_to_str(original: &str) {
reset_global_cache();
let cached = FlyStr::new(original);
assert_eq!(format!("{original}"), format!("{cached}"));
assert_eq!(format!("{original:?}"), format!("{cached:?}"));
}
// Equality against clones, `&str` and `String`, on both sides of the inline/heap boundary.
#[test_case("" ; "empty string")]
#[test_case(SHORT_STRING ; "short strings")]
#[test_case(MAX_LEN_SHORT_STRING ; "max len short strings")]
#[test_case(MIN_LEN_LONG_STRING ; "barely long strings")]
#[test_case(LONG_STRING ; "long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn string_equality_works(contents: &str) {
reset_global_cache();
let cached = FlyStr::new(contents);
assert_eq!(cached, cached.clone(), "must be equal to itself");
assert_eq!(cached, contents, "must be equal to the original");
assert_eq!(cached, contents.to_owned(), "must be equal to an owned copy of the original");
assert_ne!(cached, "goodbye");
}
// `Ord` follows string contents, including across inline and heap representations.
#[test_case("", SHORT_STRING ; "empty and short string")]
#[test_case(SHORT_STRING, MAX_LEN_SHORT_STRING ; "two short strings")]
#[test_case(MAX_LEN_SHORT_STRING, MIN_LEN_LONG_STRING ; "short and long strings")]
#[test_case(MIN_LEN_LONG_STRING, LONG_STRING ; "barely long and long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn string_comparison_works(lesser: &str, greater: &str) {
reset_global_cache();
let lesser = FlyStr::new(lesser);
let greater = FlyStr::new(greater);
assert!(lesser < greater);
assert!(lesser <= greater);
assert!(greater > lesser);
assert!(greater >= lesser);
}
// Strings that fit inline never touch the cache and report no refcount.
#[test_case("" ; "empty string")]
#[test_case(SHORT_STRING ; "short strings")]
#[test_case(MAX_LEN_SHORT_STRING ; "max len short strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn no_allocations_for_short_strings(contents: &str) {
reset_global_cache();
assert_eq!(num_strings_in_global_cache(), 0);
let original = FlyStr::new(contents);
assert_eq!(num_strings_in_global_cache(), 0);
assert_eq!(original.0.refcount(), None);
let cloned = original.clone();
assert_eq!(num_strings_in_global_cache(), 0);
assert_eq!(cloned.0.refcount(), None);
let deduped = FlyStr::new(contents);
assert_eq!(num_strings_in_global_cache(), 0);
assert_eq!(deduped.0.refcount(), None);
}
// Long strings share one cached allocation. Expected counts are (live FlyStrs + 1), because
// the cache holds its own strong reference; this depends on every earlier handle still being
// alive at each assertion.
#[test_case(MIN_LEN_LONG_STRING ; "barely long strings")]
#[test_case(LONG_STRING ; "long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn only_one_copy_allocated_for_long_strings(contents: &str) {
reset_global_cache();
assert_eq!(num_strings_in_global_cache(), 0);
let original = FlyStr::new(contents);
assert_eq!(num_strings_in_global_cache(), 1, "only one string allocated");
assert_eq!(original.0.refcount(), Some(2), "one copy on stack, one in cache");
let cloned = original.clone();
assert_eq!(num_strings_in_global_cache(), 1, "cloning just incremented refcount");
assert_eq!(cloned.0.refcount(), Some(3), "two copies on stack, one in cache");
let deduped = FlyStr::new(contents);
assert_eq!(num_strings_in_global_cache(), 1, "new string was deduped");
assert_eq!(deduped.0.refcount(), Some(4), "three copies on stack, one in cache");
}
// Dropping the last `FlyStr` for a long string evicts it from the cache.
#[test_case(MIN_LEN_LONG_STRING ; "barely long strings")]
#[test_case(LONG_STRING ; "long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn cached_strings_dropped_when_refs_dropped(contents: &str) {
reset_global_cache();
let alloced = FlyStr::new(contents);
assert_eq!(num_strings_in_global_cache(), 1, "only one string allocated");
drop(alloced);
assert_eq!(num_strings_in_global_cache(), 0, "last reference dropped");
}
// Exercises the byte/pointer-based `PartialEq` and `Hash` through a hash set.
#[test_case("", SHORT_STRING ; "empty and short string")]
#[test_case(SHORT_STRING, MAX_LEN_SHORT_STRING ; "two short strings")]
#[test_case(SHORT_STRING, LONG_STRING ; "short and long strings")]
#[test_case(LONG_STRING, MAX_LEN_SHORT_STRING ; "long and max-len-short strings")]
#[test_case(MIN_LEN_LONG_STRING, LONG_STRING ; "barely long and long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn equality_and_hashing_with_pointer_value_works_correctly(first: &str, second: &str) {
reset_global_cache();
let first = FlyStr::new(first);
let second = FlyStr::new(second);
let mut set = AHashSet::new();
set.insert(first.clone());
assert!(set.contains(&first));
assert!(!set.contains(&second));
set.insert(first);
assert_eq!(set.len(), 1, "set did not grow because the same string was inserted as before");
set.insert(second.clone());
assert_eq!(set.len(), 2, "inserting a different string must mutate the set");
assert!(set.contains(&second));
set.insert(second);
assert_eq!(set.len(), 2);
}
// Exercises the content-based `Ord` through a `BTreeSet`.
#[test_case("", SHORT_STRING ; "empty and short string")]
#[test_case(SHORT_STRING, MAX_LEN_SHORT_STRING ; "two short strings")]
#[test_case(SHORT_STRING, LONG_STRING ; "short and long strings")]
#[test_case(LONG_STRING, MAX_LEN_SHORT_STRING ; "long and max-len-short strings")]
#[test_case(MIN_LEN_LONG_STRING, LONG_STRING ; "barely long and long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn comparison_for_btree_storage_works(first: &str, second: &str) {
reset_global_cache();
let first = FlyStr::new(first);
let second = FlyStr::new(second);
let mut set = BTreeSet::new();
set.insert(first.clone());
assert!(set.contains(&first));
assert!(!set.contains(&second));
set.insert(first);
assert_eq!(set.len(), 1, "set did not grow because the same string was inserted as before");
set.insert(second.clone());
assert_eq!(set.len(), 2, "inserting a different string must mutate the set");
assert!(set.contains(&second));
set.insert(second);
assert_eq!(set.len(), 2);
}
// Round-trips through JSON. These fixtures contain no characters that JSON needs to escape,
// so the expected JSON is just the quoted contents.
#[test_case("" ; "empty string")]
#[test_case(SHORT_STRING ; "short strings")]
#[test_case(MAX_LEN_SHORT_STRING ; "max len short strings")]
#[test_case(MIN_LEN_LONG_STRING ; "min len long strings")]
#[test_case(LONG_STRING ; "long strings")]
#[cfg_attr(not(target_os = "fuchsia"), serial)]
fn serde_works(contents: &str) {
reset_global_cache();
let s = FlyStr::new(contents);
let as_json = serde_json::to_string(&s).unwrap();
assert_eq!(as_json, format!("\"{contents}\""));
assert_eq!(s, serde_json::from_str::<FlyStr>(&as_json).unwrap());
}
}