scatter.py
"""
from rusty1s' torch_scatter implementation.
"""
import torch
from itertools import repeat
def maybe_dim_size(index, dim_size=None):
if dim_size is not None:
return dim_size
dim = index.max().item() + 1 if index.numel() > 0 else 0
return int(dim)
def broadcast(src, index, dim):
dim = range(src.dim())[dim] # Get real dim value.
if index.dim() == 1:
index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim)
if index.numel() > 0:
index = index.view(index_size).expand_as(src)
else: # pragma: no cover
# PyTorch has a bug when view is used on zero-element tensors.
index = src.new_empty(index_size, dtype=torch.long)
# Broadcasting capabilties: Expand dimensions to match.
if src.dim() != index.dim():
raise ValueError(
('Number of dimensions of src and index tensor do not match, '
'got {} and {}').format(src.dim(), index.dim()))
expand_size = []
for s, i in zip(src.size(), index.size()):
expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
src = src.expand(expand_size)
index = index.expand_as(src)
return src, index
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, index = broadcast(src, index, dim)
dim = range(src.dim())[dim] # Get real dim value.
# Generate output tensor if not given.
if out is None:
out_size = list(src.size())
dim_size = maybe_dim_size(index, dim_size)
out_size[dim] = dim_size
out = src.new_full(out_size, fill_value)
return src, out, index, dim
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return out.scatter_add_(dim, index, src)
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
out = scatter_add(src, index, dim, out, dim_size, fill_value)
count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim))
return out / count.clamp(min=1)
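

# ---------------------------------------------------------------------------
# Usage sketch (illustration only; not part of the original file). It shows
# how scatter_add / scatter_mean group values in `src` by the labels in
# `index` along the last dimension, and how a 1-D index is broadcast.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    src = torch.tensor([[2.0, 0.0, 1.0, 4.0, 3.0],
                        [0.0, 2.0, 1.0, 3.0, 4.0]])
    index = torch.tensor([[4, 5, 4, 2, 3],
                          [0, 0, 2, 2, 1]])

    # max(index) + 1 = 6 bins along dim=-1; entries sharing an index are
    # summed, e.g. row 0, bin 4 collects 2.0 + 1.0 = 3.0.
    print(scatter_add(src, index, dim=-1))
    # tensor([[0., 0., 4., 3., 3., 0.],
    #         [2., 4., 4., 0., 0., 0.]])

    # Same grouping, divided by per-bin counts (clamped to >= 1 so empty
    # bins stay zero instead of producing NaNs).
    print(scatter_mean(src, index, dim=-1))
    # tensor([[0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000],
    #         [1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]])

    # A 1-D index is expanded to match `src` by `broadcast`, so every row
    # of `src` is grouped with the same labels.
    flat_index = torch.tensor([0, 0, 1, 1, 0])
    print(scatter_add(src, flat_index, dim=-1))
    # tensor([[5., 5.],
    #         [6., 4.]])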