diff --git a/eg/README.md b/eg/README.md index feb8461..6c499ae 100644 --- a/eg/README.md +++ b/eg/README.md @@ -16,6 +16,12 @@ RiveScript-Python. * [twilio](twilio/) - An example that uses the Twilio SMS API to create a bot that can receive SMS text messages from users and reply to them using RiveScript. +* [sessions](sessions/) - An example RiveScript bot which manages user session + data using RiveScript event callbacks, assuming single user per bot instance + scenario. Session continues when the script is run again (as it would happen + if the script was a stateless webhook service for example). + A simple JSON file-based data store is provided, but it's easy to implement + database persistence by subclassing SessionStore class. ## Code Snippets diff --git a/eg/brain/rpg.rive b/eg/brain/rpg.rive index aaeacf4..e8f6735 100644 --- a/eg/brain/rpg.rive +++ b/eg/brain/rpg.rive @@ -167,7 +167,7 @@ ^ life support system comes on, which includes an anesthesia to put you to sleep\s ^ for the duration of the long flight to Mars.\n\n ^ When you awaken, you are on Mars. The space shuttle seems to have crash-landed.\s - ^ There is a space suit here.{topic=crashed} + ^ There is a space suit here.{topic=crashed}{@look} < topic // Crashed on Mars diff --git a/eg/sessions/example.py b/eg/sessions/example.py new file mode 100755 index 0000000..a618eb0 --- /dev/null +++ b/eg/sessions/example.py @@ -0,0 +1,218 @@ +#!/usr/bin/python + +from __future__ import print_function +import os +import sys +import json +import logging as log +log.basicConfig(format='%(levelname)s -- %(message)s', level=log.WARN) + +# Setup sys.path to be able to import rivescript from this local git repo. +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from rivescript import RiveScript + +SCRIPT_DIR = os.path.join(os.path.dirname(__file__), '..', 'brain') +SESSION_FILE = os.path.join(os.path.dirname(__file__), 'sessions.json') + +print('-' * 70) +print("""An example RiveScript bot which saves user sessions and lets +the user continue when the script is run again. + +To see how it works, run the script with user name as the argument, eg: +`python example.py john` and enter the RPG demo by typing `rpg demo`. + +After playing a while, exit the script (via /quit, Ctrl-C, Ctrl-D, etc.) +and run again to jump back to where you stopped in the Rive script. Then +start the script as a different user. You can also check the contents +of `sessions.json` file once in a while to see session data.""") +print('-' * 70) + + +class SessionStore(object): + """Abstract SessionStore class.""" + def load(self, user): + """Load and return RiveSession object for a given user""" + raise NotImplementedError('Subclass SessionStore and override load()') + + def save(self, session): + """Save session based on given RiveSession object""" + raise NotImplementedError('Subclass SessionStore and override save()') + + +class SimpleSessionStore(SessionStore): + """Basic SessionStore implementation, reading/writing a single JSON file.""" + def __init__(self, file_name): + super(SimpleSessionStore, self).__init__() + self._file_name = file_name + + def load(self, user): + """Load session data from JSON file.""" + try: + with file(self._file_name, 'rb') as sf: + data = json.load(sf) + if user in data: + return RiveSession(user, data=data[user]) + except ValueError: + log.warn("Malformed JSON data in file: {}".format(file_name)) + except IOError: + # file not found, ignore + pass + + return RiveSession(user) # new (empty) session + + def save(self, session): + """Save session to JSON file, preserving other sessions (if any).""" + alldata = {} + try: + with file(self._file_name, 'rb') as sf: + alldata = json.load(sf) + except ValueError: + log.warn("Malformed JSON data in file: {}".format(file_name)) + except IOError: + # file not found, ignore + pass + + alldata[session._user] = session._data + with file(self._file_name, 'wb') as sf: + json.dump(alldata, sf, indent=4) + + +class RiveSession(object): + """User session object. + + Structure of session data: + { + 'topic': 'topic/redirect', + 'vars' : { + 'name' : 'value', + ... + } + } + """ + def __init__(self, user, data={'vars':{}}): + self._user = user + self._data = data + + def set_topic(self, topic, redirect=None): + self._data['topic'] = "{}/{}".format(topic, redirect) if redirect else topic + + def get_topic(self): + return self._data['topic'] if 'topic' in self._data else None + + def set_variable(self, name, value): + self._data['vars'][name] = value + + def get_variable(self, name): + if 'name' in self._data['vars']: + return self._data['vars'][name] + else: + return None + + def variables(self): + """User variables iterator.""" + for k, v in self._data['vars'].items(): + yield k, v + + +class RiveBot(object): + """An example RiveScript bot using callbacks to manage user session data. + + This example assumes single user per RiveScript instance and as + such it's suitable for use in stateless services (e.g. in web apps + receiving webhooks). Just init, get reply and teardown. Of course, + it will also work in RTM implementations with custom longer-lived + bot threads. + + Session state is persisted to a single JSON file. This wouldn't be + thread-safe in a concurrent environment (e.g. web server). In such + case it would be recommended to subclass SessionStore and implement + database persistence (preferably via one of great Python ORMs such + as SQLAlchemy or peewee). + """ + def __init__(self, script_dir, user, ss, debug=False): + self._user = user + self._redirect = None + + # init RiveScript + self._rs = RiveScript(debug=debug) + self._rs.load_directory(script_dir) + self._rs.sort_replies() + + # restore session + if isinstance(ss, SessionStore): + self._ss = ss + else: + raise RuntimeError("RiveBot init error: provided session store object is not a SessionStore instance.") + + self._restore_session() + + # register event callbacks + self._rs.on('topic', self._topic_cb) + self._rs.on('uservar', self._uservar_cb) + + def _topic_cb(self, user, topic, redirect=None): + """Topic callback. + + This is a single-user-per-rive (stateless instance) scenario; in a multi-user + scenario within a single thread, callback functions should delegate the + execution to proper user session objects. + """ + log.debug("Topic callback: user={}, topic={}, redirect={}".format(user, topic, redirect)) + self._session.set_topic(topic, redirect) + + def _uservar_cb(self, user, name, value): + """Topic callback. See comment for `_topic_cb()`""" + log.debug("User variable callback: user={}, name={}, value={}".format(user, name, value)) + self._session.set_variable(name, value) + + def _restore_session(self): + self._session = self._ss.load(self._user) + + # set saved user variables + for name, value in self._session.variables(): + self._rs.set_uservar(self._user, name, value) + + # set saved topic + topic = self._session.get_topic() + if topic: + if '/' in topic: + topic, self._redirect = topic.split('/') + self._rs.set_topic(self._user, topic) + + def _save_session(self): + self._ss.save(self._session) + + def run(self): + log.info("RiveBot starting...") + if self._redirect: + # Repeat saved redirect so that the user gets the context + # after session restart. + redir_reply = self._rs.redirect(self._user, self._redirect) + print("bot> Welcome back!") + print("bot>", redir_reply) + + while True: + msg = raw_input("{}> ".format(self._user)) + if msg == '/quit': + self.stop() + break + reply = self._rs.reply(self._user, msg) + print("bot>", reply) + + def stop(self): + log.info("RiveBot shutting down...") + print("\nbot> Bye.") + self._save_session() + + +if __name__ == "__main__": + user = sys.argv[1] if len(sys.argv) > 1 else 'default' + store = SimpleSessionStore(SESSION_FILE) + bot = RiveBot(SCRIPT_DIR, user, store) + try: + bot.run() + except (KeyboardInterrupt, EOFError): + bot.stop() + +# vim:expandtab diff --git a/requirements.txt b/requirements.txt index a735e77..1ffda83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ nose==1.3.7 +pyreadline==2.1 six==1.10.0 diff --git a/rivescript/rivescript.py b/rivescript/rivescript.py index f0aec35..82e7c53 100644 --- a/rivescript/rivescript.py +++ b/rivescript/rivescript.py @@ -32,6 +32,21 @@ import pprint import copy import codecs +from collections import deque + +# Configure readline for all interactive prompts +try: + import readline +except ImportError: + # Windows + import pyreadline as readline +else: + # Unix + import rlcompleter + if(sys.platform == 'darwin'): + readline.parse_and_bind("bind ^I rl_complete") + else: + readline.parse_and_bind("tab: complete") from . import __version__ from . import python @@ -84,7 +99,7 @@ class RiveScript(object): # Initialization and Utility Methods # ############################################################################ - def __init__(self, debug=False, strict=True, depth=50, log="", utf8=False): + def __init__(self, debug=False, strict=True, depth=50, log="", utf8=False, history_length=8): """Initialize a new RiveScript interpreter.""" ### @@ -100,8 +115,10 @@ def __init__(self, debug=False, strict=True, depth=50, log="", utf8=False): self.unicode_punctuation = re.compile(r'[.,!?;:]') # Misc. - self._strict = strict # Strict mode - self._depth = depth # Recursion depth limit + self._strict = strict # Strict mode + self._depth = depth # Recursion depth limit + self._callbacks = {} # Callbacks + self._history_length = history_length ### # Internal fields. @@ -1492,6 +1509,7 @@ def set_uservar(self, user, name, value): self._users[user] = {"topic": "random"} self._users[user][name] = value + self._fire_event('uservar', user, name, value) def set_uservars(self, user, data=None): """Set many variables for a user, or set many variables for many users. @@ -1749,6 +1767,58 @@ def current_user(self): self._warn("current_user() is meant to be used from within a Python object macro!") return self._current_user + def on(self, event_name, callback): + """Register an event callback. + + Supported events (`event_name`): + + * 'topic' - fired on topic change: `callback(user, topic)` or on + redirect: `callback(user, topic, redirect=target)` + + * 'uservar' - fired when user variable is set: callback(user, name, value) + + Pass `None` as `callback` to clear it. + """ + if not callback: + if event_name in self._callbacks: + self._say("Clearing callback for event '{}'".format(event_name)) + del self._callbacks[event_name] + else: + self._warn("Callback for event '{}' is not set".format(event_name)) + elif callable(callback): + self._say("Setting callback for event '{}'".format(event_name)) + self._callbacks[event_name] = callback + else: + self._warn("Refusing to set callback for event '{}': {} is not callable".format(event_name, callback)) + + def get_topic(self, user): + """Get user's current topic. + + :param str user: User name + :return: Topic or ``None`` if user does not exist. + """ + if user in self._users: + return self._users[user]['topic'] + else: + self._warn("Cannot get topic: user '{}' does not exist".format(user)) + return None + + def set_topic(self, user, topic): + """Set user's topic. + + :param str user: User name + :param str topic: New topic + """ + if topic in self._topics: + if not user in self._users: + self._users[user] = {} + self._say("Setting topic of user '{}' to '{}'".format(user, topic)) + self._users[user]['topic'] = topic + self._fire_event('topic', user, topic) + else: + self._warn("Cannot set topic for user '{}': topic '{}' does not exist".format(user,topic)) + + ############################################################################ # Reply Fetching Methods # ############################################################################ @@ -1810,12 +1880,8 @@ def reply(self, user, msg, errors_as_replies=True): reply = e.error_message # Save their reply history. - oldInput = self._users[user]['__history__']['input'][:8] - self._users[user]['__history__']['input'] = [msg] - self._users[user]['__history__']['input'].extend(oldInput) - oldReply = self._users[user]['__history__']['reply'][:8] - self._users[user]['__history__']['reply'] = [reply] - self._users[user]['__history__']['reply'].extend(oldReply) + self._users[user]['__history__']['input'].appendleft(msg) + self._users[user]['__history__']['reply'].appendleft(reply) # Unset the current user. self._current_user = None @@ -1904,16 +1970,8 @@ def _getreply(self, user, msg, context='normal', step=0, ignore_object_errors=Tr # Initialize this user's history. if '__history__' not in self._users[user]: self._users[user]['__history__'] = { - 'input': [ - 'undefined', 'undefined', 'undefined', 'undefined', - 'undefined', 'undefined', 'undefined', 'undefined', - 'undefined' - ], - 'reply': [ - 'undefined', 'undefined', 'undefined', 'undefined', - 'undefined', 'undefined', 'undefined', 'undefined', - 'undefined' - ] + 'input': deque(['undefined'] * self._history_length, maxlen=self._history_length), + 'reply': deque(['undefined'] * self._history_length, maxlen=self._history_length) } # More topic sanity checking. @@ -2134,14 +2192,13 @@ def _getreply(self, user, msg, context='normal', step=0, ignore_object_errors=Tr # later! reTopic = re.findall(RE.topic_tag, reply) for match in reTopic: - self._say("Setting user's topic to " + match) - self._users[user]["topic"] = match + self.set_topic(user, match) reply = reply.replace('{{topic={match}}}'.format(match=match), '') reSet = re.findall(RE.set_tag, reply) for match in reSet: self._say("Set uservar " + str(match[0]) + "=" + str(match[1])) - self._users[user][match[0]] = match[1] + self.set_uservar(user, match[0], match[1]) reply = reply.replace(''.format(key=match[0], value=match[1]), '') else: # Process more tags if not in BEGIN. @@ -2248,7 +2305,8 @@ def _expand_array(self, array_name): :return list: The final array contents. - Warning is issued when exceptions occur.""" + Warning is issued when exceptions occur. + """ ret = self._arrays[array_name] if array_name in self._arrays else [] try: ret = self._do_expand_array(array_name) @@ -2482,7 +2540,7 @@ def _process_tags(self, user, msg, reply, st=[], bst=[], depth=0, ignore_object_ # user vars. parts = data.split("=") self._say("Set uservar " + text_type(parts[0]) + "=" + text_type(parts[1])) - self._users[user][parts[0]] = parts[1] + self.set_uservar(user, parts[0], parts[1]) elif tag in ["add", "sub", "mult", "div"]: # Math operator tags. parts = data.split("=") @@ -2531,8 +2589,7 @@ def _process_tags(self, user, msg, reply, st=[], bst=[], depth=0, ignore_object_ # Topic setter. reTopic = re.findall(RE.topic_tag, reply) for match in reTopic: - self._say("Setting user's topic to " + match) - self._users[user]["topic"] = match + self.set_topic(user, match) reply = reply.replace('{{topic={match}}}'.format(match=match), '') # Inline redirecter. @@ -2542,6 +2599,7 @@ def _process_tags(self, user, msg, reply, st=[], bst=[], depth=0, ignore_object_ at = match.strip() subreply = self._getreply(user, at, step=(depth + 1)) reply = reply.replace('{{@{match}}}'.format(match=match), subreply) + self._fire_event('topic', user, self.get_topic(user), redirect=at) # Object caller. reply = reply.replace("{__call__}", "") @@ -2819,6 +2877,15 @@ def _strip_nasties(self, s): s = re.sub(RE.nasties, '', s) return s + def _fire_event(self, event_name, *args, **kwargs): + """Call callback function for event if it's been set.""" + self._say("Fire event '{}' with args: {} {}".format(event_name, args, kwargs)) + if event_name in self._callbacks: + try: + self._callbacks[event_name].__call__(*args, **kwargs) + except Exception as e: + self._warn("Error while executing callback for event '{}': {}".format(event_name, str(e))) + def _dump(self): """For debugging, dump the entire data structure.""" pp = pprint.PrettyPrinter(indent=4) diff --git a/tests/test_rivescript.py b/tests/test_rivescript.py index 026824d..735a132 100644 --- a/tests/test_rivescript.py +++ b/tests/test_rivescript.py @@ -402,6 +402,9 @@ def test_redirects(self): + hi there - {@hello} + + + howdy + - Howdy. """) for greeting in ["hello", "hey", "hi there"]: self.reply(greeting, "Hi there!") @@ -559,9 +562,12 @@ def test_punishment_topic(self): self.reply("Swear word!", "How rude! Apologize or I won't talk to you again.") self.reply("hello", "Say you're sorry!") self.reply("How are you?", "Say you're sorry!") + self.assertEqual(self.rs.get_topic(self.username), 'sorry') self.reply("Sorry!", "It's ok!") self.reply("hello", "Hi there!") self.reply("How are you?", "Catch-all.") + self.rs.set_topic(self.username, 'sorry') + self.reply("Hi there!", "Say you're sorry!") def test_topic_inheritence(self): @@ -864,5 +870,63 @@ def test_unicode_punctuation(self): self.reply("Hello, bot!", RS_ERR_MATCH) +class EventCallbackTest(RiveScriptTestCase): + """Test event callbacks.""" + + def test_on_topic_cb(self): + self.new(""" + + test + - Hi.{topic=test} + + + testred + - Hi {topic=testred}{@start} + + > topic test + + * + - Hello. + < topic + + > topic testred + + start + - there! + < topic + """) + def topic_cb(user, topic, redirect=None): + self.topic_cb_user = user + self.topic_cb_topic = topic + self.topic_cb_redirect = redirect + + self.rs.on('topic', topic_cb) + + self.reply("test", "Hi.") + self.assertEqual(self.topic_cb_user, self.username) + self.assertEqual(self.topic_cb_topic, 'test') + self.assertEqual(self.topic_cb_redirect, None) + + self.rs.set_topic(self.username, 'random') + self.reply("testred", "Hi there!") + self.assertEqual(self.topic_cb_user, self.username) + self.assertEqual(self.topic_cb_topic, 'testred') + self.assertEqual(self.topic_cb_redirect, 'start') + + + def test_on_uservar_cb(self): + self.new(""" + + test + - Hi. + """) + def uservar_cb(user, name, value): + self.var_cb_user = user + self.var_cb_name = name + self.var_cb_value = value + + self.rs.on('uservar', uservar_cb) + + self.reply("test", "Hi.") + self.assertEqual(self.var_cb_user, self.username) + self.assertEqual(self.var_cb_name, "test") + self.assertEqual(self.var_cb_value, "123") + + if __name__ == "__main__": unittest.main()