forked from BIDData/BIDMach
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnistlr2.ssc
executable file
·42 lines (33 loc) · 1.69 KB
/
mnistlr2.ssc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// Option bundle used throughout this script: the learner's own options mixed with
// the option traits of the SFilesDS data source, the GLM model and the ADAGrad updater.
// Each line break sits after `with`, so no line forms a complete statement by itself.
class xopts extends Learner.Options with
  SFilesDS.Opts with
  GLM.Opts with
  ADAGrad.Opts
// A single options instance; later in the script it is handed to the data source,
// the model, the updater and the learner.
val mnopts = new xopts

// ---- input files ----
val mdir = "/data/MNIST8M/parts/"                // directory prefix for every data file
// One file-name template per input matrix ("cats" first, then "part");
// the %02d in each template is replaced by a number.
mnopts.fnames = List("cats%02d.smat.lz4", "part%02d.smat.lz4").map(
  template => FilesDS.simpleEnum(mdir + template, 1, 0))
mnopts.nstart = 0                                // starting file number
mnopts.nend = 70                                 // ending file number

// ---- how the data source reads samples ----
mnopts.order = 0                                 // sample order: 0 = linear, 1 = random
mnopts.lookahead = 2                             // number of prefetch threads
mnopts.featType = 1                              // feature type: 0 = binary, 1 = linear
mnopts.fcounts = icol(10,784)                    // rows to pull from each input matrix (10 from "cats", 784 from "part")
mnopts.eltsPerSample = 300                       // rows to allocate, i.e. non-zeros per sample

// ---- training schedule ----
mnopts.batchSize = 1000
mnopts.npasses = 2
mnopts.lrate = 0.001

// ---- GLM setup ----
// NOTE(review): the comments below describe only the shapes visible here; the meaning of
// link code 2 and of the targets/rmask layout is defined by GLM — confirm there.
mnopts.links = 2*iones(10,1)                     // a 10x1 matrix holding the value 2 in every entry
// Presumably `\` joins its operands side by side: a 10x10 diagonal of ones next to a 10x784 block of zeros.
mnopts.targets = mkdiag(ones(10,1)) \ zeros(10, 784)
// Presumably a 1x794 row: 10 zeros followed by 784 ones.
mnopts.rmask = zeros(1,10) \ ones(1, 784)
// The data source. The thread pool is declared implicit inside this block so that it is
// in scope only while the SFilesDS is being constructed.
val ds = {
  // 4 threads: there must be more threads than the lookahead (prefetch) count set on mnopts.
  implicit val ec = threadPool(4)
  new SFilesDS(mnopts)
}
// Assemble the learner from its parts; every component is configured from the same mnopts.
val nn = new Learner(
  ds,                    // where the data comes from
  new GLM(mnopts),       // the model being fitted (a GLM)
  null,                  // no mixins / regularizers
  new ADAGrad(mnopts),   // the optimization class
  mnopts)                // the learner reads its own options from here too
// Phase 1: train with the options configured above.
nn.train

// Phase 2: point the same options object at a different file range and do a single pass.
// NOTE(review): training was configured with nstart = 0, nend = 70 and this range starts
// at 71; if nend is exclusive, file 70 is used by neither phase — confirm this is intended.
mnopts.nstart = 71
mnopts.nend = 80
mnopts.npasses = 1
// Presumably re-runs prediction with the already-trained model over the new range — confirm in Learner.
nn.repredict