-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTransform.py
83 lines (75 loc) · 3.79 KB
/
Transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch as pt
from torch.distributions.transforms import ComposeTransform
from torch.distributions.constraints import independent, real
from .torch_utils import jit_script
def calc_arma_transform(x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift):
    '''
    Broadcast the ARMA tail buffers to x's batch shape, then run the core recursion.

    Parameters
    ----------
    x : tensor whose last dimension is the sequence to transform.
    x_is_in_not_out : 1-D boolean tensor, one flag per step of x's last dim;
        True marks the step as an input (innovation), False as an output.
    input, output : tail buffers of the most recent inputs/outputs preceding x.
    i_coefs, o_coefs : input/output coefficient tensors (the core assumes the
        last coefficient is one and drops it).
    drift : additive drift term applied at each generated step.

    Returns the transformed sequence, same trailing length as x.
    '''
    assert(x_is_in_not_out.shape==x.shape[-1:])
    # Match dimensions: left-pad the tails with singleton axes, then expand to
    # x's batch shape.  Tuple indexing is used instead of the original
    # list-style index ([None]*k + [Ellipsis]) because list-as-index advanced
    # indexing is deprecated in PyTorch (and an error in modern NumPy); the
    # tuple form is semantically identical.
    input = input[(None,) * (len(x.shape) - len(input.shape)) + (Ellipsis,)].expand(*(x.shape[:-1] + (input.shape[-1],)))
    output = output[(None,) * (len(x.shape) - len(output.shape)) + (Ellipsis,)].expand(*(x.shape[:-1] + (output.shape[-1],)))
    return calc_arma_transform_core(x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift)
@jit_script
def calc_arma_transform_core(x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift):
    '''
    Sequential ARMA recursion over the last axis of x.

    Each step n either generates an output value from an innovation
    (x_is_in_not_out[n] is True) or recovers an innovation from an output
    (False), sliding the input/output history windows by one step either way.
    NOTE(review): compiled with the project's jit_script wrapper — presumably
    TorchScript; keep the body within TorchScript's supported subset.
    '''
    # Assume last coefficient is one
    i_coefs = i_coefs[..., :-1]
    o_coefs = o_coefs[..., :-1]
    # Trim input and output so the history windows match the coefficient counts
    input = input[..., (-i_coefs.shape[-1]):]
    output = output[..., (-o_coefs.shape[-1]):]
    # Loop over input
    ret_val = []
    for n in range(x.shape[-1]):
        # Current step as a trailing singleton dim so it can be cat'ed below.
        next_x = x[..., n][..., None]
        if x_is_in_not_out[n]:
            # Forward step: x[n] is an innovation; emit the ARMA output.
            next_val = next_x + ((input * i_coefs).sum(-1) - \
                (output * o_coefs).sum(-1))[..., None] + drift
            # Slide both history windows forward by one step.
            input = pt.cat([input[..., 1:], next_x], dim=-1)
            output = pt.cat([output[..., 1:], next_val], dim=-1)
        else:
            # Inverse step: x[n] is an observed output; recover the innovation
            # (sign-mirrored version of the forward update).
            next_val = next_x + ((output * o_coefs).sum(-1) - \
                (input * i_coefs).sum(-1))[..., None] - drift
            input = pt.cat([input[..., 1:], next_val], dim=-1)
            output = pt.cat([output[..., 1:], next_x], dim=-1)
        ret_val.append(next_val)
    return pt.cat(ret_val, dim=-1)
class ARMATransform(pt.distributions.transforms.Transform):
    '''
    Invertible ARMA transform with support for transforming only part of the samples.
    The transform has a Jacobian determinant of one even if only part of the samples are used as input.
    See a discussion with ChatGPT on the subject at https://chat.openai.com/share/55d34600-6b9d-49ea-b7de-0b70b0e2382f.
    '''
    domain = independent(real, 1)
    codomain = independent(real, 1)
    bijective = True

    def __init__(self, i_tail, o_tail, i_coefs, o_coefs, drift, x=None, idx=None, x_is_in_not_out=None):
        super().__init__()
        # ARMA history tails, coefficients and drift term.
        self.i_tail = i_tail
        self.o_tail = o_tail
        self.i_coefs = i_coefs
        self.o_coefs = o_coefs
        self.drift = drift
        # Optional partial-transform configuration: x is the stored full
        # sample, idx selects the entries this transform actually acts on,
        # x_is_in_not_out flags each entry as innovation (True) or output.
        self.x = x
        self.idx = idx
        self.x_is_in_not_out = x_is_in_not_out
        if x_is_in_not_out is not None and not x_is_in_not_out[idx].all():
            raise UserWarning('Inputs must be innovations.')

    def log_abs_det_jacobian(self, x, y):
        # The Jacobian determinant is one by construction, so its log is zero.
        return x.new_zeros(x.shape[:-1])

    def get_x(self, x):
        '''Return the full sample and a (mutable) copy of the in/out flags.'''
        if self.x_is_in_not_out is None:
            # Default: every entry of the full sample is an innovation.
            full_len = (x if self.x is None else self.x).shape[-1]
            flags = pt.tensor([True] * full_len)
        else:
            flags = self.x_is_in_not_out.clone()
        if self.x is None:
            return x, flags
        # Splice the partial sample into a copy of the stored full sample.
        spliced = self.x.clone()
        spliced[..., self.idx] = x
        return spliced, flags

    def _call(self, x):
        full, flags = self.get_x(x)
        result = calc_arma_transform(full, flags, self.i_tail, self.o_tail, self.i_coefs, self.o_coefs, self.drift)
        return result if self.x is None else result[..., self.idx]

    def _inverse(self, x):
        full, flags = self.get_x(x)
        # Invert the roles of inputs and outputs.  For a partial transform only
        # the selected entries change role (the net effect of the original
        # flip-all / flip-idx / flip-all sequence); otherwise every entry flips.
        if self.x is None:
            flags = ~flags
        else:
            flags[self.idx] = ~flags[self.idx]
        result = calc_arma_transform(full, flags, self.i_tail, self.o_tail, self.i_coefs, self.o_coefs, self.drift)
        return result if self.x is None else result[..., self.idx]