-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
27 lines (24 loc) · 943 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
import sys
import argparse
from train_utils import make_label,load_data, build_model
def main(dirname, epochs=1000, batch_size=16, model_path='model.h5'):
    """Load the data set, train a model, report test performance, save it.

    Args:
        dirname: Location handed unchanged to ``load_data``, which must
            return ``(x_train, y_train, x_test, y_test)``.
        epochs: Number of epochs passed to ``model.fit``.
        batch_size: Batch size used for both ``model.fit`` and
            ``model.evaluate``.
        model_path: File the trained model is written to via ``model.save``.

    Returns:
        None. Results are printed and the model is saved to ``model_path``.
    """
    x_train, y_train, x_test, y_test = load_data(dirname)

    print("x_train_shape----------->")
    print(x_train.shape)
    print("y_train_shape")
    print(y_train.shape)

    # The second dimension of y_train sizes the model; presumably this is the
    # number of output classes (one-hot labels) -- TODO confirm in train_utils.
    model = build_model(y_train.shape[1])

    print('Training stage')
    print('==============')
    # NOTE(review): the test split doubles as validation data here, so the
    # per-epoch validation metrics and the final test metrics come from the
    # same samples.
    model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
    )

    # Unpacking two values assumes the model was compiled with exactly one
    # metric (loss first, then that metric) -- verify against build_model.
    score, acc = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
    print('Test performance: accuracy={0}, loss={1}'.format(acc, score))

    model.save(model_path)
if __name__ == '__main__':
    # Command-line entry point: read the data location and start training.
    parser = argparse.ArgumentParser(description='Training Model')
    parser.add_argument(
        "--input_train_path",
        # The option stays optional (None when omitted) to preserve the
        # original command-line contract; only the blank help text is fixed.
        help="path passed to load_data() to locate the train/test data",
    )
    args = parser.parse_args()
    main(args.input_train_path)