-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathenv_external.py
60 lines (49 loc) · 1.46 KB
/
env_external.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
外部で動作する環境を想定
"""
import random
from typing import Callable
class ExternalEnv:
    """Toy one-dimensional environment that an agent steers left or right.

    The position starts at 0 and changes by one unit per step. The episode
    finishes with reward 1 once the position reaches ``+BOUNDARY`` and with
    reward -1 once it reaches ``-BOUNDARY``.
    """

    # Distance from the origin at which an episode terminates.
    BOUNDARY = 5

    def __init__(self) -> None:
        self.pos = 0  # current position on the line
        self.reward = 0  # stays 0 until the episode ends, then 1 or -1
        self.done = False  # True once either boundary has been reached

    def step(self, action: int) -> None:
        """Apply one action and update ``pos``, ``reward`` and ``done``.

        Args:
            action: ``0`` moves one unit in the negative direction; any
                other value moves one unit in the positive direction.

        Once the episode is done, further calls are ignored so the final
        position and reward cannot be overwritten by late steps.
        """
        if self.done:
            return

        if action == 0:
            self.pos -= 1
        else:
            self.pos += 1

        # pos >= BOUNDARY ends the episode with reward 1,
        # pos <= -BOUNDARY ends it with reward -1.
        if self.pos >= self.BOUNDARY:
            self.reward = 1
            self.done = True
        elif self.pos <= -self.BOUNDARY:
            self.reward = -1
            self.done = True
def run_external_env(
    agent: Callable[[int, int], int],
    *,
    episodes: int = 5,
    max_steps: int = 30,
) -> None:
    """Run simulated episodes driven by a user-defined agent and print each result.

    Args:
        agent: Called as ``agent(step, state)`` with the zero-based step
            index and the environment's current position; its return value
            is passed to ``ExternalEnv.step`` as the action.
        episodes: Number of episodes to simulate (default 5).
        max_steps: Maximum number of steps per episode (default 30); an
            episode stops earlier as soon as the environment reports done.

    For every episode the final reward, the list of actions taken and the
    list of visited positions (including the start position) are printed.
    """
    for episode in range(episodes):
        env = ExternalEnv()
        act_history = []
        pos_history = [env.pos]  # include the starting position

        for step in range(max_steps):
            action = agent(step, env.pos)
            act_history.append(action)
            env.step(action)
            pos_history.append(env.pos)
            if env.done:
                break

        print(f"--- {episode} ---")
        print(f"reward: {env.reward}")
        print(f"action: {act_history}")
        print(f"state : {pos_history}")
if __name__ == "__main__":

    def sample_agent(step: int, state: int) -> int:
        """Sample agent: ignore both inputs and return 0 or 1 at random."""
        return random.randrange(2)

    # Example run using the sample agent.
    run_external_env(agent=sample_agent)