# parallel_mru_op.py (generated from mikayahlevi/transformer-train-script)
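# A parallel prefix scan (cumulative matrix product along dim -3) over a sequence
# of matrix states, with hand-written backward passes, in two variants: a
# Hillis-Steele scan (fewer stages, more total matmuls) and a Brent-Kung scan
# (roughly twice the stages, fewer total matmuls). Inputs are expected to have
# shape (..., sequence_length, n, n).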
import torch
import math


# hs: Hillis-Steele scan
class hs_parallel_mru_op_class(torch.autograd.Function):
    @staticmethod
    def forward(ctx, start_matrix_states):
        final_matrix_states = start_matrix_states.clone()

        sequence_length = start_matrix_states.size(-3)
        n_stages = math.ceil(math.log2(sequence_length))

        # inclusive prefix scan along dim -3: after stage s, position i holds the
        # product of up to 2 ** (s + 1) consecutive start states ending at i
        for stage in range(n_stages):
            stage_stride = 2 ** stage
            final_matrix_states[..., stage_stride:, :, :] = final_matrix_states[..., :-stage_stride, :, :] @ final_matrix_states[..., stage_stride:, :, :]

        ctx.save_for_backward(start_matrix_states, final_matrix_states)
        ctx.sequence_length = sequence_length

        return final_matrix_states

    @staticmethod
    def backward(ctx, grad_final_matrix_states):
        # helpers for the slower reference implementations kept in the comments below
        def create_eye_for_shift(shape):
            resized_eye = torch.eye(*shape[-2:], device = grad_final_matrix_states.device)
            while resized_eye.dim() < len(shape):
                resized_eye = resized_eye.unsqueeze(0)
            resized_eye_shape = list(shape[:-3])
            while len(resized_eye_shape) < len(shape):
                resized_eye_shape.append(1)
            return resized_eye.repeat(*resized_eye_shape)

        def create_zeros_for_shift(shape):
            new_shape = list(shape)
            new_shape[-3] = 1
            return torch.zeros(new_shape, device = grad_final_matrix_states.device)

        start_matrix_states, final_matrix_states = ctx.saved_tensors

        # grad_before_start_matrix_states is A as described in the README
        # tl is U as described in the README
        # bl is L as described in the README

        # grad_before_start_matrix_states = torch.cat((create_eye_for_shift(final_matrix_states.shape), final_matrix_states.transpose(-1, -2)[..., :-1, :, :]), dim = -3)
        # faster implementation:
        grad_before_start_matrix_states = final_matrix_states.transpose(-1, -2).roll(1, dims = -3)
        grad_before_start_matrix_states[..., 0, :, :] = torch.eye(grad_before_start_matrix_states.size(-2), device = grad_before_start_matrix_states.device)

        # tl = torch.cat((start_matrix_states[..., 1:, :, :], create_zeros_for_shift(start_matrix_states.shape)), dim = -3).transpose(-1, -2)
        # faster implementation:
        tl = start_matrix_states.transpose(-1, -2).roll(-1, dims = -3)
        tl[..., -1, :, :] = torch.zeros((tl.size(-2), tl.size(-1)), device = tl.device)

        # clone so the reverse scan below does not mutate the incoming gradient in place
        bl = grad_final_matrix_states.clone()

        sequence_length = ctx.sequence_length
        n_stages = math.ceil(math.log2(sequence_length))

        # reverse (suffix) scan over (bl, tl) pairs; bl must be updated before tl
        # so that it combines with the pre-update transition products
        for stage in range(n_stages):
            stage_stride = 2 ** stage
            bl[..., :-stage_stride, :, :] = bl[..., stage_stride:, :, :] @ tl[..., :-stage_stride, :, :] + bl[..., :-stage_stride, :, :]
            tl[..., :-stage_stride, :, :] = tl[..., stage_stride:, :, :] @ tl[..., :-stage_stride, :, :]

        grad_start_matrix_states = grad_before_start_matrix_states @ bl

        return grad_start_matrix_states
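# For reference, a minimal sequential sketch (not part of the original file) of
# what both scans compute: output position i is the ordered product
# states[..., 0, :, :] @ states[..., 1, :, :] @ ... @ states[..., i, :, :].
def naive_sequential_mru_scan(start_matrix_states):
    final_matrix_states = start_matrix_states.clone()
    for i in range(1, final_matrix_states.size(-3)):
        # the right-hand side is materialized before assignment, so the
        # in-place update is safe
        final_matrix_states[..., i, :, :] = final_matrix_states[..., i - 1, :, :] @ final_matrix_states[..., i, :, :]
    return final_matrix_states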
# bk: Brent-Kung scan
class bk_parallel_mru_op_class(torch.autograd.Function):
    @staticmethod
    def forward(ctx, start_matrix_states):
        final_matrix_states = start_matrix_states.clone()

        sequence_length = start_matrix_states.size(-3)
        n_stages = math.ceil(math.log2(sequence_length))

        # first sweep (up-sweep): build products at power-of-two block boundaries
        for stage in range(n_stages):
            # abbreviate stage_stride as sts
            sts = 2 ** stage
            final_matrix_states[..., 2 * sts - 1::2 * sts, :, :] = final_matrix_states[..., sts - 1:-sts:2 * sts, :, :] @ final_matrix_states[..., 2 * sts - 1::2 * sts, :, :]

        # second sweep (down-sweep): propagate the block products to the remaining positions
        for stage in reversed(range(n_stages - 1)):
            # abbreviate stage_stride as sts
            sts = 2 ** stage
            final_matrix_states[..., 2 * sts + sts - 1::2 * sts, :, :] = final_matrix_states[..., 2 * sts - 1:-sts:2 * sts, :, :] @ final_matrix_states[..., 2 * sts + sts - 1::2 * sts, :, :]

        ctx.save_for_backward(start_matrix_states, final_matrix_states)
        ctx.sequence_length = sequence_length

        return final_matrix_states

    @staticmethod
    def backward(ctx, grad_final_matrix_states):
        # helpers for the slower reference implementations kept in the comments below
        def create_eye_for_shift(shape):
            resized_eye = torch.eye(*shape[-2:], device = grad_final_matrix_states.device)
            while resized_eye.dim() < len(shape):
                resized_eye = resized_eye.unsqueeze(0)
            resized_eye_shape = list(shape[:-3])
            while len(resized_eye_shape) < len(shape):
                resized_eye_shape.append(1)
            return resized_eye.repeat(*resized_eye_shape)

        def create_zeros_for_shift(shape):
            new_shape = list(shape)
            new_shape[-3] = 1
            return torch.zeros(new_shape, device = grad_final_matrix_states.device)

        start_matrix_states, final_matrix_states = ctx.saved_tensors

        # grad_before_start_matrix_states is A as described in the README
        # tl is U as described in the README
        # bl is L as described in the README

        # grad_before_start_matrix_states = torch.cat((create_eye_for_shift(final_matrix_states.shape), final_matrix_states.transpose(-1, -2)[..., :-1, :, :]), dim = -3)
        # faster implementation:
        grad_before_start_matrix_states = final_matrix_states.transpose(-1, -2).roll(1, dims = -3)
        grad_before_start_matrix_states[..., 0, :, :] = torch.eye(grad_before_start_matrix_states.size(-2), device = grad_before_start_matrix_states.device)

        # tl = torch.cat((start_matrix_states[..., 1:, :, :], create_zeros_for_shift(start_matrix_states.shape)), dim = -3).transpose(-1, -2)
        # faster implementation:
        tl = start_matrix_states.transpose(-1, -2).roll(-1, dims = -3)
        tl[..., -1, :, :] = torch.zeros((tl.size(-2), tl.size(-1)), device = tl.device)

        # clone so the reverse scan below does not mutate the incoming gradient in place
        bl = grad_final_matrix_states.clone()

        sequence_length = ctx.sequence_length
        n_stages = math.ceil(math.log2(sequence_length))

        # first sweep (up-sweep) of the reverse (suffix) scan over (bl, tl) pairs;
        # as above, bl must be updated before tl at each stage
        for stage in range(n_stages):
            # abbreviate stage_stride as sts
            sts = 2 ** stage
            bl[..., :-sts:2 * sts, :, :] = bl[..., sts::2 * sts, :, :] @ tl[..., :-sts:2 * sts, :, :] + bl[..., :-sts:2 * sts, :, :]
            tl[..., :-sts:2 * sts, :, :] = tl[..., sts::2 * sts, :, :] @ tl[..., :-sts:2 * sts, :, :]

        # second sweep (down-sweep) of the reverse scan
        for stage in reversed(range(n_stages - 1)):
            # abbreviate stage_stride as sts
            sts = 2 ** stage
            bl[..., sts:-sts:2 * sts, :, :] = bl[..., 2 * sts::2 * sts, :, :] @ tl[..., sts:-sts:2 * sts, :, :] + bl[..., sts:-sts:2 * sts, :, :]
            tl[..., sts:-sts:2 * sts, :, :] = tl[..., 2 * sts::2 * sts, :, :] @ tl[..., sts:-sts:2 * sts, :, :]

        grad_start_matrix_states = grad_before_start_matrix_states @ bl

        return grad_start_matrix_states


parallel_mru_op = bk_parallel_mru_op_class.apply
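# A small verification sketch (not part of the original file), assuming the
# shapes above: both parallel variants should match the naive sequential scan,
# and torch.autograd.gradcheck compares the custom backward against numerical
# derivatives (it needs float64 inputs for stable finite differences).
if __name__ == "__main__":
    torch.manual_seed(0)

    states = torch.randn(2, 8, 4, 4, dtype = torch.float64)
    reference = naive_sequential_mru_scan(states)
    assert torch.allclose(hs_parallel_mru_op_class.apply(states), reference)
    assert torch.allclose(bk_parallel_mru_op_class.apply(states), reference)

    small_states = torch.randn(4, 3, 3, dtype = torch.float64, requires_grad = True)
    assert torch.autograd.gradcheck(parallel_mru_op, (small_states,))

    print("forward and backward checks passed")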