Skip to content

Commit

Permalink
[metal] set threadgroup memory sizes reflected from the shader
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Aug 19, 2021
1 parent 4af887c commit a296cf6
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 19 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

## TBD
- All:
- expose more formats via adapter-specific feature
- Metal:
- fix usage of work group memory

## v0.10 (2021-08-18)
- Infrastructure:
- `gfx-hal` is replaced by the in-house graphics abstraction `wgpu-hal`. Backends: Vulkan, Metal, D3D-12, and OpenGL ES-3.
Expand Down
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ default-members = ["wgpu", "wgpu-hal", "wgpu-info"]
#glow = { path = "../glow" }

[patch.crates-io]
#metal = { path = "../metal-rs" }
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }
#wasm-bindgen = { path = "../wasm-bindgen" }
2 changes: 1 addition & 1 deletion wgpu-hal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ winapi = { version = "0.3", features = ["libloaderapi", "windef", "winuser"] }
native = { package = "d3d12", version = "0.4.1", features = ["libloading"], optional = true }

[target.'cfg(any(target_os="macos", target_os="ios"))'.dependencies]
mtl = { package = "metal", version = "0.23" }
mtl = { package = "metal", version = "0.23.1" }
objc = "0.2.5"
core-graphics-types = "0.1"

Expand Down
33 changes: 29 additions & 4 deletions wgpu-hal/src/metal/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ impl Default for super::CommandState {
raw_wg_size: mtl::MTLSize::new(0, 0, 0),
stage_infos: Default::default(),
storage_buffer_length_map: Default::default(),
work_group_memory_sizes: Vec::new(),
}
}
}
Expand Down Expand Up @@ -45,15 +46,20 @@ impl super::CommandEncoder {
}

/// Prepares the encoder state for a new pass.
///
/// All per-pass tracking kept in `self.state` is cleared through
/// `CommandState::reset`, after which `leave_blit` is called.
fn begin_pass(&mut self) {
    // `reset` already clears the storage-buffer length map, the per-stage
    // infos (vs/fs/cs) and the work group memory sizes, so nothing is
    // cleared by hand here: the list of per-pass state lives in one place.
    self.state.reset();
    self.leave_blit();
}
}

impl super::CommandState {
/// Clears every piece of tracked state held by this `CommandState`:
/// the cached work group memory sizes, the storage-buffer length map and
/// the vertex/fragment/compute stage infos.
fn reset(&mut self) {
    self.work_group_memory_sizes.clear();
    self.storage_buffer_length_map.clear();

    // The three stage infos are cleared independently of each other.
    let stages = &mut self.stage_infos;
    stages.cs.clear();
    stages.fs.clear();
    stages.vs.clear();
}

fn make_sizes_buffer_update<'a>(
&self,
stage: naga::ShaderStage,
Expand Down Expand Up @@ -840,6 +846,25 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
sizes.as_ptr() as _,
);
}

// update the threadgroup memory sizes
while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() {
self.state.work_group_memory_sizes.push(0);
}
for (index, (cur_size, pipeline_size)) in self
.state
.work_group_memory_sizes
.iter_mut()
.zip(pipeline.work_group_memory_sizes.iter())
.enumerate()
{
const ALIGN_MASK: u32 = 0xF; // must be a multiple of 16 bytes
let size = ((*pipeline_size - 1) | ALIGN_MASK) + 1;
if *cur_size != size {
*cur_size = size;
encoder.set_threadgroup_memory_length(index as _, size as _);
}
}
}

unsafe fn dispatch(&mut self, count: [u32; 3]) {
Expand Down
37 changes: 25 additions & 12 deletions wgpu-hal/src/metal/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ struct CompiledShader {
library: mtl::Library,
function: mtl::Function,
wg_size: mtl::MTLSize,
wg_memory_sizes: Vec<u32>,
sized_bindings: Vec<naga::ResourceBinding>,
immutable_buffer_mask: usize,
}
Expand Down Expand Up @@ -104,28 +105,38 @@ impl super::Device {
crate::PipelineError::EntryPoint(naga_stage)
})?;

// collect sizes indices and immutable buffers
// collect sizes indices, immutable buffers, and work group memory sizes
let ep_info = &stage.module.naga.info.get_entry_point(ep_index);
let mut wg_memory_sizes = Vec::new();
let mut sized_bindings = Vec::new();
let mut immutable_buffer_mask = 0;
for (var_handle, var) in module.global_variables.iter() {
if var.class == naga::StorageClass::WorkGroup {
let size = module.types[var.ty].inner.span(&module.constants);
wg_memory_sizes.push(size);
}

if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner {
let br = match var.binding {
Some(ref br) => br.clone(),
None => continue,
};
let storage_access_store = if let naga::StorageClass::Storage { access } = var.class
{
access.contains(naga::StorageAccess::STORE)
} else {
false
};
// check for an immutable buffer
if !ep_info[var_handle].is_empty() && !storage_access_store {
let psm = &layout.naga_options.per_stage_map[naga_stage];
let slot = psm.resources[&br].buffer.unwrap();
immutable_buffer_mask |= 1 << slot;

if !ep_info[var_handle].is_empty() {
let storage_access_store = match var.class {
naga::StorageClass::Storage { access } => {
access.contains(naga::StorageAccess::STORE)
}
_ => false,
};
// check for an immutable buffer
if !storage_access_store {
let psm = &layout.naga_options.per_stage_map[naga_stage];
let slot = psm.resources[&br].buffer.unwrap();
immutable_buffer_mask |= 1 << slot;
}
}

// check for the unsized buffer
if let Some(member) = members.last() {
if let naga::TypeInner::Array {
Expand All @@ -144,6 +155,7 @@ impl super::Device {
library,
function,
wg_size,
wg_memory_sizes,
sized_bindings,
immutable_buffer_mask,
})
Expand Down Expand Up @@ -915,6 +927,7 @@ impl crate::Device<super::Api> for super::Device {
},
cs_lib: cs.library,
work_group_size: cs.wg_size,
work_group_memory_sizes: cs.wg_memory_sizes,
})
}
unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {}
Expand Down
2 changes: 2 additions & 0 deletions wgpu-hal/src/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ pub struct ComputePipeline {
cs_lib: mtl::Library,
cs_info: PipelineStageInfo,
work_group_size: mtl::MTLSize,
work_group_memory_sizes: Vec<u32>,
}

unsafe impl Send for ComputePipeline {}
Expand Down Expand Up @@ -689,6 +690,7 @@ struct CommandState {
raw_wg_size: mtl::MTLSize,
stage_infos: MultiStageData<PipelineStageInfo>,
storage_buffer_length_map: fxhash::FxHashMap<naga::ResourceBinding, wgt::BufferSize>,
work_group_memory_sizes: Vec<u32>,
}

pub struct CommandEncoder {
Expand Down

0 comments on commit a296cf6

Please sign in to comment.