Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 90 additions & 88 deletions wgpu-hal/src/metal/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ impl Default for super::CommandState {
}
}

/// Helper for passing encoders to `update_bind_group_state`.
///
/// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for
/// that stage.
enum Encoder<'e> {
Vertex(&'e metal::RenderCommandEncoder),
Fragment(&'e metal::RenderCommandEncoder),
Task(&'e metal::RenderCommandEncoder),
Mesh(&'e metal::RenderCommandEncoder),
Compute(&'e metal::ComputeCommandEncoder),
}

impl Encoder<'_> {
fn stage(&self) -> naga::ShaderStage {
match self {
Self::Vertex(_) => naga::ShaderStage::Vertex,
Self::Fragment(_) => naga::ShaderStage::Fragment,
Self::Task(_) => naga::ShaderStage::Task,
Self::Mesh(_) => naga::ShaderStage::Mesh,
Self::Compute(_) => naga::ShaderStage::Compute,
}
}
}

impl super::CommandEncoder {
pub fn raw_command_buffer(&self) -> Option<&metal::CommandBuffer> {
self.raw_cmd_buf.as_ref()
Expand Down Expand Up @@ -146,31 +170,30 @@ impl super::CommandEncoder {
}

/// Updates the bindings for a single shader stage, called in `set_bind_group`.
#[expect(clippy::too_many_arguments)]
fn update_bind_group_state(
&mut self,
stage: naga::ShaderStage,
render_encoder: Option<&metal::RenderCommandEncoder>,
compute_encoder: Option<&metal::ComputeCommandEncoder>,
Comment on lines -149 to -154
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad we've been able to get rid of some of this!

encoder: Encoder<'_>,
index_base: super::ResourceData<u32>,
bg_info: &super::BindGroupLayoutInfo,
dynamic_offsets: &[wgt::DynamicOffset],
group_index: u32,
group: &super::BindGroup,
) {
let resource_indices = match stage {
naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs,
naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs,
naga::ShaderStage::Task => &bg_info.base_resource_indices.ts,
naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms,
naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs,
use naga::ShaderStage as S;
use Encoder as E;
let resource_indices = match encoder.stage() {
S::Vertex => &bg_info.base_resource_indices.vs,
S::Fragment => &bg_info.base_resource_indices.fs,
S::Task => &bg_info.base_resource_indices.ts,
S::Mesh => &bg_info.base_resource_indices.ms,
S::Compute => &bg_info.base_resource_indices.cs,
};
let buffers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.buffers,
naga::ShaderStage::Fragment => group.counters.fs.buffers,
naga::ShaderStage::Task => group.counters.ts.buffers,
naga::ShaderStage::Mesh => group.counters.ms.buffers,
naga::ShaderStage::Compute => group.counters.cs.buffers,
let buffers = match encoder.stage() {
S::Vertex => group.counters.vs.buffers,
S::Fragment => group.counters.fs.buffers,
S::Task => group.counters.ts.buffers,
S::Mesh => group.counters.ms.buffers,
S::Compute => group.counters.cs.buffers,
};
let mut changes_sizes_buffer = false;
for index in 0..buffers {
Expand All @@ -179,17 +202,14 @@ impl super::CommandEncoder {
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
let a1 = (resource_indices.buffers + index) as u64;
let a2 = Some(buf.ptr.as_native());
let a3 = offset;
Comment on lines -182 to -184
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eww gross, glad we got rid of "a1, a2, a3"

match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3),
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_buffer(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3),
let index = (resource_indices.buffers + index) as u64;
let buf_ptr = Some(buf.ptr.as_native());
match encoder {
E::Vertex(encoder) => encoder.set_vertex_buffer(index, buf_ptr, offset),
E::Fragment(encoder) => encoder.set_fragment_buffer(index, buf_ptr, offset),
E::Task(encoder) => encoder.set_object_buffer(index, buf_ptr, offset),
E::Mesh(encoder) => encoder.set_mesh_buffer(index, buf_ptr, offset),
E::Compute(encoder) => encoder.set_buffer(index, buf_ptr, offset),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some new and somewhat similar logic for mesh pipeline creation that was handled in #8139, although I doubt it would be applicable here. Still, maybe we could move this into a function on the encoder to simplify this function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the logic from #8139?

My preference would be to defer additional refactoring (extracting these to encoder methods) until after objc2 lands, to minimize conflict resolution. Do you mind putting it off for a bit? I can file an issue so I don't forget.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really mind. Yeah most of the code I'm commenting on is originally mine, but since you introduced the Encoder struct, I'm thinking we might as well use it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see what you mean now. In my head I was thinking of something more complicated -- while working on this I started looking at changing the encoder Options in CommandState to an enum, but realized that was going to be a much bigger change. When I first read your comment I was thinking of something like that.

Moving them onto Encoder isn't that bad, I'll just do it.

}
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
Expand All @@ -203,65 +223,57 @@ impl super::CommandEncoder {
if changes_sizes_buffer {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(stage, &mut self.temp.binding_sizes)
.make_sizes_buffer_update(encoder.stage(), &mut self.temp.binding_sizes)
{
let a1 = index as _;
let a2 = (sizes.len() * WORD_SIZE) as u64;
let a3 = sizes.as_ptr().cast();
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_bytes(a1, a2, a3)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_bytes(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3),
let index = index as _;
let length = (sizes.len() * WORD_SIZE) as u64;
let bytes_ptr = sizes.as_ptr().cast();
match encoder {
E::Vertex(encoder) => encoder.set_vertex_bytes(index, length, bytes_ptr),
E::Fragment(encoder) => encoder.set_fragment_bytes(index, length, bytes_ptr),
E::Task(encoder) => encoder.set_object_bytes(index, length, bytes_ptr),
E::Mesh(encoder) => encoder.set_mesh_bytes(index, length, bytes_ptr),
E::Compute(encoder) => encoder.set_bytes(index, length, bytes_ptr),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, lets put this on Encoder

}
}
}
let samplers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.samplers,
naga::ShaderStage::Fragment => group.counters.fs.samplers,
naga::ShaderStage::Task => group.counters.ts.samplers,
naga::ShaderStage::Mesh => group.counters.ms.samplers,
naga::ShaderStage::Compute => group.counters.cs.samplers,
let samplers = match encoder.stage() {
S::Vertex => group.counters.vs.samplers,
S::Fragment => group.counters.fs.samplers,
S::Task => group.counters.ts.samplers,
S::Mesh => group.counters.ms.samplers,
S::Compute => group.counters.cs.samplers,
};
for index in 0..samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
let a1 = (resource_indices.samplers + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_sampler_state(a1, a2)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_sampler_state(a1, a2)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2),
let index = (resource_indices.samplers + index) as u64;
let state = Some(res.as_native());
match encoder {
E::Vertex(encoder) => encoder.set_vertex_sampler_state(index, state),
E::Fragment(encoder) => encoder.set_fragment_sampler_state(index, state),
E::Task(encoder) => encoder.set_object_sampler_state(index, state),
E::Mesh(encoder) => encoder.set_mesh_sampler_state(index, state),
E::Compute(encoder) => encoder.set_sampler_state(index, state),
}
}

let textures = match stage {
naga::ShaderStage::Vertex => group.counters.vs.textures,
naga::ShaderStage::Fragment => group.counters.fs.textures,
naga::ShaderStage::Task => group.counters.ts.textures,
naga::ShaderStage::Mesh => group.counters.ms.textures,
naga::ShaderStage::Compute => group.counters.cs.textures,
let textures = match encoder.stage() {
S::Vertex => group.counters.vs.textures,
S::Fragment => group.counters.fs.textures,
S::Task => group.counters.ts.textures,
S::Mesh => group.counters.ms.textures,
S::Compute => group.counters.cs.textures,
};
for index in 0..textures {
let res = group.textures[(index_base.textures + index) as usize];
let a1 = (resource_indices.textures + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2),
naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2),
naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2),
let index = (resource_indices.textures + index) as u64;
let texture = Some(res.as_native());
match encoder {
E::Vertex(encoder) => encoder.set_vertex_texture(index, texture),
E::Fragment(encoder) => encoder.set_fragment_texture(index, texture),
E::Task(encoder) => encoder.set_object_texture(index, texture),
E::Mesh(encoder) => encoder.set_mesh_texture(index, texture),
E::Compute(encoder) => encoder.set_texture(index, texture),
}
}
}
Expand Down Expand Up @@ -826,9 +838,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
let compute_encoder = self.state.compute.clone();
if let Some(encoder) = render_encoder {
self.update_bind_group_state(
naga::ShaderStage::Vertex,
Some(&encoder),
None,
Encoder::Vertex(&encoder),
// All zeros, as vs comes first
super::ResourceData::default(),
bg_info,
Expand All @@ -837,9 +847,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
group,
);
self.update_bind_group_state(
naga::ShaderStage::Task,
Some(&encoder),
None,
Encoder::Task(&encoder),
// All zeros, as ts comes first
super::ResourceData::default(),
bg_info,
Expand All @@ -848,19 +856,15 @@ impl crate::CommandEncoder for super::CommandEncoder {
group,
);
self.update_bind_group_state(
naga::ShaderStage::Mesh,
Some(&encoder),
None,
Encoder::Mesh(&encoder),
group.counters.ts.clone(),
bg_info,
dynamic_offsets,
group_index,
group,
);
self.update_bind_group_state(
naga::ShaderStage::Fragment,
Some(&encoder),
None,
Encoder::Fragment(&encoder),
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
Expand All @@ -884,9 +888,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
if let Some(encoder) = compute_encoder {
self.update_bind_group_state(
naga::ShaderStage::Compute,
None,
Some(&encoder),
Encoder::Compute(&encoder),
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
Expand Down