Skip to content

Commit

Permalink
refs #32, #34: Update Theano kernel to use PUSH/PULL protocol.
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Feb 25, 2017
1 parent d1e08f3 commit 44231a7
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 120 deletions.
248 changes: 131 additions & 117 deletions python3-theano/run.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,62 @@
#! /usr/bin/env python3
#! /usr/bin/env python

import builtins as builtin_mod
import code
import io
import enum
from functools import partial
import logging
from namedlist import namedtuple, namedlist
import os
from os import path
import sys
import time
import traceback
import types
import zmq
try:
import simplejson
has_simplejson = True
except ImportError:
has_simplejson = False

import sorna.drawing
import simplejson as json
import zmq

ExceptionInfo = namedtuple('ExceptionInfo', [
'exc',
('args', tuple()),
('raised_before_exec', False),
('traceback', None),
])
import getpass

Result = namedlist('Result', [
('stdout', ''),
('stderr', ''),
('media', None),
])
from sorna.types import (
InputRequest, ControlRecord, ConsoleRecord, MediaRecord, HTMLRecord,
)

log = logging.getLogger('code-runner')

@staticmethod
def _create_excinfo(e, raised_before_exec, tb):
assert isinstance(e, Exception)
return ExceptionInfo(type(e).__name__, e.args, raised_before_exec, tb)
ExceptionInfo.create = _create_excinfo

class StreamToEmitter:

class SockWriter(object):
def __init__(self, sock, cell_id):
self.cell_id_encoded = '{0}'.format(cell_id).encode('ascii')
self.sock = sock
self.buffer = io.StringIO()
def __init__(self, emitter, stream_type):
self.emit = emitter
self.stream_type = stream_type

def write(self, s):
if '\n' in s: # flush on occurrence of a newline.
s1, s2 = s.split('\n', maxsplit=1)
s0 = self.buffer.getvalue()
self.sock.send_multipart([self.cell_id_encoded, (s0 + s1 + '\n').encode('utf8')])
self.buffer.seek(0)
self.buffer.truncate(0)
self.buffer.write(s2)
else:
self.buffer.write(s)
if self.buffer.tell() > 1024: # flush if the buffer is too large.
s0 = self.buffer.getvalue()
self.sock.send_multipart([self.cell_id_encoded, s0.encode('utf8')])
self.buffer.seek(0)
self.buffer.truncate(0)
# TODO: timeout to flush?
self.emit(ConsoleRecord(self.stream_type, s))

def flush(self):
pass


class CodeRunner(object):
class CodeRunner:
'''
A thin wrapper for REPL.
It creates a dummy module that user codes run and keeps the references to user-created objects
(e.g., variables and functions).
'''

def __init__(self):
self.stdout_buffer = io.StringIO()
self.stderr_buffer = io.StringIO()
def __init__(self, api_version=1):
self.api_version = api_version
self.input_supported = (api_version >= 2)

ctx = zmq.Context.instance()
self.input_stream = ctx.socket(zmq.PULL)
self.input_stream.bind('tcp://*:2000')
self.output_stream = ctx.socket(zmq.PUSH)
self.output_stream.bind('tcp://*:2001')

self.stdout_emitter = StreamToEmitter(self.emit, 'stdout')
self.stderr_emitter = StreamToEmitter(self.emit, 'stderr')

# Initialize user module and namespaces.
user_module = types.ModuleType('__main__',
Expand All @@ -82,83 +66,113 @@ def __init__(self):
self.user_module = user_module
self.user_ns = user_module.__dict__

def execute(self, cell_id, src):
self.stdout_writer = self.stdout_buffer
self.stderr_writer = self.stderr_buffer
sys.stdout, orig_stdout = self.stdout_writer, sys.stdout
sys.stderr, orig_stderr = self.stderr_writer, sys.stderr

exceptions = []
result = Result()
before_exec = True

def my_excepthook(type_, value, tb):
exceptions.append(ExceptionInfo.create(value, before_exec, tb))
sys.excepthook = my_excepthook

try:
code_obj = code.compile_command(src, symbol='exec')
except IndentationError as e:
exceptions.append(ExceptionInfo.create(e, before_exec, None))
except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError) as e:
exceptions.append(ExceptionInfo.create(e, before_exec, None))
def handle_input(self, prompt=None, password=False):
if prompt is None:
prompt = 'Password: ' if password else ''
self.emit(ConsoleRecord('stdout', prompt))
self.emit(InputRequest(is_password=password))
data = self.input_stream.recv_multipart()
return data[1].decode('utf8')

def emit(self, record):
if isinstance(record, ConsoleRecord):
assert record.target in ('stdout', 'stderr')
self.output_stream.send_multipart([
record.target.encode('ascii'),
record.data.encode('utf8'),
])
elif isinstance(record, MediaRecord):
self.output_stream.send_multipart([
b'media',
json.dumps({
'type': record.type,
'data': record.data,
}).encode('utf8'),
])
elif isinstance(record, HTMLRecord):
self.output_stream.send_multipart([
b'html',
record.html.encode('utf8'),
])
elif isinstance(record, InputRequest):
self.output_stream.send_multipart([
b'waiting-input',
json.dumps({
'is_password': record.is_password,
}).encode('utf8'),
])
elif isinstance(record, ControlRecord):
self.output_stream.send_multipart([
record.event.encode('ascii'),
b'',
])
else:
self.user_module.__builtins__._sorna_media = []
before_exec = False
raise TypeError('Unsupported record type.')

@staticmethod
def strip_traceback(tb):
while tb is not None:
frame_summary = traceback.extract_tb(tb, limit=1)[0]
if frame_summary[0] == '<input>':
break
tb = tb.tb_next
return tb

def run(self):
json_opts = {'namedtuple_as_object': False}
while True:
data = self.input_stream.recv_multipart()
code_id = data[0].decode('ascii')
code_text = data[1].decode('utf8')
self.user_module.__builtins__._sorna_emit = self.emit
if self.input_supported:
self.user_module.__builtins__.input = self.handle_input
getpass.getpass = partial(self.handle_input, password=True)
try:
exec(code_obj, self.user_ns)
except Exception as e:
exceptions.append(ExceptionInfo.create(e, before_exec, None))

sys.excepthook = sys.__excepthook__

result.stdout = self.stdout_writer.getvalue()
result.stderr = self.stderr_writer.getvalue()
# TODO: sanitize media?
result.media = self.user_module.__builtins__._sorna_media
self.stdout_writer.seek(0, io.SEEK_SET)
self.stdout_writer.truncate(0)
self.stderr_writer.seek(0, io.SEEK_SET)
self.stderr_writer.truncate(0)
code_obj = code.compile_command(code_text, symbol='exec')
except (OverflowError, IndentationError, SyntaxError,
ValueError, TypeError, MemoryError) as e:
exc_type, exc_val, tb = sys.exc_info()
user_tb = type(self).strip_traceback(tb)
err_str = ''.join(traceback.format_exception(exc_type, exc_val, user_tb))
hdr_str = 'Traceback (most recent call last):\n' if not err_str.startswith('Traceback ') else ''
self.emit(ConsoleRecord('stderr', hdr_str + err_str))
self.emit(ControlRecord('finished'))
else:
sys.stdout, orig_stdout = self.stdout_emitter, sys.stdout
sys.stderr, orig_stderr = self.stderr_emitter, sys.stderr
try:
exec(code_obj, self.user_ns)
except Exception as e:
# strip the first frame
exc_type, exc_val, tb = sys.exc_info()
user_tb = type(self).strip_traceback(tb)
traceback.print_exception(exc_type, exc_val, user_tb)
finally:
self.emit(ControlRecord('finished'))
sys.stdout = orig_stdout
sys.stderr = orig_stderr


def main():
log = logging.getLogger('main')

sys.stdout = orig_stdout
sys.stderr = orig_stderr
return exceptions, result


if __name__ == '__main__':
# Use the "confined" working directory
os.chdir('/home/work')
# Replace stdin with a "null" file
# (trying to read stdin will raise EOFError immediately afterwards.)
sys.stdin = open(os.devnull, 'rb')

# Initialize context object.
runner = CodeRunner()

# Initialize minimal ZMQ server socket.
ctx = zmq.Context(io_threads=1)
sock = ctx.socket(zmq.REP)
sock.bind('tcp://*:2001')
print('serving at port 2001...')

runner = CodeRunner(api_version=2)
try:
while True:
data = sock.recv_multipart()
exceptions, result = runner.execute(data[0].decode('ascii'),
data[1].decode('utf8'))
response = {
'stdout': result.stdout,
'stderr': result.stderr,
'media': result.media,
'exceptions': exceptions,
}
json_opts = {}
if has_simplejson:
json_opts['namedtuple_as_object'] = False
sock.send_json(response, **json_opts)
runner.run()
except (KeyboardInterrupt, SystemExit):
pass
except:
log.exception('unexpected error')
finally:
sock.close()
print('exit.')


if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
main()
3 changes: 0 additions & 3 deletions python3-theano/sorna_media-0.2.0-py2.py3-none-any.whl

This file was deleted.

3 changes: 3 additions & 0 deletions python3-theano/sorna_media-0.3.0-py2.py3-none-any.whl
Git LFS file not shown

0 comments on commit 44231a7

Please sign in to comment.