-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path readWeights.py
67 lines (52 loc) · 1.65 KB
/
readWeights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import sys
from reignforce import Agent
import os
import gym
import argparse
# Maps the environment folder name (first component of the weights layout)
# to the OpenAI Gym environment id to instantiate.
ENVS = {
    'CartPole': 'CartPole-v0',
    'Acrobot': 'Acrobot-v1',
    'LunarLander': 'LunarLander-v2',
    'MountainCar': 'MountainCar-v0',
}


def parse_weights_path(path):
    """Split a weights path into its (environment, variant, model_data) folder names.

    The expected layout is ``.../<Environment>/<Variant>/<model_data>/<weights>``;
    only the last four components are inspected, so absolute or deeper paths work.

    Args:
        path: Path whose final component is the weights entry passed to
            ``load_weights``.

    Returns:
        A tuple ``(environment, variant, model_data)`` of bare folder names.
    """
    model_dir = os.path.dirname(path)
    variant_dir = os.path.dirname(model_dir)
    env_dir = os.path.dirname(variant_dir)
    # basename (not dirname) so only the single folder name is compared
    # against ENVS, regardless of how many parent directories precede it.
    return (
        os.path.basename(env_dir),
        os.path.basename(variant_dir),
        os.path.basename(model_dir),
    )


def parse_layer_sizes(fields, variant):
    """Return the ``(h1, h2)`` hidden-layer sizes encoded in a model-data folder name.

    Args:
        fields: The model-data folder name split on ``"_"``.
        variant: The variant folder name. ``"Reinforce"`` folder names carry the
            layer-count field at index 3; every other variant at index 2. The
            size field(s) immediately follow the layer-count field.

    Returns:
        ``(h1, h2)`` as ints; ``h2`` is 0 when the layer-count field is ``'1'``.

    Raises:
        ValueError: if the layer-count field is missing or is not ``'1'`` or
            ``'2'``, or if a size field is missing or not an integer.
    """
    idx = 3 if variant == "Reinforce" else 2
    try:
        n_layers = fields[idx]
        if n_layers == '1':
            return int(fields[idx + 1]), 0
        if n_layers == '2':
            return int(fields[idx + 1]), int(fields[idx + 2])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"cannot parse hidden-layer sizes from {fields!r}") from err
    raise ValueError(f"unsupported hidden-layer count in {fields!r}")


def main():
    """Rebuild the Agent described by a weights path, load the weights, and print them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("path", help="Folder Path to weights")
    args = parser.parse_args()
    path = args.path

    environment, variant, modeldata = parse_weights_path(path)
    print(environment)
    opengym_env = ENVS.get(environment)
    if opengym_env is None:
        print("Please provide the right environment", file=sys.stderr)
        sys.exit(1)

    # The env is only needed for its action/observation space sizes.
    env = gym.make(opengym_env)
    n_actions = env.action_space.n
    n_states = len(env.observation_space.low)
    env.close()

    fields = modeldata.split("_")
    print(fields)
    try:
        h1_layer, h2_layer = parse_layer_sizes(fields, variant)
    except ValueError as err:
        print(err, file=sys.stderr)
        sys.exit(1)

    # NOTE(review): for single-hidden-layer models this passes layer2_size=0;
    # assumes reignforce.Agent treats 0 as "no second hidden layer" — confirm.
    agent = Agent(input_dims=n_states, n_actions=n_actions,
                  layer1_size=h1_layer, layer2_size=h2_layer)
    agent.policy.load_weights(path)
    print(agent.policy.get_weights())


if __name__ == '__main__':
    main()