From 8f1001b039ce5e6c901774bafe0ffac78b40f02f Mon Sep 17 00:00:00 2001 From: Jacob Date: Mon, 9 Dec 2024 18:32:12 -0600 Subject: [PATCH 1/2] TimeSeriesDataset .split_measures --- torchcast/utils/data.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/torchcast/utils/data.py b/torchcast/utils/data.py index 89e33ee..3b380da 100644 --- a/torchcast/utils/data.py +++ b/torchcast/utils/data.py @@ -219,38 +219,26 @@ def get_groups(self, groups: Sequence[Any]) -> 'TimeSeriesDataset': """ return self[np.isin(self.group_names, groups)] - def split_measures(self, *measure_groups, which: Optional[int] = None) -> 'TimeSeriesDataset': + def split_measures(self, *measure_groups) -> 'TimeSeriesDataset': """ - Take a dataset with one tensor, split it into a dataset with multiple tensors. + Take a dataset and split it into a dataset with multiple tensors. - :param measure_groups: Each argument should be be a list of measure-names, or an indexer (i.e. list of ints or - a slice). - :param which: If there are already multiple measure groups, the split will occur within one of them; must - specify which. + :param measure_groups: Each argument should be a list of measure-names. :return: A :class:`.TimeSeriesDataset`, now with multiple tensors for the measure-groups. """ + concat_tensors = torch.cat(self.tensors, dim=2) - if which is None: - if len(self.measures) > 1: - raise RuntimeError(f"Must pass `which` if there's more than one groups:\n{self.measures}") - which = 0 - - self_tensor = self.tensors[which] - self_measures = self.measures[which] - - idxs = [] + idx_groups = [] for measure_group in measure_groups: - if isinstance(measure_group, slice) or isinstance(measure_group[0], int): - idxs.append(measure_group) - else: - idxs.append([self_measures.index(m) for m in measure_group]) + idx_groups.append([]) + for measure in measure_group: + idx_groups[-1].append(self.all_measures.index(measure)) - self_measures = np.array(self_measures) return type(self)( - *(self_tensor[:, :, idx].clone() for idx in idxs), + *(concat_tensors[:, :, idxs] for idxs in idx_groups), start_times=self.start_times, group_names=self.group_names, - measures=[tuple(self_measures[idx]) for idx in idxs], + measures=measure_groups, dt_unit=self.dt_unit ) From 62ae734254d8a074994a5d5076667f27acc7eb20 Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 12 Dec 2024 16:52:04 -0600 Subject: [PATCH 2/2] add model_mat_kwarg_name --- torchcast/process/regression.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchcast/process/regression.py b/torchcast/process/regression.py index b0ab33d..161fda0 100644 --- a/torchcast/process/regression.py +++ b/torchcast/process/regression.py @@ -31,7 +31,8 @@ def __init__(self, predictors: Sequence[str], measure: Optional[str] = None, fixed: bool = True, - decay: Optional[Tuple[float, float]] = None): + decay: Optional[Tuple[float, float]] = None, + model_mat_kwarg_name: str = 'X'): super().__init__( id=id, @@ -50,7 +51,7 @@ def __init__(self, if isinstance(decay, tuple): decay = SingleOutput(numel=len(predictors), transform=Bounded(*decay)) self.f_modules['all_self'] = decay - self.expected_kwargs = ['X'] + self.expected_kwargs = [model_mat_kwarg_name] def _build_h_mat(self, inputs: Dict[str, torch.Tensor], num_groups: int, num_times: int) -> torch.Tensor: # if not torch.jit.is_scripting(): @@ -59,7 +60,7 @@ def _build_h_mat(self, inputs: Dict[str, torch.Tensor], num_groups: int, num_tim # except KeyError as e: # raise TypeError(f"Missing required keyword-arg `X` (or `{self.id}__X`).") from e # else: - X = inputs['X'] + X = inputs[self.expected_kwargs[0]] assert not torch.isnan(X).any() assert not torch.isinf(X).any()