diff --git a/CHANGELOG.md b/CHANGELOG.md index e1cb97dc99..db0fe48212 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## TBD + - All: + - expose more formats via adapter-specific feature + - Metal: + - fix usage of work group memory + ## v0.10 (2021-08-18) - Infrastructure: - `gfx-hal` is replaced by the in-house graphics abstraction `wgpu-hal`. Backends: Vulkan, Metal, D3D-12, and OpenGL ES-3. diff --git a/Cargo.lock b/Cargo.lock index 083bdc6982..fd3038fcd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -909,9 +909,9 @@ dependencies = [ [[package]] name = "metal" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79d7d769f1c104b8388294d6594d491d2e21240636f5f94d37f8a0f3d7904450" +checksum = "e0514f491f4cc03632ab399ee01e2c1c1b12d3e1cf2d667c1ff5f87d6dcd2084" dependencies = [ "bitflags", "block", diff --git a/Cargo.toml b/Cargo.toml index 7601192131..3784d373ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ default-members = ["wgpu", "wgpu-hal", "wgpu-info"] #glow = { path = "../glow" } [patch.crates-io] +#metal = { path = "../metal-rs" } #web-sys = { path = "../wasm-bindgen/crates/web-sys" } #js-sys = { path = "../wasm-bindgen/crates/js-sys" } #wasm-bindgen = { path = "../wasm-bindgen" } diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml index b52606f939..f6f24e62ee 100644 --- a/wgpu-hal/Cargo.toml +++ b/wgpu-hal/Cargo.toml @@ -63,7 +63,7 @@ winapi = { version = "0.3", features = ["libloaderapi", "windef", "winuser"] } native = { package = "d3d12", version = "0.4.1", features = ["libloading"], optional = true } [target.'cfg(any(target_os="macos", target_os="ios"))'.dependencies] -mtl = { package = "metal", version = "0.23" } +mtl = { package = "metal", version = "0.23.1" } objc = "0.2.5" core-graphics-types = "0.1" diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 5fb648e151..e7a7a3cf5b 100644 --- a/wgpu-hal/src/metal/command.rs 
+++ b/wgpu-hal/src/metal/command.rs @@ -14,6 +14,7 @@ impl Default for super::CommandState { raw_wg_size: mtl::MTLSize::new(0, 0, 0), stage_infos: Default::default(), storage_buffer_length_map: Default::default(), + work_group_memory_sizes: Vec::new(), } } } @@ -45,15 +46,20 @@ impl super::CommandEncoder { } fn begin_pass(&mut self) { - self.state.storage_buffer_length_map.clear(); - self.state.stage_infos.vs.clear(); - self.state.stage_infos.fs.clear(); - self.state.stage_infos.cs.clear(); + self.state.reset(); self.leave_blit(); } } impl super::CommandState { + fn reset(&mut self) { + self.storage_buffer_length_map.clear(); + self.stage_infos.vs.clear(); + self.stage_infos.fs.clear(); + self.stage_infos.cs.clear(); + self.work_group_memory_sizes.clear(); + } + fn make_sizes_buffer_update<'a>( &self, stage: naga::ShaderStage, @@ -840,6 +846,25 @@ impl crate::CommandEncoder for super::CommandEncoder { sizes.as_ptr() as _, ); } + + // update the threadgroup memory sizes + while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() { + self.state.work_group_memory_sizes.push(0); + } + for (index, (cur_size, pipeline_size)) in self + .state + .work_group_memory_sizes + .iter_mut() + .zip(pipeline.work_group_memory_sizes.iter()) + .enumerate() + { + const ALIGN_MASK: u32 = 0xF; // must be a multiple of 16 bytes + let size = ((*pipeline_size - 1) | ALIGN_MASK) + 1; + if *cur_size != size { + *cur_size = size; + encoder.set_threadgroup_memory_length(index as _, size as _); + } + } } unsafe fn dispatch(&mut self, count: [u32; 3]) { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index fa7f20d224..5f01abef59 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -13,6 +13,7 @@ struct CompiledShader { library: mtl::Library, function: mtl::Function, wg_size: mtl::MTLSize, + wg_memory_sizes: Vec, sized_bindings: Vec, immutable_buffer_mask: usize, } @@ -104,28 +105,38 @@ impl super::Device { 
crate::PipelineError::EntryPoint(naga_stage) })?; - // collect sizes indices and immutable buffers + // collect sizes indices, immutable buffers, and work group memory sizes let ep_info = &stage.module.naga.info.get_entry_point(ep_index); + let mut wg_memory_sizes = Vec::new(); let mut sized_bindings = Vec::new(); let mut immutable_buffer_mask = 0; for (var_handle, var) in module.global_variables.iter() { + if var.class == naga::StorageClass::WorkGroup { + let size = module.types[var.ty].inner.span(&module.constants); + wg_memory_sizes.push(size); + } + if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner { let br = match var.binding { Some(ref br) => br.clone(), None => continue, }; - let storage_access_store = if let naga::StorageClass::Storage { access } = var.class - { - access.contains(naga::StorageAccess::STORE) - } else { - false - }; - // check for an immutable buffer - if !ep_info[var_handle].is_empty() && !storage_access_store { - let psm = &layout.naga_options.per_stage_map[naga_stage]; - let slot = psm.resources[&br].buffer.unwrap(); - immutable_buffer_mask |= 1 << slot; + + if !ep_info[var_handle].is_empty() { + let storage_access_store = match var.class { + naga::StorageClass::Storage { access } => { + access.contains(naga::StorageAccess::STORE) + } + _ => false, + }; + // check for an immutable buffer + if !storage_access_store { + let psm = &layout.naga_options.per_stage_map[naga_stage]; + let slot = psm.resources[&br].buffer.unwrap(); + immutable_buffer_mask |= 1 << slot; + } } + // check for the unsized buffer if let Some(member) = members.last() { if let naga::TypeInner::Array { @@ -144,6 +155,7 @@ impl super::Device { library, function, wg_size, + wg_memory_sizes, sized_bindings, immutable_buffer_mask, }) @@ -915,6 +927,7 @@ impl crate::Device for super::Device { }, cs_lib: cs.library, work_group_size: cs.wg_size, + work_group_memory_sizes: cs.wg_memory_sizes, }) } unsafe fn destroy_compute_pipeline(&self, _pipeline: 
super::ComputePipeline) {} diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 54aa39eac4..4ebbacce5a 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -626,6 +626,7 @@ pub struct ComputePipeline { cs_lib: mtl::Library, cs_info: PipelineStageInfo, work_group_size: mtl::MTLSize, + work_group_memory_sizes: Vec<u32>, } unsafe impl Send for ComputePipeline {} @@ -689,6 +690,7 @@ struct CommandState { raw_wg_size: mtl::MTLSize, stage_infos: MultiStageData, storage_buffer_length_map: fxhash::FxHashMap, + work_group_memory_sizes: Vec<u32>, } pub struct CommandEncoder {