diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 86be90427d7..386efa2f0c3 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -7,7 +7,7 @@ use alloc::{ use core::ops::Range; use metal::{ MTLIndexType, MTLLoadAction, MTLPrimitiveType, MTLScissorRect, MTLSize, MTLStoreAction, - MTLViewport, MTLVisibilityResultMode, NSRange, + MTLViewport, MTLVisibilityResultMode, NSRange, NSUInteger, }; use smallvec::SmallVec; @@ -31,6 +31,75 @@ impl Default for super::CommandState { } } +/// Helper for passing encoders to `update_bind_group_state`. +/// +/// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for +/// that stage. +enum Encoder<'e> { + Vertex(&'e metal::RenderCommandEncoder), + Fragment(&'e metal::RenderCommandEncoder), + Task(&'e metal::RenderCommandEncoder), + Mesh(&'e metal::RenderCommandEncoder), + Compute(&'e metal::ComputeCommandEncoder), +} + +impl Encoder<'_> { + fn stage(&self) -> naga::ShaderStage { + match self { + Self::Vertex(_) => naga::ShaderStage::Vertex, + Self::Fragment(_) => naga::ShaderStage::Fragment, + Self::Task(_) => naga::ShaderStage::Task, + Self::Mesh(_) => naga::ShaderStage::Mesh, + Self::Compute(_) => naga::ShaderStage::Compute, + } + } + + fn set_buffer( + &self, + index: NSUInteger, + buffer: Option<&metal::BufferRef>, + offset: wgt::BufferAddress, + ) { + match *self { + Self::Vertex(enc) => enc.set_vertex_buffer(index, buffer, offset), + Self::Fragment(enc) => enc.set_fragment_buffer(index, buffer, offset), + Self::Task(enc) => enc.set_object_buffer(index, buffer, offset), + Self::Mesh(enc) => enc.set_mesh_buffer(index, buffer, offset), + Self::Compute(enc) => enc.set_buffer(index, buffer, offset), + } + } + + fn set_bytes(&self, index: NSUInteger, length: u64, bytes: *const core::ffi::c_void) { + match *self { + Self::Vertex(enc) => enc.set_vertex_bytes(index, length, bytes), + Self::Fragment(enc) => enc.set_fragment_bytes(index, length, bytes), + Self::Task(enc) => enc.set_object_bytes(index, length, bytes), + Self::Mesh(enc) => enc.set_mesh_bytes(index, length, bytes), + Self::Compute(enc) => enc.set_bytes(index, length, bytes), + } + } + + fn set_sampler_state(&self, index: NSUInteger, state: Option<&metal::SamplerStateRef>) { + match *self { + Self::Vertex(enc) => enc.set_vertex_sampler_state(index, state), + Self::Fragment(enc) => enc.set_fragment_sampler_state(index, state), + Self::Task(enc) => enc.set_object_sampler_state(index, state), + Self::Mesh(enc) => enc.set_mesh_sampler_state(index, state), + Self::Compute(enc) => enc.set_sampler_state(index, state), + } + } + + fn set_texture(&self, index: NSUInteger, texture: Option<&metal::TextureRef>) { + match *self { + Self::Vertex(enc) => enc.set_vertex_texture(index, texture), + Self::Fragment(enc) => enc.set_fragment_texture(index, texture), + Self::Task(enc) => enc.set_object_texture(index, texture), + Self::Mesh(enc) => enc.set_mesh_texture(index, texture), + Self::Compute(enc) => enc.set_texture(index, texture), + } + } +} + impl super::CommandEncoder { pub fn raw_command_buffer(&self) -> Option<&metal::CommandBuffer> { self.raw_cmd_buf.as_ref() @@ -146,31 +215,29 @@ impl super::CommandEncoder { } /// Updates the bindings for a single shader stage, called in `set_bind_group`. - #[expect(clippy::too_many_arguments)] fn update_bind_group_state( &mut self, - stage: naga::ShaderStage, - render_encoder: Option<&metal::RenderCommandEncoder>, - compute_encoder: Option<&metal::ComputeCommandEncoder>, + encoder: Encoder<'_>, index_base: super::ResourceData, bg_info: &super::BindGroupLayoutInfo, dynamic_offsets: &[wgt::DynamicOffset], group_index: u32, group: &super::BindGroup, ) { - let resource_indices = match stage { - naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, - naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, - naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, - naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, - naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, + use naga::ShaderStage as S; + let resource_indices = match encoder.stage() { + S::Vertex => &bg_info.base_resource_indices.vs, + S::Fragment => &bg_info.base_resource_indices.fs, + S::Task => &bg_info.base_resource_indices.ts, + S::Mesh => &bg_info.base_resource_indices.ms, + S::Compute => &bg_info.base_resource_indices.cs, }; - let buffers = match stage { - naga::ShaderStage::Vertex => group.counters.vs.buffers, - naga::ShaderStage::Fragment => group.counters.fs.buffers, - naga::ShaderStage::Task => group.counters.ts.buffers, - naga::ShaderStage::Mesh => group.counters.ms.buffers, - naga::ShaderStage::Compute => group.counters.cs.buffers, + let buffers = match encoder.stage() { + S::Vertex => group.counters.vs.buffers, + S::Fragment => group.counters.fs.buffers, + S::Task => group.counters.ts.buffers, + S::Mesh => group.counters.ms.buffers, + S::Compute => group.counters.cs.buffers, }; let mut changes_sizes_buffer = false; for index in 0..buffers { @@ -179,18 +246,9 @@ impl super::CommandEncoder { if let Some(dyn_index) = buf.dynamic_index { offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; } - let a1 = (resource_indices.buffers + index) as u64; - let a2 = Some(buf.ptr.as_native()); - let a3 = offset; - match stage { - naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3), - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) - } - naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3), - naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3), - naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3), - } + let index = (resource_indices.buffers + index) as u64; + let buffer = Some(buf.ptr.as_native()); + encoder.set_buffer(index, buffer, offset); if let Some(size) = buf.binding_size { let br = naga::ResourceBinding { group: group_index, @@ -203,66 +261,40 @@ impl super::CommandEncoder { if changes_sizes_buffer { if let Some((index, sizes)) = self .state - .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) + .make_sizes_buffer_update(encoder.stage(), &mut self.temp.binding_sizes) { - let a1 = index as _; - let a2 = (sizes.len() * WORD_SIZE) as u64; - let a3 = sizes.as_ptr().cast(); - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) - } - naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3), - naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3), - naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3), - } + let index = index as _; + let length = (sizes.len() * WORD_SIZE) as u64; + let bytes_ptr = sizes.as_ptr().cast(); + encoder.set_bytes(index, length, bytes_ptr); } } - let samplers = match stage { - naga::ShaderStage::Vertex => group.counters.vs.samplers, - naga::ShaderStage::Fragment => group.counters.fs.samplers, - naga::ShaderStage::Task => group.counters.ts.samplers, - naga::ShaderStage::Mesh => group.counters.ms.samplers, - naga::ShaderStage::Compute => group.counters.cs.samplers, + let samplers = match encoder.stage() { + S::Vertex => group.counters.vs.samplers, + S::Fragment => group.counters.fs.samplers, + S::Task => group.counters.ts.samplers, + S::Mesh => group.counters.ms.samplers, + S::Compute => group.counters.cs.samplers, }; for index in 0..samplers { let res = group.samplers[(index_base.samplers + index) as usize]; - let a1 = (resource_indices.samplers + index) as u64; - let a2 = Some(res.as_native()); - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_sampler_state(a1, a2) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_sampler_state(a1, a2) - } - naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2), - naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2), - naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2), - } + let index = (resource_indices.samplers + index) as u64; + let state = Some(res.as_native()); + encoder.set_sampler_state(index, state); } - let textures = match stage { - naga::ShaderStage::Vertex => group.counters.vs.textures, - naga::ShaderStage::Fragment => group.counters.fs.textures, - naga::ShaderStage::Task => group.counters.ts.textures, - naga::ShaderStage::Mesh => group.counters.ms.textures, - naga::ShaderStage::Compute => group.counters.cs.textures, + let textures = match encoder.stage() { + S::Vertex => group.counters.vs.textures, + S::Fragment => group.counters.fs.textures, + S::Task => group.counters.ts.textures, + S::Mesh => group.counters.ms.textures, + S::Compute => group.counters.cs.textures, }; for index in 0..textures { let res = group.textures[(index_base.textures + index) as usize]; - let a1 = (resource_indices.textures + index) as u64; - let a2 = Some(res.as_native()); - match stage { - naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2), - naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2), - naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2), - naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), - naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), - } + let index = (resource_indices.textures + index) as u64; + let texture = Some(res.as_native()); + encoder.set_texture(index, texture); } } } @@ -826,9 +858,7 @@ impl crate::CommandEncoder for super::CommandEncoder { let compute_encoder = self.state.compute.clone(); if let Some(encoder) = render_encoder { self.update_bind_group_state( - naga::ShaderStage::Vertex, - Some(&encoder), - None, + Encoder::Vertex(&encoder), // All zeros, as vs comes first super::ResourceData::default(), bg_info, @@ -837,9 +867,7 @@ impl crate::CommandEncoder for super::CommandEncoder { group, ); self.update_bind_group_state( - naga::ShaderStage::Task, - Some(&encoder), - None, + Encoder::Task(&encoder), // All zeros, as ts comes first super::ResourceData::default(), bg_info, @@ -848,9 +876,7 @@ impl crate::CommandEncoder for super::CommandEncoder { group, ); self.update_bind_group_state( - naga::ShaderStage::Mesh, - Some(&encoder), - None, + Encoder::Mesh(&encoder), group.counters.ts.clone(), bg_info, dynamic_offsets, @@ -858,9 +884,7 @@ impl crate::CommandEncoder for super::CommandEncoder { group, ); self.update_bind_group_state( - naga::ShaderStage::Fragment, - Some(&encoder), - None, + Encoder::Fragment(&encoder), super::ResourceData { buffers: group.counters.vs.buffers + group.counters.ts.buffers @@ -884,9 +908,7 @@ impl crate::CommandEncoder for super::CommandEncoder { } if let Some(encoder) = compute_encoder { self.update_bind_group_state( - naga::ShaderStage::Compute, - None, - Some(&encoder), + Encoder::Compute(&encoder), super::ResourceData { buffers: group.counters.vs.buffers + group.counters.ts.buffers