1#[cfg(feature = "zdict_builder")]
18use std::io::{self, Read};
19
20pub use zstd_safe::{CDict, DDict};
21
/// A zstd compression dictionary (`CDict`) ready to be handed to encoders.
///
/// The lifetime `'a` is `'static` for dictionaries built with
/// `EncoderDictionary::copy`, and is tied to the borrowed dictionary bytes for
/// dictionaries built with `EncoderDictionary::new`.
pub struct EncoderDictionary<'a> {
    // The underlying zstd compression dictionary.
    cdict: CDict<'a>,
}
29
30impl EncoderDictionary<'static> {
31 pub fn copy(dictionary: &[u8], level: i32) -> Self {
35 Self {
36 cdict: zstd_safe::create_cdict(dictionary, level),
37 }
38 }
39}
40
41impl<'a> EncoderDictionary<'a> {
42 #[cfg(feature = "experimental")]
43 #[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "experimental")))]
44 pub fn new(dictionary: &'a [u8], level: i32) -> Self {
50 Self {
51 cdict: zstd_safe::create_cdict_by_reference(dictionary, level),
52 }
53 }
54
55 pub fn as_cdict(&self) -> &CDict<'a> {
57 &self.cdict
58 }
59}
60
/// A zstd decompression dictionary (`DDict`) ready to be handed to decoders.
///
/// The lifetime `'a` is `'static` for dictionaries built with
/// `DecoderDictionary::copy`, and is tied to the borrowed dictionary bytes for
/// dictionaries built with `DecoderDictionary::new`.
pub struct DecoderDictionary<'a> {
    // The underlying zstd decompression dictionary.
    ddict: DDict<'a>,
}
65
66impl DecoderDictionary<'static> {
67 pub fn copy(dictionary: &[u8]) -> Self {
71 Self {
72 ddict: zstd_safe::DDict::create(dictionary),
73 }
74 }
75}
76
77impl<'a> DecoderDictionary<'a> {
78 #[cfg(feature = "experimental")]
79 #[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "experimental")))]
80 pub fn new(dict: &'a [u8]) -> Self {
84 Self {
85 ddict: zstd_safe::create_ddict_by_reference(dict),
86 }
87 }
88
89 pub fn as_ddict(&self) -> &DDict<'a> {
91 &self.ddict
92 }
93}
94
/// Trains a dictionary from samples stored back-to-back in one buffer.
///
/// `sample_data` holds every sample concatenated, and `sample_sizes[i]` is the
/// length in bytes of the `i`-th sample. `max_size` is used as the capacity of
/// the output buffer handed to the trainer. This presumably bounds the
/// resulting dictionary size; confirm against `zstd_safe::train_from_buffer`.
///
/// # Errors
///
/// Returns an error if the sample sizes do not sum exactly to
/// `sample_data.len()`, including when their sum overflows `usize`. Any
/// training failure reported by zstd is also returned.
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_continuous(
    sample_data: &[u8],
    sample_sizes: &[usize],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    use crate::map_error_code;

    // Use a checked sum. A plain `sum()` panics on overflow in debug builds and
    // wraps in release builds. A wrapped total could then equal
    // `sample_data.len()` while the sizes describe far more bytes than
    // `sample_data` actually holds, and those sizes go straight to the trainer.
    let total = sample_sizes
        .iter()
        .try_fold(0usize, |acc, &len| acc.checked_add(len));
    if total != Some(sample_data.len()) {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            "sample sizes don't add up".to_string(),
        ));
    }

    let mut result = Vec::with_capacity(max_size);
    zstd_safe::train_from_buffer(&mut result, sample_data, sample_sizes)
        .map_err(map_error_code)?;
    Ok(result)
}
121
/// Trains a dictionary from a list of in-memory samples.
///
/// The samples are concatenated into one contiguous buffer, with each
/// sample's length recorded as its size, and passed to [`from_continuous`].
///
/// # Errors
///
/// Propagates any error returned by [`from_continuous`].
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_samples<S: AsRef<[u8]>>(
    samples: &[S],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    let sizes: Vec<usize> = samples.iter().map(|s| s.as_ref().len()).collect();

    // Reserve the whole buffer once, then copy each sample as a slice.
    // `flat_map(..).collect()` has no usable size hint, so it would push byte by
    // byte while repeatedly regrowing. If the total overflows, no up-front
    // reservation is made and growth proceeds normally.
    let capacity = sizes
        .iter()
        .try_fold(0usize, |acc, &len| acc.checked_add(len))
        .unwrap_or(0);
    let mut data = Vec::with_capacity(capacity);
    for sample in samples {
        data.extend_from_slice(sample.as_ref());
    }

    from_continuous(&data, &sizes, max_size)
}
144
/// Trains a dictionary from the contents of a series of files.
///
/// Each file is read in full and appended to one shared buffer. The number of
/// bytes read from each file becomes that sample's size, and everything is
/// then handed to [`from_continuous`].
///
/// # Errors
///
/// Returns the first I/O error hit while opening or reading a file, or any
/// error produced by [`from_continuous`].
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_files<I, P>(filenames: I, max_size: usize) -> io::Result<Vec<u8>>
where
    P: AsRef<std::path::Path>,
    I: IntoIterator<Item = P>,
{
    let mut samples = Vec::new();

    // Collecting into `io::Result` stops at the first failing file.
    let lengths = filenames
        .into_iter()
        .map(|name| std::fs::File::open(name)?.read_to_end(&mut samples))
        .collect::<io::Result<Vec<usize>>>()?;

    from_continuous(&samples, &lengths, max_size)
}
166
#[cfg(test)]
#[cfg(feature = "zdict_builder")]
mod tests {
    use std::fs;
    use std::io;

    use walkdir;

    /// Trains a dictionary on this crate's `.rs` sources. It then checks that
    /// every source file round-trips unchanged through a dictionary-backed
    /// encoder and decoder.
    #[test]
    fn test_dict_training() {
        // Keep only regular files with an `.rs` extension. `Path::extension`
        // avoids panicking on non-UTF-8 paths, and the file-type filter skips
        // any directory whose name happens to end in ".rs".
        let paths: Vec<_> = walkdir::WalkDir::new("src")
            .into_iter()
            .map(|entry| entry.unwrap())
            .filter(|entry| entry.file_type().is_file())
            .map(|entry| entry.into_path())
            .filter(|path| path.extension().map_or(false, |ext| ext == "rs"))
            .collect();
        assert!(!paths.is_empty(), "no .rs files found under src/");

        let dict = super::from_files(&paths, 4000).unwrap();

        for path in paths {
            let content = fs::read(&path).unwrap();

            // The auto-finishing encoder is a temporary. It is dropped (and so
            // finished) at the end of this statement, which also releases the
            // borrow on `buffer`.
            let mut buffer = Vec::new();
            io::copy(
                &mut &content[..],
                &mut crate::stream::Encoder::with_dictionary(&mut buffer, 1, &dict)
                    .unwrap()
                    .auto_finish(),
            )
            .unwrap();

            let mut result = Vec::new();
            io::copy(
                &mut crate::stream::Decoder::with_dictionary(&buffer[..], &dict[..]).unwrap(),
                &mut result,
            )
            .unwrap();

            assert_eq!(content, result, "round-trip mismatch for {}", path.display());
        }
    }
}