Skip to content

Commit

Permalink
Merge pull request #30 from strongio/feature/bern-filter
Browse files Browse the repository at this point in the history
binomial filter
  • Loading branch information
jwdink authored Jan 10, 2025
2 parents 9df579c + 2828686 commit 0e09515
Show file tree
Hide file tree
Showing 8 changed files with 540 additions and 127 deletions.
222 changes: 130 additions & 92 deletions docs/examples/electricity.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion torchcast/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.3'
__version__ = '0.5.0'
11 changes: 3 additions & 8 deletions torchcast/covariance/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import Tensor, nn, jit

from torchcast.process.utils import Identity
from torchcast.covariance.util import num_off_diag
from torchcast.covariance.util import num_off_diag, mini_cov_mask
from torchcast.internals.utils import is_near_zero, validate_gt_shape
from torchcast.process.base import Process

Expand Down Expand Up @@ -80,7 +80,7 @@ def from_processes(cls,
if 'init_diag_multi' not in kwargs:
kwargs['init_diag_multi'] = .01
if 'method' in kwargs and kwargs['method'] == 'low_rank':
warn("``method='low_rank'`` not recommended for processes, consider 'low_rank+block_diag'")
warn("``method='low_rank'`` not recommended for processes")
elif cov_type == 'initial':
if (state_rank - len(no_cov_idx)) >= 10:
# by default, use low-rank parameterization for initial cov:
Expand Down Expand Up @@ -173,12 +173,7 @@ def __init__(self,
empty_idx = set(empty_idx)
assert all(isinstance(x, int) for x in empty_idx)
self.param_rank = self.rank - len(empty_idx)
mask = torch.zeros((self.rank, self.param_rank))
c = 0
for r in range(self.rank):
if r not in empty_idx:
mask[r, c] = 1.
c += 1
mask = mini_cov_mask(rank=self.rank, empty_idx=empty_idx)
self.register_buffer('mask', mask)

self._set_params(method, init_diag_multi)
Expand Down
13 changes: 13 additions & 0 deletions torchcast/covariance/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Collection

import torch
from torch import Tensor

Expand All @@ -9,3 +11,14 @@ def num_off_diag(rank: int) -> int:
def cov2corr(cov: Tensor) -> Tensor:
std_ = torch.sqrt(torch.diagonal(cov, dim1=-2, dim2=-1))
return cov / (std_.unsqueeze(-1) @ std_.unsqueeze(-2))


def mini_cov_mask(rank: int, empty_idx: Collection[int], **kwargs) -> Tensor:
param_rank = rank - len(empty_idx)
mask = torch.zeros((rank, param_rank), **kwargs)
c = 0
for r in range(rank):
if r not in empty_idx:
mask[r, c] = 1.
c += 1
return mask
Loading

0 comments on commit 0e09515

Please sign in to comment.