1use crate::prelude::*;
2use crate::vk;
3use crate::RawPtr;
4use crate::{Device, Instance};
5use std::ffi::CStr;
6use std::mem;
7
8#[derive(Clone)]
9pub struct RayTracing {
10 handle: vk::Device,
11 fp: vk::NvRayTracingFn,
12}
13
14impl RayTracing {
15 pub fn new(instance: &Instance, device: &Device) -> Self {
16 let handle = device.handle();
17 let fp = vk::NvRayTracingFn::load(|name| unsafe {
18 mem::transmute(instance.get_device_proc_addr(handle, name.as_ptr()))
19 });
20 Self { handle, fp }
21 }
22
23 pub unsafe fn get_properties(
24 instance: &Instance,
25 pdevice: vk::PhysicalDevice,
26 ) -> vk::PhysicalDeviceRayTracingPropertiesNV {
27 let mut props_rt = vk::PhysicalDeviceRayTracingPropertiesNV::default();
28 {
29 let mut props = vk::PhysicalDeviceProperties2::builder().push_next(&mut props_rt);
30 instance.get_physical_device_properties2(pdevice, &mut props);
31 }
32 props_rt
33 }
34
35 pub unsafe fn create_acceleration_structure(
37 &self,
38 create_info: &vk::AccelerationStructureCreateInfoNV,
39 allocation_callbacks: Option<&vk::AllocationCallbacks>,
40 ) -> VkResult<vk::AccelerationStructureNV> {
41 let mut accel_struct = mem::zeroed();
42 (self.fp.create_acceleration_structure_nv)(
43 self.handle,
44 create_info,
45 allocation_callbacks.as_raw_ptr(),
46 &mut accel_struct,
47 )
48 .result_with_success(accel_struct)
49 }
50
51 pub unsafe fn destroy_acceleration_structure(
53 &self,
54 accel_struct: vk::AccelerationStructureNV,
55 allocation_callbacks: Option<&vk::AllocationCallbacks>,
56 ) {
57 (self.fp.destroy_acceleration_structure_nv)(
58 self.handle,
59 accel_struct,
60 allocation_callbacks.as_raw_ptr(),
61 );
62 }
63
64 pub unsafe fn get_acceleration_structure_memory_requirements(
66 &self,
67 info: &vk::AccelerationStructureMemoryRequirementsInfoNV,
68 ) -> vk::MemoryRequirements2KHR {
69 let mut requirements = mem::zeroed();
70 (self.fp.get_acceleration_structure_memory_requirements_nv)(
71 self.handle,
72 info,
73 &mut requirements,
74 );
75 requirements
76 }
77
78 pub unsafe fn bind_acceleration_structure_memory(
80 &self,
81 bind_info: &[vk::BindAccelerationStructureMemoryInfoNV],
82 ) -> VkResult<()> {
83 (self.fp.bind_acceleration_structure_memory_nv)(
84 self.handle,
85 bind_info.len() as u32,
86 bind_info.as_ptr(),
87 )
88 .result()
89 }
90
91 pub unsafe fn cmd_build_acceleration_structure(
93 &self,
94 command_buffer: vk::CommandBuffer,
95 info: &vk::AccelerationStructureInfoNV,
96 instance_data: vk::Buffer,
97 instance_offset: vk::DeviceSize,
98 update: bool,
99 dst: vk::AccelerationStructureNV,
100 src: vk::AccelerationStructureNV,
101 scratch: vk::Buffer,
102 scratch_offset: vk::DeviceSize,
103 ) {
104 (self.fp.cmd_build_acceleration_structure_nv)(
105 command_buffer,
106 info,
107 instance_data,
108 instance_offset,
109 if update { vk::TRUE } else { vk::FALSE },
110 dst,
111 src,
112 scratch,
113 scratch_offset,
114 );
115 }
116
117 pub unsafe fn cmd_copy_acceleration_structure(
119 &self,
120 command_buffer: vk::CommandBuffer,
121 dst: vk::AccelerationStructureNV,
122 src: vk::AccelerationStructureNV,
123 mode: vk::CopyAccelerationStructureModeNV,
124 ) {
125 (self.fp.cmd_copy_acceleration_structure_nv)(command_buffer, dst, src, mode);
126 }
127
128 pub unsafe fn cmd_trace_rays(
130 &self,
131 command_buffer: vk::CommandBuffer,
132 raygen_shader_binding_table_buffer: vk::Buffer,
133 raygen_shader_binding_offset: vk::DeviceSize,
134 miss_shader_binding_table_buffer: vk::Buffer,
135 miss_shader_binding_offset: vk::DeviceSize,
136 miss_shader_binding_stride: vk::DeviceSize,
137 hit_shader_binding_table_buffer: vk::Buffer,
138 hit_shader_binding_offset: vk::DeviceSize,
139 hit_shader_binding_stride: vk::DeviceSize,
140 callable_shader_binding_table_buffer: vk::Buffer,
141 callable_shader_binding_offset: vk::DeviceSize,
142 callable_shader_binding_stride: vk::DeviceSize,
143 width: u32,
144 height: u32,
145 depth: u32,
146 ) {
147 (self.fp.cmd_trace_rays_nv)(
148 command_buffer,
149 raygen_shader_binding_table_buffer,
150 raygen_shader_binding_offset,
151 miss_shader_binding_table_buffer,
152 miss_shader_binding_offset,
153 miss_shader_binding_stride,
154 hit_shader_binding_table_buffer,
155 hit_shader_binding_offset,
156 hit_shader_binding_stride,
157 callable_shader_binding_table_buffer,
158 callable_shader_binding_offset,
159 callable_shader_binding_stride,
160 width,
161 height,
162 depth,
163 );
164 }
165
166 pub unsafe fn create_ray_tracing_pipelines(
168 &self,
169 pipeline_cache: vk::PipelineCache,
170 create_info: &[vk::RayTracingPipelineCreateInfoNV],
171 allocation_callbacks: Option<&vk::AllocationCallbacks>,
172 ) -> VkResult<Vec<vk::Pipeline>> {
173 let mut pipelines = vec![mem::zeroed(); create_info.len()];
174 (self.fp.create_ray_tracing_pipelines_nv)(
175 self.handle,
176 pipeline_cache,
177 create_info.len() as u32,
178 create_info.as_ptr(),
179 allocation_callbacks.as_raw_ptr(),
180 pipelines.as_mut_ptr(),
181 )
182 .result_with_success(pipelines)
183 }
184
185 pub unsafe fn get_ray_tracing_shader_group_handles(
187 &self,
188 pipeline: vk::Pipeline,
189 first_group: u32,
190 group_count: u32,
191 data: &mut [u8],
192 ) -> VkResult<()> {
193 (self.fp.get_ray_tracing_shader_group_handles_nv)(
194 self.handle,
195 pipeline,
196 first_group,
197 group_count,
198 data.len(),
199 data.as_mut_ptr() as *mut std::ffi::c_void,
200 )
201 .result()
202 }
203
204 pub unsafe fn get_acceleration_structure_handle(
206 &self,
207 accel_struct: vk::AccelerationStructureNV,
208 ) -> VkResult<u64> {
209 let mut handle: u64 = 0;
210 let handle_ptr: *mut u64 = &mut handle;
211 (self.fp.get_acceleration_structure_handle_nv)(
212 self.handle,
213 accel_struct,
214 std::mem::size_of::<u64>(),
215 handle_ptr as *mut std::ffi::c_void,
216 )
217 .result_with_success(handle)
218 }
219
220 pub unsafe fn cmd_write_acceleration_structures_properties(
222 &self,
223 command_buffer: vk::CommandBuffer,
224 structures: &[vk::AccelerationStructureNV],
225 query_type: vk::QueryType,
226 query_pool: vk::QueryPool,
227 first_query: u32,
228 ) {
229 (self.fp.cmd_write_acceleration_structures_properties_nv)(
230 command_buffer,
231 structures.len() as u32,
232 structures.as_ptr(),
233 query_type,
234 query_pool,
235 first_query,
236 );
237 }
238
239 pub unsafe fn compile_deferred(&self, pipeline: vk::Pipeline, shader: u32) -> VkResult<()> {
241 (self.fp.compile_deferred_nv)(self.handle, pipeline, shader).result()
242 }
243
244 pub const fn name() -> &'static CStr {
245 vk::NvRayTracingFn::name()
246 }
247
248 pub fn fp(&self) -> &vk::NvRayTracingFn {
249 &self.fp
250 }
251
252 pub fn device(&self) -> vk::Device {
253 self.handle
254 }
255}