-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenvi.py
159 lines (143 loc) · 5.17 KB
/
envi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import sys
import config as conf
import torch
import random
import collections
import numpy as np
from collections import Counter
sys.path.insert(0, conf.ENV_DIR)
import r
from env import Env as CEnv
class Env(CEnv):
    """Python-side wrapper around the compiled environment ``CEnv``.

    On top of the base environment it records, per seat, which cards have
    been played, and exposes the state (:attr:`face`) and the legal moves
    (:meth:`valid_actions`) as ``15 x 4`` one-hot encodings / torch tensors.

    Seat index ``role`` is ``get_role_ID() - 1``:
    0 = upstream player (上家), 1 = landlord (地主), 2 = downstream player (下家).

    Card values are ints 3..17, standing for
    3, 4, ..., 10, J, Q, K, A, 2, black joker, red joker; the length-15
    count arrays are indexed by ``value - 3``.
    """

    # Number of distinct card values (3..17).
    _N_RANKS = 15
    # Width of the per-rank one-hot encoding (the ``15 x 4`` layout below).
    _MAX_COPIES = 4
    # Cards held by each seat at the start of a game: the landlord (index 1)
    # starts with 20, the other two seats with 17.
    _INITIAL_LEFT = (17, 20, 17)

    def __init__(self, debug=False, seed=None):
        """
        :param debug: when true, every play is printed and waits for Enter.
        :param seed: forwarded to ``CEnv`` when given (0 is a valid seed).
        """
        # `is not None` so that seed=0 is forwarded rather than dropped.
        if seed is not None:
            super().__init__(seed=seed)
        else:
            super().__init__()
        self._reset_tracking()
        self.debug = debug

    def _reset_tracking(self):
        """(Re)initialise the per-game bookkeeping kept on the Python side."""
        n_ranks = self._N_RANKS
        # How many cards of each value have been played so far (all seats).
        self.taken = np.zeros((n_ranks,))
        # Cards still in each seat's hand.
        self.left = np.array(self._INITIAL_LEFT, dtype=int)
        # Per-seat cumulative counts of played card values.
        # (Closures capture `n_ranks`, not `self`, to avoid a reference cycle.)
        self.history = collections.defaultdict(lambda: np.zeros((n_ranks,)))
        # Per-seat counts of that seat's most recent play.
        self.recent_handout = collections.defaultdict(
            lambda: np.zeros((n_ranks,)))
        # Per-seat hand snapshot taken just before that seat's latest move.
        self.old_cards = dict()

    def reset(self):
        """Reset the base environment and all Python-side tracking."""
        super().reset()
        self._reset_tracking()

    def _begin_step(self):
        """Snapshot the acting seat's hand and return its seat index."""
        role = self.get_role_ID() - 1
        self.old_cards[role] = self.get_curr_handcards()
        return role

    def _update(self, role, cards):
        """Record that seat ``role`` played ``cards`` (a sequence of card values)."""
        played = self.cards2arr(cards)
        self.left[role] -= len(cards)
        self.taken += played
        self.history[role] += played
        self.recent_handout[role] = played
        if self.debug:
            handcards = self.cards2str(self.old_cards[role])
            if role == 1:
                marker, name = '#', '地主'
            elif role == 0:
                marker, name = '$', '上家'
            else:
                marker, name = '$', '下家'
            # Show the hand before the play, then pause until Enter is pressed.
            print('\n{} {}手牌: {}'.format(marker, name, handcards), end='')
            input()
            print('{} {}出牌: {},分别剩余: {}'.format(
                marker, name, self.cards2str(cards), self.left))

    def step_manual(self, onehot_cards):
        """Play the given ``15 x 4`` one-hot move for the current seat.

        Returns whatever ``CEnv.step_manual`` returns.
        """
        role = self._begin_step()
        cards = self.arr2cards(self.onehot2arr(onehot_cards))
        self._update(role, cards)
        return super().step_manual(cards)

    def step_auto(self):
        """Let ``CEnv`` choose and play a move for the current seat.

        :return: the 3-tuple from ``CEnv.step_auto`` (played cards, reward,
            and a third value passed through unchanged).
        """
        role = self._begin_step()
        cards, reward, extra = super().step_auto()
        self._update(role, cards)
        return cards, reward, extra

    def step_random(self):
        """Play a uniformly random legal move for the current seat.

        Returns whatever ``CEnv.step_manual`` returns.
        """
        role = self._begin_step()
        actions = self.valid_actions(tensor=False)
        cards = self.arr2cards(random.choice(actions))
        self._update(role, cards)
        # Deliberately the base-class step_manual: it takes card values,
        # whereas this class's step_manual takes a one-hot encoding.
        return super().step_manual(cards)

    @staticmethod
    def _tensor_device():
        """Return the torch device that state/action tensors are moved to.

        ``DEVICE`` is not defined in this module, so a module-level
        ``DEVICE`` is honoured only if some caller injects one; otherwise
        ``conf.DEVICE`` is used if present, else the CPU.
        """
        device = globals().get('DEVICE')
        if device is None:
            device = getattr(conf, 'DEVICE', 'cpu')
        return device

    @property
    def face(self):
        """Current state as a float tensor of shape ``(4, 15, 4)``.

        Slices 0-1 one-hot encode the current hand (``get_curr_handcards()``)
        and the cards played so far (``self.taken``); slices 2-3 are
        ``get_state_prob()`` reshaped to ``(2, 15, 4)`` (presumably a
        probability estimate of unseen cards -- confirm against CEnv).
        """
        handcards = self.cards2arr(self.get_curr_handcards())
        known = self.batch_arr2onehot([handcards, self.taken])
        prob = self.get_state_prob().reshape(
            2, self._N_RANKS, self._MAX_COPIES)
        state = np.concatenate((known, prob))
        return torch.tensor(state, dtype=torch.float).to(self._tensor_device())

    def valid_actions(self, tensor=True):
        """Legal moves for the current seat.

        :param tensor: if true, return a float tensor of shape
            ``(n_actions, 15, 4)``; otherwise the raw result of
            ``r.get_moves`` (per-rank count arrays).
        """
        handcards = self.cards2arr(self.get_curr_handcards())
        last_two = self.get_last_two_cards()
        # The first non-empty entry of get_last_two_cards() is the play to
        # respond to; if both are empty the move is unconstrained.
        # NOTE(review): assumes entry 0 is the most recent play -- confirm.
        if last_two[0]:
            last = last_two[0]
        elif last_two[1]:
            last = last_two[1]
        else:
            last = []
        actions = r.get_moves(handcards, self.cards2arr(last))
        if not tensor:
            return actions
        return torch.tensor(self.batch_arr2onehot(actions),
                            dtype=torch.float).to(self._tensor_device())

    @classmethod
    def arr2cards(cls, arr):
        """Expand a per-value count array into a flat array of card values.

        :param arr: length-15 sequence of counts, indexed by ``value - 3``.
        :return: int array of card values in ascending order, e.g. two 3s
            and three As give ``[3, 3, 14, 14, 14]``::

                 3 4 5 6 7 8 9 10  J  Q  K  A  2 BJ CJ
                 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
        """
        res = []
        for idx in range(cls._N_RANKS):
            # int() so float counts (e.g. from np.zeros) are accepted too.
            res.extend([idx + 3] * int(arr[idx]))
        return np.array(res, dtype=int)

    @classmethod
    def cards2arr(cls, cards):
        """Count card values per rank; inverse of :meth:`arr2cards`.

        :param cards: iterable of card values (3..17).
        :return: int array of shape ``(15,)``.
        """
        arr = np.zeros((cls._N_RANKS,), dtype=int)
        for card in cards:
            arr[card - 3] += 1
        return arr

    @classmethod
    def batch_arr2onehot(cls, batch_arr):
        """One-hot encode a batch of length-15 count arrays.

        For each rank with count ``c`` the first ``c`` of its 4 slots are 1.

        :return: int array of shape ``(len(batch_arr), 15, 4)``.
        """
        res = np.zeros((len(batch_arr), cls._N_RANKS, cls._MAX_COPIES),
                       dtype=int)
        for row, arr in zip(res, batch_arr):
            # `row` is a view into `res`, so writes land in the result.
            for rank, count in enumerate(arr):
                if count > 0:
                    row[rank, :int(count)] = 1
        return res

    @classmethod
    def onehot2arr(cls, onehot_cards):
        """Collapse a ``15 x 4`` one-hot encoding back to per-rank counts.

        :return: int array of shape ``(15,)``.
        """
        res = np.zeros((cls._N_RANKS,), dtype=int)
        for idx, onehot in enumerate(onehot_cards):
            res[idx] = int(sum(onehot))
        return res

    def cards2str(self, cards):
        """Map card values to display labels via ``conf.DICT``.

        Despite the name, this returns a list of labels, not a single string.
        """
        return [conf.DICT[card] for card in cards]