1use static_assertions::const_assert_eq;
6use std::sync::LazyLock;
7
8#[derive(Clone, Copy)]
9pub(crate) struct State {
10 pub(crate) buffer: XSaveArea,
11 strategy: Strategy,
12}
13
/// Size in bytes of the XSAVE image used here: the 512-byte legacy (FXSAVE) region, the
/// 64-byte XSAVE header and the 256-byte AVX state component.
pub const XSAVE_AREA_SIZE: usize = 512 + 64 + 256;

// Individual bits of the XSAVE requested-feature bitmap.
const XSAVE_FEATURE_X87: u64 = 0b001;
const XSAVE_FEATURE_SSE: u64 = 0b010;
const XSAVE_FEATURE_AVX: u64 = 0b100;

/// Every state component this module asks XSAVE/XSAVEOPT/XRSTOR to handle.
pub const SUPPORTED_XSAVE_FEATURES: u64 =
    XSAVE_FEATURE_AVX | XSAVE_FEATURE_SSE | XSAVE_FEATURE_X87;
24
/// One 16-byte ST/MM register slot of the legacy save area, kept as two 64-bit halves
/// (`low` at the lower address).
#[derive(Default, Copy, Clone)]
#[repr(C)]
struct X87MMXState {
    low: u64,
    high: u64,
}

/// One 16-byte XMM register slot of the legacy save area, kept as two 64-bit halves
/// (`low` at the lower address).
#[derive(Default, Copy, Clone)]
#[repr(C)]
struct SSERegister {
    low: u64,
    high: u64,
}
38
/// Legacy (FXSAVE-format) region holding x87 and SSE state: the first 416 bytes of the
/// 512-byte FXSAVE image.
///
/// Field order is the hardware layout and must not change. With `repr(C)` the offsets noted
/// below follow from the field types, and the assertion after the struct pins the total size.
#[derive(Clone, Copy)]
#[repr(C)]
struct X86LegacySaveArea {
    // Offset 0: x87 control word.
    fcw: u16,
    // Offset 2: x87 status word.
    fsw: u16,
    // Offset 4: x87 tag word, in the abridged one-byte form used by FXSAVE.
    ftw: u8,
    // Offset 5: reserved.
    _reserved: u8,

    // Offset 6: opcode of the last x87 instruction.
    fop: u16,
    // Offset 8: instruction pointer of the last x87 instruction.
    fip: u64,
    // Offset 16: operand (data) pointer of the last x87 instruction.
    fdp: u64,

    // Offset 24: SSE control/status register.
    mxcsr: u32,
    // Offset 28: mask of the MXCSR bits the processor supports.
    mxcsr_mask: u32,

    // Offset 32: ST0-ST7 / MM0-MM7, one 16-byte slot each.
    st: [X87MMXState; 8],

    // Offset 160: XMM0-XMM15, one 16-byte slot each.
    xmm: [SSERegister; 16],
}

// 416 = 160 + 16 * 16; a mismatch means the layout above drifted from the hardware format.
const_assert_eq!(std::mem::size_of::<X86LegacySaveArea>(), 416);
61
/// The 512-byte image read and written by FXSAVE/FXRSTOR. It is also the leading 512 bytes of
/// `XSaveArea`.
///
/// `align(16)` matches the 16-byte alignment FXSAVE/FXRSTOR require of their memory operand.
#[derive(Clone, Copy)]
#[repr(C, align(16))]
struct FXSaveArea {
    x86_legacy_save_area: X86LegacySaveArea,
    // Bytes 416..512 of the image; zeroed by `Default` and by `State::set_xsave_area`.
    _reserved: [u8; 96],
}
const_assert_eq!(std::mem::size_of::<FXSaveArea>(), 512);
69
70impl Default for FXSaveArea {
71 fn default() -> Self {
72 Self {
73 x86_legacy_save_area: X86LegacySaveArea {
74 fcw: 0x37f, fsw: 0,
76 ftw: 0,
79 _reserved: Default::default(),
80 fop: 0,
81 fip: 0,
82 fdp: 0,
83 mxcsr: 0x3f << 7, mxcsr_mask: 0,
85 st: Default::default(),
86 xmm: Default::default(),
87 },
88 _reserved: [0; 96],
89 }
90 }
91}
92
/// Save area handed to XSAVE/XSAVEOPT/XRSTOR, and to FXSAVE/FXRSTOR through its leading
/// `FXSaveArea`.
///
/// Layout, fixed by `repr(C)` (no padding, since 512 + 64 + 256 is a multiple of 64):
/// - bytes 0..512: legacy x87/SSE region
/// - bytes 512..576: XSAVE header
/// - bytes 576..832: AVX state component
///
/// `align(64)` is the alignment XSAVE/XRSTOR require of their memory operand.
#[derive(Clone, Copy)]
#[repr(C, align(64))]
pub(crate) struct XSaveArea {
    fxsave_area: FXSaveArea,
    // XSAVE header; its first 8 bytes are the XSTATE_BV bitmap and the next 8 are XCOMP_BV.
    xsave_header: [u8; 64],
    avx_state: [u8; 256],
}

// Keeps the struct and the public size constant in sync.
const_assert_eq!(std::mem::size_of::<XSaveArea>(), XSAVE_AREA_SIZE);
105
106impl XSaveArea {
107 fn addr(&self) -> *const u8 {
108 self as *const _ as *const u8
109 }
110
111 fn addr_mut(&mut self) -> *mut u8 {
112 self as *mut _ as *mut u8
113 }
114}
115
116impl Default for XSaveArea {
117 fn default() -> Self {
118 Self { fxsave_area: Default::default(), xsave_header: [0; 64], avx_state: [0; 256] }
119 }
120}
121
/// Instruction family used to save and restore extended processor state.
///
/// Variants are declared from most to least preferred (the order `PREFERRED_STRATEGY` tries
/// them in). The derived `PartialOrd` follows declaration order, so do not reorder them.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Strategy {
    /// Save with `xsaveopt`, restore with `xrstor`.
    XSaveOpt,
    /// Save with `xsave`, restore with `xrstor`.
    XSave,
    /// Save with `fxsave`, restore with `fxrstor`.
    FXSave,
}
128
129pub static PREFERRED_STRATEGY: LazyLock<Strategy> = LazyLock::new(|| {
130 if is_x86_feature_detected!("xsaveopt") {
131 Strategy::XSaveOpt
132 } else if is_x86_feature_detected!("xsave") {
133 Strategy::XSave
134 } else {
135 assert!(!is_x86_feature_detected!("avx"));
141 Strategy::FXSave
142 }
143});
144
145impl State {
146 pub fn with_strategy(strategy: Strategy) -> Self {
147 Self { buffer: XSaveArea::default(), strategy }
148 }
149
150 #[inline(always)]
151 pub(crate) fn save(&mut self) {
152 match self.strategy {
153 Strategy::XSaveOpt => unsafe {
154 std::arch::x86_64::_xsaveopt(self.buffer.addr_mut(), SUPPORTED_XSAVE_FEATURES);
155 },
156 Strategy::XSave => unsafe {
157 std::arch::x86_64::_xsave(self.buffer.addr_mut(), SUPPORTED_XSAVE_FEATURES);
158 },
159 Strategy::FXSave => unsafe {
160 std::arch::x86_64::_fxsave(self.buffer.addr_mut());
161 },
162 }
163 }
164
165 #[inline(always)]
166 pub(crate) unsafe fn restore(&self) {
168 match self.strategy {
169 Strategy::XSave | Strategy::XSaveOpt => {
170 std::arch::x86_64::_xrstor(self.buffer.addr(), SUPPORTED_XSAVE_FEATURES)
171 }
172 Strategy::FXSave => std::arch::x86_64::_fxrstor(self.buffer.addr()),
173 }
174 }
175
176 pub fn reset(&mut self) {
177 self.initialize_saved_area()
178 }
179
180 fn initialize_saved_area(&mut self) {
181 *self = Default::default()
182 }
183
184 pub(crate) fn set_xsave_area(&mut self, xsave_area: [u8; XSAVE_AREA_SIZE]) {
185 self.buffer = unsafe { std::mem::transmute(xsave_area) };
186
187 self.buffer.fxsave_area._reserved = [0u8; 96];
190 }
191}
192
193impl Default for State {
194 fn default() -> Self {
195 Self { buffer: XSaveArea::default(), strategy: *PREFERRED_STRATEGY }
196 }
197}
198
#[cfg(test)]
mod test {
    use super::*;

    /// Puts distinctive values into the x87 status/control words, MXCSR and xmm0-xmm2. It then
    /// checks that `State::save`/`State::restore` bring them back after the state has been
    /// reset.
    #[::fuchsia::test]
    fn save_restore_sse_registers() {
        use core::arch::asm;

        // Every asm block that *stores* into a Rust local gets a pointer derived from `&mut` of
        // a `mut` local. Writing through a pointer derived from a shared reference is undefined
        // behavior, and would let the compiler assume the local still holds its initializer.
        //
        // NOTE(review): this test relies on x87/SSE register contents surviving between
        // separate `asm!` blocks and across `State::save`/`State::restore`, which Rust does not
        // guarantee. Avoid adding code between the steps below that could be lowered to SSE
        // instructions or to `memcpy`/`bcmp`-style calls. This is why the byte checks below are
        // plain index loops rather than whole-array comparisons.

        let write_custom_state = || {
            // Storing from an empty x87 register stack is a stack underflow. That sets
            // status-word flags which `fninit` (in `clear_state`) clears again.
            let mut flt = [0u8; 8];
            unsafe {
                asm!("fstp dword ptr [{flt}]", flt = in(reg) &mut flt);
            }
            let mut fpust = 0u16;
            unsafe {
                asm!("fnstsw [{fpust}]", fpust = in(reg) &mut fpust);
            }
            // Invalid-operation (bit 0) and stack-fault (bit 6) are set. C1 (bit 9) is clear,
            // which identifies the stack fault as an underflow.
            assert_eq!(fpust & 1 << 0, 0x1);
            assert_eq!(fpust & 1 << 6, 1 << 6);
            assert_eq!(fpust & 1 << 9, 0);

            // Clear the six x87 exception-mask bits of the control word (0x37f -> 0x340).
            let mut fpucw = 0u16;
            unsafe {
                asm!("fnstcw [{fpucw}]", fpucw = in(reg) &mut fpucw);
            }
            fpucw &= !0x3f;
            unsafe {
                asm!("fldcw [{fpucw}]", fpucw = in(reg) &fpucw);
            }

            // Clear MXCSR mask bits 7..=9 (invalid, denormal, divide-by-zero).
            let mut mxcsr = 0u32;
            unsafe {
                asm!("stmxcsr [{mxcsr}]", mxcsr = in(reg) &mut mxcsr);
            }
            mxcsr &= !(0x7 << 7);
            unsafe {
                asm!("ldmxcsr [{mxcsr}]", mxcsr = in(reg) &mxcsr);
            }

            let vals_a = [0x42u8; 16];
            let vals_b = [0x43u8; 16];
            let vals_c = [0x44u8; 16];
            unsafe {
                asm!("movups xmm0, [{vals_a}]
                      movups xmm1, [{vals_b}]
                      movups xmm2, [{vals_c}]",
                    vals_a = in(reg) &vals_a,
                    vals_b = in(reg) &vals_b,
                    vals_c = in(reg) &vals_c,
                    out("xmm0") _,
                    out("xmm1") _,
                    out("xmm2") _,
                );
            }
        };

        // Resets x87 state, loads the default MXCSR value and zeroes xmm0-xmm2.
        let clear_state = || {
            unsafe {
                asm!("fninit");
                let mxcsr: u32 = 0x3f << 7;
                asm!("ldmxcsr [{mxcsr}]", mxcsr = in(reg) &mxcsr);
                asm!("xorps xmm0, xmm0
                      xorps xmm1, xmm1
                      xorps xmm2, xmm2",
                    out("xmm0") _,
                    out("xmm1") _,
                    out("xmm2") _,
                );
            }
        };

        // Scratch buffer for reading back xmm registers. It is initialized here, before any
        // custom register state exists, and handed to the checks by `&mut`.
        let mut dest = [0u8; 16];

        let validate_state_cleared = |dest: &mut [u8; 16]| {
            let mut fpust = 0u16;
            unsafe {
                asm!("fnstsw [{fpust}]", fpust = in(reg) &mut fpust);
            }
            assert_eq!(fpust, 0);

            let mut fpucw = 0u16;
            unsafe { asm!("fnstcw [{fpucw}]", fpucw = in(reg) &mut fpucw) };
            assert_eq!(fpucw, 0x37f);

            let mut mxcsr = 0u32;
            unsafe {
                asm!("stmxcsr [{mxcsr}]", mxcsr = in(reg) &mut mxcsr);
            }
            // No exception flags in bits 0..=4, and all six mask bits set.
            assert_eq!(mxcsr & 0x1f, 0);
            assert_eq!((mxcsr >> 7) & 0x3f, 0x3f);

            unsafe {
                asm!("movups [{dest}], xmm0", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0);
            }
            unsafe {
                asm!("movups [{dest}], xmm1", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0);
            }
            unsafe {
                asm!("movups [{dest}], xmm2", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0);
            }
        };

        let validate_state_restored = |dest: &mut [u8; 16]| {
            // The stack-underflow flags written by `write_custom_state` are back.
            let mut fpust = 0u16;
            unsafe {
                asm!("fnstsw [{fpust}]", fpust = in(reg) &mut fpust);
            }
            assert_eq!(fpust & 1 << 0, 0x1);
            assert_eq!(fpust & 1 << 6, 1 << 6);
            assert_eq!(fpust & 1 << 9, 0);

            // Control word with the exception masks cleared.
            let mut fpucw = 0u16;
            unsafe { asm!("fnstcw [{fpucw}]", fpucw = in(reg) &mut fpucw) };
            assert_eq!(fpucw, 0x340);

            // MXCSR with mask bits 7..=9 cleared and no exception flags raised.
            let mut mxcsr = 0u32;
            unsafe {
                asm!("stmxcsr [{mxcsr}]", mxcsr = in(reg) &mut mxcsr);
            }
            assert_eq!(mxcsr & 0x1f, 0);
            assert_eq!((mxcsr >> 7) & 0x3f, 0x38);

            unsafe {
                asm!("movups [{dest}], xmm0", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0x42);
            }
            unsafe {
                asm!("movups [{dest}], xmm1", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0x43);
            }
            unsafe {
                asm!("movups [{dest}], xmm2", dest = in(reg) &mut *dest);
            }
            for i in 0..16 {
                assert_eq!(dest[i], 0x44);
            }
        };

        let mut state = State::default();
        write_custom_state();
        state.save();
        clear_state();
        validate_state_cleared(&mut dest);
        unsafe {
            state.restore();
        }
        validate_state_restored(&mut dest);

        // Put the default control values back so the rest of this thread does not run with the
        // exception masks that `write_custom_state` cleared.
        clear_state();
    }
}