Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 115 additions & 93 deletions wgpu-hal/src/metal/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use alloc::{
use core::ops::Range;
use metal::{
MTLIndexType, MTLLoadAction, MTLPrimitiveType, MTLScissorRect, MTLSize, MTLStoreAction,
MTLViewport, MTLVisibilityResultMode, NSRange,
MTLViewport, MTLVisibilityResultMode, NSRange, NSUInteger,
};
use smallvec::SmallVec;

Expand All @@ -31,6 +31,75 @@ impl Default for super::CommandState {
}
}

/// Helper for passing encoders to `update_bind_group_state`.
///
/// Combines [`naga::ShaderStage`] and an encoder of the appropriate type for
/// that stage.
enum Encoder<'e> {
Vertex(&'e metal::RenderCommandEncoder),
Fragment(&'e metal::RenderCommandEncoder),
Task(&'e metal::RenderCommandEncoder),
Mesh(&'e metal::RenderCommandEncoder),
Compute(&'e metal::ComputeCommandEncoder),
}

impl Encoder<'_> {
fn stage(&self) -> naga::ShaderStage {
match self {
Self::Vertex(_) => naga::ShaderStage::Vertex,
Self::Fragment(_) => naga::ShaderStage::Fragment,
Self::Task(_) => naga::ShaderStage::Task,
Self::Mesh(_) => naga::ShaderStage::Mesh,
Self::Compute(_) => naga::ShaderStage::Compute,
}
}

fn set_buffer(
&self,
index: NSUInteger,
buffer: Option<&metal::BufferRef>,
offset: wgt::BufferAddress,
) {
match *self {
Self::Vertex(enc) => enc.set_vertex_buffer(index, buffer, offset),
Self::Fragment(enc) => enc.set_fragment_buffer(index, buffer, offset),
Self::Task(enc) => enc.set_object_buffer(index, buffer, offset),
Self::Mesh(enc) => enc.set_mesh_buffer(index, buffer, offset),
Self::Compute(enc) => enc.set_buffer(index, buffer, offset),
}
}

fn set_bytes(&self, index: NSUInteger, length: u64, bytes: *const core::ffi::c_void) {
match *self {
Self::Vertex(enc) => enc.set_vertex_bytes(index, length, bytes),
Self::Fragment(enc) => enc.set_fragment_bytes(index, length, bytes),
Self::Task(enc) => enc.set_object_bytes(index, length, bytes),
Self::Mesh(enc) => enc.set_mesh_bytes(index, length, bytes),
Self::Compute(enc) => enc.set_bytes(index, length, bytes),
}
}

fn set_sampler_state(&self, index: NSUInteger, state: Option<&metal::SamplerStateRef>) {
match *self {
Self::Vertex(enc) => enc.set_vertex_sampler_state(index, state),
Self::Fragment(enc) => enc.set_fragment_sampler_state(index, state),
Self::Task(enc) => enc.set_object_sampler_state(index, state),
Self::Mesh(enc) => enc.set_mesh_sampler_state(index, state),
Self::Compute(enc) => enc.set_sampler_state(index, state),
}
}

fn set_texture(&self, index: NSUInteger, texture: Option<&metal::TextureRef>) {
match *self {
Self::Vertex(enc) => enc.set_vertex_texture(index, texture),
Self::Fragment(enc) => enc.set_fragment_texture(index, texture),
Self::Task(enc) => enc.set_object_texture(index, texture),
Self::Mesh(enc) => enc.set_mesh_texture(index, texture),
Self::Compute(enc) => enc.set_texture(index, texture),
}
}
}

impl super::CommandEncoder {
pub fn raw_command_buffer(&self) -> Option<&metal::CommandBuffer> {
self.raw_cmd_buf.as_ref()
Expand Down Expand Up @@ -146,31 +215,29 @@ impl super::CommandEncoder {
}

/// Updates the bindings for a single shader stage, called in `set_bind_group`.
#[expect(clippy::too_many_arguments)]
fn update_bind_group_state(
&mut self,
stage: naga::ShaderStage,
render_encoder: Option<&metal::RenderCommandEncoder>,
compute_encoder: Option<&metal::ComputeCommandEncoder>,
Comment on lines -149 to -154
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad we've been able to get rid of some of this!

encoder: Encoder<'_>,
index_base: super::ResourceData<u32>,
bg_info: &super::BindGroupLayoutInfo,
dynamic_offsets: &[wgt::DynamicOffset],
group_index: u32,
group: &super::BindGroup,
) {
let resource_indices = match stage {
naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs,
naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs,
naga::ShaderStage::Task => &bg_info.base_resource_indices.ts,
naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms,
naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs,
use naga::ShaderStage as S;
let resource_indices = match encoder.stage() {
S::Vertex => &bg_info.base_resource_indices.vs,
S::Fragment => &bg_info.base_resource_indices.fs,
S::Task => &bg_info.base_resource_indices.ts,
S::Mesh => &bg_info.base_resource_indices.ms,
S::Compute => &bg_info.base_resource_indices.cs,
};
let buffers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.buffers,
naga::ShaderStage::Fragment => group.counters.fs.buffers,
naga::ShaderStage::Task => group.counters.ts.buffers,
naga::ShaderStage::Mesh => group.counters.ms.buffers,
naga::ShaderStage::Compute => group.counters.cs.buffers,
let buffers = match encoder.stage() {
S::Vertex => group.counters.vs.buffers,
S::Fragment => group.counters.fs.buffers,
S::Task => group.counters.ts.buffers,
S::Mesh => group.counters.ms.buffers,
S::Compute => group.counters.cs.buffers,
};
let mut changes_sizes_buffer = false;
for index in 0..buffers {
Expand All @@ -179,18 +246,9 @@ impl super::CommandEncoder {
if let Some(dyn_index) = buf.dynamic_index {
offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
}
let a1 = (resource_indices.buffers + index) as u64;
let a2 = Some(buf.ptr.as_native());
let a3 = offset;
Comment on lines -182 to -184
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eww gross, glad we got rid of "a1, a2, a3"

match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3),
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_buffer(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3),
}
let index = (resource_indices.buffers + index) as u64;
let buffer = Some(buf.ptr.as_native());
encoder.set_buffer(index, buffer, offset);
if let Some(size) = buf.binding_size {
let br = naga::ResourceBinding {
group: group_index,
Expand All @@ -203,66 +261,40 @@ impl super::CommandEncoder {
if changes_sizes_buffer {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(stage, &mut self.temp.binding_sizes)
.make_sizes_buffer_update(encoder.stage(), &mut self.temp.binding_sizes)
{
let a1 = index as _;
let a2 = (sizes.len() * WORD_SIZE) as u64;
let a3 = sizes.as_ptr().cast();
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_bytes(a1, a2, a3)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_bytes(a1, a2, a3)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3),
}
let index = index as _;
let length = (sizes.len() * WORD_SIZE) as u64;
let bytes_ptr = sizes.as_ptr().cast();
encoder.set_bytes(index, length, bytes_ptr);
}
}
let samplers = match stage {
naga::ShaderStage::Vertex => group.counters.vs.samplers,
naga::ShaderStage::Fragment => group.counters.fs.samplers,
naga::ShaderStage::Task => group.counters.ts.samplers,
naga::ShaderStage::Mesh => group.counters.ms.samplers,
naga::ShaderStage::Compute => group.counters.cs.samplers,
let samplers = match encoder.stage() {
S::Vertex => group.counters.vs.samplers,
S::Fragment => group.counters.fs.samplers,
S::Task => group.counters.ts.samplers,
S::Mesh => group.counters.ms.samplers,
S::Compute => group.counters.cs.samplers,
};
for index in 0..samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
let a1 = (resource_indices.samplers + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => {
render_encoder.unwrap().set_vertex_sampler_state(a1, a2)
}
naga::ShaderStage::Fragment => {
render_encoder.unwrap().set_fragment_sampler_state(a1, a2)
}
naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2),
}
let index = (resource_indices.samplers + index) as u64;
let state = Some(res.as_native());
encoder.set_sampler_state(index, state);
}

let textures = match stage {
naga::ShaderStage::Vertex => group.counters.vs.textures,
naga::ShaderStage::Fragment => group.counters.fs.textures,
naga::ShaderStage::Task => group.counters.ts.textures,
naga::ShaderStage::Mesh => group.counters.ms.textures,
naga::ShaderStage::Compute => group.counters.cs.textures,
let textures = match encoder.stage() {
S::Vertex => group.counters.vs.textures,
S::Fragment => group.counters.fs.textures,
S::Task => group.counters.ts.textures,
S::Mesh => group.counters.ms.textures,
S::Compute => group.counters.cs.textures,
};
for index in 0..textures {
let res = group.textures[(index_base.textures + index) as usize];
let a1 = (resource_indices.textures + index) as u64;
let a2 = Some(res.as_native());
match stage {
naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2),
naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2),
naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2),
naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2),
naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2),
}
let index = (resource_indices.textures + index) as u64;
let texture = Some(res.as_native());
encoder.set_texture(index, texture);
}
}
}
Expand Down Expand Up @@ -826,9 +858,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
let compute_encoder = self.state.compute.clone();
if let Some(encoder) = render_encoder {
self.update_bind_group_state(
naga::ShaderStage::Vertex,
Some(&encoder),
None,
Encoder::Vertex(&encoder),
// All zeros, as vs comes first
super::ResourceData::default(),
bg_info,
Expand All @@ -837,9 +867,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
group,
);
self.update_bind_group_state(
naga::ShaderStage::Task,
Some(&encoder),
None,
Encoder::Task(&encoder),
// All zeros, as ts comes first
super::ResourceData::default(),
bg_info,
Expand All @@ -848,19 +876,15 @@ impl crate::CommandEncoder for super::CommandEncoder {
group,
);
self.update_bind_group_state(
naga::ShaderStage::Mesh,
Some(&encoder),
None,
Encoder::Mesh(&encoder),
group.counters.ts.clone(),
bg_info,
dynamic_offsets,
group_index,
group,
);
self.update_bind_group_state(
naga::ShaderStage::Fragment,
Some(&encoder),
None,
Encoder::Fragment(&encoder),
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
Expand All @@ -884,9 +908,7 @@ impl crate::CommandEncoder for super::CommandEncoder {
}
if let Some(encoder) = compute_encoder {
self.update_bind_group_state(
naga::ShaderStage::Compute,
None,
Some(&encoder),
Encoder::Compute(&encoder),
super::ResourceData {
buffers: group.counters.vs.buffers
+ group.counters.ts.buffers
Expand Down
Loading