-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtransforms.py
62 lines (50 loc) · 1.71 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import skimage
# `import skimage` alone does not guarantee the `transform` submodule is
# loaded (only newer scikit-image lazy-loads it); import it explicitly.
import skimage.transform
import torch
import torch.nn as nn
class ToTensor(nn.Module):
    """Convert every element of a track list to a float ``torch.Tensor``.

    The conversion happens in place: each entry of ``tracks`` is replaced
    and the same (mutated) list object is returned.
    """

    def __init__(self):
        super(ToTensor, self).__init__()

    def forward(self, tracks):
        # Gradient tracking is pointless for a pure data conversion.
        with torch.no_grad():
            for idx in range(len(tracks)):
                # torch.Tensor(...) copies the data and casts to float32.
                tracks[idx] = torch.Tensor(tracks[idx])
        return tracks
class Normalize(nn.Module):
    """Min-max normalise the first track and prepend the result.

    The original first element is kept: after ``forward`` the list is one
    entry longer, with the normalised copy at index 0 and the raw vector
    shifted to index 1.
    """

    def __init__(self):
        super(Normalize, self).__init__()

    def forward(self, tracks):
        with torch.no_grad():
            raw = tracks[0]
            lo = np.min(raw)
            span = np.max(raw) - lo
            # A constant vector would divide by zero; map it to all-zeros.
            scaled = (raw - lo) / span if span > 0 else np.zeros(raw.shape)
            # NOTE(review): this inserts rather than replaces, so the list
            # grows by one on every call — confirm callers expect that.
            tracks.insert(0, scaled)
        return tracks
class HorizontalCrop(nn.Module):
    """Truncate every track to its first ``crop_size`` columns."""

    def __init__(self, crop_size):
        super(HorizontalCrop, self).__init__()
        # Number of columns kept along the second axis.
        self.crop_size = crop_size

    def forward(self, vector):
        with torch.no_grad():
            # Slicing past the end is safe: shorter tracks come back unchanged.
            return [track[:, :self.crop_size] for track in vector]
class Resize(nn.Module):
    """Resize every track with ``skimage.transform.resize``, returning tensors."""

    def __init__(self, width, height):
        super(Resize, self).__init__()
        # NOTE(review): skimage's resize takes output_shape as (rows, cols),
        # so ``width`` here becomes the row count — confirm that is intended.
        self.width = width
        self.height = height

    def forward(self, vector):
        with torch.no_grad():
            target = (self.width, self.height)
            # resize returns a float ndarray; wrap each result as a tensor.
            return [torch.Tensor(skimage.transform.resize(t, target))
                    for t in vector]