-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmysqlstore.py
117 lines (107 loc) · 3.81 KB
/
mysqlstore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import MySQLdb
import pickle
import threading
import logging
import random
logger = logging.getLogger(__name__)
# DDL for the backing table: one row per (modelid, key) holding the pickled
# model state blob and the last processed token.  `modified_on` is refreshed
# automatically on every write and indexed so stale rows can be pruned
# efficiently (see MySQLStore.commit's max_days cleanup).
create_table = '''
CREATE TABLE IF NOT EXISTS `state_storage` (
`modelid` INT NOT NULL,
`key` VARCHAR(128) NOT NULL,
`state` LONGBLOB NOT NULL,
`last_token` INT NOT NULL,
`modified_on` TIMESTAMP on update CURRENT_TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`modelid`, `key`),
INDEX (`modified_on`));
'''
class MySQLStore:
    """Write-back store for per-key model states, persisted in MySQL.

    States are held in an in-memory cache (``self.cached``, mapping
    key -> (pickled_state, last_token)); keys modified via ``write_state``
    are tracked in ``self.dirty`` and flushed to the database after
    ``commit_every`` writes.  ``forward``/``backward`` adapt this store to a
    request object carrying ``key``/``initial_state``/``final_state`` fields.

    SECURITY NOTE(review): state blobs are deserialized with ``pickle.loads``;
    this is only safe while the database contents are fully trusted.
    """

    def __init__(self, model, dbhost, dbname, dbuser, dbpass, default_token=0, commit_every = 256, max_cache = 1024, modelid=0, max_days=None):
        """Create the store and ensure the backing table exists.

        model: owning model object (kept for callers; not used internally here).
        default_token: token reported when no state exists for a key.
        commit_every: number of writes between automatic database flushes.
        max_cache: soft cap on cached entries before clean ones are evicted.
        modelid: discriminator allowing several models to share one table.
        max_days: if set, rows older than this many days are pruned on commit.
        """
        self.model = model
        self.default_token = default_token
        self.dbhost = dbhost
        self.dbname = dbname
        self.dbuser = dbuser
        self.dbpass = dbpass
        self.commit_every = commit_every
        self.modelid = modelid
        self.max_days = max_days
        self.max_cache = max_cache
        self.cached = {}            # key -> (pickled_state, last_token)
        self.dirty = set()          # keys written since the last commit
        self.lock = threading.RLock()
        self.writes = 0             # writes since the last commit
        # Create the table eagerly so later operations can assume it exists.
        c = self.opendb()
        c.cursor().execute(create_table)
        c.commit()
        c.close()

    def opendb(self):
        """Open and return a fresh MySQL connection (caller must close it)."""
        conn = MySQLdb.connect(host=self.dbhost, database=self.dbname, user=self.dbuser, password=self.dbpass)
        return conn

    def get_state(self, key):
        """Return (pickled_state, last_token) for *key*, or None if unknown.

        Fix: rows fetched from the database are now inserted into the cache
        (read-through).  The original called check_cache_size() here but never
        populated the cache on reads, so every cache-miss read hit the
        database again.
        """
        with self.lock:
            if key in self.cached:
                return self.cached[key]
            con = self.opendb()
            try:
                cur = con.cursor()
                cur.execute('SELECT `state`, `last_token` FROM `state_storage` WHERE `modelid` = %s AND `key` = %s', (self.modelid, key))
                row = cur.fetchone()
                if row is not None:
                    # Cache only hits; a miss stays uncached so a state
                    # written by another process can still be picked up later.
                    self.cached[key] = row
                self.check_cache_size()
                return row
            finally:
                con.close()

    def check_cache_size(self):
        """Evict ~1/8 of the *clean* cached entries once over max_cache.

        Dirty entries are never evicted — they have not been persisted yet.
        """
        #logger.info("Current cache size %d", len(self.cached))
        if len(self.cached) > self.max_cache:
            cleankeys = [x for x in self.cached.keys() if (x not in self.dirty)]
            random.shuffle(cleankeys)
            for k in cleankeys[:len(cleankeys)//8]:
                del self.cached[k]

    def write_state(self, key, state, token):
        """Record a (pickled) state for *key*; flush if the write budget is hit."""
        with self.lock:
            self.cached[key] = (state, token)
            self.dirty.add(key)
            self.writes += 1
            if self.writes > self.commit_every:
                self.commit()

    def forward(self, request):
        """Load the stored state for request.key into the request.

        Falls back to the '_default' key's state, then to an empty state with
        default_token.  Sets request.initial_state / request.initial_token.
        """
        r = self.get_state(request.key)
        if r:
            (state, token) = r
        else:
            r = self.get_state('_default')
            if r:
                logger.info("loading default state for %s", request.key)
                (state, token) = r
            else:
                logger.info("loading empty state for %s", request.key)
                (state, token) = (None, self.default_token)
        # pickle.loads on the stored blob — trusted-database assumption applies.
        request.initial_state = pickle.loads(state) if state else None
        request.initial_token = token

    def backward(self, request):
        """Persist request.final_state (pickled) under request.key."""
        self.write_state(request.key, pickle.dumps(request.final_state), request.last_token)

    def commit(self):
        """Flush all dirty states to the database; optionally prune old rows.

        Fix: the empty-check now precedes the log line, so an empty commit no
        longer logs "Committing 0 states".  On failure the dirty set is kept,
        so the states are retried on the next commit (best-effort, logged).
        """
        with self.lock:
            if len(self.dirty) == 0:
                return
            logger.info("Committing %d states", len(self.dirty))
            try:
                dirty_states = [(self.modelid, x) + self.cached[x] for x in self.dirty]
                con = self.opendb()
                try:
                    cur = con.cursor()
                    # REPLACE upserts on the (modelid, key) primary key.
                    cur.executemany("REPLACE INTO state_storage (`modelid`, `key`, `state`, `last_token`) VALUES (%s, %s, %s, %s)", dirty_states)
                    if self.max_days:
                        logger.info("Deleting states older than %d days", self.max_days)
                        cur.execute("DELETE FROM state_storage WHERE TIMESTAMPDIFF(DAY, `modified_on`, CURRENT_TIMESTAMP()) > %s AND `key` <> '_default'", (self.max_days,))
                        logger.info("Deleted %d states", cur.rowcount)
                    con.commit()
                    # Reset write-tracking only after the commit succeeded.
                    self.writes = 0
                    self.dirty.clear()
                finally:
                    con.close()
                logger.info("Committed")
            except Exception:
                logger.exception("Error writing states")