From a296cf6e9eee39265eb5beae17b1f81d6b16016d Mon Sep 17 00:00:00 2001
From: Dzmitry Malyshau <kvarkus@gmail.com>
Date: Thu, 19 Aug 2021 12:22:54 -0400
Subject: [PATCH] [metal] set threadgroup memory sizes reflected from the
 shader

---
 CHANGELOG.md                  |  6 ++++++
 Cargo.lock                    |  4 ++--
 Cargo.toml                    |  1 +
 wgpu-hal/Cargo.toml           |  2 +-
 wgpu-hal/src/metal/command.rs | 33 +++++++++++++++++++++++++++----
 wgpu-hal/src/metal/device.rs  | 37 +++++++++++++++++++++++------------
 wgpu-hal/src/metal/mod.rs     |  2 ++
 7 files changed, 66 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1cb97dc99..db0fe48212 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Change Log
 
+## TBD
+  - All:
+    - expose more formats via adapter-specific feature
+  - Metal:
+    - fix compute shaders that use workgroup memory by setting the threadgroup memory lengths reflected from the shader
+
 ## v0.10 (2021-08-18)
   - Infrastructure:
     - `gfx-hal` is replaced by the in-house graphics abstraction `wgpu-hal`. Backends: Vulkan, Metal, D3D-12, and OpenGL ES-3.
diff --git a/Cargo.lock b/Cargo.lock
index 083bdc6982..fd3038fcd9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -909,9 +909,9 @@ dependencies = [
 
 [[package]]
 name = "metal"
-version = "0.23.0"
+version = "0.23.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79d7d769f1c104b8388294d6594d491d2e21240636f5f94d37f8a0f3d7904450"
+checksum = "e0514f491f4cc03632ab399ee01e2c1c1b12d3e1cf2d667c1ff5f87d6dcd2084"
 dependencies = [
  "bitflags",
  "block",
diff --git a/Cargo.toml b/Cargo.toml
index 7601192131..3784d373ec 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ default-members = ["wgpu", "wgpu-hal", "wgpu-info"]
 #glow = { path = "../glow" }
 
 [patch.crates-io]
+#metal = { path = "../metal-rs" }
 #web-sys = { path = "../wasm-bindgen/crates/web-sys" }
 #js-sys = { path = "../wasm-bindgen/crates/js-sys" }
 #wasm-bindgen = { path = "../wasm-bindgen" }
diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml
index b52606f939..f6f24e62ee 100644
--- a/wgpu-hal/Cargo.toml
+++ b/wgpu-hal/Cargo.toml
@@ -63,7 +63,7 @@ winapi = { version = "0.3", features = ["libloaderapi", "windef", "winuser"] }
 native = { package = "d3d12", version = "0.4.1", features = ["libloading"], optional = true }
 
 [target.'cfg(any(target_os="macos", target_os="ios"))'.dependencies]
-mtl = { package = "metal", version = "0.23" }
+mtl = { package = "metal", version = "0.23.1" }
 objc = "0.2.5"
 core-graphics-types = "0.1"
 
diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs
index 5fb648e151..e7a7a3cf5b 100644
--- a/wgpu-hal/src/metal/command.rs
+++ b/wgpu-hal/src/metal/command.rs
@@ -14,6 +14,7 @@ impl Default for super::CommandState {
             raw_wg_size: mtl::MTLSize::new(0, 0, 0),
             stage_infos: Default::default(),
             storage_buffer_length_map: Default::default(),
+            work_group_memory_sizes: Vec::new(),
         }
     }
 }
@@ -45,15 +46,20 @@ impl super::CommandEncoder {
     }
 
     fn begin_pass(&mut self) {
-        self.state.storage_buffer_length_map.clear();
-        self.state.stage_infos.vs.clear();
-        self.state.stage_infos.fs.clear();
-        self.state.stage_infos.cs.clear();
+        self.state.reset();
         self.leave_blit();
     }
 }
 
 impl super::CommandState {
+    fn reset(&mut self) { // clears the state cached below; invoked by `begin_pass`
+        self.storage_buffer_length_map.clear();
+        self.stage_infos.vs.clear();
+        self.stage_infos.fs.clear();
+        self.stage_infos.cs.clear();
+        self.work_group_memory_sizes.clear(); // so the next pipeline bind re-sends non-zero lengths
+    }
+
     fn make_sizes_buffer_update<'a>(
         &self,
         stage: naga::ShaderStage,
@@ -840,6 +846,25 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
                 sizes.as_ptr() as _,
             );
         }
+
+        // bind threadgroup memory lengths, skipping slots whose cached size is already current
+        while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() {
+            self.state.work_group_memory_sizes.push(0); // 0 == nothing set for this slot yet
+        }
+        for (index, (cur_size, pipeline_size)) in self
+            .state
+            .work_group_memory_sizes
+            .iter_mut()
+            .zip(pipeline.work_group_memory_sizes.iter())
+            .enumerate()
+        {
+            const ALIGN_MASK: u32 = 0xF; // must be a multiple of 16 bytes
+            let size = (*pipeline_size + ALIGN_MASK) & !ALIGN_MASK; // round up; 0 stays 0, no underflow
+            if *cur_size != size {
+                *cur_size = size;
+                encoder.set_threadgroup_memory_length(index as _, size as _);
+            }
+        }
     }
 
     unsafe fn dispatch(&mut self, count: [u32; 3]) {
diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs
index fa7f20d224..5f01abef59 100644
--- a/wgpu-hal/src/metal/device.rs
+++ b/wgpu-hal/src/metal/device.rs
@@ -13,6 +13,7 @@ struct CompiledShader {
     library: mtl::Library,
     function: mtl::Function,
     wg_size: mtl::MTLSize,
+    wg_memory_sizes: Vec<u32>,
     sized_bindings: Vec<naga::ResourceBinding>,
     immutable_buffer_mask: usize,
 }
@@ -104,28 +105,38 @@ impl super::Device {
             crate::PipelineError::EntryPoint(naga_stage)
         })?;
 
-        // collect sizes indices and immutable buffers
+        // collect sizes indices, immutable buffers, and work group memory sizes
         let ep_info = &stage.module.naga.info.get_entry_point(ep_index);
+        let mut wg_memory_sizes = Vec::new();
         let mut sized_bindings = Vec::new();
         let mut immutable_buffer_mask = 0;
         for (var_handle, var) in module.global_variables.iter() {
+            if var.class == naga::StorageClass::WorkGroup && !ep_info[var_handle].is_empty() {
+                let size = module.types[var.ty].inner.span(&module.constants);
+                wg_memory_sizes.push(size); // position == threadgroup slot, so skip unused globals
+            }
+
             if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner {
                 let br = match var.binding {
                     Some(ref br) => br.clone(),
                     None => continue,
                 };
-                let storage_access_store = if let naga::StorageClass::Storage { access } = var.class
-                {
-                    access.contains(naga::StorageAccess::STORE)
-                } else {
-                    false
-                };
-                // check for an immutable buffer
-                if !ep_info[var_handle].is_empty() && !storage_access_store {
-                    let psm = &layout.naga_options.per_stage_map[naga_stage];
-                    let slot = psm.resources[&br].buffer.unwrap();
-                    immutable_buffer_mask |= 1 << slot;
+
+                if !ep_info[var_handle].is_empty() {
+                    let storage_access_store = match var.class {
+                        naga::StorageClass::Storage { access } => {
+                            access.contains(naga::StorageAccess::STORE)
+                        }
+                        _ => false,
+                    };
+                    // check for an immutable buffer
+                    if !storage_access_store {
+                        let psm = &layout.naga_options.per_stage_map[naga_stage];
+                        let slot = psm.resources[&br].buffer.unwrap();
+                        immutable_buffer_mask |= 1 << slot;
+                    }
                 }
+
                 // check for the unsized buffer
                 if let Some(member) = members.last() {
                     if let naga::TypeInner::Array {
@@ -144,6 +155,7 @@ impl super::Device {
             library,
             function,
             wg_size,
+            wg_memory_sizes,
             sized_bindings,
             immutable_buffer_mask,
         })
@@ -915,6 +927,7 @@ impl crate::Device<super::Api> for super::Device {
             },
             cs_lib: cs.library,
             work_group_size: cs.wg_size,
+            work_group_memory_sizes: cs.wg_memory_sizes,
         })
     }
     unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {}
diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs
index 54aa39eac4..4ebbacce5a 100644
--- a/wgpu-hal/src/metal/mod.rs
+++ b/wgpu-hal/src/metal/mod.rs
@@ -626,6 +626,7 @@ pub struct ComputePipeline {
     cs_lib: mtl::Library,
     cs_info: PipelineStageInfo,
     work_group_size: mtl::MTLSize,
+    work_group_memory_sizes: Vec<u32>,
 }
 
 unsafe impl Send for ComputePipeline {}
@@ -689,6 +690,7 @@ struct CommandState {
     raw_wg_size: mtl::MTLSize,
     stage_infos: MultiStageData<PipelineStageInfo>,
     storage_buffer_length_map: fxhash::FxHashMap<naga::ResourceBinding, wgt::BufferSize>,
+    work_group_memory_sizes: Vec<u32>,
 }
 
 pub struct CommandEncoder {