-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgnn_train.py
24 lines (20 loc) · 843 Bytes
/
gnn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import tensorflow as tf
from data_loader.gnn_data_generator import load_data
from models.text_gnn import TextGNN
from utils.config_utils import get_config
def train(args, *, seed=19):
    """Load the configured dataset, build a TextGNN model on it and train it.

    Args:
        args: Nested configuration mapping (as produced by ``get_config``).
            It must provide ``args['dataset']['path']``,
            ``args['dataset']['dataset_name']`` and ``args['model']``. The
            ``dataset`` and ``model`` sub-mappings, and ``args`` itself, are
            all unpacked into ``TextGNN`` keyword arguments.
        seed: Graph-level TensorFlow random seed. Defaults to 19, the value
            that used to be hard-coded, so existing callers are unaffected.
    """
    # Fix the graph-level seed so random ops are reproducible across runs.
    tf.set_random_seed(seed)
    # Let TensorFlow grow GPU memory on demand rather than grabbing the
    # whole device at session start.
    tf_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=tf_config) as sess:
        data_generator = load_data(args['dataset']['path'],
                                   args['dataset']['dataset_name'])
        # data_generator is passed whole AND unpacked with **, so it must be
        # a mapping. NOTE(review): the unpacked mappings (data_generator,
        # args['dataset'], args['model'], args) must have disjoint keys and
        # must not contain 'sess' or 'data_generator'; otherwise Python raises
        # TypeError ("got multiple values for keyword argument").
        model = TextGNN(sess=sess, data_generator=data_generator,
                        **data_generator, **args['dataset'], **args['model'],
                        **args)
        model.train()
if __name__ == "__main__":
    import argparse

    # Hide every CUDA device so TensorFlow runs on the CPU; the allow_growth
    # GPU option used in train() therefore has no effect in this entry point.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # The config name and tag used to be hard-coded here (with alternatives
    # left as commented-out lines); they are now optional CLI arguments whose
    # defaults reproduce the previous behaviour.
    parser = argparse.ArgumentParser(description='Train a TextGNN model.')
    parser.add_argument('config', nargs='?', default='gnn/cnews_voc',
                        help="config name passed to get_config, e.g. "
                             "'gnn/aclImdb', 'gnn/cnews' or 'gnn/cnews_voc' "
                             "(default: %(default)s)")
    parser.add_argument('--tag', default='base',
                        help="value stored in config['tag'] "
                             "(default: %(default)s)")
    cli_args = parser.parse_args()

    config = get_config(cli_args.config)
    config['tag'] = cli_args.tag
    train(config)