# Source code for aiorpc.server

# -*- coding: utf-8 -*-
import asyncio
import msgpack
import datetime

from aiorpc.constants import MSGPACKRPC_REQUEST, MSGPACKRPC_RESPONSE
from aiorpc.exceptions import MethodNotFoundError, RPCProtocolError, MethodRegisteredError
from aiorpc.connection import Connection
from aiorpc.log import rootLogger

# Public API of this module.
__all__ = [
    'register',
    'msgpack_init',
    'set_timeout',
    'serve',
    'register_class',
]

_logger = rootLogger.getChild(__name__)

# Registries filled by register() and register_class().
_methods = {}          # remote name -> callable
_class_methods = {}    # class name -> instance created by register_class()

# msgpack packer/unpacker settings; replaced by msgpack_init().
_pack_encoding = 'utf-8'
_pack_params = {}
_unpack_encoding = 'utf-8'
_unpack_params = {'use_list': False}

# IO timeout in seconds; replaced by set_timeout().
_timeout = 3


def register(name, f):
    """Register a function on the RPC server.

    Usage:

    >>> def sum(x, y):
    >>>     return x + y
    >>> register('sum', sum)

    :param name: The remote name of the function; may differ from ``f.__name__``.
    :param f: A callable. Coroutine functions are accepted too: when a call
        returns a coroutine, :func:`serve` awaits it.
    :raises MethodRegisteredError: if ``f`` is not callable, or if ``name`` is
        already registered.
    :return: None
    """
    if not callable(f):
        # Non-callables (ints, strings, ...) usually have no ``__name__``;
        # fall back to repr() so the intended error is raised instead of an
        # AttributeError.
        raise MethodRegisteredError("{} is not a callable object".format(
            getattr(f, '__name__', repr(f))))
    if name in _methods:
        raise MethodRegisteredError("Name {} has already been used".format(name))
    _methods[name] = f
def register_class(cls):
    """Register an instance of ``cls`` on the RPC server.

    ``cls`` is instantiated with no arguments. Remote callers then reach the
    instance's methods as ``ClassName.method``.

    :param cls: Class to load; must be constructible without arguments.
    :raises MethodRegisteredError: if a class with the same ``__name__`` has
        already been loaded.
    :return: None
    """
    name = cls.__name__
    if name in _class_methods:
        raise MethodRegisteredError("Class {} has already been loaded".format(name))
    _class_methods[name] = cls()
    # Log only once the instance is actually registered. A rejected duplicate
    # or a failing constructor must not produce a "Loaded" message.
    _logger.info("Loaded class `{0}`".format(name))
def msgpack_init(**kwargs):
    """Init parameters of msgpack packer and unpacker.

    Usage:

    >>> msgpack_init(pack_encoding='utf-8')

    Every call resets all four settings. Options not passed fall back to the
    defaults listed below.

    :param kwargs: See http://pythonhosted.org/msgpack-python/api.html
        default:
        pack_encoding='utf-8'
        pack_params=dict()
        unpack_encoding='utf-8'
        unpack_params=dict(use_list=False)
    :return: None
    """
    global _pack_encoding, _pack_params, _unpack_encoding, _unpack_params
    _pack_encoding = kwargs.pop('pack_encoding', 'utf-8')
    _pack_params = kwargs.pop('pack_params', dict())
    _unpack_encoding = kwargs.pop('unpack_encoding', 'utf-8')
    _unpack_params = kwargs.pop('unpack_params', dict(use_list=False))
    if kwargs:
        # Anything left over is not a recognised option. Warn instead of
        # dropping it silently, so typos such as ``pack_encodng`` are noticed.
        _logger.warning("msgpack_init ignored unknown option(s): {}".format(
            ', '.join(sorted(kwargs))))
def set_timeout(timeout):
    """Change the server-wide IO timeout.

    The value bounds waiting for a request, sending a response, and awaiting
    a coroutine returned by a registered method.

    Usage:

    >>> set_timeout(1)

    :param timeout: New timeout, in seconds.
    :return: None
    """
    global _timeout
    _timeout = timeout
async def _send_error(conn, exception, error, msg_id):
    """Send an error response for request ``msg_id`` to the client.

    Errors raised while packing or sending are logged, not raised.

    :param conn: Connection to write the response to.
    :param exception: Name of the exception type reported to the client.
    :param error: Error message reported to the client.
    :param msg_id: ID of the request being answered.
    """
    response = (MSGPACKRPC_RESPONSE, msg_id, (exception, error), None)
    try:
        await conn.sendall(msgpack.packb(response, encoding=_pack_encoding, **_pack_params), _timeout)
    except asyncio.TimeoutError:
        _logger.error("Timeout when _send_error {} to {}".format(
            error, conn.writer.get_extra_info('peername')))
    except Exception as e:
        _logger.error("Exception {} raised when _send_error {} to {}".format(
            str(e), error, conn.writer.get_extra_info("peername")
        ))


async def _send_result(conn, result, msg_id):
    """Send a successful response carrying ``result`` for request ``msg_id``.

    If ``result`` cannot be packed with msgpack, an error response is sent
    instead, so the client still gets a reply for ``msg_id``. Errors raised
    while sending are logged, not raised.
    """
    _logger.debug('entering _send_result')
    response = (MSGPACKRPC_RESPONSE, msg_id, None, result)
    try:
        payload = msgpack.packb(response, encoding=_pack_encoding, **_pack_params)
    except Exception as e:
        # Unserializable result: report it to the client instead of only logging.
        _logger.error("Exception {} raised when packing result {} for {}".format(
            str(e), str(result), conn.writer.get_extra_info("peername")
        ))
        await _send_error(conn, type(e).__name__, str(e), msg_id)
        return
    try:
        _logger.debug('begin to sendall')
        await conn.sendall(payload, _timeout)
        _logger.debug('sendall completed')
    except asyncio.TimeoutError:
        _logger.error("Timeout when _send_result {} to {}".format(
            str(result), conn.writer.get_extra_info('peername')))
    except Exception as e:
        _logger.error("Exception {} raised when _send_result {} to {}".format(
            str(e), str(result), conn.writer.get_extra_info("peername")
        ))


def _parse_request(req):
    """Validate a decoded request and resolve the method it names.

    A request is a 4-item sequence
    ``(MSGPACKRPC_REQUEST, msg_id, method_name, args)``. ``method_name`` is
    either a name passed to :func:`register`, or ``ClassName.method`` for a
    class loaded with :func:`register_class`.

    :return: ``(msg_id, method, args, method_name)``
    :raises RPCProtocolError: if ``req`` is not a well-formed request.
    :raises MethodNotFoundError: if nothing callable is registered under
        ``method_name``.
    """
    if len(req) != 4 or req[0] != MSGPACKRPC_REQUEST:
        raise RPCProtocolError('Invalid protocol')
    _, msg_id, method_name, args = req
    class_name, dot, attr_name = method_name.partition('.')
    if not dot:
        method = _methods.get(method_name)
    elif attr_name.startswith('_'):
        # Names come from the network: never expose private/dunder attributes
        # (``__init__``, ``__setattr__``, ...) of registered instances.
        method = None
    else:
        # Unknown classes or attributes resolve to None rather than raising
        # AttributeError. ``A.b.c`` does not silently resolve to ``A.b``.
        method = getattr(_class_methods.get(class_name), attr_name, None)
    if method is None or not callable(method):
        raise MethodNotFoundError("No such method {}".format(method_name))
    return msg_id, method, args, method_name
async def serve(reader, writer):
    """Serve function. Don't use this outside asyncio.start_server.

    Reads requests from the connection until it is closed. Each request is
    parsed and executed, and a result or error response is written back.
    """
    _logger.debug('enter serve: {}'.format(writer.get_extra_info('peername')))
    conn = Connection(reader, writer,
                      msgpack.Unpacker(encoding=_unpack_encoding, **_unpack_params))
    while not conn.is_closed():
        try:
            req = await conn.recvall(_timeout)
        except asyncio.TimeoutError:
            # NOTE(review): fixed 3s pause before closing; the reason is not
            # evident from this module — confirm it is still needed.
            await asyncio.sleep(3)
            _logger.warning("Client did not send any data before timeout. Closing connection...")
            conn.close()
            continue
        except IOError:
            break
        except Exception as e:
            conn.reader.set_exception(e)
            raise

        # Arrays decode as tuples with the default use_list=False, but as
        # lists if msgpack_init() sets use_list=True; accept both.
        if not isinstance(req, (tuple, list)):
            # _send_error logs its own failures, so no try/except is needed here.
            await _send_error(conn, RPCProtocolError.__name__, "Invalid protocol", -1)
            continue

        req_start = datetime.datetime.now()
        try:
            _logger.debug('parsing req: {}'.format(str(req)))
            msg_id, method, args, method_name = _parse_request(req)
            _logger.debug('parsing completed: {}'.format(str(req)))
        except Exception as e:
            _logger.error("Exception {} raised when _parse_request {}".format(str(e), req))
            # Always answer, e.g. for an unknown method. Echo the client's
            # msg_id when the request is at least request-shaped.
            err_msg_id = req[1] if len(req) == 4 else -1
            await _send_error(conn, type(e).__name__, str(e), err_msg_id)
            continue

        # Execute the parsed request
        try:
            _logger.debug('calling method: {}'.format(str(method)))
            ret = method(*args)
            if asyncio.iscoroutine(ret):
                _logger.debug("start to wait_for")
                ret = await asyncio.wait_for(ret, _timeout)
            _logger.debug('calling {} completed. result: {}'.format(str(method), str(ret)))
        except Exception as e:
            _logger.error("Caught Exception in `{0}`.\n{1}: {2}".format(
                method_name, type(e).__name__, str(e)))
            await _send_error(conn, type(e).__name__, str(e), msg_id)
            _logger.debug('sending exception {} completed'.format(str(e)))
        else:
            _logger.debug('sending result: {}'.format(str(ret)))
            await _send_result(conn, ret, msg_id)
            _logger.debug('sending result {} completed'.format(str(ret)))
        req_end = datetime.datetime.now()
        # total_seconds() covers the full duration. ``.microseconds`` is only
        # the sub-second component and drops whole seconds.
        _logger.info("Method `{0}` took {1}ms".format(
            method_name, (req_end - req_start).total_seconds() * 1000))