diff --git a/dev/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip b/dev/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip
index 0a10b45..9a91e8b 100644
Binary files a/dev/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip and b/dev/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip differ
diff --git a/dev/_downloads/0d78e075dd52a34e158d7f5f710dfe89/plot_incremental_FNO_darcy.zip b/dev/_downloads/0d78e075dd52a34e158d7f5f710dfe89/plot_incremental_FNO_darcy.zip
index 918ad6f..041085c 100644
Binary files a/dev/_downloads/0d78e075dd52a34e158d7f5f710dfe89/plot_incremental_FNO_darcy.zip and b/dev/_downloads/0d78e075dd52a34e158d7f5f710dfe89/plot_incremental_FNO_darcy.zip differ
diff --git a/dev/_downloads/20c43dd37baf603889c4dc23e93bdb60/plot_count_flops.zip b/dev/_downloads/20c43dd37baf603889c4dc23e93bdb60/plot_count_flops.zip
index 5be3881..f329c07 100644
Binary files a/dev/_downloads/20c43dd37baf603889c4dc23e93bdb60/plot_count_flops.zip and b/dev/_downloads/20c43dd37baf603889c4dc23e93bdb60/plot_count_flops.zip differ
diff --git a/dev/_downloads/3864a2d85c7ce11adeac9580559229ab/plot_darcy_flow.zip b/dev/_downloads/3864a2d85c7ce11adeac9580559229ab/plot_darcy_flow.zip
index f6d37eb..fb9d0d8 100644
Binary files a/dev/_downloads/3864a2d85c7ce11adeac9580559229ab/plot_darcy_flow.zip and b/dev/_downloads/3864a2d85c7ce11adeac9580559229ab/plot_darcy_flow.zip differ
diff --git a/dev/_downloads/3faf9d2eaee5cc8e9f1c631c002ce544/plot_darcy_flow_spectrum.zip b/dev/_downloads/3faf9d2eaee5cc8e9f1c631c002ce544/plot_darcy_flow_spectrum.zip
index 17b2788..8e9efd8 100644
Binary files a/dev/_downloads/3faf9d2eaee5cc8e9f1c631c002ce544/plot_darcy_flow_spectrum.zip and b/dev/_downloads/3faf9d2eaee5cc8e9f1c631c002ce544/plot_darcy_flow_spectrum.zip differ
diff --git a/dev/_downloads/4e1cba46bb51062a073424e1efbaad57/plot_DISCO_convolutions.zip b/dev/_downloads/4e1cba46bb51062a073424e1efbaad57/plot_DISCO_convolutions.zip
index 2726632..71adda5 100644
Binary files a/dev/_downloads/4e1cba46bb51062a073424e1efbaad57/plot_DISCO_convolutions.zip and b/dev/_downloads/4e1cba46bb51062a073424e1efbaad57/plot_DISCO_convolutions.zip differ
diff --git a/dev/_downloads/5e60095ce99919773daa83384f767e02/plot_SFNO_swe.zip b/dev/_downloads/5e60095ce99919773daa83384f767e02/plot_SFNO_swe.zip
index 7e8cc50..2110cdf 100644
Binary files a/dev/_downloads/5e60095ce99919773daa83384f767e02/plot_SFNO_swe.zip and b/dev/_downloads/5e60095ce99919773daa83384f767e02/plot_SFNO_swe.zip differ
diff --git a/dev/_downloads/645da00b8fbbb9bb5cae877fd0f31635/plot_FNO_darcy.zip b/dev/_downloads/645da00b8fbbb9bb5cae877fd0f31635/plot_FNO_darcy.zip
index 63b427a..9e3bdd0 100644
Binary files a/dev/_downloads/645da00b8fbbb9bb5cae877fd0f31635/plot_FNO_darcy.zip and b/dev/_downloads/645da00b8fbbb9bb5cae877fd0f31635/plot_FNO_darcy.zip differ
diff --git a/dev/_downloads/6f1e7a639e0699d6164445b55e6c116d/auto_examples_jupyter.zip b/dev/_downloads/6f1e7a639e0699d6164445b55e6c116d/auto_examples_jupyter.zip
index eb038f7..ae51831 100644
Binary files a/dev/_downloads/6f1e7a639e0699d6164445b55e6c116d/auto_examples_jupyter.zip and b/dev/_downloads/6f1e7a639e0699d6164445b55e6c116d/auto_examples_jupyter.zip differ
diff --git a/dev/_downloads/7296405f6df7c2cfe184e9b258cee33e/checkpoint_FNO_darcy.zip b/dev/_downloads/7296405f6df7c2cfe184e9b258cee33e/checkpoint_FNO_darcy.zip
index 1e3e4ab..824cc7c 100644
Binary files a/dev/_downloads/7296405f6df7c2cfe184e9b258cee33e/checkpoint_FNO_darcy.zip and b/dev/_downloads/7296405f6df7c2cfe184e9b258cee33e/checkpoint_FNO_darcy.zip differ
diff --git a/dev/_downloads/cefc537c5730a6b3e916b83c1fd313d6/plot_UNO_darcy.zip b/dev/_downloads/cefc537c5730a6b3e916b83c1fd313d6/plot_UNO_darcy.zip
index f9b5af3..b3d1288 100644
Binary files a/dev/_downloads/cefc537c5730a6b3e916b83c1fd313d6/plot_UNO_darcy.zip and b/dev/_downloads/cefc537c5730a6b3e916b83c1fd313d6/plot_UNO_darcy.zip differ
diff --git a/dev/_images/sphx_glr_plot_SFNO_swe_001.png b/dev/_images/sphx_glr_plot_SFNO_swe_001.png
index cf7bd9b..2139597 100644
Binary files a/dev/_images/sphx_glr_plot_SFNO_swe_001.png and b/dev/_images/sphx_glr_plot_SFNO_swe_001.png differ
diff --git a/dev/_images/sphx_glr_plot_SFNO_swe_thumb.png b/dev/_images/sphx_glr_plot_SFNO_swe_thumb.png
index 8d095c9..d06682d 100644
Binary files a/dev/_images/sphx_glr_plot_SFNO_swe_thumb.png and b/dev/_images/sphx_glr_plot_SFNO_swe_thumb.png differ
diff --git a/dev/_images/sphx_glr_plot_UNO_darcy_001.png b/dev/_images/sphx_glr_plot_UNO_darcy_001.png
index e4c5d6f..0dd1a96 100644
Binary files a/dev/_images/sphx_glr_plot_UNO_darcy_001.png and b/dev/_images/sphx_glr_plot_UNO_darcy_001.png differ
diff --git a/dev/_images/sphx_glr_plot_UNO_darcy_thumb.png b/dev/_images/sphx_glr_plot_UNO_darcy_thumb.png
index 2ad7b46..f349070 100644
Binary files a/dev/_images/sphx_glr_plot_UNO_darcy_thumb.png and b/dev/_images/sphx_glr_plot_UNO_darcy_thumb.png differ
diff --git a/dev/_modules/neuralop/models/base_model.html b/dev/_modules/neuralop/models/base_model.html
index 53ea87d..efb4f8d 100644
--- a/dev/_modules/neuralop/models/base_model.html
+++ b/dev/_modules/neuralop/models/base_model.html
@@ -191,7 +191,68 @@

Source code for neuralop.models.base_model

         instance._init_kwargs = kwargs
 
         return instance
-    
+
+    def state_dict(self, destination: dict=None, prefix: str='', keep_vars: bool=False):
+        """
+        state_dict overrides nn.Module.state_dict() and adds a metadata field
+        to track the model version and ensure only compatible saves are loaded.
+
+        Parameters
+        ----------
+        destination : dict, optional
+            If provided, the state of the module will
+            be updated into the dict and the same object is returned.
+            Otherwise, an OrderedDict will be created and returned, by default None
+        prefix : str, optional
+            a prefix added to parameter and buffer
+            names to compose the keys in state_dict, by default ``''``
+        keep_vars : bool, optional
+            by default the torch.Tensors returned in the state dict
+            are detached from autograd. If True, detaching will not
+            be performed, by default False
+
+        """
+        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+        if state_dict.get('_metadata') is None:
+            state_dict['_metadata'] = self._init_kwargs
+        else:
+            warnings.warn("Attempting to update metadata for a module with metadata already in self.state_dict()")
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict=True, assign=False):
+        """load_state_dict subclasses nn.Module.load_state_dict() and adds a metadata field
+        to track the model version and ensure only compatible saves are loaded.
+
+        Parameters
+        ----------
+        state_dict : dict
+            state dictionary generated by ``nn.Module.state_dict()``
+        strict : bool, optional
+            whether to strictly enforce that the keys in ``state_dict``
+            match the keys returned by this module's ``state_dict()``, by default True.
+        assign : bool, optional
+            whether to assign items in the state dict to their corresponding keys
+            in the module instead of copying them inplace into the module's current
+            parameters and buffers. When False, the properties of the tensors in the
+            current module are preserved while when True, the properties of the Tensors
+            in the state dict are preserved, by default False
+
+        Returns
+        -------
+        NamedTuple
+            a named tuple with ``missing_keys`` and ``unexpected_keys`` fields,
+            as returned by ``nn.Module.load_state_dict()``
+        """
+        metadata = state_dict.pop('_metadata', None)
+
+        if metadata is not None:
+            saved_version = metadata.get('_version', None)
+            if saved_version is None:
+                warnings.warn(f"Saved instance of {self.__class__} has no stored version attribute.")
+            elif saved_version != self._version:
+                warnings.warn(f"Attempting to load a {self.__class__} of version {saved_version}, "
+                              f"but the current version of {self.__class__} is {self._version}.")
+        # metadata was popped above so the remaining keys match what nn.Module.load_state_dict expects
+        return super().load_state_dict(state_dict, strict=strict, assign=assign)
+
     def save_checkpoint(self, save_folder, save_name):
         """Saves the model state and init param in the given folder under the given name
         """
diff --git a/dev/_sources/auto_examples/plot_DISCO_convolutions.rst.txt b/dev/_sources/auto_examples/plot_DISCO_convolutions.rst.txt
index b319b4f..ead0a56 100644
--- a/dev/_sources/auto_examples/plot_DISCO_convolutions.rst.txt
+++ b/dev/_sources/auto_examples/plot_DISCO_convolutions.rst.txt
@@ -357,7 +357,7 @@ in order to compute the convolved image, we need to first bring it into the righ
  .. code-block:: none
 
 
-    
+    
 
 
 
@@ -448,7 +448,7 @@ in order to compute the convolved image, we need to first bring it into the righ
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 31.674 seconds)
+   **Total running time of the script:** (0 minutes 31.281 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_DISCO_convolutions.py:
diff --git a/dev/_sources/auto_examples/plot_FNO_darcy.rst.txt b/dev/_sources/auto_examples/plot_FNO_darcy.rst.txt
index 832c146..77e4a0e 100644
--- a/dev/_sources/auto_examples/plot_FNO_darcy.rst.txt
+++ b/dev/_sources/auto_examples/plot_FNO_darcy.rst.txt
@@ -248,13 +248,13 @@ Training the model
     )
 
     ### SCHEDULER ###
-     
+     
 
     ### LOSSES ###
 
-     * Train: 
+     * Train: 
 
-     * Test: {'h1': , 'l2': }
+     * Test: {'h1': , 'l2': }
 
 
 
@@ -311,22 +311,22 @@ Then train the model on our small Darcy-Flow dataset:
     Training on 1000 samples
     Testing on [50, 50] samples         on resolutions [16, 32].
     Raw outputs of shape torch.Size([32, 1, 16, 16])
-    [0] time=2.59, avg_loss=0.6956, train_err=21.7383
+    [0] time=2.57, avg_loss=0.6956, train_err=21.7383
     Eval: 16_h1=0.4298, 16_l2=0.3487, 32_h1=0.5847, 32_l2=0.3542
     [3] time=2.57, avg_loss=0.2103, train_err=6.5705
     Eval: 16_h1=0.2030, 16_l2=0.1384, 32_h1=0.5075, 32_l2=0.1774
-    [6] time=2.58, avg_loss=0.1911, train_err=5.9721
+    [6] time=2.55, avg_loss=0.1911, train_err=5.9721
     Eval: 16_h1=0.2099, 16_l2=0.1374, 32_h1=0.4907, 32_l2=0.1783
-    [9] time=2.64, avg_loss=0.1410, train_err=4.4073
+    [9] time=2.56, avg_loss=0.1410, train_err=4.4073
     Eval: 16_h1=0.2052, 16_l2=0.1201, 32_h1=0.5268, 32_l2=0.1615
-    [12] time=2.63, avg_loss=0.1422, train_err=4.4434
+    [12] time=2.57, avg_loss=0.1422, train_err=4.4434
     Eval: 16_h1=0.2131, 16_l2=0.1285, 32_h1=0.5413, 32_l2=0.1741
-    [15] time=2.61, avg_loss=0.1198, train_err=3.7424
+    [15] time=2.56, avg_loss=0.1198, train_err=3.7424
     Eval: 16_h1=0.1984, 16_l2=0.1137, 32_h1=0.5255, 32_l2=0.1569
-    [18] time=2.59, avg_loss=0.1104, train_err=3.4502
+    [18] time=2.55, avg_loss=0.1104, train_err=3.4502
     Eval: 16_h1=0.2039, 16_l2=0.1195, 32_h1=0.5062, 32_l2=0.1603
 
-    {'train_err': 2.9605126455426216, 'avg_loss': 0.0947364046573639, 'avg_lasso_loss': None, 'epoch_train_time': 2.582470736999994}
+    {'train_err': 2.9605126455426216, 'avg_loss': 0.0947364046573639, 'avg_lasso_loss': None, 'epoch_train_time': 2.5690329919999613}
 
 
 
@@ -476,7 +476,7 @@ are other ways to scale the outputs of the FNO to train a true super-resolution
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 53.239 seconds)
+   **Total running time of the script:** (0 minutes 52.571 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_FNO_darcy.py:
diff --git a/dev/_sources/auto_examples/plot_SFNO_swe.rst.txt b/dev/_sources/auto_examples/plot_SFNO_swe.rst.txt
index 75afd7c..359c6f9 100644
--- a/dev/_sources/auto_examples/plot_SFNO_swe.rst.txt
+++ b/dev/_sources/auto_examples/plot_SFNO_swe.rst.txt
@@ -234,13 +234,13 @@ Creating the losses
     )
 
     ### SCHEDULER ###
-     
+     
 
     ### LOSSES ###
 
-     * Train: 
+     * Train: 
 
-     * Test: {'l2': }
+     * Test: {'l2': }
 
 
 
@@ -297,22 +297,22 @@ Train the model on the spherical SWE dataset
     Training on 200 samples
     Testing on [50, 50] samples         on resolutions [(32, 64), (64, 128)].
     Raw outputs of shape torch.Size([4, 3, 32, 64])
-    [0] time=3.71, avg_loss=2.5890, train_err=10.3559
-    Eval: (32, 64)_l2=2.0514, (64, 128)_l2=2.5055
-    [3] time=3.64, avg_loss=0.4467, train_err=1.7867
-    Eval: (32, 64)_l2=0.6580, (64, 128)_l2=2.5187
-    [6] time=3.60, avg_loss=0.2807, train_err=1.1226
-    Eval: (32, 64)_l2=0.5458, (64, 128)_l2=2.4767
-    [9] time=3.59, avg_loss=0.2355, train_err=0.9419
-    Eval: (32, 64)_l2=0.5549, (64, 128)_l2=2.4920
-    [12] time=3.65, avg_loss=0.2170, train_err=0.8680
-    Eval: (32, 64)_l2=0.5206, (64, 128)_l2=2.4921
-    [15] time=3.68, avg_loss=0.1824, train_err=0.7297
-    Eval: (32, 64)_l2=0.4906, (64, 128)_l2=2.4846
-    [18] time=3.62, avg_loss=0.1764, train_err=0.7054
-    Eval: (32, 64)_l2=0.4640, (64, 128)_l2=2.4902
+    [0] time=3.61, avg_loss=2.5783, train_err=10.3132
+    Eval: (32, 64)_l2=1.7929, (64, 128)_l2=2.3944
+    [3] time=3.61, avg_loss=0.3960, train_err=1.5841
+    Eval: (32, 64)_l2=0.4045, (64, 128)_l2=2.6128
+    [6] time=3.61, avg_loss=0.2645, train_err=1.0581
+    Eval: (32, 64)_l2=0.2855, (64, 128)_l2=2.5574
+    [9] time=3.57, avg_loss=0.2303, train_err=0.9211
+    Eval: (32, 64)_l2=0.3216, (64, 128)_l2=2.5180
+    [12] time=3.58, avg_loss=0.1810, train_err=0.7240
+    Eval: (32, 64)_l2=0.2103, (64, 128)_l2=2.5075
+    [15] time=3.57, avg_loss=0.1538, train_err=0.6152
+    Eval: (32, 64)_l2=0.1890, (64, 128)_l2=2.4402
+    [18] time=3.57, avg_loss=0.1323, train_err=0.5294
+    Eval: (32, 64)_l2=0.1798, (64, 128)_l2=2.4265
 
-    {'train_err': 0.6676547992229461, 'avg_loss': 0.16691369980573653, 'avg_lasso_loss': None, 'epoch_train_time': 3.599016188000064}
+    {'train_err': 0.5195050823688507, 'avg_loss': 0.12987627059221268, 'avg_lasso_loss': None, 'epoch_train_time': 3.607363260999989}
 
 
 
@@ -383,7 +383,7 @@ In practice we would train a Neural Operator on one or multiple GPUs
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (1 minutes 28.830 seconds)
+   **Total running time of the script:** (1 minutes 28.039 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_SFNO_swe.py:
diff --git a/dev/_sources/auto_examples/plot_UNO_darcy.rst.txt b/dev/_sources/auto_examples/plot_UNO_darcy.rst.txt
index 1860cb5..05ae47d 100644
--- a/dev/_sources/auto_examples/plot_UNO_darcy.rst.txt
+++ b/dev/_sources/auto_examples/plot_UNO_darcy.rst.txt
@@ -345,13 +345,13 @@ Creating the losses
     )
 
     ### SCHEDULER ###
-     
+     
 
     ### LOSSES ###
 
-     * Train: 
+     * Train: 
 
-     * Test: {'h1': , 'l2': }
+     * Test: {'h1': , 'l2': }
 
 
 
@@ -410,22 +410,22 @@ Actually train the model on our small Darcy-Flow dataset
     Training on 1000 samples
     Testing on [50, 50] samples         on resolutions [16, 32].
     Raw outputs of shape torch.Size([32, 1, 16, 16])
-    [0] time=10.19, avg_loss=0.6679, train_err=20.8720
-    Eval: 16_h1=0.3981, 16_l2=0.2711, 32_h1=0.9217, 32_l2=0.6494
-    [3] time=10.13, avg_loss=0.2426, train_err=7.5825
-    Eval: 16_h1=0.2854, 16_l2=0.1799, 32_h1=0.7671, 32_l2=0.4978
-    [6] time=10.09, avg_loss=0.2373, train_err=7.4160
-    Eval: 16_h1=0.3061, 16_l2=0.1973, 32_h1=0.7619, 32_l2=0.4908
-    [9] time=10.08, avg_loss=0.2195, train_err=6.8592
-    Eval: 16_h1=0.2422, 16_l2=0.1575, 32_h1=0.7381, 32_l2=0.4485
-    [12] time=10.11, avg_loss=0.1690, train_err=5.2798
-    Eval: 16_h1=0.2518, 16_l2=0.1580, 32_h1=0.7492, 32_l2=0.4722
-    [15] time=10.12, avg_loss=0.1825, train_err=5.7038
-    Eval: 16_h1=0.2699, 16_l2=0.1664, 32_h1=0.7599, 32_l2=0.4589
-    [18] time=10.08, avg_loss=0.1676, train_err=5.2361
-    Eval: 16_h1=0.2356, 16_l2=0.1478, 32_h1=0.7352, 32_l2=0.4557
+    [0] time=10.15, avg_loss=0.6631, train_err=20.7222
+    Eval: 16_h1=0.4579, 16_l2=0.2992, 32_h1=0.9470, 32_l2=0.6763
+    [3] time=10.05, avg_loss=0.2476, train_err=7.7385
+    Eval: 16_h1=0.2439, 16_l2=0.1618, 32_h1=0.9045, 32_l2=0.6343
+    [6] time=10.04, avg_loss=0.2285, train_err=7.1392
+    Eval: 16_h1=0.2579, 16_l2=0.1738, 32_h1=0.8590, 32_l2=0.6153
+    [9] time=10.07, avg_loss=0.1985, train_err=6.2036
+    Eval: 16_h1=0.2429, 16_l2=0.1520, 32_h1=0.8739, 32_l2=0.5978
+    [12] time=10.04, avg_loss=0.1856, train_err=5.8014
+    Eval: 16_h1=0.2410, 16_l2=0.1467, 32_h1=0.8608, 32_l2=0.5604
+    [15] time=10.05, avg_loss=0.1546, train_err=4.8301
+    Eval: 16_h1=0.3156, 16_l2=0.2123, 32_h1=0.8466, 32_l2=0.6000
+    [18] time=10.03, avg_loss=0.1247, train_err=3.8961
+    Eval: 16_h1=0.2340, 16_l2=0.1354, 32_h1=0.8477, 32_l2=0.5822
 
-    {'train_err': 4.981805760413408, 'avg_loss': 0.15941778433322906, 'avg_lasso_loss': None, 'epoch_train_time': 10.092775133000032}
+    {'train_err': 4.671443767845631, 'avg_loss': 0.14948620057106018, 'avg_lasso_loss': None, 'epoch_train_time': 10.053496939999945}
 
 
 
@@ -499,7 +499,7 @@ In practice we would train a Neural Operator on one or multiple GPUs
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (3 minutes 26.114 seconds)
+   **Total running time of the script:** (3 minutes 24.779 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_UNO_darcy.py:
diff --git a/dev/_sources/auto_examples/plot_count_flops.rst.txt b/dev/_sources/auto_examples/plot_count_flops.rst.txt
index 4536697..181286f 100644
--- a/dev/_sources/auto_examples/plot_count_flops.rst.txt
+++ b/dev/_sources/auto_examples/plot_count_flops.rst.txt
@@ -80,7 +80,7 @@ This output is organized as a defaultdict object that counts the FLOPS used in e
 
  .. code-block:: none
 
-    defaultdict(. at 0x7f568b61f1a0>, {'': defaultdict(, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.2.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'projection': defaultdict(, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(, {'convolution.default': 4194304})})
+    defaultdict(. at 0x7f79193fb7e0>, {'': defaultdict(, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.2.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(, {'convolution.default': 134217728}), 'projection': defaultdict(, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(, {'convolution.default': 4194304})})
 
 
 
@@ -125,7 +125,7 @@ To check the maximum FLOPS used during the forward pass, let's create a recursiv
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 3.316 seconds)
+   **Total running time of the script:** (0 minutes 3.210 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_count_flops.py:
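The FLOP counts quoted in the hunk above form a nested mapping of module name to per-operator counts. As a reading aid, here is a small self-contained sketch (not part of this diff, and not the recursive helper the example itself builds) that totals the counts per submodule; the dictionary literal simply copies a few of the values printed above.

    # stand-in for the printed defaultdict: {module_name: {op_name: flop_count}}
    flop_counts = {
        "lifting": {"convolution.default": 562036736},
        "fno_blocks": {"convolution.default": 2147483648, "bmm.default": 138412032},
        "projection": {"convolution.default": 272629760},
    }

    # total FLOPS per submodule and pick the most expensive one
    totals = {name: sum(ops.values()) for name, ops in flop_counts.items()}
    busiest = max(totals, key=totals.get)
    print(busiest, totals[busiest])  # fno_blocks 2285895680
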
diff --git a/dev/_sources/auto_examples/plot_darcy_flow.rst.txt b/dev/_sources/auto_examples/plot_darcy_flow.rst.txt
index eedcee1..75f63a7 100644
--- a/dev/_sources/auto_examples/plot_darcy_flow.rst.txt
+++ b/dev/_sources/auto_examples/plot_darcy_flow.rst.txt
@@ -163,7 +163,7 @@ Visualizing the data
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 0.333 seconds)
+   **Total running time of the script:** (0 minutes 0.328 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_darcy_flow.py:
diff --git a/dev/_sources/auto_examples/plot_darcy_flow_spectrum.rst.txt b/dev/_sources/auto_examples/plot_darcy_flow_spectrum.rst.txt
index 4b0afdb..71c7802 100644
--- a/dev/_sources/auto_examples/plot_darcy_flow_spectrum.rst.txt
+++ b/dev/_sources/auto_examples/plot_darcy_flow_spectrum.rst.txt
@@ -219,7 +219,7 @@ Loading the Navier-Stokes dataset in 128x128 resolution
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 0.204 seconds)
+   **Total running time of the script:** (0 minutes 0.200 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_darcy_flow_spectrum.py:
diff --git a/dev/_sources/auto_examples/plot_incremental_FNO_darcy.rst.txt b/dev/_sources/auto_examples/plot_incremental_FNO_darcy.rst.txt
index e9294c2..d4d537e 100644
--- a/dev/_sources/auto_examples/plot_incremental_FNO_darcy.rst.txt
+++ b/dev/_sources/auto_examples/plot_incremental_FNO_darcy.rst.txt
@@ -240,15 +240,15 @@ Set up the losses
     )
 
     ### SCHEDULER ###
-     
+     
 
     ### LOSSES ###
 
     ### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###
 
-     * Train: 
+     * Train: 
 
-     * Test: {'h1': , 'l2': }
+     * Test: {'h1': , 'l2': }
 
 
 
@@ -337,9 +337,9 @@ Train the model
     Eval: 16_h1=0.7698, 16_l2=0.4235, 32_h1=0.9175, 32_l2=0.4405
     [3] time=0.23, avg_loss=0.6082, train_err=8.6886
     Eval: 16_h1=0.7600, 16_l2=0.4570, 32_h1=1.0126, 32_l2=0.5005
-    [4] time=0.23, avg_loss=0.5604, train_err=8.0054
+    [4] time=0.22, avg_loss=0.5604, train_err=8.0054
     Eval: 16_h1=0.8301, 16_l2=0.4577, 32_h1=1.1722, 32_l2=0.4987
-    [5] time=0.22, avg_loss=0.5366, train_err=7.6663
+    [5] time=0.23, avg_loss=0.5366, train_err=7.6663
     Eval: 16_h1=0.8099, 16_l2=0.4076, 32_h1=1.1414, 32_l2=0.4436
     [6] time=0.23, avg_loss=0.5050, train_err=7.2139
     Eval: 16_h1=0.7204, 16_l2=0.3888, 32_h1=0.9945, 32_l2=0.4285
@@ -352,9 +352,9 @@ Train the model
     Incre Res Update: change index to 1
     Incre Res Update: change sub to 1
     Incre Res Update: change res to 16
-    [10] time=0.29, avg_loss=0.5158, train_err=7.3681
+    [10] time=0.31, avg_loss=0.5158, train_err=7.3681
     Eval: 16_h1=0.5133, 16_l2=0.3167, 32_h1=0.6120, 32_l2=0.3022
-    [11] time=0.28, avg_loss=0.4536, train_err=6.4795
+    [11] time=0.27, avg_loss=0.4536, train_err=6.4795
     Eval: 16_h1=0.4680, 16_l2=0.3422, 32_h1=0.6436, 32_l2=0.3659
     [12] time=0.28, avg_loss=0.4155, train_err=5.9358
     Eval: 16_h1=0.4119, 16_l2=0.2692, 32_h1=0.5285, 32_l2=0.2692
@@ -366,14 +366,14 @@ Train the model
     Eval: 16_h1=0.3611, 16_l2=0.2323, 32_h1=0.5079, 32_l2=0.2455
     [16] time=0.28, avg_loss=0.3251, train_err=4.6438
     Eval: 16_h1=0.3433, 16_l2=0.2224, 32_h1=0.4757, 32_l2=0.2351
-    [17] time=0.29, avg_loss=0.3072, train_err=4.3888
+    [17] time=0.28, avg_loss=0.3072, train_err=4.3888
     Eval: 16_h1=0.3458, 16_l2=0.2226, 32_h1=0.4776, 32_l2=0.2371
-    [18] time=0.29, avg_loss=0.2982, train_err=4.2593
+    [18] time=0.28, avg_loss=0.2982, train_err=4.2593
     Eval: 16_h1=0.3251, 16_l2=0.2116, 32_h1=0.4519, 32_l2=0.2245
     [19] time=0.28, avg_loss=0.2802, train_err=4.0024
     Eval: 16_h1=0.3201, 16_l2=0.2110, 32_h1=0.4533, 32_l2=0.2245
 
-    {'train_err': 4.002395851271493, 'avg_loss': 0.2801677095890045, 'avg_lasso_loss': None, 'epoch_train_time': 0.28468873500003156, '16_h1': tensor(0.3201), '16_l2': tensor(0.2110), '32_h1': tensor(0.4533), '32_l2': tensor(0.2245)}
+    {'train_err': 4.002395851271493, 'avg_loss': 0.2801677095890045, 'avg_lasso_loss': None, 'epoch_train_time': 0.2810444039999993, '16_h1': tensor(0.3201), '16_l2': tensor(0.2110), '32_h1': tensor(0.4533), '32_l2': tensor(0.2245)}
 
 
 
@@ -447,7 +447,7 @@ In practice we would train a Neural Operator on one or multiple GPUs
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** (0 minutes 7.057 seconds)
+   **Total running time of the script:** (0 minutes 7.023 seconds)
 
 
 .. _sphx_glr_download_auto_examples_plot_incremental_FNO_darcy.py:
diff --git a/dev/_sources/auto_examples/sg_execution_times.rst.txt b/dev/_sources/auto_examples/sg_execution_times.rst.txt
index 2461063..de1cfe5 100644
--- a/dev/_sources/auto_examples/sg_execution_times.rst.txt
+++ b/dev/_sources/auto_examples/sg_execution_times.rst.txt
@@ -6,7 +6,7 @@
 
 Computation times
 =================
-**06:30.766** total execution time for 9 files **from auto_examples**:
+**06:27.430** total execution time for 9 files **from auto_examples**:
 
 .. container::
 
@@ -33,28 +33,28 @@ Computation times
      - Time
      - Mem (MB)
    * - :ref:`sphx_glr_auto_examples_plot_UNO_darcy.py` (``plot_UNO_darcy.py``)
-     - 03:26.114
+     - 03:24.779
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_SFNO_swe.py` (``plot_SFNO_swe.py``)
-     - 01:28.830
+     - 01:28.039
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_FNO_darcy.py` (``plot_FNO_darcy.py``)
-     - 00:53.239
+     - 00:52.571
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_DISCO_convolutions.py` (``plot_DISCO_convolutions.py``)
-     - 00:31.674
+     - 00:31.281
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_incremental_FNO_darcy.py` (``plot_incremental_FNO_darcy.py``)
-     - 00:07.057
+     - 00:07.023
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_count_flops.py` (``plot_count_flops.py``)
-     - 00:03.316
+     - 00:03.210
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_darcy_flow.py` (``plot_darcy_flow.py``)
-     - 00:00.333
+     - 00:00.328
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_darcy_flow_spectrum.py` (``plot_darcy_flow_spectrum.py``)
-     - 00:00.204
+     - 00:00.200
      - 0.0
    * - :ref:`sphx_glr_auto_examples_checkpoint_FNO_darcy.py` (``checkpoint_FNO_darcy.py``)
      - 00:00.000
diff --git a/dev/_sources/sg_execution_times.rst.txt b/dev/_sources/sg_execution_times.rst.txt
index a586820..637751a 100644
--- a/dev/_sources/sg_execution_times.rst.txt
+++ b/dev/_sources/sg_execution_times.rst.txt
@@ -6,7 +6,7 @@
 
 Computation times
 =================
-**06:30.766** total execution time for 9 files **from all galleries**:
+**06:27.430** total execution time for 9 files **from all galleries**:
 
 .. container::
 
@@ -33,28 +33,28 @@ Computation times
      - Time
      - Mem (MB)
    * - :ref:`sphx_glr_auto_examples_plot_UNO_darcy.py` (``../../examples/plot_UNO_darcy.py``)
-     - 03:26.114
+     - 03:24.779
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_SFNO_swe.py` (``../../examples/plot_SFNO_swe.py``)
-     - 01:28.830
+     - 01:28.039
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_FNO_darcy.py` (``../../examples/plot_FNO_darcy.py``)
-     - 00:53.239
+     - 00:52.571
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_DISCO_convolutions.py` (``../../examples/plot_DISCO_convolutions.py``)
-     - 00:31.674
+     - 00:31.281
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_incremental_FNO_darcy.py` (``../../examples/plot_incremental_FNO_darcy.py``)
-     - 00:07.057
+     - 00:07.023
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_count_flops.py` (``../../examples/plot_count_flops.py``)
-     - 00:03.316
+     - 00:03.210
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_darcy_flow.py` (``../../examples/plot_darcy_flow.py``)
-     - 00:00.333
+     - 00:00.328
      - 0.0
    * - :ref:`sphx_glr_auto_examples_plot_darcy_flow_spectrum.py` (``../../examples/plot_darcy_flow_spectrum.py``)
-     - 00:00.204
+     - 00:00.200
      - 0.0
    * - :ref:`sphx_glr_auto_examples_checkpoint_FNO_darcy.py` (``../../examples/checkpoint_FNO_darcy.py``)
      - 00:00.000
diff --git a/dev/auto_examples/plot_DISCO_convolutions.html b/dev/auto_examples/plot_DISCO_convolutions.html
index 90bb05c..09a63fb 100644
--- a/dev/auto_examples/plot_DISCO_convolutions.html
+++ b/dev/auto_examples/plot_DISCO_convolutions.html
@@ -330,7 +330,7 @@
 # plt.show()
 
-plot DISCO convolutions
<matplotlib.colorbar.Colorbar object at 0x7f568db98550>
+plot DISCO convolutions
<matplotlib.colorbar.Colorbar object at 0x7f790614c550>
 
convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quadrature_weights=q_out, kernel_shape=[2,4], radius_cutoff=3/nyo, periodic=False).float()
@@ -381,7 +381,7 @@
 plot DISCO convolutions
torch.Size([1, 1, 120, 90])
 
-

Total running time of the script: (0 minutes 31.674 seconds)

+

Total running time of the script: (0 minutes 31.281 seconds)

Create the trainer

@@ -323,22 +323,22 @@
Training on 200 samples
 Testing on [50, 50] samples         on resolutions [(32, 64), (64, 128)].
 Raw outputs of shape torch.Size([4, 3, 32, 64])
-[0] time=3.71, avg_loss=2.5890, train_err=10.3559
-Eval: (32, 64)_l2=2.0514, (64, 128)_l2=2.5055
-[3] time=3.64, avg_loss=0.4467, train_err=1.7867
-Eval: (32, 64)_l2=0.6580, (64, 128)_l2=2.5187
-[6] time=3.60, avg_loss=0.2807, train_err=1.1226
-Eval: (32, 64)_l2=0.5458, (64, 128)_l2=2.4767
-[9] time=3.59, avg_loss=0.2355, train_err=0.9419
-Eval: (32, 64)_l2=0.5549, (64, 128)_l2=2.4920
-[12] time=3.65, avg_loss=0.2170, train_err=0.8680
-Eval: (32, 64)_l2=0.5206, (64, 128)_l2=2.4921
-[15] time=3.68, avg_loss=0.1824, train_err=0.7297
-Eval: (32, 64)_l2=0.4906, (64, 128)_l2=2.4846
-[18] time=3.62, avg_loss=0.1764, train_err=0.7054
-Eval: (32, 64)_l2=0.4640, (64, 128)_l2=2.4902
-
-{'train_err': 0.6676547992229461, 'avg_loss': 0.16691369980573653, 'avg_lasso_loss': None, 'epoch_train_time': 3.599016188000064}
+[0] time=3.61, avg_loss=2.5783, train_err=10.3132
+Eval: (32, 64)_l2=1.7929, (64, 128)_l2=2.3944
+[3] time=3.61, avg_loss=0.3960, train_err=1.5841
+Eval: (32, 64)_l2=0.4045, (64, 128)_l2=2.6128
+[6] time=3.61, avg_loss=0.2645, train_err=1.0581
+Eval: (32, 64)_l2=0.2855, (64, 128)_l2=2.5574
+[9] time=3.57, avg_loss=0.2303, train_err=0.9211
+Eval: (32, 64)_l2=0.3216, (64, 128)_l2=2.5180
+[12] time=3.58, avg_loss=0.1810, train_err=0.7240
+Eval: (32, 64)_l2=0.2103, (64, 128)_l2=2.5075
+[15] time=3.57, avg_loss=0.1538, train_err=0.6152
+Eval: (32, 64)_l2=0.1890, (64, 128)_l2=2.4402
+[18] time=3.57, avg_loss=0.1323, train_err=0.5294
+Eval: (32, 64)_l2=0.1798, (64, 128)_l2=2.4265
+
+{'train_err': 0.5195050823688507, 'avg_loss': 0.12987627059221268, 'avg_lasso_loss': None, 'epoch_train_time': 3.607363260999989}
 

Plot the prediction, and compare with the ground-truth
@@ -385,7 +385,7 @@
 fig.show()

-Inputs, ground-truth output and prediction., Input x (32, 64), Ground-truth y, Model prediction, Input x (64, 128), Ground-truth y, Model prediction

Total running time of the script: (1 minutes 28.830 seconds)

+Inputs, ground-truth output and prediction., Input x (32, 64), Ground-truth y, Model prediction, Input x (64, 128), Ground-truth y, Model prediction

Total running time of the script: (1 minutes 28.039 seconds)

Create the trainer

@@ -454,22 +454,22 @@
Training on 1000 samples
 Testing on [50, 50] samples         on resolutions [16, 32].
 Raw outputs of shape torch.Size([32, 1, 16, 16])
-[0] time=10.19, avg_loss=0.6679, train_err=20.8720
-Eval: 16_h1=0.3981, 16_l2=0.2711, 32_h1=0.9217, 32_l2=0.6494
-[3] time=10.13, avg_loss=0.2426, train_err=7.5825
-Eval: 16_h1=0.2854, 16_l2=0.1799, 32_h1=0.7671, 32_l2=0.4978
-[6] time=10.09, avg_loss=0.2373, train_err=7.4160
-Eval: 16_h1=0.3061, 16_l2=0.1973, 32_h1=0.7619, 32_l2=0.4908
-[9] time=10.08, avg_loss=0.2195, train_err=6.8592
-Eval: 16_h1=0.2422, 16_l2=0.1575, 32_h1=0.7381, 32_l2=0.4485
-[12] time=10.11, avg_loss=0.1690, train_err=5.2798
-Eval: 16_h1=0.2518, 16_l2=0.1580, 32_h1=0.7492, 32_l2=0.4722
-[15] time=10.12, avg_loss=0.1825, train_err=5.7038
-Eval: 16_h1=0.2699, 16_l2=0.1664, 32_h1=0.7599, 32_l2=0.4589
-[18] time=10.08, avg_loss=0.1676, train_err=5.2361
-Eval: 16_h1=0.2356, 16_l2=0.1478, 32_h1=0.7352, 32_l2=0.4557
-
-{'train_err': 4.981805760413408, 'avg_loss': 0.15941778433322906, 'avg_lasso_loss': None, 'epoch_train_time': 10.092775133000032}
+[0] time=10.15, avg_loss=0.6631, train_err=20.7222
+Eval: 16_h1=0.4579, 16_l2=0.2992, 32_h1=0.9470, 32_l2=0.6763
+[3] time=10.05, avg_loss=0.2476, train_err=7.7385
+Eval: 16_h1=0.2439, 16_l2=0.1618, 32_h1=0.9045, 32_l2=0.6343
+[6] time=10.04, avg_loss=0.2285, train_err=7.1392
+Eval: 16_h1=0.2579, 16_l2=0.1738, 32_h1=0.8590, 32_l2=0.6153
+[9] time=10.07, avg_loss=0.1985, train_err=6.2036
+Eval: 16_h1=0.2429, 16_l2=0.1520, 32_h1=0.8739, 32_l2=0.5978
+[12] time=10.04, avg_loss=0.1856, train_err=5.8014
+Eval: 16_h1=0.2410, 16_l2=0.1467, 32_h1=0.8608, 32_l2=0.5604
+[15] time=10.05, avg_loss=0.1546, train_err=4.8301
+Eval: 16_h1=0.3156, 16_l2=0.2123, 32_h1=0.8466, 32_l2=0.6000
+[18] time=10.03, avg_loss=0.1247, train_err=3.8961
+Eval: 16_h1=0.2340, 16_l2=0.1354, 32_h1=0.8477, 32_l2=0.5822
+
+{'train_err': 4.671443767845631, 'avg_loss': 0.14948620057106018, 'avg_lasso_loss': None, 'epoch_train_time': 10.053496939999945}
 

Plot the prediction, and compare with the ground-truth
@@ -519,7 +519,7 @@
 fig.show()
-Inputs, ground-truth output and prediction., Input x, Ground-truth y, Model prediction

Total running time of the script: (3 minutes 26.114 seconds)

+Inputs, ground-truth output and prediction., Input x, Ground-truth y, Model prediction

Total running time of the script: (3 minutes 24.779 seconds)

-Inputs, ground-truth output and prediction., Input x, Ground-truth y, Model prediction

Total running time of the script: (0 minutes 7.057 seconds)

+Inputs, ground-truth output and prediction., Input x, Ground-truth y, Model prediction

Total running time of the script: (0 minutes 7.023 seconds)