-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathamaz_sampling.py
32 lines (26 loc) · 964 Bytes
/
amaz_sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import six
import numpy as np
class Sampling(object):
    """Helpers that draw random index subsets using NumPy's global RNG."""

    def __init__(self):
        """Create a sampler; no state is stored on the instance."""

    def random_sampling(self, epoch, batch_size, data_length):
        """Yield one batch of randomly chosen indices per epoch.

        Every iteration draws a fresh ``np.random.permutation(data_length)``
        and yields its first ``batch_size`` entries, so entries inside one
        batch are distinct while separate batches are drawn independently.

        Args:
            epoch: Number of batches to yield.
            batch_size: Maximum number of entries per batch. If it is larger
                than the permutation, or ``None``, the whole permutation is
                yielded.
            data_length: Forwarded unchanged to ``np.random.permutation``.

        Yields:
            numpy.ndarray: The leading ``batch_size`` entries of a new
            permutation.

        Raises:
            ValueError: If ``batch_size`` is negative. Because this is a
                generator, the error surfaces on the first iteration.
        """
        # A negative stop would slice from the end ("all but the last k")
        # instead of limiting the batch, silently returning the wrong size.
        if batch_size is not None and batch_size < 0:
            raise ValueError(
                "batch_size must be non-negative, got %r" % (batch_size,)
            )
        for _ in six.moves.range(epoch):
            yield np.random.permutation(data_length)[:batch_size]

    def random_sampling_label_normarize(self, data_length, batch_size, category_num):
        """Placeholder for category-balanced random sampling (unimplemented).

        The intended behaviour, per the original FIX ME note, is to yield
        randomly sampled indices such that every category contributes the
        same number of sampled items. None of the arguments are used yet.

        Returns:
            None: Always; no sampling is performed.
        """
        # TODO: implement balanced sampling. The current signature carries no
        # per-item label information, which the implementation will need.
        return None

    def pick_random_permutation(self, pick_number, sample_number, sort=False):
        """Pick distinct random integers from ``range(sample_number)``.

        Args:
            pick_number: How many values to pick (coerced with ``int``). If it
                exceeds ``sample_number``, all ``sample_number`` values are
                returned.
            sample_number: Size of the population to draw from (coerced with
                ``int``).
            sort: When truthy, return the picked values in ascending order.

        Returns:
            numpy.ndarray: At most ``pick_number`` distinct integers in
            ``[0, sample_number)``.

        Raises:
            ValueError: If ``pick_number`` is negative.
        """
        pick_number = int(pick_number)
        sample_number = int(sample_number)
        # A negative count would slice from the end of the permutation and
        # return ``sample_number - |pick_number|`` values instead of failing.
        if pick_number < 0:
            raise ValueError(
                "pick_number must be non-negative, got %d" % pick_number
            )
        picked = np.random.permutation(sample_number)[:pick_number]
        return np.sort(picked) if sort else picked