Source code for aiorpc.client

# -*- coding: utf-8 -*-
import asyncio
import msgpack
import logging

from aiorpc.connection import Connection
from aiorpc.log import rootLogger
from aiorpc.constants import MSGPACKRPC_RESPONSE, MSGPACKRPC_REQUEST
from aiorpc.exceptions import RPCProtocolError, RPCError, EnhancedRPCError

# Public API of this module: only the client class is exported.
__all__ = ['RPCClient']

# Module-level logger, created as a child of the package's ``rootLogger``
# and named after this module.
_logger = rootLogger.getChild(__name__)


class RPCClient:
    """RPC client.

    Usage:
        >>> from aiorpc.client import RPCClient
        >>> client = RPCClient('127.0.0.1', 6000)
        >>> import asyncio
        >>> loop = asyncio.get_event_loop()
        >>> loop.run_until_complete(client.call('sum', 1, 2))

    The client can also be used as an asynchronous context manager, in which
    case the connection is opened on entry and closed on exit.

    :param str host: Hostname.
    :param int port: Port number.
    :param int timeout: (optional) Socket timeout.
    :param loop: (optional) Event loop forwarded to
        ``asyncio.open_connection``. Only forwarded when not ``None``.
    :param str pack_encoding: (optional) Character encoding used to pack data
        using Messagepack.
    :param str unpack_encoding: (optional) Character encoding used to unpack
        data using Messagepack.
    :param dict pack_params: (optional) Parameters to pass to Messagepack
        Packer.
    :param dict unpack_params: (optional) Parameters to pass to Messagepack
        Unpacker. Defaults to ``{'use_list': False}``.
    """

    def __init__(self, host, port, *, timeout=3, loop=None,
                 pack_encoding='utf-8', unpack_encoding='utf-8',
                 pack_params=None, unpack_params=None):
        self._host = host
        self._port = port
        self._timeout = timeout

        self._loop = loop
        self._conn = None  # created lazily by _open_connection()
        self._msg_id = 0   # id of the most recently created request

        self._pack_encoding = pack_encoding
        self._pack_params = pack_params or dict()
        self._unpack_encoding = unpack_encoding
        self._unpack_params = unpack_params or dict(use_list=False)

    def getpeername(self):
        """Return the address of the remote endpoint.

        This is the ``(host, port)`` pair the client was configured with.
        """
        return self._host, self._port

    def close(self):
        """Close the current connection; do nothing if none was opened."""
        try:
            self._conn.close()
        except AttributeError:
            # ``self._conn`` is still None: nothing to close.
            pass

    async def _open_connection(self):
        """Open a stream connection and wrap it in a ``Connection``."""
        _logger.debug("connect to {}:{}...".format(self._host, self._port))

        # ``asyncio.open_connection`` no longer accepts ``loop`` on
        # Python 3.10+, so only forward it when the caller supplied one.
        extra = {} if self._loop is None else {'loop': self._loop}
        reader, writer = await asyncio.open_connection(
            self._host, self._port, **extra)

        # NOTE(review): newer msgpack releases no longer accept the
        # ``encoding`` keyword -- confirm the pinned msgpack version.
        self._conn = Connection(
            reader, writer,
            msgpack.Unpacker(encoding=self._unpack_encoding,
                             **self._unpack_params))
        _logger.debug("Connection to {}:{} established".format(
            self._host, self._port))

    async def call(self, method, *args, _close=False):
        """Calls a RPC method.

        :param str method: Method name.
        :param args: Method arguments.
        :param _close: Close the connection at the end of the request,
            whether it succeeded or failed. Defaults to false.
        :return: The ``result`` field of the server response.
        :raises asyncio.TimeoutError: if sending or receiving timed out.
        :raises IOError: if no response was received (``recvall`` returned
            ``None``).
        :raises RPCProtocolError: if the response is malformed.
        :raises RPCError: if the message id does not match or the server
            reported an error.
        """
        try:
            _logger.debug('creating request')
            req = self._create_request(method, args)

            if self._conn is None or self._conn.is_closed():
                await self._open_connection()

            try:
                _logger.debug('Sending req: {}'.format(req))
                await self._conn.sendall(req, self._timeout)
                _logger.debug('Sending complete')
            except asyncio.TimeoutError:
                _logger.error("Write request to {}:{} timeout".format(
                    self._host, self._port))
                raise

            try:
                _logger.debug('receiving result from server')
                response = await self._conn.recvall(self._timeout)
                _logger.debug('receiving result completed')
            except asyncio.TimeoutError as te:
                _logger.error("Read request to {}:{} timeout".format(
                    self._host, self._port))
                # Record the failure on the reader before propagating.
                self._conn.reader.set_exception(te)
                raise
            except Exception as e:
                self._conn.reader.set_exception(e)
                raise

            if response is None:
                raise IOError("Connection closed")

            # Lists are accepted as well so that a caller-supplied
            # ``use_list=True`` unpacker does not make every call fail.
            if not isinstance(response, (tuple, list)):
                _logger.debug(
                    'Protocol error, received unexpected data: {}'.format(
                        response))
                raise RPCProtocolError('Invalid protocol')

            return self._parse_response(response)
        finally:
            # Honour ``_close`` on error paths too, so a failed one-shot
            # call does not leave the connection open.
            if _close:
                self.close()

    async def call_once(self, method, *args):
        """Call an RPC Method, then close the connection

        :param str method: Method name.
        :param args: Method arguments.
        :return: Whatever :meth:`call` returns for this request.
        """
        return await self.call(method, *args, _close=True)

    def _create_request(self, method, args):
        """Allocate the next message id and pack a request message."""
        self._msg_id += 1

        req = (MSGPACKRPC_REQUEST, self._msg_id, method, args)

        # NOTE(review): newer msgpack releases no longer accept the
        # ``encoding`` keyword -- confirm the pinned msgpack version.
        return msgpack.packb(req, encoding=self._pack_encoding,
                             **self._pack_params)

    def _parse_response(self, response):
        """Validate a ``(type, msg_id, error, result)`` response.

        :return: The ``result`` element of the response.
        :raises RPCProtocolError: if the response does not have four
            elements or is not tagged as a response.
        :raises RPCError: if the message id differs from the last request
            or the server reported an error.
        :raises EnhancedRPCError: if the reported error is a two-element
            tuple/list; its elements are passed as constructor arguments.
        """
        if len(response) != 4 or response[0] != MSGPACKRPC_RESPONSE:
            raise RPCProtocolError('Invalid protocol')

        (_, msg_id, error, result) = response

        if msg_id != self._msg_id:
            raise RPCError('Invalid Message ID')

        if error:
            # Only a real pair is unpacked; this avoids ``len()`` failing on
            # a numeric error and a 2-character string being splatted.
            if isinstance(error, (tuple, list)) and len(error) == 2:
                raise EnhancedRPCError(*error)
            raise RPCError(error)

        return result

    async def __aenter__(self):
        """Open the connection and return the client itself."""
        await self._open_connection()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection if it is still open."""
        if self._conn and not self._conn.is_closed():
            _logger.debug('Closing connection from context manager')
            self.close()