-
Notifications
You must be signed in to change notification settings - Fork 1.2k
refactor(metal): Use descriptive names in update_bind_group_state
#8628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,30 @@ impl Default for super::CommandState { | |
| } | ||
| } | ||
|
|
||
| /// Helper for passing encoders to `update_bind_group_state`. | ||
| /// | ||
| /// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for | ||
| /// that stage. | ||
| enum Encoder<'e> { | ||
| Vertex(&'e metal::RenderCommandEncoder), | ||
| Fragment(&'e metal::RenderCommandEncoder), | ||
| Task(&'e metal::RenderCommandEncoder), | ||
| Mesh(&'e metal::RenderCommandEncoder), | ||
| Compute(&'e metal::ComputeCommandEncoder), | ||
| } | ||
|
|
||
| impl Encoder<'_> { | ||
| fn stage(&self) -> naga::ShaderStage { | ||
| match self { | ||
| Self::Vertex(_) => naga::ShaderStage::Vertex, | ||
| Self::Fragment(_) => naga::ShaderStage::Fragment, | ||
| Self::Task(_) => naga::ShaderStage::Task, | ||
| Self::Mesh(_) => naga::ShaderStage::Mesh, | ||
| Self::Compute(_) => naga::ShaderStage::Compute, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl super::CommandEncoder { | ||
| pub fn raw_command_buffer(&self) -> Option<&metal::CommandBuffer> { | ||
| self.raw_cmd_buf.as_ref() | ||
|
|
@@ -146,31 +170,30 @@ impl super::CommandEncoder { | |
| } | ||
|
|
||
| /// Updates the bindings for a single shader stage, called in `set_bind_group`. | ||
| #[expect(clippy::too_many_arguments)] | ||
| fn update_bind_group_state( | ||
| &mut self, | ||
| stage: naga::ShaderStage, | ||
| render_encoder: Option<&metal::RenderCommandEncoder>, | ||
| compute_encoder: Option<&metal::ComputeCommandEncoder>, | ||
| encoder: Encoder<'_>, | ||
| index_base: super::ResourceData<u32>, | ||
| bg_info: &super::BindGroupLayoutInfo, | ||
| dynamic_offsets: &[wgt::DynamicOffset], | ||
| group_index: u32, | ||
| group: &super::BindGroup, | ||
| ) { | ||
| let resource_indices = match stage { | ||
| naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, | ||
| naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, | ||
| naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, | ||
| naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, | ||
| naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, | ||
| use naga::ShaderStage as S; | ||
| use Encoder as E; | ||
| let resource_indices = match encoder.stage() { | ||
| S::Vertex => &bg_info.base_resource_indices.vs, | ||
| S::Fragment => &bg_info.base_resource_indices.fs, | ||
| S::Task => &bg_info.base_resource_indices.ts, | ||
| S::Mesh => &bg_info.base_resource_indices.ms, | ||
| S::Compute => &bg_info.base_resource_indices.cs, | ||
| }; | ||
| let buffers = match stage { | ||
| naga::ShaderStage::Vertex => group.counters.vs.buffers, | ||
| naga::ShaderStage::Fragment => group.counters.fs.buffers, | ||
| naga::ShaderStage::Task => group.counters.ts.buffers, | ||
| naga::ShaderStage::Mesh => group.counters.ms.buffers, | ||
| naga::ShaderStage::Compute => group.counters.cs.buffers, | ||
| let buffers = match encoder.stage() { | ||
| S::Vertex => group.counters.vs.buffers, | ||
| S::Fragment => group.counters.fs.buffers, | ||
| S::Task => group.counters.ts.buffers, | ||
| S::Mesh => group.counters.ms.buffers, | ||
| S::Compute => group.counters.cs.buffers, | ||
| }; | ||
| let mut changes_sizes_buffer = false; | ||
| for index in 0..buffers { | ||
|
|
@@ -179,17 +202,14 @@ impl super::CommandEncoder { | |
| if let Some(dyn_index) = buf.dynamic_index { | ||
| offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; | ||
| } | ||
| let a1 = (resource_indices.buffers + index) as u64; | ||
| let a2 = Some(buf.ptr.as_native()); | ||
| let a3 = offset; | ||
|
Comment on lines
-182
to
-184
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eww gross, glad we got rid of "a1, a2, a3" |
||
| match stage { | ||
| naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3), | ||
| naga::ShaderStage::Fragment => { | ||
| render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) | ||
| } | ||
| naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3), | ||
| naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3), | ||
| naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3), | ||
| let index = (resource_indices.buffers + index) as u64; | ||
| let buf_ptr = Some(buf.ptr.as_native()); | ||
| match encoder { | ||
| E::Vertex(encoder) => encoder.set_vertex_buffer(index, buf_ptr, offset), | ||
| E::Fragment(encoder) => encoder.set_fragment_buffer(index, buf_ptr, offset), | ||
| E::Task(encoder) => encoder.set_object_buffer(index, buf_ptr, offset), | ||
| E::Mesh(encoder) => encoder.set_mesh_buffer(index, buf_ptr, offset), | ||
| E::Compute(encoder) => encoder.set_buffer(index, buf_ptr, offset), | ||
|
||
| } | ||
| if let Some(size) = buf.binding_size { | ||
| let br = naga::ResourceBinding { | ||
|
|
@@ -203,65 +223,57 @@ impl super::CommandEncoder { | |
| if changes_sizes_buffer { | ||
| if let Some((index, sizes)) = self | ||
| .state | ||
| .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) | ||
| .make_sizes_buffer_update(encoder.stage(), &mut self.temp.binding_sizes) | ||
| { | ||
| let a1 = index as _; | ||
| let a2 = (sizes.len() * WORD_SIZE) as u64; | ||
| let a3 = sizes.as_ptr().cast(); | ||
| match stage { | ||
| naga::ShaderStage::Vertex => { | ||
| render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) | ||
| } | ||
| naga::ShaderStage::Fragment => { | ||
| render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) | ||
| } | ||
| naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3), | ||
| naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3), | ||
| naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3), | ||
| let index = index as _; | ||
| let length = (sizes.len() * WORD_SIZE) as u64; | ||
| let bytes_ptr = sizes.as_ptr().cast(); | ||
| match encoder { | ||
| E::Vertex(encoder) => encoder.set_vertex_bytes(index, length, bytes_ptr), | ||
| E::Fragment(encoder) => encoder.set_fragment_bytes(index, length, bytes_ptr), | ||
| E::Task(encoder) => encoder.set_object_bytes(index, length, bytes_ptr), | ||
| E::Mesh(encoder) => encoder.set_mesh_bytes(index, length, bytes_ptr), | ||
| E::Compute(encoder) => encoder.set_bytes(index, length, bytes_ptr), | ||
|
||
| } | ||
| } | ||
| } | ||
| let samplers = match stage { | ||
| naga::ShaderStage::Vertex => group.counters.vs.samplers, | ||
| naga::ShaderStage::Fragment => group.counters.fs.samplers, | ||
| naga::ShaderStage::Task => group.counters.ts.samplers, | ||
| naga::ShaderStage::Mesh => group.counters.ms.samplers, | ||
| naga::ShaderStage::Compute => group.counters.cs.samplers, | ||
| let samplers = match encoder.stage() { | ||
| S::Vertex => group.counters.vs.samplers, | ||
| S::Fragment => group.counters.fs.samplers, | ||
| S::Task => group.counters.ts.samplers, | ||
| S::Mesh => group.counters.ms.samplers, | ||
| S::Compute => group.counters.cs.samplers, | ||
| }; | ||
| for index in 0..samplers { | ||
| let res = group.samplers[(index_base.samplers + index) as usize]; | ||
| let a1 = (resource_indices.samplers + index) as u64; | ||
| let a2 = Some(res.as_native()); | ||
| match stage { | ||
| naga::ShaderStage::Vertex => { | ||
| render_encoder.unwrap().set_vertex_sampler_state(a1, a2) | ||
| } | ||
| naga::ShaderStage::Fragment => { | ||
| render_encoder.unwrap().set_fragment_sampler_state(a1, a2) | ||
| } | ||
| naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2), | ||
| naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2), | ||
| naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2), | ||
| let index = (resource_indices.samplers + index) as u64; | ||
| let state = Some(res.as_native()); | ||
| match encoder { | ||
| E::Vertex(encoder) => encoder.set_vertex_sampler_state(index, state), | ||
| E::Fragment(encoder) => encoder.set_fragment_sampler_state(index, state), | ||
| E::Task(encoder) => encoder.set_object_sampler_state(index, state), | ||
| E::Mesh(encoder) => encoder.set_mesh_sampler_state(index, state), | ||
| E::Compute(encoder) => encoder.set_sampler_state(index, state), | ||
| } | ||
| } | ||
|
|
||
| let textures = match stage { | ||
| naga::ShaderStage::Vertex => group.counters.vs.textures, | ||
| naga::ShaderStage::Fragment => group.counters.fs.textures, | ||
| naga::ShaderStage::Task => group.counters.ts.textures, | ||
| naga::ShaderStage::Mesh => group.counters.ms.textures, | ||
| naga::ShaderStage::Compute => group.counters.cs.textures, | ||
| let textures = match encoder.stage() { | ||
| S::Vertex => group.counters.vs.textures, | ||
| S::Fragment => group.counters.fs.textures, | ||
| S::Task => group.counters.ts.textures, | ||
| S::Mesh => group.counters.ms.textures, | ||
| S::Compute => group.counters.cs.textures, | ||
| }; | ||
| for index in 0..textures { | ||
| let res = group.textures[(index_base.textures + index) as usize]; | ||
| let a1 = (resource_indices.textures + index) as u64; | ||
| let a2 = Some(res.as_native()); | ||
| match stage { | ||
| naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2), | ||
| naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2), | ||
| naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2), | ||
| naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), | ||
| naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), | ||
| let index = (resource_indices.textures + index) as u64; | ||
| let texture = Some(res.as_native()); | ||
| match encoder { | ||
| E::Vertex(encoder) => encoder.set_vertex_texture(index, texture), | ||
| E::Fragment(encoder) => encoder.set_fragment_texture(index, texture), | ||
| E::Task(encoder) => encoder.set_object_texture(index, texture), | ||
| E::Mesh(encoder) => encoder.set_mesh_texture(index, texture), | ||
| E::Compute(encoder) => encoder.set_texture(index, texture), | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -826,9 +838,7 @@ impl crate::CommandEncoder for super::CommandEncoder { | |
| let compute_encoder = self.state.compute.clone(); | ||
| if let Some(encoder) = render_encoder { | ||
| self.update_bind_group_state( | ||
| naga::ShaderStage::Vertex, | ||
| Some(&encoder), | ||
| None, | ||
| Encoder::Vertex(&encoder), | ||
| // All zeros, as vs comes first | ||
| super::ResourceData::default(), | ||
| bg_info, | ||
|
|
@@ -837,9 +847,7 @@ impl crate::CommandEncoder for super::CommandEncoder { | |
| group, | ||
| ); | ||
| self.update_bind_group_state( | ||
| naga::ShaderStage::Task, | ||
| Some(&encoder), | ||
| None, | ||
| Encoder::Task(&encoder), | ||
| // All zeros, as ts comes first | ||
| super::ResourceData::default(), | ||
| bg_info, | ||
|
|
@@ -848,19 +856,15 @@ impl crate::CommandEncoder for super::CommandEncoder { | |
| group, | ||
| ); | ||
| self.update_bind_group_state( | ||
| naga::ShaderStage::Mesh, | ||
| Some(&encoder), | ||
| None, | ||
| Encoder::Mesh(&encoder), | ||
| group.counters.ts.clone(), | ||
| bg_info, | ||
| dynamic_offsets, | ||
| group_index, | ||
| group, | ||
| ); | ||
| self.update_bind_group_state( | ||
| naga::ShaderStage::Fragment, | ||
| Some(&encoder), | ||
| None, | ||
| Encoder::Fragment(&encoder), | ||
| super::ResourceData { | ||
| buffers: group.counters.vs.buffers | ||
| + group.counters.ts.buffers | ||
|
|
@@ -884,9 +888,7 @@ impl crate::CommandEncoder for super::CommandEncoder { | |
| } | ||
| if let Some(encoder) = compute_encoder { | ||
| self.update_bind_group_state( | ||
| naga::ShaderStage::Compute, | ||
| None, | ||
| Some(&encoder), | ||
| Encoder::Compute(&encoder), | ||
| super::ResourceData { | ||
| buffers: group.counters.vs.buffers | ||
| + group.counters.ts.buffers | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Glad we've been able to get rid of some of this!