first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from zmq.backend.cython cimport libzmq
from zmq.backend.cython.context cimport Context
from zmq.backend.cython.message cimport Frame
from zmq.backend.cython.socket cimport Socket

View File

@@ -0,0 +1,165 @@
"""Python bindings for 0MQ."""
"""""" # start delvewheel patch
def _delvewheel_init_patch_0_0_12():
import os
import sys
libs_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'pyzmq.libs'))
if sys.version_info[:2] >= (3, 8):
os.add_dll_directory(libs_dir)
else:
from ctypes import WinDLL
with open(os.path.join(libs_dir, '.load-order-pyzmq-23.0.0')) as file:
load_order = file.read().split()
for lib in load_order:
WinDLL(os.path.join(libs_dir, lib))
_delvewheel_init_patch_0_0_12()
del _delvewheel_init_patch_0_0_12
# end delvewheel patch
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
# load bundled libzmq, if there is one:
import os
import sys
from contextlib import contextmanager
def _load_libzmq():
"""load bundled libzmq if there is one"""
import platform
dlopen = hasattr(sys, 'getdlopenflags') # unix-only
# RTLD flags are added to os in Python 3
# get values from os because ctypes values are WRONG on pypy
PYPY = platform.python_implementation().lower() == 'pypy'
if dlopen:
import ctypes
dlflags = sys.getdlopenflags()
# set RTLD_GLOBAL, unset RTLD_LOCAL
flags = ctypes.RTLD_GLOBAL | dlflags
# ctypes.RTLD_LOCAL is 0 on pypy, which is *wrong*
flags &= ~getattr(os, 'RTLD_LOCAL', 4)
# pypy on darwin needs RTLD_LAZY for some reason
if PYPY and sys.platform == 'darwin':
flags |= getattr(os, 'RTLD_LAZY', 1)
flags &= ~getattr(os, 'RTLD_NOW', 2)
sys.setdlopenflags(flags)
try:
from . import libzmq
except ImportError:
# raise on failure to load if libzmq is present
from importlib.util import find_spec
if find_spec(".libzmq", "zmq"):
# found libzmq, but failed to load it!
# raise instead of silently moving on
raise
else:
# store libzmq as zmq._libzmq for backward-compat
globals()['_libzmq'] = libzmq
if PYPY:
# should already have been imported above, so reimporting is as cheap as checking
import ctypes
# some versions of pypy (5.3 < ? < 5.8) needs explicit CDLL load for some reason,
# otherwise symbols won't be globally available
# do this unconditionally because it should be harmless (?)
ctypes.CDLL(libzmq.__file__, ctypes.RTLD_GLOBAL)
finally:
if dlopen:
sys.setdlopenflags(dlflags)
_load_libzmq()
@contextmanager
def _libs_on_path():
"""context manager for libs directory on $PATH
Works around mysterious issue where os.add_dll_directory
does not resolve imports (conda-forge Python >= 3.8)
"""
if not sys.platform.startswith("win"):
yield
return
libs_dir = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
os.pardir,
"pyzmq.libs",
)
)
if not os.path.exists(libs_dir):
# no bundled libs
yield
return
path_before = os.environ.get("PATH")
try:
os.environ["PATH"] = os.pathsep.join([path_before or "", libs_dir])
yield
finally:
if path_before is None:
os.environ.pop("PATH")
else:
os.environ["PATH"] = path_before
# zmq top-level imports
# workaround for Windows
with _libs_on_path():
from zmq import backend
from . import constants # noqa
from .constants import * # noqa
from zmq.backend import * # noqa
from zmq import sugar
from zmq.sugar import * # noqa
def get_includes():
"""Return a list of directories to include for linking against pyzmq with cython."""
from os.path import abspath, dirname, exists, join, pardir
base = dirname(__file__)
parent = abspath(join(base, pardir))
includes = [parent] + [join(parent, base, subdir) for subdir in ('utils',)]
if exists(join(parent, base, 'include')):
includes.append(join(parent, base, 'include'))
return includes
def get_library_dirs():
"""Return a list of directories used to link against pyzmq's bundled libzmq."""
from os.path import abspath, dirname, join, pardir
base = dirname(__file__)
parent = abspath(join(base, pardir))
return [join(parent, base)]
COPY_THRESHOLD = 65536
DRAFT_API = backend.has("draft")
__all__ = (
[
'get_includes',
'COPY_THRESHOLD',
'DRAFT_API',
]
+ sugar.__all__
+ backend.__all__
)

View File

@@ -0,0 +1,29 @@
from typing import List
from . import backend, sugar
COPY_THRESHOLD: int
DRAFT_API: bool
__version__: str
# mypy doesn't like overwriting symbols with * so be explicit
# about what comes from backend, not from sugar
# see tools/backend_imports.py to generate this list
# note: `x as x` is required for re-export
# see https://github.com/python/mypy/issues/2190
from .backend import IPC_PATH_MAX_LEN as IPC_PATH_MAX_LEN
from .backend import curve_keypair as curve_keypair
from .backend import curve_public as curve_public
from .backend import device as device
from .backend import has as has
from .backend import proxy as proxy
from .backend import proxy_steerable as proxy_steerable
from .backend import strerror as strerror
from .backend import zmq_errno as zmq_errno
from .backend import zmq_poll as zmq_poll
from .constants import *
from .error import *
from .sugar import *
def get_includes() -> List[str]: ...
def get_library_dirs() -> List[str]: ...

View File

@@ -0,0 +1,698 @@
"""Future-returning APIs for coroutines."""
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import warnings
from asyncio import Future
from collections import deque
from itertools import chain
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
from zmq._typing import Literal
class _FutureEvent(NamedTuple):
future: Future
kind: str
kwargs: Dict
msg: Any
timer: Any
# These are incomplete classes and need a Mixin for compatibility with an eventloop
# defining the following attributes:
#
# _Future
# _READ
# _WRITE
# _default_loop()
class _Async:
"""Mixin for common async logic"""
_current_loop: Any = None
_Future: Type[Future]
def _get_loop(self) -> Any:
"""Get event loop
Notice if event loop has changed,
and register init_io_state on activation of a new event loop
"""
if self._current_loop is None:
self._current_loop = self._default_loop()
self._init_io_state(self._current_loop)
return self._current_loop
current_loop = self._default_loop()
if current_loop is not self._current_loop:
# warn? This means a socket is being used in multiple loops!
self._current_loop = current_loop
self._init_io_state(current_loop)
return current_loop
def _default_loop(self) -> Any:
raise NotImplementedError("Must be implemented in a subclass")
def _init_io_state(self, loop=None) -> None:
pass
class _AsyncPoller(_Async, _zmq.Poller):
"""Poller that returns a Future on poll, instead of blocking."""
_socket_class: Type["_AsyncSocket"]
_READ: int
_WRITE: int
raw_sockets: List[Any]
def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
"""Schedule callback for a raw socket"""
raise NotImplementedError()
def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
"""Unschedule callback for a raw socket"""
raise NotImplementedError()
def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]: # type: ignore
"""Return a Future for a poll event"""
future = self._Future()
if timeout == 0:
try:
result = super().poll(0)
except Exception as e:
future.set_exception(e)
else:
future.set_result(result)
return future
loop = self._get_loop()
# register Future to be called as soon as any event is available on any socket
watcher = self._Future()
# watch raw sockets:
raw_sockets: List[Any] = []
def wake_raw(*args):
if not watcher.done():
watcher.set_result(None)
watcher.add_done_callback(
lambda f: self._unwatch_raw_sockets(loop, *raw_sockets)
)
for socket, mask in self.sockets:
if isinstance(socket, _zmq.Socket):
if not isinstance(socket, self._socket_class):
# it's a blocking zmq.Socket, wrap it in async
socket = self._socket_class.from_socket(socket)
if mask & _zmq.POLLIN:
socket._add_recv_event('poll', future=watcher)
if mask & _zmq.POLLOUT:
socket._add_send_event('poll', future=watcher)
else:
raw_sockets.append(socket)
evt = 0
if mask & _zmq.POLLIN:
evt |= self._READ
if mask & _zmq.POLLOUT:
evt |= self._WRITE
self._watch_raw_socket(loop, socket, evt, wake_raw)
def on_poll_ready(f):
if future.done():
return
if watcher.cancelled():
try:
future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
return
if watcher.exception():
future.set_exception(watcher.exception())
else:
try:
result = super(_AsyncPoller, self).poll(0)
except Exception as e:
future.set_exception(e)
else:
future.set_result(result)
watcher.add_done_callback(on_poll_ready)
if timeout is not None and timeout > 0:
# schedule cancel to fire on poll timeout, if any
def trigger_timeout():
if not watcher.done():
watcher.set_result(None)
timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout)
def cancel_timeout(f):
if hasattr(timeout_handle, 'cancel'):
timeout_handle.cancel()
else:
loop.remove_timeout(timeout_handle)
future.add_done_callback(cancel_timeout)
def cancel_watcher(f):
if not watcher.done():
watcher.cancel()
future.add_done_callback(cancel_watcher)
return future
class _NoTimer:
@staticmethod
def cancel():
pass
T = TypeVar("T", bound="_AsyncSocket")
class _AsyncSocket(_Async, _zmq.Socket):
# Warning : these class variables are only here to allow to call super().__setattr__.
# They be overridden at instance initialization and not shared in the whole class
_recv_futures = None
_send_futures = None
_state = 0
_shadow_sock: "_zmq.Socket"
_poller_class = _AsyncPoller
_fd = None
def __init__(
self,
context=None,
socket_type=-1,
io_loop=None,
_from_socket: Optional["_zmq.Socket"] = None,
**kwargs,
) -> None:
if isinstance(context, _zmq.Socket):
context, _from_socket = (None, context)
if _from_socket is not None:
super().__init__(shadow=_from_socket.underlying)
self._shadow_sock = _from_socket
else:
super().__init__(context, socket_type, **kwargs)
self._shadow_sock = _zmq.Socket.shadow(self.underlying)
if io_loop is not None:
warnings.warn(
f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
" The currently active loop will always be used.",
DeprecationWarning,
stacklevel=3,
)
self._recv_futures = deque()
self._send_futures = deque()
self._state = 0
self._fd = self._shadow_sock.FD
@classmethod
def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
"""Create an async socket from an existing Socket"""
return cls(_from_socket=socket, io_loop=io_loop)
def close(self, linger: Optional[int] = None) -> None:
if not self.closed and self._fd is not None:
event_list: List[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
)
for event in event_list:
if not event.future.done():
try:
event.future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
self._clear_io_state()
super().close(linger=linger)
close.__doc__ = _zmq.Socket.close.__doc__
def get(self, key):
result = super().get(key)
if key == EVENTS:
self._schedule_remaining_events(result)
return result
get.__doc__ = _zmq.Socket.get.__doc__
@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[List[bytes]]:
...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[List[bytes]]:
...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[List[_zmq.Frame]]: # type: ignore
...
@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
...
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
"""Receive a complete multipart zmq message.
Returns a Future whose result will be a multipart message.
"""
return self._add_recv_event(
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)
def recv( # type: ignore
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[bytes, _zmq.Frame]]:
"""Receive a single zmq frame.
Returns a Future, whose result will be the received frame.
Recommend using recv_multipart instead.
"""
return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))
def send_multipart( # type: ignore
self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
) -> Awaitable[Optional[_zmq.MessageTracker]]:
"""Send a complete multipart zmq message.
Returns a Future that resolves when sending is complete.
"""
kwargs['flags'] = flags
kwargs['copy'] = copy
kwargs['track'] = track
return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)
def send( # type: ignore
self,
data: Any,
flags: int = 0,
copy: bool = True,
track: bool = False,
**kwargs: Any,
) -> Awaitable[Optional[_zmq.MessageTracker]]:
"""Send a single zmq frame.
Returns a Future that resolves when sending is complete.
Recommend using send_multipart instead.
"""
kwargs['flags'] = flags
kwargs['copy'] = copy
kwargs['track'] = track
kwargs.update(dict(flags=flags, copy=copy, track=track))
return self._add_send_event('send', msg=data, kwargs=kwargs)
def _deserialize(self, recvd, load):
"""Deserialize with Futures"""
f = self._Future()
def _chain(_):
"""Chain result through serialization to recvd"""
if f.done():
return
if recvd.exception():
f.set_exception(recvd.exception())
else:
buf = recvd.result()
try:
loaded = load(buf)
except Exception as e:
f.set_exception(e)
else:
f.set_result(loaded)
recvd.add_done_callback(_chain)
def _chain_cancel(_):
"""Chain cancellation from f to recvd"""
if recvd.done():
return
if f.cancelled():
recvd.cancel()
f.add_done_callback(_chain_cancel)
return f
def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore
"""poll the socket for events
returns a Future for the poll results.
"""
if self.closed:
raise _zmq.ZMQError(_zmq.ENOTSUP)
p = self._poller_class()
p.register(self, flags)
f = cast(Future, p.poll(timeout))
future = self._Future()
def unwrap_result(f):
if future.done():
return
if f.cancelled():
try:
future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
return
if f.exception():
future.set_exception(f.exception())
else:
evts = dict(f.result())
future.set_result(evts.get(self, 0))
if f.done():
# hook up result if
unwrap_result(f)
else:
f.add_done_callback(unwrap_result)
return future
# overrides only necessary for updated types
def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
return super().recv_string(*args, **kwargs) # type: ignore
def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]: # type: ignore
return super().send_string(s, flags=flags, encoding=encoding) # type: ignore
def _add_timeout(self, future, timeout):
"""Add a timeout for a send or recv Future"""
def future_timeout():
if future.done():
# future already resolved, do nothing
return
# raise EAGAIN
future.set_exception(_zmq.Again())
return self._call_later(timeout, future_timeout)
def _call_later(self, delay, callback):
"""Schedule a function to be called later
Override for different IOLoop implementations
Tornado and asyncio happen to both have ioloop.call_later
with the same signature.
"""
return self._get_loop().call_later(delay, callback)
@staticmethod
def _remove_finished_future(future, event_list):
"""Make sure that futures are removed from the event list when they resolve
Avoids delaying cleanup until the next send/recv event,
which may never come.
"""
for f_idx, event in enumerate(event_list):
if event.future is future:
break
else:
return
# "future" instance is shared between sockets, but each socket has its own event list.
event_list.remove(event_list[f_idx])
def _add_recv_event(self, kind, kwargs=None, future=None):
"""Add a recv event, returning the corresponding Future"""
f = future or self._Future()
if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
# short-circuit non-blocking calls
recv = getattr(self._shadow_sock, kind)
try:
r = recv(**kwargs)
except Exception as e:
f.set_exception(e)
else:
f.set_result(r)
return f
timer = _NoTimer
if hasattr(_zmq, 'RCVTIMEO'):
timeout_ms = self._shadow_sock.rcvtimeo
if timeout_ms >= 0:
timer = self._add_timeout(f, timeout_ms * 1e-3)
# we add it to the list of futures before we add the timeout as the
# timeout will remove the future from recv_futures to avoid leaks
self._recv_futures.append(_FutureEvent(f, kind, kwargs, msg=None, timer=timer))
# Don't let the Future sit in _recv_events after it's done
f.add_done_callback(
lambda f: self._remove_finished_future(f, self._recv_futures)
)
if self._shadow_sock.get(EVENTS) & POLLIN:
# recv immediately, if we can
self._handle_recv()
if self._recv_futures:
self._add_io_state(POLLIN)
return f
def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
"""Add a send event, returning the corresponding Future"""
f = future or self._Future()
# attempt send with DONTWAIT if no futures are waiting
# short-circuit for sends that will resolve immediately
# only call if no send Futures are waiting
if kind in ('send', 'send_multipart') and not self._send_futures:
flags = kwargs.get('flags', 0)
nowait_kwargs = kwargs.copy()
nowait_kwargs['flags'] = flags | _zmq.DONTWAIT
# short-circuit non-blocking calls
send = getattr(self._shadow_sock, kind)
# track if the send resolved or not
# (EAGAIN if DONTWAIT is not set should proceed with)
finish_early = True
try:
r = send(msg, **nowait_kwargs)
except _zmq.Again as e:
if flags & _zmq.DONTWAIT:
f.set_exception(e)
else:
# EAGAIN raised and DONTWAIT not requested,
# proceed with async send
finish_early = False
except Exception as e:
f.set_exception(e)
else:
f.set_result(r)
if finish_early:
# short-circuit resolved, return finished Future
# schedule wake for recv if there are any receivers waiting
if self._recv_futures:
self._schedule_remaining_events()
return f
timer = _NoTimer
if hasattr(_zmq, 'SNDTIMEO'):
timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
if timeout_ms >= 0:
timer = self._add_timeout(f, timeout_ms * 1e-3)
# we add it to the list of futures before we add the timeout as the
# timeout will remove the future from recv_futures to avoid leaks
self._send_futures.append(
_FutureEvent(f, kind, kwargs=kwargs, msg=msg, timer=timer)
)
# Don't let the Future sit in _send_futures after it's done
f.add_done_callback(
lambda f: self._remove_finished_future(f, self._send_futures)
)
self._add_io_state(POLLOUT)
return f
def _handle_recv(self):
"""Handle recv events"""
if not self._shadow_sock.get(EVENTS) & POLLIN:
# event triggered, but state may have been changed between trigger and callback
return
f = None
while self._recv_futures:
f, kind, kwargs, _, timer = self._recv_futures.popleft()
# skip any cancelled futures
if f.done():
f = None
else:
break
if not self._recv_futures:
self._drop_io_state(POLLIN)
if f is None:
return
timer.cancel()
if kind == 'poll':
# on poll event, just signal ready, nothing else.
f.set_result(None)
return
elif kind == 'recv_multipart':
recv = self._shadow_sock.recv_multipart
elif kind == 'recv':
recv = self._shadow_sock.recv
else:
raise ValueError("Unhandled recv event type: %r" % kind)
kwargs['flags'] |= _zmq.DONTWAIT
try:
result = recv(**kwargs)
except Exception as e:
f.set_exception(e)
else:
f.set_result(result)
def _handle_send(self):
if not self._shadow_sock.get(EVENTS) & POLLOUT:
# event triggered, but state may have been changed between trigger and callback
return
f = None
while self._send_futures:
f, kind, kwargs, msg, timer = self._send_futures.popleft()
# skip any cancelled futures
if f.done():
f = None
else:
break
if not self._send_futures:
self._drop_io_state(POLLOUT)
if f is None:
return
timer.cancel()
if kind == 'poll':
# on poll event, just signal ready, nothing else.
f.set_result(None)
return
elif kind == 'send_multipart':
send = self._shadow_sock.send_multipart
elif kind == 'send':
send = self._shadow_sock.send
else:
raise ValueError("Unhandled send event type: %r" % kind)
kwargs['flags'] |= _zmq.DONTWAIT
try:
result = send(msg, **kwargs)
except Exception as e:
f.set_exception(e)
else:
f.set_result(result)
# event masking from ZMQStream
def _handle_events(self, fd=0, events=0):
"""Dispatch IO events to _handle_recv, etc."""
zmq_events = self._shadow_sock.get(EVENTS)
if zmq_events & _zmq.POLLIN:
self._handle_recv()
if zmq_events & _zmq.POLLOUT:
self._handle_send()
self._schedule_remaining_events()
def _schedule_remaining_events(self, events=None):
"""Schedule a call to handle_events next loop iteration
If there are still events to handle.
"""
# edge-triggered handling
# allow passing events in, in case this is triggered by retrieving events,
# so we don't have to retrieve it twice.
if self._state == 0:
# not watching for anything, nothing to schedule
return
if events is None:
events = self._shadow_sock.get(EVENTS)
if events & self._state:
self._call_later(0, self._handle_events)
def _add_io_state(self, state):
"""Add io_state to poller."""
if self._state != state:
state = self._state = self._state | state
self._update_handler(self._state)
def _drop_io_state(self, state):
"""Stop poller from watching an io_state."""
if self._state & state:
self._state = self._state & (~state)
self._update_handler(self._state)
def _update_handler(self, state):
"""Update IOLoop handler with state.
zmq FD is always read-only.
"""
# ensure loop is registered and init_io has been called
# if there are any events to watch for
if state:
self._get_loop()
self._schedule_remaining_events()
def _init_io_state(self, loop=None):
"""initialize the ioloop event handler"""
if loop is None:
loop = self._get_loop()
loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
self._call_later(0, self._handle_events)
def _clear_io_state(self):
"""unregister the ioloop event handler
called once during close
"""
fd = self._shadow_sock
if self._shadow_sock.closed:
fd = self._fd
if self._current_loop is not None:
self._current_loop.remove_handler(fd)

View File

@@ -0,0 +1,19 @@
import sys
from typing import Any, Dict
if sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
# avoid runtime dependency on typing_extensions on py37
try:
from typing_extensions import Literal, TypedDict # type: ignore
except ImportError:
class _Literal:
def __getitem__(self, key):
return Any
Literal = _Literal() # type: ignore
class TypedDict(Dict): # type: ignore
pass

View File

@@ -0,0 +1,213 @@
"""AsyncIO support for zmq
Requires asyncio and Python 3.
"""
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import asyncio
import selectors
import sys
import warnings
from asyncio import Future, SelectorEventLoop
from weakref import WeakKeyDictionary
import zmq as _zmq
from zmq import _future
# registry of asyncio loop : selector thread
_selectors: WeakKeyDictionary = WeakKeyDictionary()
class ProactorSelectorThreadWarning(RuntimeWarning):
"""Warning class for notifying about the extra thread spawned by tornado
We automatically support proactor via tornado's AddThreadSelectorEventLoop"""
def _get_selector_windows(
asyncio_loop,
) -> asyncio.AbstractEventLoop:
"""Get selector-compatible loop
Returns an object with ``add_reader`` family of methods,
either the loop itself or a SelectorThread instance.
Workaround Windows proactor removal of
*reader methods, which we need for zmq sockets.
"""
if asyncio_loop in _selectors:
return _selectors[asyncio_loop]
# detect add_reader instead of checking for proactor?
if hasattr(asyncio, "ProactorEventLoop") and isinstance(
asyncio_loop, asyncio.ProactorEventLoop # type: ignore
):
try:
from tornado.platform.asyncio import AddThreadSelectorEventLoop
except ImportError:
raise RuntimeError(
"Proactor event loop does not implement add_reader family of methods required for zmq."
" zmq will work with proactor if tornado >= 6.1 can be found."
" Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
" or install 'tornado>=6.1' to avoid this error."
)
warnings.warn(
"Proactor event loop does not implement add_reader family of methods required for zmq."
" Registering an additional selector thread for add_reader support via tornado."
" Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
" to avoid this warning.",
RuntimeWarning,
# stacklevel 5 matches most likely zmq.asyncio.Context().socket()
stacklevel=5,
)
selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop(asyncio_loop) # type: ignore
# patch loop.close to also close the selector thread
loop_close = asyncio_loop.close
def _close_selector_and_loop():
# restore original before calling selector.close,
# which in turn calls eventloop.close!
asyncio_loop.close = loop_close
_selectors.pop(asyncio_loop, None)
selector_loop.close()
asyncio_loop.close = _close_selector_and_loop
return selector_loop
else:
return asyncio_loop
def _get_selector_noop(loop) -> asyncio.AbstractEventLoop:
"""no-op on non-Windows"""
return loop
if sys.platform == "win32":
_get_selector = _get_selector_windows
else:
_get_selector = _get_selector_noop
class _AsyncIO:
_Future = Future
_WRITE = selectors.EVENT_WRITE
_READ = selectors.EVENT_READ
def _default_loop(self):
if sys.version_info >= (3, 7):
try:
return asyncio.get_running_loop()
except RuntimeError:
warnings.warn(
"No running event loop. zmq.asyncio should be used from within an asyncio loop.",
RuntimeWarning,
stacklevel=4,
)
# get_event_loop deprecated in 3.10:
return asyncio.get_event_loop()
class Poller(_AsyncIO, _future._AsyncPoller):
"""Poller returning asyncio.Future for poll results."""
def _watch_raw_socket(self, loop, socket, evt, f):
"""Schedule callback for a raw socket"""
selector = _get_selector(loop)
if evt & self._READ:
selector.add_reader(socket, lambda *args: f())
if evt & self._WRITE:
selector.add_writer(socket, lambda *args: f())
def _unwatch_raw_sockets(self, loop, *sockets):
"""Unschedule callback for a raw socket"""
selector = _get_selector(loop)
for socket in sockets:
selector.remove_reader(socket)
selector.remove_writer(socket)
class Socket(_AsyncIO, _future._AsyncSocket):
"""Socket returning asyncio Futures for send/recv/poll methods."""
_poller_class = Poller
def _get_selector(self, io_loop=None):
if io_loop is None:
io_loop = self._get_loop()
return _get_selector(io_loop)
def _init_io_state(self, io_loop=None):
"""initialize the ioloop event handler"""
self._get_selector(io_loop).add_reader(
self._fd, lambda: self._handle_events(0, 0)
)
def _clear_io_state(self):
"""clear any ioloop event handler
called once at close
"""
loop = self._current_loop
if loop and not loop.is_closed() and self._fd != -1:
self._get_selector(loop).remove_reader(self._fd)
Poller._socket_class = Socket
class Context(_zmq.Context[Socket]):
"""Context for creating asyncio-compatible Sockets"""
_socket_class = Socket
# avoid sharing instance with base Context class
_instance = None
class ZMQEventLoop(SelectorEventLoop):
"""DEPRECATED: AsyncIO eventloop using zmq_poll.
pyzmq sockets should work with any asyncio event loop as of pyzmq 17.
"""
def __init__(self, selector=None):
_deprecated()
return super().__init__(selector)
_loop = None
def _deprecated():
if _deprecated.called: # type: ignore
return
_deprecated.called = True # type: ignore
warnings.warn(
"ZMQEventLoop and zmq.asyncio.install are deprecated in pyzmq 17. Special eventloop integration is no longer needed.",
DeprecationWarning,
stacklevel=3,
)
_deprecated.called = False # type: ignore
def install():
"""DEPRECATED: No longer needed in pyzmq 17"""
_deprecated()
__all__ = [
"Context",
"Socket",
"Poller",
"ZMQEventLoop",
"install",
]

View File

@@ -0,0 +1,14 @@
"""Utilities for ZAP authentication.
To run authentication in a background thread, see :mod:`zmq.auth.thread`.
For integration with the tornado eventloop, see :mod:`zmq.auth.ioloop`.
For integration with the asyncio event loop, see :mod:`zmq.auth.asyncio`.
Authentication examples are provided in the pyzmq codebase, under
`/examples/security/`.
.. versionadded:: 14.1
"""
from .base import *
from .certs import *

View File

@@ -0,0 +1,59 @@
"""ZAP Authenticator integrated with the asyncio IO loop.
.. versionadded:: 15.2
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import warnings
from typing import Any, Optional
import zmq
from zmq.asyncio import Poller
from .base import Authenticator
class AsyncioAuthenticator(Authenticator):
"""ZAP authentication for use in the asyncio IO loop"""
__poller: Optional[Poller]
__task: Any
zap_socket: "zmq.asyncio.Socket"
def __init__(self, context: Optional["zmq.Context"] = None, loop: Any = None):
super().__init__(context)
if loop is not None:
warnings.warn(f"{self.__class__.__name__}(loop) is deprecated and ignored")
self.__poller = None
self.__task = None
async def __handle_zap(self) -> None:
while True:
if self.__poller is None:
break
events = await self.__poller.poll()
if self.zap_socket in dict(events):
msg = await self.zap_socket.recv_multipart()
self.handle_zap_message(msg)
def start(self) -> None:
"""Start ZAP authentication"""
super().start()
self.__poller = Poller()
self.__poller.register(self.zap_socket, zmq.POLLIN)
self.__task = asyncio.ensure_future(self.__handle_zap())
def stop(self) -> None:
"""Stop ZAP authentication"""
if self.__task:
self.__task.cancel()
if self.__poller:
self.__poller.unregister(self.zap_socket)
self.__poller = None
super().stop()
__all__ = ["AsyncioAuthenticator"]

View File

@@ -0,0 +1,443 @@
"""Base implementation of 0MQ authentication."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import zmq
from zmq.error import _check_version
from zmq.utils import z85
from .certs import load_certificates
CURVE_ALLOW_ANY = '*'
VERSION = b'1.0'
class Authenticator:
"""Implementation of ZAP authentication for zmq connections.
This authenticator class does not register with an event loop. As a result,
you will need to manually call `handle_zap_message`::
auth = zmq.Authenticator()
auth.allow("127.0.0.1")
auth.start()
while True:
auth.handle_zap_msg(auth.zap_socket.recv_multipart()
Alternatively, you can register `auth.zap_socket` with a poller.
Since many users will want to run ZAP in a way that does not block the
main thread, other authentication classes (such as :mod:`zmq.auth.thread`)
are provided.
Note:
- libzmq provides four levels of security: default NULL (which the Authenticator does
not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see.
- until you add policies, all incoming NULL connections are allowed.
(classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
- GSSAPI requires no configuration.
"""
context: "zmq.Context"
encoding: str
allow_any: bool
credentials_providers: Dict[str, Any]
zap_socket: "zmq.Socket"
whitelist: Set[str]
blacklist: Set[str]
passwords: Dict[str, Dict[str, str]]
certs: Dict[str, Dict[bytes, Any]]
log: Any
def __init__(
self,
context: Optional["zmq.Context"] = None,
encoding: str = 'utf-8',
log: Any = None,
):
_check_version((4, 0), "security")
self.context = context or zmq.Context.instance()
self.encoding = encoding
self.allow_any = False
self.credentials_providers = {}
self.zap_socket = None # type: ignore
self.whitelist = set()
self.blacklist = set()
# passwords is a dict keyed by domain and contains values
# of dicts with username:password pairs.
self.passwords = {}
# certs is dict keyed by domain and contains values
# of dicts keyed by the public keys from the specified location.
self.certs = {}
self.log = log or logging.getLogger('zmq.auth')
def start(self) -> None:
"""Create and bind the ZAP socket"""
self.zap_socket = self.context.socket(zmq.REP)
self.zap_socket.linger = 1
self.zap_socket.bind("inproc://zeromq.zap.01")
self.log.debug("Starting")
def stop(self) -> None:
"""Close the ZAP socket"""
if self.zap_socket:
self.zap_socket.close()
self.zap_socket = None # type: ignore
def allow(self, *addresses: str) -> None:
"""Allow (whitelist) IP address(es).
Connections from addresses not in the whitelist will be rejected.
- For NULL, all clients from this address will be accepted.
- For real auth setups, they will be allowed to continue with authentication.
whitelist is mutually exclusive with blacklist.
"""
if self.blacklist:
raise ValueError("Only use a whitelist or a blacklist, not both")
self.log.debug("Allowing %s", ','.join(addresses))
self.whitelist.update(addresses)
def deny(self, *addresses: str) -> None:
"""Deny (blacklist) IP address(es).
Addresses not in the blacklist will be allowed to continue with authentication.
Blacklist is mutually exclusive with whitelist.
"""
if self.whitelist:
raise ValueError("Only use a whitelist or a blacklist, not both")
self.log.debug("Denying %s", ','.join(addresses))
self.blacklist.update(addresses)
def configure_plain(
self, domain: str = '*', passwords: Dict[str, str] = None
) -> None:
"""Configure PLAIN authentication for a given domain.
PLAIN authentication uses a plain-text password file.
To cover all domains, use "*".
You can modify the password file at any time; it is reloaded automatically.
"""
if passwords:
self.passwords[domain] = passwords
self.log.debug("Configure plain: %s", domain)
def configure_curve(
self, domain: str = '*', location: Union[str, os.PathLike] = "."
) -> None:
"""Configure CURVE authentication for a given domain.
CURVE authentication uses a directory that holds all public client certificates,
i.e. their public keys.
To cover all domains, use "*".
You can add and remove certificates in that directory at any time. configure_curve must be called
every time certificates are added or removed, in order to update the Authenticator's state
To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
"""
# If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
# treat location as a directory that holds the certificates.
self.log.debug("Configure curve: %s[%s]", domain, location)
if location == CURVE_ALLOW_ANY:
self.allow_any = True
else:
self.allow_any = False
try:
self.certs[domain] = load_certificates(location)
except Exception as e:
self.log.error("Failed to load CURVE certs from %s: %s", location, e)
def configure_curve_callback(
self, domain: str = '*', credentials_provider: Any = None
) -> None:
"""Configure CURVE authentication for a given domain.
CURVE authentication using a callback function validating
the client public key according to a custom mechanism, e.g. checking the
key against records in a db. credentials_provider is an object of a class which
implements a callback method accepting two parameters (domain and key), e.g.::
class CredentialsProvider(object):
def __init__(self):
...e.g. db connection
def callback(self, domain, key):
valid = ...lookup key and/or domain in db
if valid:
logging.info('Authorizing: {0}, {1}'.format(domain, key))
return True
else:
logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
return False
To cover all domains, use "*".
To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
"""
self.allow_any = False
if credentials_provider is not None:
self.credentials_providers[domain] = credentials_provider
else:
self.log.error("None credentials_provider provided for domain:%s", domain)
def curve_user_id(self, client_public_key: bytes) -> str:
"""Return the User-Id corresponding to a CURVE client's public key
Default implementation uses the z85-encoding of the public key.
Override to define a custom mapping of public key : user-id
This is only called on successful authentication.
Parameters
----------
client_public_key: bytes
The client public key used for the given message
Returns
-------
user_id: unicode
The user ID as text
"""
return z85.encode(client_public_key).decode('ascii')
def configure_gssapi(
self, domain: str = '*', location: Optional[str] = None
) -> None:
"""Configure GSSAPI authentication
Currently this is a no-op because there is nothing to configure with GSSAPI.
"""
def handle_zap_message(self, msg: List[bytes]):
"""Perform ZAP authentication"""
if len(msg) < 6:
self.log.error("Invalid ZAP message, not enough frames: %r", msg)
if len(msg) < 2:
self.log.error("Not enough information to reply")
else:
self._send_zap_reply(msg[1], b"400", b"Not enough frames")
return
version, request_id, domain, address, identity, mechanism = msg[:6]
credentials = msg[6:]
domain = domain.decode(self.encoding, 'replace')
address = address.decode(self.encoding, 'replace')
if version != VERSION:
self.log.error("Invalid ZAP version: %r", msg)
self._send_zap_reply(request_id, b"400", b"Invalid version")
return
self.log.debug(
"version: %r, request_id: %r, domain: %r,"
" address: %r, identity: %r, mechanism: %r",
version,
request_id,
domain,
address,
identity,
mechanism,
)
# Is address is explicitly whitelisted or blacklisted?
allowed = False
denied = False
reason = b"NO ACCESS"
if self.whitelist:
if address in self.whitelist:
allowed = True
self.log.debug("PASSED (whitelist) address=%s", address)
else:
denied = True
reason = b"Address not in whitelist"
self.log.debug("DENIED (not in whitelist) address=%s", address)
elif self.blacklist:
if address in self.blacklist:
denied = True
reason = b"Address is blacklisted"
self.log.debug("DENIED (blacklist) address=%s", address)
else:
allowed = True
self.log.debug("PASSED (not in blacklist) address=%s", address)
# Perform authentication mechanism-specific checks if necessary
username = "anonymous"
if not denied:
if mechanism == b'NULL' and not allowed:
# For NULL, we allow if the address wasn't blacklisted
self.log.debug("ALLOWED (NULL)")
allowed = True
elif mechanism == b'PLAIN':
# For PLAIN, even a whitelisted address must authenticate
if len(credentials) != 2:
self.log.error("Invalid PLAIN credentials: %r", credentials)
self._send_zap_reply(request_id, b"400", b"Invalid credentials")
return
username, password = (
c.decode(self.encoding, 'replace') for c in credentials
)
allowed, reason = self._authenticate_plain(domain, username, password)
elif mechanism == b'CURVE':
# For CURVE, even a whitelisted address must authenticate
if len(credentials) != 1:
self.log.error("Invalid CURVE credentials: %r", credentials)
self._send_zap_reply(request_id, b"400", b"Invalid credentials")
return
key = credentials[0]
allowed, reason = self._authenticate_curve(domain, key)
if allowed:
username = self.curve_user_id(key)
elif mechanism == b'GSSAPI':
if len(credentials) != 1:
self.log.error("Invalid GSSAPI credentials: %r", credentials)
self._send_zap_reply(request_id, b"400", b"Invalid credentials")
return
# use principal as user-id for now
principal = credentials[0]
username = principal.decode("utf8")
allowed, reason = self._authenticate_gssapi(domain, principal)
if allowed:
self._send_zap_reply(request_id, b"200", b"OK", username)
else:
self._send_zap_reply(request_id, b"400", reason)
def _authenticate_plain(
self, domain: str, username: str, password: str
) -> Tuple[bool, bytes]:
"""PLAIN ZAP authentication"""
allowed = False
reason = b""
if self.passwords:
# If no domain is not specified then use the default domain
if not domain:
domain = '*'
if domain in self.passwords:
if username in self.passwords[domain]:
if password == self.passwords[domain][username]:
allowed = True
else:
reason = b"Invalid password"
else:
reason = b"Invalid username"
else:
reason = b"Invalid domain"
if allowed:
self.log.debug(
"ALLOWED (PLAIN) domain=%s username=%s password=%s",
domain,
username,
password,
)
else:
self.log.debug("DENIED %s", reason)
else:
reason = b"No passwords defined"
self.log.debug("DENIED (PLAIN) %s", reason)
return allowed, reason
def _authenticate_curve(self, domain: str, client_key: bytes) -> Tuple[bool, bytes]:
"""CURVE ZAP authentication"""
allowed = False
reason = b""
if self.allow_any:
allowed = True
reason = b"OK"
self.log.debug("ALLOWED (CURVE allow any client)")
elif self.credentials_providers != {}:
# If no explicit domain is specified then use the default domain
if not domain:
domain = '*'
if domain in self.credentials_providers:
z85_client_key = z85.encode(client_key)
# Callback to check if key is Allowed
if self.credentials_providers[domain].callback(domain, z85_client_key):
allowed = True
reason = b"OK"
else:
reason = b"Unknown key"
status = "ALLOWED" if allowed else "DENIED"
self.log.debug(
"%s (CURVE auth_callback) domain=%s client_key=%s",
status,
domain,
z85_client_key,
)
else:
reason = b"Unknown domain"
else:
# If no explicit domain is specified then use the default domain
if not domain:
domain = '*'
if domain in self.certs:
# The certs dict stores keys in z85 format, convert binary key to z85 bytes
z85_client_key = z85.encode(client_key)
if self.certs[domain].get(z85_client_key):
allowed = True
reason = b"OK"
else:
reason = b"Unknown key"
status = "ALLOWED" if allowed else "DENIED"
self.log.debug(
"%s (CURVE) domain=%s client_key=%s",
status,
domain,
z85_client_key,
)
else:
reason = b"Unknown domain"
return allowed, reason
def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]:
"""Nothing to do for GSSAPI, which has already been handled by an external service."""
self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
return True, b'OK'
def _send_zap_reply(
self,
request_id: bytes,
status_code: bytes,
status_text: bytes,
user_id: str = 'anonymous',
) -> None:
"""Send a ZAP reply to finish the authentication."""
user_id = user_id if status_code == b'200' else b''
if isinstance(user_id, str):
user_id = user_id.encode(self.encoding, 'replace')
metadata = b'' # not currently used
self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
reply = [VERSION, request_id, status_code, status_text, user_id, metadata]
self.zap_socket.send_multipart(reply)
__all__ = ['Authenticator', 'CURVE_ALLOW_ANY']

View File

@@ -0,0 +1,141 @@
"""0MQ authentication related functions and classes."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import datetime
import glob
import os
from typing import Dict, Optional, Tuple, Union
import zmq
_cert_secret_banner = """# **** Generated on {0} by pyzmq ****
# ZeroMQ CURVE **Secret** Certificate
# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
"""
_cert_public_banner = """# **** Generated on {0} by pyzmq ****
# ZeroMQ CURVE Public Certificate
# Exchange securely, or use a secure mechanism to verify the contents
# of this file after exchange. Store public certificates in your home
# directory, in the .curve subdirectory.
"""
def _write_key_file(
key_filename: Union[str, os.PathLike],
banner: str,
public_key: Union[str, bytes],
secret_key: Optional[Union[str, bytes]] = None,
metadata: Optional[Dict[str, str]] = None,
encoding: str = 'utf-8',
) -> None:
"""Create a certificate file"""
if isinstance(public_key, bytes):
public_key = public_key.decode(encoding)
if isinstance(secret_key, bytes):
secret_key = secret_key.decode(encoding)
with open(key_filename, 'w', encoding='utf8') as f:
f.write(banner.format(datetime.datetime.now()))
f.write('metadata\n')
if metadata:
for k, v in metadata.items():
if isinstance(k, bytes):
k = k.decode(encoding)
if isinstance(v, bytes):
v = v.decode(encoding)
f.write(f" {k} = {v}\n")
f.write('curve\n')
f.write(f" public-key = \"{public_key}\"\n")
if secret_key:
f.write(f" secret-key = \"{secret_key}\"\n")
def create_certificates(
key_dir: Union[str, os.PathLike],
name: str,
metadata: Optional[Dict[str, str]] = None,
) -> Tuple[str, str]:
"""Create zmq certificates.
Returns the file paths to the public and secret certificate files.
"""
public_key, secret_key = zmq.curve_keypair()
base_filename = os.path.join(key_dir, name)
secret_key_file = f"{base_filename}.key_secret"
public_key_file = f"{base_filename}.key"
now = datetime.datetime.now()
_write_key_file(public_key_file, _cert_public_banner.format(now), public_key)
_write_key_file(
secret_key_file,
_cert_secret_banner.format(now),
public_key,
secret_key=secret_key,
metadata=metadata,
)
return public_key_file, secret_key_file
def load_certificate(
filename: Union[str, os.PathLike]
) -> Tuple[bytes, Optional[bytes]]:
"""Load public and secret key from a zmq certificate.
Returns (public_key, secret_key)
If the certificate file only contains the public key,
secret_key will be None.
If there is no public key found in the file, ValueError will be raised.
"""
public_key = None
secret_key = None
if not os.path.exists(filename):
raise OSError(f"Invalid certificate file: {filename}")
with open(filename, 'rb') as f:
for line in f:
line = line.strip()
if line.startswith(b'#'):
continue
if line.startswith(b'public-key'):
public_key = line.split(b"=", 1)[1].strip(b' \t\'"')
if line.startswith(b'secret-key'):
secret_key = line.split(b"=", 1)[1].strip(b' \t\'"')
if public_key and secret_key:
break
if public_key is None:
raise ValueError("No public key found in %s" % filename)
return public_key, secret_key
def load_certificates(directory: Union[str, os.PathLike] = '.') -> Dict[bytes, bool]:
"""Load public keys from all certificates in a directory"""
certs = {}
if not os.path.isdir(directory):
raise OSError(f"Invalid certificate directory: {directory}")
# Follow czmq pattern of public keys stored in *.key files.
glob_string = os.path.join(directory, "*.key")
cert_files = glob.glob(glob_string)
for cert_file in cert_files:
public_key, _ = load_certificate(cert_file)
if public_key:
certs[public_key] = True
return certs
__all__ = ['create_certificates', 'load_certificate', 'load_certificates']

View File

@@ -0,0 +1,49 @@
"""ZAP Authenticator integrated with the tornado IOLoop.
.. versionadded:: 14.1
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from typing import Any, Optional
from tornado import ioloop
import zmq
from zmq.eventloop import zmqstream
from .base import Authenticator
class IOLoopAuthenticator(Authenticator):
"""ZAP authentication for use in the tornado IOLoop"""
zap_stream: zmqstream.ZMQStream
io_loop: ioloop.IOLoop
def __init__(
self,
context: Optional["zmq.Context"] = None,
encoding: str = 'utf-8',
log: Any = None,
io_loop: Optional[ioloop.IOLoop] = None,
):
super().__init__(context, encoding, log)
self.zap_stream = None # type: ignore
self.io_loop = io_loop or ioloop.IOLoop.current()
def start(self) -> None:
"""Start ZAP authentication"""
super().start()
self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop)
self.zap_stream.on_recv(self.handle_zap_message)
def stop(self):
"""Stop ZAP authentication"""
if self.zap_stream:
self.zap_stream.close()
self.zap_stream = None
super().stop()
__all__ = ['IOLoopAuthenticator']

View File

@@ -0,0 +1,259 @@
"""ZAP Authenticator in a Python Thread.
.. versionadded:: 14.1
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
from itertools import chain
from threading import Event, Thread
from typing import Any, Dict, List, Optional, TypeVar, cast
import zmq
from zmq.utils import jsonapi
from .base import Authenticator
class AuthenticationThread(Thread):
"""A Thread for running a zmq Authenticator
This is run in the background by ThreadedAuthenticator
"""
def __init__(
self,
context: "zmq.Context",
endpoint: str,
encoding: str = 'utf-8',
log: Any = None,
authenticator: Optional[Authenticator] = None,
) -> None:
super().__init__()
self.context = context or zmq.Context.instance()
self.encoding = encoding
self.log = log = log or logging.getLogger('zmq.auth')
self.started = Event()
self.authenticator: Authenticator = authenticator or Authenticator(
context, encoding=encoding, log=log
)
# create a socket to communicate back to main thread.
self.pipe = context.socket(zmq.PAIR)
self.pipe.linger = 1
self.pipe.connect(endpoint)
def run(self) -> None:
"""Start the Authentication Agent thread task"""
self.authenticator.start()
self.started.set()
zap = self.authenticator.zap_socket
poller = zmq.Poller()
poller.register(self.pipe, zmq.POLLIN)
poller.register(zap, zmq.POLLIN)
while True:
try:
socks = dict(poller.poll())
except zmq.ZMQError:
break # interrupted
if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
# Make sure all API requests are processed before
# looking at the ZAP socket.
while True:
try:
msg = self.pipe.recv_multipart(flags=zmq.NOBLOCK)
except zmq.Again:
break
terminate = self._handle_pipe(msg)
if terminate:
break
if terminate:
break
if zap in socks and socks[zap] == zmq.POLLIN:
self._handle_zap()
self.pipe.close()
self.authenticator.stop()
def _handle_zap(self) -> None:
"""
Handle a message from the ZAP socket.
"""
if self.authenticator.zap_socket is None:
raise RuntimeError("ZAP socket closed")
msg = self.authenticator.zap_socket.recv_multipart()
if not msg:
return
self.authenticator.handle_zap_message(msg)
def _handle_pipe(self, msg: List[bytes]) -> bool:
"""
Handle a message from front-end API.
"""
terminate = False
if msg is None:
terminate = True
return terminate
command = msg[0]
self.log.debug("auth received API command %r", command)
if command == b'ALLOW':
addresses = [m.decode(self.encoding) for m in msg[1:]]
try:
self.authenticator.allow(*addresses)
except Exception:
self.log.exception("Failed to allow %s", addresses)
elif command == b'DENY':
addresses = [m.decode(self.encoding) for m in msg[1:]]
try:
self.authenticator.deny(*addresses)
except Exception:
self.log.exception("Failed to deny %s", addresses)
elif command == b'PLAIN':
domain = msg[1].decode(self.encoding)
json_passwords = msg[2]
passwords: Dict[str, str] = cast(
Dict[str, str], jsonapi.loads(json_passwords)
)
self.authenticator.configure_plain(domain, passwords)
elif command == b'CURVE':
# For now we don't do anything with domains
domain = msg[1].decode(self.encoding)
# If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
# treat location as a directory that holds the certificates.
location = msg[2].decode(self.encoding)
self.authenticator.configure_curve(domain, location)
elif command == b'TERMINATE':
terminate = True
else:
self.log.error("Invalid auth command from API: %r", command)
return terminate
T = TypeVar("T", bound=type)
def _inherit_docstrings(cls: T) -> T:
"""inherit docstrings from Authenticator, so we don't duplicate them"""
for name, method in cls.__dict__.items():
if name.startswith('_') or not callable(method):
continue
upstream_method = getattr(Authenticator, name, None)
if not method.__doc__:
method.__doc__ = upstream_method.__doc__
return cls
@_inherit_docstrings
class ThreadAuthenticator:
"""Run ZAP authentication in a background thread"""
context: "zmq.Context"
log: Any
encoding: str
pipe: "zmq.Socket"
pipe_endpoint: str = ''
thread: AuthenticationThread
def __init__(
self,
context: Optional["zmq.Context"] = None,
encoding: str = 'utf-8',
log: Any = None,
):
self.log = log
self.encoding = encoding
self.pipe = None # type: ignore
self.pipe_endpoint = f"inproc://{id(self)}.inproc"
self.thread = None # type: ignore
self.context = context or zmq.Context.instance()
# proxy base Authenticator attributes
def __setattr__(self, key: str, value: Any):
for obj in chain([self], self.__class__.mro()):
if key in obj.__dict__ or (key in getattr(obj, "__annotations__", {})):
object.__setattr__(self, key, value)
return
setattr(self.thread.authenticator, key, value)
def __getattr__(self, key: str):
return getattr(self.thread.authenticator, key)
def allow(self, *addresses: str):
self.pipe.send_multipart(
[b'ALLOW'] + [a.encode(self.encoding) for a in addresses]
)
def deny(self, *addresses: str):
self.pipe.send_multipart(
[b'DENY'] + [a.encode(self.encoding) for a in addresses]
)
def configure_plain(
self, domain: str = '*', passwords: Optional[Dict[str, str]] = None
):
self.pipe.send_multipart(
[b'PLAIN', domain.encode(self.encoding), jsonapi.dumps(passwords or {})]
)
def configure_curve(self, domain: str = '*', location: str = ''):
domain = domain.encode(self.encoding)
location = location.encode(self.encoding)
self.pipe.send_multipart([b'CURVE', domain, location])
def configure_curve_callback(
self, domain: str = '*', credentials_provider: Any = None
):
self.thread.authenticator.configure_curve_callback(
domain, credentials_provider=credentials_provider
)
def start(self) -> None:
"""Start the authentication thread"""
# create a socket to communicate with auth thread.
self.pipe = self.context.socket(zmq.PAIR)
self.pipe.linger = 1
self.pipe.bind(self.pipe_endpoint)
self.thread = AuthenticationThread(
self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log
)
self.thread.start()
if not self.thread.started.wait(timeout=10):
raise RuntimeError("Authenticator thread failed to start")
def stop(self) -> None:
"""Stop the authentication thread"""
if self.pipe:
self.pipe.send(b'TERMINATE')
if self.is_alive():
self.thread.join()
self.thread = None # type: ignore
self.pipe.close()
self.pipe = None # type: ignore
def is_alive(self) -> bool:
"""Is the ZAP thread currently running?"""
if self.thread and self.thread.is_alive():
return True
return False
def __del__(self) -> None:
self.stop()
__all__ = ['ThreadAuthenticator']

View File

@@ -0,0 +1,35 @@
"""Import basic exposure of libzmq C API as a backend"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import platform
from .select import public_api, select_backend
if 'PYZMQ_BACKEND' in os.environ:
backend = os.environ['PYZMQ_BACKEND']
if backend in ('cython', 'cffi'):
backend = 'zmq.backend.%s' % backend
_ns = select_backend(backend)
else:
# default to cython, fallback to cffi
# (reverse on PyPy)
if platform.python_implementation() == 'PyPy':
first, second = ('zmq.backend.cffi', 'zmq.backend.cython')
else:
first, second = ('zmq.backend.cython', 'zmq.backend.cffi')
try:
_ns = select_backend(first)
except Exception as original_error:
try:
_ns = select_backend(second)
except ImportError:
raise original_error from None
globals().update(_ns)
__all__ = public_api

View File

@@ -0,0 +1,114 @@
from typing import Any, List, Optional, Set, Tuple, TypeVar, Union, overload
from typing_extensions import Literal
import zmq
from .select import select_backend
# avoid collision in Frame.bytes
_bytestr = bytes
T = TypeVar("T")
class Frame:
buffer: Any
bytes: bytes
more: bool
tracker: Any
def __init__(
self,
data: Any = None,
track: bool = False,
copy: Optional[bool] = None,
copy_threshold: Optional[int] = None,
): ...
def copy_fast(self: T) -> T: ...
def get(self, option: int) -> Union[int, _bytestr, str]: ...
def set(self, option: int, value: Union[int, _bytestr, str]) -> None: ...
class Socket:
underlying: int
context: "zmq.Context"
copy_threshold: int
# specific option types
FD: int
def close(self, linger: Optional[int] = ...) -> None: ...
def get(self, option: int) -> Union[int, bytes, str]: ...
def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
def connect(self, url: str): ...
def disconnect(self, url: str) -> None: ...
def bind(self, url: str): ...
def unbind(self, url: str) -> None: ...
def send(
self,
data: Any,
flags: int = ...,
copy: bool = ...,
track: bool = ...,
) -> Optional["zmq.MessageTracker"]: ...
@overload
def recv(
self,
flags: int = ...,
*,
copy: Literal[False],
track: bool = ...,
) -> "zmq.Frame": ...
@overload
def recv(
self,
flags: int = ...,
*,
copy: Literal[True],
track: bool = ...,
) -> bytes: ...
@overload
def recv(
self,
flags: int = ...,
track: bool = False,
) -> bytes: ...
@overload
def recv(
self,
flags: Optional[int] = ...,
copy: bool = ...,
track: Optional[bool] = False,
) -> Union["zmq.Frame", bytes]: ...
def monitor(self, addr: Optional[str], events: int) -> None: ...
# draft methods
def join(self, group: str) -> None: ...
def leave(self, group: str) -> None: ...
class Context:
underlying: int
def __init__(self, io_threads: int = 1, shadow: Any = None): ...
def get(self, option: int) -> Union[int, bytes, str]: ...
def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
def socket(self, socket_type: int) -> Socket: ...
def term(self) -> None: ...
IPC_PATH_MAX_LEN: int
def has(capability: str) -> bool: ...
def curve_keypair() -> Tuple[bytes, bytes]: ...
def curve_public(secret_key: bytes) -> bytes: ...
def strerror(errno: Optional[int] = ...) -> str: ...
def zmq_errno() -> int: ...
def zmq_version() -> str: ...
def zmq_version_info() -> Tuple[int, int, int]: ...
def zmq_poll(
sockets: List[Any], timeout: Optional[int] = ...
) -> List[Tuple[Socket, int]]: ...
def device(
device_type: int, frontend: Socket, backend: Optional[Socket] = ...
) -> int: ...
def proxy(frontend: Socket, backend: Socket) -> int: ...
def proxy_steerable(
frontend: Socket,
backend: Socket,
capture: Optional[Socket] = ...,
control: Optional[Socket] = ...,
) -> int: ...

View File

@@ -0,0 +1,33 @@
"""CFFI backend (for PyPy)"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq.backend.cffi import _poll, context, devices, error, message, socket, utils
from ._cffi import ffi
from ._cffi import lib as C
def zmq_version_info():
"""Get libzmq version as tuple of ints"""
major = ffi.new('int*')
minor = ffi.new('int*')
patch = ffi.new('int*')
C.zmq_version(major, minor, patch)
return (int(major[0]), int(minor[0]), int(patch[0]))
__all__ = ["zmq_version_info"]
for submod in (error, message, context, socket, _poll, devices, utils):
__all__.extend(submod.__all__)
from ._poll import *
from .context import *
from .devices import *
from .error import *
from .message import *
from .socket import *
from .utils import *

View File

@@ -0,0 +1,90 @@
void zmq_version(int *major, int *minor, int *patch);
void* zmq_socket(void *context, int type);
int zmq_close(void *socket);
int zmq_bind(void *socket, const char *endpoint);
int zmq_connect(void *socket, const char *endpoint);
int zmq_errno(void);
const char * zmq_strerror(int errnum);
int zmq_device(int device, void *frontend, void *backend);
int zmq_unbind(void *socket, const char *endpoint);
int zmq_disconnect(void *socket, const char *endpoint);
void* zmq_ctx_new();
int zmq_ctx_destroy(void *context);
int zmq_ctx_get(void *context, int opt);
int zmq_ctx_set(void *context, int opt, int optval);
int zmq_proxy(void *frontend, void *backend, void *capture);
int zmq_proxy_steerable(void *frontend,
void *backend,
void *capture,
void *control);
int zmq_socket_monitor(void *socket, const char *addr, int events);
int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key);
int zmq_curve_public (char *z85_public_key, char *z85_secret_key);
int zmq_has (const char *capability);
typedef struct { ...; } zmq_msg_t;
typedef ... zmq_free_fn;
int zmq_msg_init(zmq_msg_t *msg);
int zmq_msg_init_size(zmq_msg_t *msg, size_t size);
int zmq_msg_init_data(zmq_msg_t *msg,
void *data,
size_t size,
zmq_free_fn *ffn,
void *hint);
size_t zmq_msg_size(zmq_msg_t *msg);
void *zmq_msg_data(zmq_msg_t *msg);
int zmq_msg_close(zmq_msg_t *msg);
int zmq_msg_copy(zmq_msg_t *dst, zmq_msg_t *src);
int zmq_msg_send(zmq_msg_t *msg, void *socket, int flags);
int zmq_msg_recv(zmq_msg_t *msg, void *socket, int flags);
int zmq_getsockopt(void *socket,
int option_name,
void *option_value,
size_t *option_len);
int zmq_setsockopt(void *socket,
int option_name,
const void *option_value,
size_t option_len);
typedef int... ZMQ_FD_T;
typedef struct
{
void *socket;
ZMQ_FD_T fd;
short events;
short revents;
} zmq_pollitem_t;
int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout);
// miscellany
void * memcpy(void *restrict s1, const void *restrict s2, size_t n);
void * malloc(size_t sz);
void free(void *p);
int get_ipc_path_max_len(void);
typedef struct _zhint {
void *sock;
void *mutex;
size_t id;
} zhint;
typedef ... mutex_t;
mutex_t* mutex_allocate();
int zmq_wrap_msg_init_data(zmq_msg_t *msg,
void *data,
size_t size,
void *hint);

View File

@@ -0,0 +1,92 @@
"""zmq poll function"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
try:
from time import monotonic
except ImportError:
from time import clock as monotonic
import warnings
from zmq.error import InterruptedSystemCall, _check_rc
from ._cffi import ffi
from ._cffi import lib as C
def _make_zmq_pollitem(socket, flags):
zmq_socket = socket._zmq_socket
zmq_pollitem = ffi.new('zmq_pollitem_t*')
zmq_pollitem.socket = zmq_socket
zmq_pollitem.fd = 0
zmq_pollitem.events = flags
zmq_pollitem.revents = 0
return zmq_pollitem[0]
def _make_zmq_pollitem_fromfd(socket_fd, flags):
zmq_pollitem = ffi.new('zmq_pollitem_t*')
zmq_pollitem.socket = ffi.NULL
zmq_pollitem.fd = socket_fd
zmq_pollitem.events = flags
zmq_pollitem.revents = 0
return zmq_pollitem[0]
def zmq_poll(sockets, timeout):
cffi_pollitem_list = []
low_level_to_socket_obj = {}
from zmq import Socket
for item in sockets:
if isinstance(item[0], Socket):
low_level_to_socket_obj[item[0]._zmq_socket] = item
cffi_pollitem_list.append(_make_zmq_pollitem(item[0], item[1]))
else:
if not isinstance(item[0], int):
# not an FD, get it from fileno()
item = (item[0].fileno(), item[1])
low_level_to_socket_obj[item[0]] = item
cffi_pollitem_list.append(_make_zmq_pollitem_fromfd(item[0], item[1]))
items = ffi.new('zmq_pollitem_t[]', cffi_pollitem_list)
list_length = ffi.cast('int', len(cffi_pollitem_list))
while True:
c_timeout = ffi.cast('long', timeout)
start = monotonic()
rc = C.zmq_poll(items, list_length, c_timeout)
try:
_check_rc(rc)
except InterruptedSystemCall:
if timeout > 0:
ms_passed = int(1000 * (monotonic() - start))
if ms_passed < 0:
# don't allow negative ms_passed,
# which can happen on old Python versions without time.monotonic.
warnings.warn(
"Negative elapsed time for interrupted poll: %s."
" Did the clock change?" % ms_passed,
RuntimeWarning,
)
ms_passed = 0
timeout = max(0, timeout - ms_passed)
continue
else:
break
result = []
for index in range(len(items)):
if items[index].revents > 0:
if not items[index].socket == ffi.NULL:
result.append(
(
low_level_to_socket_obj[items[index].socket][0],
items[index].revents,
)
)
else:
result.append((items[index].fd, items[index].revents))
return result
__all__ = ['zmq_poll']

View File

@@ -0,0 +1,78 @@
"""zmq Context class"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq.constants import EINVAL, IO_THREADS
from zmq.error import InterruptedSystemCall, ZMQError, _check_rc
from ._cffi import ffi
from ._cffi import lib as C
class Context:
_zmq_ctx = None
_iothreads = None
_closed = None
_shadow = False
def __init__(self, io_threads=1, shadow=None):
if shadow:
self._zmq_ctx = ffi.cast("void *", shadow)
self._shadow = True
else:
self._shadow = False
if not io_threads >= 0:
raise ZMQError(EINVAL)
self._zmq_ctx = C.zmq_ctx_new()
if self._zmq_ctx == ffi.NULL:
raise ZMQError(C.zmq_errno())
if not shadow:
C.zmq_ctx_set(self._zmq_ctx, IO_THREADS, io_threads)
self._closed = False
@property
def underlying(self):
"""The address of the underlying libzmq context"""
return int(ffi.cast('size_t', self._zmq_ctx))
@property
def closed(self):
return self._closed
def set(self, option, value):
"""set a context option
see zmq_ctx_set
"""
rc = C.zmq_ctx_set(self._zmq_ctx, option, value)
_check_rc(rc)
def get(self, option):
"""get context option
see zmq_ctx_get
"""
rc = C.zmq_ctx_get(self._zmq_ctx, option)
_check_rc(rc, error_without_errno=False)
return rc
def term(self):
if self.closed:
return
rc = C.zmq_ctx_destroy(self._zmq_ctx)
try:
_check_rc(rc)
except InterruptedSystemCall:
# ignore interrupted term
# see PEP 475 notes about close & EINTR for why
pass
self._zmq_ctx = None
self._closed = True
__all__ = ['Context']

View File

@@ -0,0 +1,63 @@
"""zmq device functions"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from ._cffi import ffi
from ._cffi import lib as C
from .socket import Socket
from .utils import _retry_sys_call
def device(device_type, frontend, backend):
return proxy(frontend, backend)
def proxy(frontend, backend, capture=None):
if isinstance(capture, Socket):
capture = capture._zmq_socket
else:
capture = ffi.NULL
_retry_sys_call(C.zmq_proxy, frontend._zmq_socket, backend._zmq_socket, capture)
def proxy_steerable(frontend, backend, capture=None, control=None):
"""proxy_steerable(frontend, backend, capture, control)
Start a zeromq proxy with control flow.
.. versionadded:: libzmq-4.1
.. versionadded:: 18.0
Parameters
----------
frontend : Socket
The Socket instance for the incoming traffic.
backend : Socket
The Socket instance for the outbound traffic.
capture : Socket (optional)
The Socket instance for capturing traffic.
control : Socket (optional)
The Socket instance for control flow.
"""
if isinstance(capture, Socket):
capture = capture._zmq_socket
else:
capture = ffi.NULL
if isinstance(control, Socket):
control = control._zmq_socket
else:
control = ffi.NULL
_retry_sys_call(
C.zmq_proxy_steerable,
frontend._zmq_socket,
backend._zmq_socket,
capture,
control,
)
__all__ = ['device', 'proxy', 'proxy_steerable']

View File

@@ -0,0 +1,20 @@
"""zmq error functions"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from ._cffi import ffi
from ._cffi import lib as C
def strerror(errno):
s = ffi.string(C.zmq_strerror(errno))
if not isinstance(s, str):
# py3
s = s.decode()
return s
zmq_errno = C.zmq_errno
__all__ = ['strerror', 'zmq_errno']

View File

@@ -0,0 +1,225 @@
"""Dummy Frame object"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import errno
from threading import Event
import zmq
import zmq.error
from zmq.constants import ETERM
from ._cffi import ffi
from ._cffi import lib as C
zmq_gc = None
try:
from __pypy__.bufferable import bufferable as maybe_bufferable
except ImportError:
maybe_bufferable = object
def _content(obj):
"""Return content of obj as bytes"""
if type(obj) is bytes:
return obj
if not isinstance(obj, memoryview):
obj = memoryview(obj)
return obj.tobytes()
def _check_rc(rc):
err = C.zmq_errno()
if rc == -1:
if err == errno.EINTR:
raise zmq.error.InterrruptedSystemCall(err)
elif err == errno.EAGAIN:
raise zmq.error.Again(errno)
elif err == ETERM:
raise zmq.error.ContextTerminated(err)
else:
raise zmq.error.ZMQError(err)
return 0
class Frame(maybe_bufferable):
_data = None
tracker = None
closed = False
more = False
_buffer = None
_bytes = None
_failed_init = False
tracker_event = None
zmq_msg = None
def __init__(self, data=None, track=False, copy=None, copy_threshold=None):
self._failed_init = True
self.zmq_msg = ffi.cast('zmq_msg_t[1]', C.malloc(ffi.sizeof("zmq_msg_t")))
# self.tracker should start finished
# except in the case where we are sharing memory with libzmq
if track:
self.tracker = zmq._FINISHED_TRACKER
if isinstance(data, str):
raise TypeError(
"Unicode strings are not allowed. Only: bytes, buffer interfaces."
)
if data is None:
rc = C.zmq_msg_init(self.zmq_msg)
_check_rc(rc)
self._failed_init = False
return
self._data = data
if type(data) is bytes:
# avoid unnecessary copy on .bytes access
self._bytes = data
self._buffer = memoryview(data)
c_data = ffi.from_buffer(self._buffer)
data_len_c = self._buffer.nbytes
if copy is None:
if copy_threshold and data_len_c < copy_threshold:
copy = True
else:
copy = False
if copy:
# copy message data instead of sharing memory
rc = C.zmq_msg_init_size(self.zmq_msg, data_len_c)
_check_rc(rc)
ffi.buffer(C.zmq_msg_data(self.zmq_msg), data_len_c)[:] = self._buffer
self._failed_init = False
return
# Getting here means that we are doing a true zero-copy Frame,
# where libzmq and Python are sharing memory.
# Hook up garbage collection with MessageTracker and zmq_free_fn
# Event and MessageTracker for monitoring when zmq is done with data:
if track:
evt = Event()
self.tracker_event = evt
self.tracker = zmq.MessageTracker(evt)
# create the hint for zmq_free_fn
# two pointers: the zmq_gc context and a message to be sent to the zmq_gc PULL socket
# allows libzmq to signal to Python when it is done with Python-owned memory.
global zmq_gc
if zmq_gc is None:
from zmq.utils.garbage import gc as zmq_gc
# can't use ffi.new because it will be freed at the wrong time!
hint = ffi.cast("zhint[1]", C.malloc(ffi.sizeof("zhint")))
hint[0].id = zmq_gc.store(data, self.tracker_event)
if not zmq_gc._push_mutex:
zmq_gc._push_mutex = C.mutex_allocate()
hint[0].mutex = ffi.cast("mutex_t*", zmq_gc._push_mutex)
hint[0].sock = ffi.cast("void*", zmq_gc._push_socket.underlying)
# calls zmq_wrap_msg_init_data with the C.free_python_msg callback
rc = C.zmq_wrap_msg_init_data(
self.zmq_msg,
c_data,
data_len_c,
hint,
)
if rc != 0:
C.free(hint)
C.free(self.zmq_msg)
_check_rc(rc)
self._failed_init = False
def __del__(self):
if not self.closed and not self._failed_init:
self.close()
def close(self):
if self.closed or self._failed_init or self.zmq_msg is None:
return
self.closed = True
rc = C.zmq_msg_close(self.zmq_msg)
C.free(self.zmq_msg)
self.zmq_msg = None
if rc != 0:
_check_rc(rc)
def _buffer_from_zmq_msg(self):
"""one-time extract buffer from zmq_msg
for Frames created by recv
"""
if self._data is None:
self._data = ffi.buffer(
C.zmq_msg_data(self.zmq_msg), C.zmq_msg_size(self.zmq_msg)
)
if self._buffer is None:
self._buffer = memoryview(self._data)
@property
def buffer(self):
if self._buffer is None:
self._buffer_from_zmq_msg()
return self._buffer
@property
def bytes(self):
if self._bytes is None:
self._bytes = self.buffer.tobytes()
return self._bytes
def __len__(self):
return self.buffer.nbytes
def __eq__(self, other):
return self.bytes == _content(other)
def __str__(self):
return self.bytes.decode()
@property
def done(self):
return self.tracker.done()
def __buffer__(self, flags):
return self.buffer
def __copy__(self):
"""Create a shallow copy of the message.
This does not copy the contents of the Frame, just the pointer.
This will increment the 0MQ ref count of the message, but not
the ref count of the Python object. That is only done once when
the Python is first turned into a 0MQ message.
"""
return self.fast_copy()
def fast_copy(self):
"""Fast shallow copy of the Frame.
Does not copy underlying data.
"""
new_msg = Frame()
# This does not copy the contents, but just increases the ref-count
# of the zmq_msg by one.
C.zmq_msg_copy(new_msg.zmq_msg, self.zmq_msg)
# Copy the ref to underlying data
new_msg._data = self._data
new_msg._buffer = self._buffer
# Frame copies share the tracker and tracker_event
new_msg.tracker_event = self.tracker_event
new_msg.tracker = self.tracker
return new_msg
Message = Frame
__all__ = ['Frame', 'Message']

View File

@@ -0,0 +1,351 @@
"""zmq Socket class"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import errno as errno_mod
from ._cffi import ffi
from ._cffi import lib as C
nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length)
new_uint64_pointer = lambda: (ffi.new('uint64_t*'), nsp(ffi.sizeof('uint64_t')))
new_int64_pointer = lambda: (ffi.new('int64_t*'), nsp(ffi.sizeof('int64_t')))
new_int_pointer = lambda: (ffi.new('int*'), nsp(ffi.sizeof('int')))
new_binary_data = lambda length: (
ffi.new('char[%d]' % (length)),
nsp(ffi.sizeof('char') * length),
)
value_uint64_pointer = lambda val: (ffi.new('uint64_t*', val), ffi.sizeof('uint64_t'))
value_int64_pointer = lambda val: (ffi.new('int64_t*', val), ffi.sizeof('int64_t'))
value_int_pointer = lambda val: (ffi.new('int*', val), ffi.sizeof('int'))
value_binary_data = lambda val, length: (
ffi.new('char[%d]' % (length + 1), val),
ffi.sizeof('char') * length,
)
ZMQ_FD_64BIT = ffi.sizeof('ZMQ_FD_T') == 8
IPC_PATH_MAX_LEN = C.get_ipc_path_max_len()
import zmq
from zmq.constants import SocketOption, _OptType
from zmq.error import ZMQError, _check_rc, _check_version
from .message import Frame
from .utils import _retry_sys_call
def new_pointer_from_opt(option, length=0):
opt_type = getattr(option, "_opt_type", _OptType.int)
if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd):
return new_int64_pointer()
elif opt_type == _OptType.bytes:
return new_binary_data(length)
else:
# default
return new_int_pointer()
def value_from_opt_pointer(option, opt_pointer, length=0):
try:
option = SocketOption(option)
except ValueError:
# unrecognized option,
# assume from the future,
# let EINVAL raise
opt_type = _OptType.int
else:
opt_type = option._opt_type
if opt_type == _OptType.bytes:
return ffi.buffer(opt_pointer, length)[:]
else:
return int(opt_pointer[0])
def initialize_opt_pointer(option, value, length=0):
opt_type = getattr(option, "_opt_type", _OptType.int)
if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd):
return value_int64_pointer(value)
elif opt_type == _OptType.bytes:
return value_binary_data(value, length)
else:
return value_int_pointer(value)
class Socket:
context = None
socket_type = None
_zmq_socket = None
_closed = None
_ref = None
_shadow = False
copy_threshold = 0
def __init__(self, context=None, socket_type=None, shadow=None):
self.context = context
if shadow is not None:
if isinstance(shadow, Socket):
shadow = shadow.underlying
self._zmq_socket = ffi.cast("void *", shadow)
self._shadow = True
else:
self._shadow = False
self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type)
if self._zmq_socket == ffi.NULL:
raise ZMQError()
self._closed = False
@property
def underlying(self):
"""The address of the underlying libzmq socket"""
return int(ffi.cast('size_t', self._zmq_socket))
def _check_closed_deep(self):
"""thorough check of whether the socket has been closed,
even if by another entity (e.g. ctx.destroy).
Only used by the `closed` property.
returns True if closed, False otherwise
"""
if self._closed:
return True
try:
self.get(zmq.TYPE)
except ZMQError as e:
if e.errno == zmq.ENOTSOCK:
self._closed = True
return True
else:
raise
return False
@property
def closed(self):
return self._check_closed_deep()
def close(self, linger=None):
rc = 0
if not self._closed and hasattr(self, '_zmq_socket'):
if self._zmq_socket is not None:
if linger is not None:
self.set(zmq.LINGER, linger)
rc = C.zmq_close(self._zmq_socket)
self._closed = True
if rc < 0:
_check_rc(rc)
def bind(self, address):
if isinstance(address, str):
address_b = address.encode('utf8')
else:
address_b = address
if isinstance(address, bytes):
address = address_b.decode('utf8')
rc = C.zmq_bind(self._zmq_socket, address_b)
if rc < 0:
if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG:
path = address.split('://', 1)[-1]
msg = (
'ipc path "{}" is longer than {} '
'characters (sizeof(sockaddr_un.sun_path)).'.format(
path, IPC_PATH_MAX_LEN
)
)
raise ZMQError(C.zmq_errno(), msg=msg)
elif C.zmq_errno() == errno_mod.ENOENT:
path = address.split('://', 1)[-1]
msg = f'No such file or directory for ipc path "{path}".'
raise ZMQError(C.zmq_errno(), msg=msg)
else:
_check_rc(rc)
def unbind(self, address):
_check_version((3, 2), "unbind")
if isinstance(address, str):
address = address.encode('utf8')
rc = C.zmq_unbind(self._zmq_socket, address)
_check_rc(rc)
def connect(self, address):
if isinstance(address, str):
address = address.encode('utf8')
rc = C.zmq_connect(self._zmq_socket, address)
_check_rc(rc)
def disconnect(self, address):
_check_version((3, 2), "disconnect")
if isinstance(address, str):
address = address.encode('utf8')
rc = C.zmq_disconnect(self._zmq_socket, address)
_check_rc(rc)
def set(self, option, value):
length = None
if isinstance(value, str):
raise TypeError("unicode not allowed, use bytes")
try:
option = SocketOption(option)
except ValueError:
# unrecognized option,
# assume from the future,
# let EINVAL raise
opt_type = _OptType.int
else:
opt_type = option._opt_type
if isinstance(value, bytes):
if opt_type != _OptType.bytes:
raise TypeError("not a bytes sockopt: %s" % option)
length = len(value)
c_value_pointer, c_sizet = initialize_opt_pointer(option, value, length)
_retry_sys_call(
C.zmq_setsockopt,
self._zmq_socket,
option,
ffi.cast('void*', c_value_pointer),
c_sizet,
)
def get(self, option):
try:
option = SocketOption(option)
except ValueError:
# unrecognized option,
# assume from the future,
# let EINVAL raise
opt_type = _OptType.int
else:
opt_type = option._opt_type
c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255)
_retry_sys_call(
C.zmq_getsockopt, self._zmq_socket, option, c_value_pointer, c_sizet_pointer
)
sz = c_sizet_pointer[0]
v = value_from_opt_pointer(option, c_value_pointer, sz)
if (
option != zmq.SocketOption.ROUTING_ID
and opt_type == _OptType.bytes
and v.endswith(b'\0')
):
v = v[:-1]
return v
def _send_copy(self, buf, flags):
"""Send a copy of a bufferable"""
zmq_msg = ffi.new('zmq_msg_t*')
if not isinstance(buf, bytes):
# cast any bufferable data to bytes via memoryview
buf = memoryview(buf).tobytes()
c_message = ffi.new('char[]', buf)
rc = C.zmq_msg_init_size(zmq_msg, len(buf))
_check_rc(rc)
C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(buf))
_retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
rc2 = C.zmq_msg_close(zmq_msg)
_check_rc(rc2)
def _send_frame(self, frame, flags):
"""Send a Frame on this socket in a non-copy manner."""
# Always copy the Frame so the original message isn't garbage collected.
# This doesn't do a real copy, just a reference.
frame_copy = frame.fast_copy()
zmq_msg = frame_copy.zmq_msg
_retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
tracker = frame_copy.tracker
frame_copy.close()
return tracker
def send(self, data, flags=0, copy=False, track=False):
if isinstance(data, str):
raise TypeError("Message must be in bytes, not a unicode object")
if copy and not isinstance(data, Frame):
return self._send_copy(data, flags)
else:
close_frame = False
if isinstance(data, Frame):
if track and not data.tracker:
raise ValueError('Not a tracked message')
frame = data
else:
if self.copy_threshold:
buf = memoryview(data)
# always copy messages smaller than copy_threshold
if buf.nbytes < self.copy_threshold:
self._send_copy(buf, flags)
return zmq._FINISHED_TRACKER
frame = Frame(data, track=track, copy_threshold=self.copy_threshold)
close_frame = True
tracker = self._send_frame(frame, flags)
if close_frame:
frame.close()
return tracker
def recv(self, flags=0, copy=True, track=False):
if copy:
zmq_msg = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg)
else:
frame = zmq.Frame(track=track)
zmq_msg = frame.zmq_msg
try:
_retry_sys_call(C.zmq_msg_recv, zmq_msg, self._zmq_socket, flags)
except Exception:
if copy:
C.zmq_msg_close(zmq_msg)
raise
if not copy:
return frame
_buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg))
_bytes = _buffer[:]
rc = C.zmq_msg_close(zmq_msg)
_check_rc(rc)
return _bytes
def monitor(self, addr, events=-1):
"""s.monitor(addr, flags)
Start publishing socket events on inproc.
See libzmq docs for zmq_monitor for details.
Note: requires libzmq >= 3.2
Parameters
----------
addr : str
The inproc url used for monitoring. Passing None as
the addr will cause an existing socket monitor to be
deregistered.
events : int [default: zmq.EVENT_ALL]
The zmq event bitmask for which events will be sent to the monitor.
"""
_check_version((3, 2), "monitor")
if events < 0:
events = zmq.EVENT_ALL
if addr is None:
addr = ffi.NULL
if isinstance(addr, str):
addr = addr.encode('utf8')
C.zmq_socket_monitor(self._zmq_socket, addr, events)
__all__ = ['Socket', 'IPC_PATH_MAX_LEN']

View File

@@ -0,0 +1,78 @@
"""miscellaneous zmq_utils wrapping"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq.error import InterruptedSystemCall, _check_rc, _check_version
from ._cffi import ffi
from ._cffi import lib as C
def has(capability):
"""Check for zmq capability by name (e.g. 'ipc', 'curve')
.. versionadded:: libzmq-4.1
.. versionadded:: 14.1
"""
_check_version((4, 1), 'zmq.has')
if isinstance(capability, str):
capability = capability.encode('utf8')
return bool(C.zmq_has(capability))
def curve_keypair():
"""generate a Z85 keypair for use with zmq.CURVE security
Requires libzmq (≥ 4.0) to have been built with CURVE support.
Returns
-------
(public, secret) : two bytestrings
The public and private keypair as 40 byte z85-encoded bytestrings.
"""
_check_version((3, 2), "curve_keypair")
public = ffi.new('char[64]')
private = ffi.new('char[64]')
rc = C.zmq_curve_keypair(public, private)
_check_rc(rc)
return ffi.buffer(public)[:40], ffi.buffer(private)[:40]
def curve_public(private):
"""Compute the public key corresponding to a private key for use
with zmq.CURVE security
Requires libzmq (≥ 4.2) to have been built with CURVE support.
Parameters
----------
private
The private key as a 40 byte z85-encoded bytestring
Returns
-------
bytestring
The public key as a 40 byte z85-encoded bytestring.
"""
if isinstance(private, str):
private = private.encode('utf8')
_check_version((4, 2), "curve_public")
public = ffi.new('char[64]')
rc = C.zmq_curve_public(public, private)
_check_rc(rc)
return ffi.buffer(public)[:40]
def _retry_sys_call(f, *args, **kwargs):
"""make a call, retrying if interrupted with EINTR"""
while True:
rc = f(*args)
try:
_check_rc(rc)
except InterruptedSystemCall:
continue
else:
break
__all__ = ['has', 'curve_keypair', 'curve_public']

View File

@@ -0,0 +1,3 @@
from zmq.backend.cython.context cimport Context
from zmq.backend.cython.message cimport Frame
from zmq.backend.cython.socket cimport Socket

View File

@@ -0,0 +1,40 @@
"""Python bindings for core 0MQ objects."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Lesser GNU Public License (LGPL).
from . import (
_device,
_poll,
_proxy_steerable,
_version,
context,
error,
message,
socket,
utils,
)
__all__ = []
for submod in (
error,
message,
context,
socket,
utils,
_poll,
_version,
_device,
_proxy_steerable,
):
__all__.extend(submod.__all__)
from ._device import * # noqa
from ._poll import * # noqa
from ._proxy_steerable import * # noqa
from ._version import * # noqa
from .context import * # noqa
from .error import * # noqa
from .message import * # noqa
from .socket import * # noqa
from .utils import * # noqa

View File

@@ -0,0 +1,29 @@
from cpython cimport PyErr_CheckSignals
from libc.errno cimport EAGAIN, EINTR
from .libzmq cimport ZMQ_ETERM, zmq_errno
cdef inline int _check_rc(int rc, bint error_without_errno=True) except -1:
"""internal utility for checking zmq return condition
and raising the appropriate Exception class
"""
cdef int errno = zmq_errno()
PyErr_CheckSignals()
if errno == 0 and not error_without_errno:
return 0
if rc == -1: # if rc < -1, it's a bug in libzmq. Should we warn?
if errno == EINTR:
from zmq.error import InterruptedSystemCall
raise InterruptedSystemCall(errno)
elif errno == EAGAIN:
from zmq.error import Again
raise Again(errno)
elif errno == ZMQ_ETERM:
from zmq.error import ContextTerminated
raise ContextTerminated(errno)
else:
from zmq.error import ZMQError
raise ZMQError(errno)
return 0

View File

@@ -0,0 +1,228 @@
cdef extern from "zmq.h" nogil:
enum: PYZMQ_DRAFT_API
enum: ZMQ_VERSION
enum: ZMQ_VERSION_MAJOR
enum: ZMQ_VERSION_MINOR
enum: ZMQ_VERSION_PATCH
enum: ZMQ_IO_THREADS
enum: ZMQ_MAX_SOCKETS
enum: ZMQ_SOCKET_LIMIT
enum: ZMQ_THREAD_PRIORITY
enum: ZMQ_THREAD_SCHED_POLICY
enum: ZMQ_MAX_MSGSZ
enum: ZMQ_MSG_T_SIZE
enum: ZMQ_THREAD_AFFINITY_CPU_ADD
enum: ZMQ_THREAD_AFFINITY_CPU_REMOVE
enum: ZMQ_THREAD_NAME_PREFIX
enum: ZMQ_STREAMER
enum: ZMQ_FORWARDER
enum: ZMQ_QUEUE
enum: ZMQ_EAGAIN "EAGAIN"
enum: ZMQ_EFAULT "EFAULT"
enum: ZMQ_EINVAL "EINVAL"
enum: ZMQ_ENOTSUP "ENOTSUP"
enum: ZMQ_EPROTONOSUPPORT "EPROTONOSUPPORT"
enum: ZMQ_ENOBUFS "ENOBUFS"
enum: ZMQ_ENETDOWN "ENETDOWN"
enum: ZMQ_EADDRINUSE "EADDRINUSE"
enum: ZMQ_EADDRNOTAVAIL "EADDRNOTAVAIL"
enum: ZMQ_ECONNREFUSED "ECONNREFUSED"
enum: ZMQ_EINPROGRESS "EINPROGRESS"
enum: ZMQ_ENOTSOCK "ENOTSOCK"
enum: ZMQ_EMSGSIZE "EMSGSIZE"
enum: ZMQ_EAFNOSUPPORT "EAFNOSUPPORT"
enum: ZMQ_ENETUNREACH "ENETUNREACH"
enum: ZMQ_ECONNABORTED "ECONNABORTED"
enum: ZMQ_ECONNRESET "ECONNRESET"
enum: ZMQ_ENOTCONN "ENOTCONN"
enum: ZMQ_ETIMEDOUT "ETIMEDOUT"
enum: ZMQ_EHOSTUNREACH "EHOSTUNREACH"
enum: ZMQ_ENETRESET "ENETRESET"
enum: ZMQ_EFSM "EFSM"
enum: ZMQ_ENOCOMPATPROTO "ENOCOMPATPROTO"
enum: ZMQ_ETERM "ETERM"
enum: ZMQ_EMTHREAD "EMTHREAD"
enum: ZMQ_EVENT_CONNECTED
enum: ZMQ_EVENT_CONNECT_DELAYED
enum: ZMQ_EVENT_CONNECT_RETRIED
enum: ZMQ_EVENT_LISTENING
enum: ZMQ_EVENT_BIND_FAILED
enum: ZMQ_EVENT_ACCEPTED
enum: ZMQ_EVENT_ACCEPT_FAILED
enum: ZMQ_EVENT_CLOSED
enum: ZMQ_EVENT_CLOSE_FAILED
enum: ZMQ_EVENT_DISCONNECTED
enum: ZMQ_EVENT_MONITOR_STOPPED
enum: ZMQ_EVENT_ALL
enum: ZMQ_HANDSHAKE_FAILED_NO_DETAIL
enum: ZMQ_HANDSHAKE_SUCCEEDED
enum: ZMQ_HANDSHAKE_FAILED_PROTOCOL
enum: ZMQ_HANDSHAKE_FAILED_AUTH
enum: ZMQ_PROTOCOL_ERROR_ZMTP_UNSPECIFIED
enum: ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND
enum: ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE
enum: ZMQ_PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME
enum: ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA
enum: ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC
enum: ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH
enum: ZMQ_PROTOCOL_ERROR_ZAP_UNSPECIFIED
enum: ZMQ_PROTOCOL_ERROR_ZAP_MALFORMED_REPLY
enum: ZMQ_PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID
enum: ZMQ_PROTOCOL_ERROR_ZAP_BAD_VERSION
enum: ZMQ_PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE
enum: ZMQ_PROTOCOL_ERROR_ZAP_INVALID_METADATA
enum: ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED
enum: ZMQ_EVENT_PIPES_STATS
enum: ZMQ_EVENT_ALL_V1
enum: ZMQ_EVENT_ALL_V2
enum: ZMQ_DONTWAIT
enum: ZMQ_SNDMORE
enum: ZMQ_NOBLOCK
enum: ZMQ_MORE
enum: ZMQ_SHARED
enum: ZMQ_SRCFD
enum: ZMQ_POLLIN
enum: ZMQ_POLLOUT
enum: ZMQ_POLLERR
enum: ZMQ_POLLPRI
enum: ZMQ_NULL
enum: ZMQ_PLAIN
enum: ZMQ_CURVE
enum: ZMQ_GSSAPI
enum: ZMQ_HWM
enum: ZMQ_AFFINITY
enum: ZMQ_ROUTING_ID
enum: ZMQ_SUBSCRIBE
enum: ZMQ_UNSUBSCRIBE
enum: ZMQ_RATE
enum: ZMQ_RECOVERY_IVL
enum: ZMQ_SNDBUF
enum: ZMQ_RCVBUF
enum: ZMQ_RCVMORE
enum: ZMQ_FD
enum: ZMQ_EVENTS
enum: ZMQ_TYPE
enum: ZMQ_LINGER
enum: ZMQ_RECONNECT_IVL
enum: ZMQ_BACKLOG
enum: ZMQ_RECONNECT_IVL_MAX
enum: ZMQ_MAXMSGSIZE
enum: ZMQ_SNDHWM
enum: ZMQ_RCVHWM
enum: ZMQ_MULTICAST_HOPS
enum: ZMQ_RCVTIMEO
enum: ZMQ_SNDTIMEO
enum: ZMQ_LAST_ENDPOINT
enum: ZMQ_ROUTER_MANDATORY
enum: ZMQ_TCP_KEEPALIVE
enum: ZMQ_TCP_KEEPALIVE_CNT
enum: ZMQ_TCP_KEEPALIVE_IDLE
enum: ZMQ_TCP_KEEPALIVE_INTVL
enum: ZMQ_IMMEDIATE
enum: ZMQ_XPUB_VERBOSE
enum: ZMQ_ROUTER_RAW
enum: ZMQ_IPV6
enum: ZMQ_MECHANISM
enum: ZMQ_PLAIN_SERVER
enum: ZMQ_PLAIN_USERNAME
enum: ZMQ_PLAIN_PASSWORD
enum: ZMQ_CURVE_SERVER
enum: ZMQ_CURVE_PUBLICKEY
enum: ZMQ_CURVE_SECRETKEY
enum: ZMQ_CURVE_SERVERKEY
enum: ZMQ_PROBE_ROUTER
enum: ZMQ_REQ_CORRELATE
enum: ZMQ_REQ_RELAXED
enum: ZMQ_CONFLATE
enum: ZMQ_ZAP_DOMAIN
enum: ZMQ_ROUTER_HANDOVER
enum: ZMQ_TOS
enum: ZMQ_CONNECT_ROUTING_ID
enum: ZMQ_GSSAPI_SERVER
enum: ZMQ_GSSAPI_PRINCIPAL
enum: ZMQ_GSSAPI_SERVICE_PRINCIPAL
enum: ZMQ_GSSAPI_PLAINTEXT
enum: ZMQ_HANDSHAKE_IVL
enum: ZMQ_SOCKS_PROXY
enum: ZMQ_XPUB_NODROP
enum: ZMQ_BLOCKY
enum: ZMQ_XPUB_MANUAL
enum: ZMQ_XPUB_WELCOME_MSG
enum: ZMQ_STREAM_NOTIFY
enum: ZMQ_INVERT_MATCHING
enum: ZMQ_HEARTBEAT_IVL
enum: ZMQ_HEARTBEAT_TTL
enum: ZMQ_HEARTBEAT_TIMEOUT
enum: ZMQ_XPUB_VERBOSER
enum: ZMQ_CONNECT_TIMEOUT
enum: ZMQ_TCP_MAXRT
enum: ZMQ_THREAD_SAFE
enum: ZMQ_MULTICAST_MAXTPDU
enum: ZMQ_VMCI_BUFFER_SIZE
enum: ZMQ_VMCI_BUFFER_MIN_SIZE
enum: ZMQ_VMCI_BUFFER_MAX_SIZE
enum: ZMQ_VMCI_CONNECT_TIMEOUT
enum: ZMQ_USE_FD
enum: ZMQ_GSSAPI_PRINCIPAL_NAMETYPE
enum: ZMQ_GSSAPI_SERVICE_PRINCIPAL_NAMETYPE
enum: ZMQ_BINDTODEVICE
enum: ZMQ_IDENTITY
enum: ZMQ_CONNECT_RID
enum: ZMQ_TCP_ACCEPT_FILTER
enum: ZMQ_IPC_FILTER_PID
enum: ZMQ_IPC_FILTER_UID
enum: ZMQ_IPC_FILTER_GID
enum: ZMQ_IPV4ONLY
enum: ZMQ_DELAY_ATTACH_ON_CONNECT
enum: ZMQ_FAIL_UNROUTABLE
enum: ZMQ_ROUTER_BEHAVIOR
enum: ZMQ_ZAP_ENFORCE_DOMAIN
enum: ZMQ_LOOPBACK_FASTPATH
enum: ZMQ_METADATA
enum: ZMQ_MULTICAST_LOOP
enum: ZMQ_ROUTER_NOTIFY
enum: ZMQ_XPUB_MANUAL_LAST_VALUE
enum: ZMQ_SOCKS_USERNAME
enum: ZMQ_SOCKS_PASSWORD
enum: ZMQ_IN_BATCH_SIZE
enum: ZMQ_OUT_BATCH_SIZE
enum: ZMQ_WSS_KEY_PEM
enum: ZMQ_WSS_CERT_PEM
enum: ZMQ_WSS_TRUST_PEM
enum: ZMQ_WSS_HOSTNAME
enum: ZMQ_WSS_TRUST_SYSTEM
enum: ZMQ_ONLY_FIRST_SUBSCRIBE
enum: ZMQ_RECONNECT_STOP
enum: ZMQ_HELLO_MSG
enum: ZMQ_DISCONNECT_MSG
enum: ZMQ_PRIORITY
enum: ZMQ_PAIR
enum: ZMQ_PUB
enum: ZMQ_SUB
enum: ZMQ_REQ
enum: ZMQ_REP
enum: ZMQ_DEALER
enum: ZMQ_ROUTER
enum: ZMQ_PULL
enum: ZMQ_PUSH
enum: ZMQ_XPUB
enum: ZMQ_XSUB
enum: ZMQ_STREAM
enum: ZMQ_XREQ
enum: ZMQ_XREP
enum: ZMQ_SERVER
enum: ZMQ_CLIENT
enum: ZMQ_RADIO
enum: ZMQ_DISH
enum: ZMQ_GATHER
enum: ZMQ_SCATTER
enum: ZMQ_DGRAM
enum: ZMQ_PEER
enum: ZMQ_CHANNEL

View File

@@ -0,0 +1,34 @@
"""0MQ Context class declaration."""
#
# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley
#
# This file is part of pyzmq.
#
# pyzmq is free software; you can redistribute it and/or modify it under
# the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# pyzmq is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------
cdef class Context:
cdef object __weakref__ # enable weakref
cdef void *handle # The C handle for the underlying zmq object.
cdef bint _shadow # whether the Context is a shadow wrapper of another
cdef int _pid # the pid of the process which created me (for fork safety)
cdef public bint closed # bool property for a closed context.
cdef inline int _term(self)

View File

@@ -0,0 +1,122 @@
"""All the C imports for 0MQ"""
#
# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley
#
# This file is part of pyzmq.
#
# pyzmq is free software; you can redistribute it and/or modify it under
# the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# pyzmq is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
# Import the C header files
#-----------------------------------------------------------------------------
# common includes, such as zmq compat, pyversion_compat
# make sure we load pyversion compat in every Cython module
cdef extern from "pyversion_compat.h":
pass
# were it not for Windows,
# we could cimport these from libc.stdint
cdef extern from "zmq_compat.h":
ctypedef signed long long int64_t "pyzmq_int64_t"
ctypedef unsigned int uint32_t "pyzmq_uint32_t"
include "constant_enums.pxi"
cdef extern from "zmq.h" nogil:
void _zmq_version "zmq_version"(int *major, int *minor, int *patch)
ctypedef int fd_t "ZMQ_FD_T"
enum: errno
const char *zmq_strerror (int errnum)
int zmq_errno()
void *zmq_ctx_new ()
int zmq_ctx_destroy (void *context)
int zmq_ctx_set (void *context, int option, int optval)
int zmq_ctx_get (void *context, int option)
void *zmq_init (int io_threads)
int zmq_term (void *context)
# blackbox def for zmq_msg_t
ctypedef void * zmq_msg_t "zmq_msg_t"
ctypedef void zmq_free_fn(void *data, void *hint)
int zmq_msg_init (zmq_msg_t *msg)
int zmq_msg_init_size (zmq_msg_t *msg, size_t size)
int zmq_msg_init_data (zmq_msg_t *msg, void *data,
size_t size, zmq_free_fn *ffn, void *hint)
int zmq_msg_send (zmq_msg_t *msg, void *s, int flags)
int zmq_msg_recv (zmq_msg_t *msg, void *s, int flags)
int zmq_msg_close (zmq_msg_t *msg)
int zmq_msg_move (zmq_msg_t *dest, zmq_msg_t *src)
int zmq_msg_copy (zmq_msg_t *dest, zmq_msg_t *src)
void *zmq_msg_data (zmq_msg_t *msg)
size_t zmq_msg_size (zmq_msg_t *msg)
int zmq_msg_more (zmq_msg_t *msg)
int zmq_msg_get (zmq_msg_t *msg, int option)
int zmq_msg_set (zmq_msg_t *msg, int option, int optval)
const char *zmq_msg_gets (zmq_msg_t *msg, const char *property)
int zmq_has (const char *capability)
void *zmq_socket (void *context, int type)
int zmq_close (void *s)
int zmq_setsockopt (void *s, int option, void *optval, size_t optvallen)
int zmq_getsockopt (void *s, int option, void *optval, size_t *optvallen)
int zmq_bind (void *s, char *addr)
int zmq_connect (void *s, char *addr)
int zmq_unbind (void *s, char *addr)
int zmq_disconnect (void *s, char *addr)
int zmq_socket_monitor (void *s, char *addr, int flags)
# send/recv
int zmq_sendbuf (void *s, const void *buf, size_t n, int flags)
int zmq_recvbuf (void *s, void *buf, size_t n, int flags)
ctypedef struct zmq_pollitem_t:
void *socket
fd_t fd
short events
short revents
int zmq_poll (zmq_pollitem_t *items, int nitems, long timeout)
int zmq_device (int device_, void *insocket_, void *outsocket_)
int zmq_proxy (void *frontend, void *backend, void *capture)
int zmq_proxy_steerable (void *frontend,
void *backend,
void *capture,
void *control)
int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key)
int zmq_curve_public (char *z85_public_key, char *z85_secret_key)
# 4.2 draft
int zmq_join (void *s, const char *group)
int zmq_leave (void *s, const char *group)
int zmq_msg_set_routing_id(zmq_msg_t *msg, uint32_t routing_id)
uint32_t zmq_msg_routing_id(zmq_msg_t *msg)
int zmq_msg_set_group(zmq_msg_t *msg, const char *group)
const char *zmq_msg_group(zmq_msg_t *msg)

View File

@@ -0,0 +1,61 @@
"""0MQ Message related class declarations."""
#
# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley
#
# This file is part of pyzmq.
#
# pyzmq is free software; you can redistribute it and/or modify it under
# the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# pyzmq is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
from cpython cimport PyBytes_FromStringAndSize
from zmq.backend.cython.libzmq cimport zmq_msg_data, zmq_msg_size, zmq_msg_t
#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------
cdef class MessageTracker(object):
cdef set events # Message Event objects to track.
cdef set peers # Other Message or MessageTracker objects.
cdef class Frame:
cdef zmq_msg_t zmq_msg
cdef object _data # The actual message data as a Python object.
cdef object _buffer # A Python Buffer/View of the message contents
cdef object _bytes # A bytes/str copy of the message.
cdef bint _failed_init # Flag to handle failed zmq_msg_init
cdef public object tracker_event # Event for use with zmq_free_fn.
cdef public object tracker # MessageTracker object.
cdef public bint more # whether RCVMORE was set
cdef Frame fast_copy(self) # Create shallow copy of Message object.
cdef object _getbuffer(self) # Construct self._buffer.
cdef inline object copy_zmq_msg_bytes(zmq_msg_t *zmq_msg):
""" Copy the data from a zmq_msg_t """
cdef char *data_c = NULL
cdef Py_ssize_t data_len_c
data_c = <char *>zmq_msg_data(zmq_msg)
data_len_c = zmq_msg_size(zmq_msg)
return PyBytes_FromStringAndSize(data_c, data_len_c)

View File

@@ -0,0 +1,48 @@
"""0MQ Socket class declaration."""
#
# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley
#
# This file is part of pyzmq.
#
# pyzmq is free software; you can redistribute it and/or modify it under
# the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# pyzmq is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
from .context cimport Context
#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------
cdef class Socket:
cdef object __weakref__ # enable weakref
cdef void *handle # The C handle for the underlying zmq object.
cdef bint _shadow # whether the Socket is a shadow wrapper of another
# Hold on to a reference to the context to make sure it is not garbage
# collected until the socket it done with it.
cdef public Context context # The zmq Context object that owns this.
cdef public bint _closed # bool property for a closed socket.
cdef public int copy_threshold # threshold below which pyzmq will always copy messages
cdef int _pid # the pid of the process which created me (for fork safety)
cdef void _c_close(self) # underlying close of zmq socket
# cpdef methods for direct-cython access:
cpdef object send(self, object data, int flags=*, copy=*, track=*)
cpdef object recv(self, int flags=*, copy=*, track=*)

View File

@@ -0,0 +1,40 @@
"""Import basic exposure of libzmq C API as a backend"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from importlib import import_module
from typing import Dict
public_api = [
'Context',
'Socket',
'Frame',
'Message',
'device',
'proxy',
'proxy_steerable',
'zmq_poll',
'strerror',
'zmq_errno',
'has',
'curve_keypair',
'curve_public',
'zmq_version_info',
'IPC_PATH_MAX_LEN',
]
def select_backend(name: str) -> Dict:
"""Select the pyzmq backend"""
try:
mod = import_module(name)
except ImportError:
raise
except Exception as e:
raise ImportError(f"Importing {name} failed with {e}") from e
ns = {}
for key in public_api:
ns[key] = getattr(mod, key)
return ns

View File

@@ -0,0 +1,862 @@
"""zmq constants as enums"""
import errno
import sys
from enum import Enum, IntEnum, IntFlag
from typing import List
_HAUSNUMERO = 156384712
class Errno(IntEnum):
"""libzmq error codes
.. versionadded:: 23
"""
EAGAIN = errno.EAGAIN
EFAULT = errno.EFAULT
EINVAL = errno.EINVAL
if sys.platform.startswith("win"):
# Windows: libzmq uses errno.h
# while Python errno prefers WSA* variants
# many of these were introduced to errno.h in vs2010
# ref: https://github.com/python/cpython/blob/3.9/Modules/errnomodule.c#L10-L37
# source: https://docs.microsoft.com/en-us/cpp/c-runtime-library/errno-constants
ENOTSUP = 129
EPROTONOSUPPORT = 135
ENOBUFS = 119
ENETDOWN = 116
EADDRINUSE = 100
EADDRNOTAVAIL = 101
ECONNREFUSED = 107
EINPROGRESS = 112
ENOTSOCK = 128
EMSGSIZE = 115
EAFNOSUPPORT = 102
ENETUNREACH = 118
ECONNABORTED = 106
ECONNRESET = 108
ENOTCONN = 126
ETIMEDOUT = 138
EHOSTUNREACH = 110
ENETRESET = 117
else:
ENOTSUP = getattr(errno, "ENOTSUP", _HAUSNUMERO + 1)
EPROTONOSUPPORT = getattr(errno, "EPROTONOSUPPORT", _HAUSNUMERO + 2)
ENOBUFS = getattr(errno, "ENOBUFS", _HAUSNUMERO + 3)
ENETDOWN = getattr(errno, "ENETDOWN", _HAUSNUMERO + 4)
EADDRINUSE = getattr(errno, "EADDRINUSE", _HAUSNUMERO + 5)
EADDRNOTAVAIL = getattr(errno, "EADDRNOTAVAIL", _HAUSNUMERO + 6)
ECONNREFUSED = getattr(errno, "ECONNREFUSED", _HAUSNUMERO + 7)
EINPROGRESS = getattr(errno, "EINPROGRESS", _HAUSNUMERO + 8)
ENOTSOCK = getattr(errno, "ENOTSOCK", _HAUSNUMERO + 9)
EMSGSIZE = getattr(errno, "EMSGSIZE", _HAUSNUMERO + 10)
EAFNOSUPPORT = getattr(errno, "EAFNOSUPPORT", _HAUSNUMERO + 11)
ENETUNREACH = getattr(errno, "ENETUNREACH", _HAUSNUMERO + 12)
ECONNABORTED = getattr(errno, "ECONNABORTED", _HAUSNUMERO + 13)
ECONNRESET = getattr(errno, "ECONNRESET", _HAUSNUMERO + 14)
ENOTCONN = getattr(errno, "ENOTCONN", _HAUSNUMERO + 15)
ETIMEDOUT = getattr(errno, "ETIMEDOUT", _HAUSNUMERO + 16)
EHOSTUNREACH = getattr(errno, "EHOSTUNREACH", _HAUSNUMERO + 17)
ENETRESET = getattr(errno, "ENETRESET", _HAUSNUMERO + 18)
# Native 0MQ error codes
EFSM = _HAUSNUMERO + 51
ENOCOMPATPROTO = _HAUSNUMERO + 52
ETERM = _HAUSNUMERO + 53
EMTHREAD = _HAUSNUMERO + 54
class ContextOption(IntEnum):
"""Options for Context.get/set
.. versionadded:: 23
"""
IO_THREADS = 1
MAX_SOCKETS = 2
SOCKET_LIMIT = 3
THREAD_PRIORITY = 3
THREAD_SCHED_POLICY = 4
MAX_MSGSZ = 5
MSG_T_SIZE = 6
THREAD_AFFINITY_CPU_ADD = 7
THREAD_AFFINITY_CPU_REMOVE = 8
THREAD_NAME_PREFIX = 9
class SocketType(IntEnum):
"""zmq socket types
.. versionadded:: 23
"""
PAIR = 0
PUB = 1
SUB = 2
REQ = 3
REP = 4
DEALER = 5
ROUTER = 6
PULL = 7
PUSH = 8
XPUB = 9
XSUB = 10
STREAM = 11
# deprecated aliases
XREQ = DEALER
XREP = ROUTER
# DRAFT socket types
SERVER = 12
CLIENT = 13
RADIO = 14
DISH = 15
GATHER = 16
SCATTER = 17
DGRAM = 18
PEER = 19
CHANNEL = 20
class _OptType(Enum):
int = 'int'
int64 = 'int64'
bytes = 'bytes'
fd = 'fd'
class SocketOption(IntEnum):
"""Options for Socket.get/set
.. versionadded:: 23
"""
_opt_type: str
def __new__(cls, value, opt_type=_OptType.int):
"""Attach option type as `._opt_type`"""
obj = int.__new__(cls, value)
obj._value_ = value
obj._opt_type = opt_type
return obj
HWM = 1
AFFINITY = 4, _OptType.int64
ROUTING_ID = 5, _OptType.bytes
SUBSCRIBE = 6, _OptType.bytes
UNSUBSCRIBE = 7, _OptType.bytes
RATE = 8
RECOVERY_IVL = 9
SNDBUF = 11
RCVBUF = 12
RCVMORE = 13
FD = 14, _OptType.fd
EVENTS = 15
TYPE = 16
LINGER = 17
RECONNECT_IVL = 18
BACKLOG = 19
RECONNECT_IVL_MAX = 21
MAXMSGSIZE = 22, _OptType.int64
SNDHWM = 23
RCVHWM = 24
MULTICAST_HOPS = 25
RCVTIMEO = 27
SNDTIMEO = 28
LAST_ENDPOINT = 32, _OptType.bytes
ROUTER_MANDATORY = 33
TCP_KEEPALIVE = 34
TCP_KEEPALIVE_CNT = 35
TCP_KEEPALIVE_IDLE = 36
TCP_KEEPALIVE_INTVL = 37
IMMEDIATE = 39
XPUB_VERBOSE = 40
ROUTER_RAW = 41
IPV6 = 42
MECHANISM = 43
PLAIN_SERVER = 44
PLAIN_USERNAME = 45, _OptType.bytes
PLAIN_PASSWORD = 46, _OptType.bytes
CURVE_SERVER = 47
CURVE_PUBLICKEY = 48, _OptType.bytes
CURVE_SECRETKEY = 49, _OptType.bytes
CURVE_SERVERKEY = 50, _OptType.bytes
PROBE_ROUTER = 51
REQ_CORRELATE = 52
REQ_RELAXED = 53
CONFLATE = 54
ZAP_DOMAIN = 55, _OptType.bytes
ROUTER_HANDOVER = 56
TOS = 57
CONNECT_ROUTING_ID = 61, _OptType.bytes
GSSAPI_SERVER = 62
GSSAPI_PRINCIPAL = 63, _OptType.bytes
GSSAPI_SERVICE_PRINCIPAL = 64, _OptType.bytes
GSSAPI_PLAINTEXT = 65
HANDSHAKE_IVL = 66
SOCKS_PROXY = 68, _OptType.bytes
XPUB_NODROP = 69
BLOCKY = 70
XPUB_MANUAL = 71
XPUB_WELCOME_MSG = 72, _OptType.bytes
STREAM_NOTIFY = 73
INVERT_MATCHING = 74
HEARTBEAT_IVL = 75
HEARTBEAT_TTL = 76
HEARTBEAT_TIMEOUT = 77
XPUB_VERBOSER = 78
CONNECT_TIMEOUT = 79
TCP_MAXRT = 80
THREAD_SAFE = 81
MULTICAST_MAXTPDU = 84
VMCI_BUFFER_SIZE = 85, _OptType.int64
VMCI_BUFFER_MIN_SIZE = 86, _OptType.int64
VMCI_BUFFER_MAX_SIZE = 87, _OptType.int64
VMCI_CONNECT_TIMEOUT = 88
USE_FD = 89
GSSAPI_PRINCIPAL_NAMETYPE = 90
GSSAPI_SERVICE_PRINCIPAL_NAMETYPE = 91
BINDTODEVICE = 92, _OptType.bytes
# Deprecated options and aliases
# must not use name-assignment, must have the same value
IDENTITY = ROUTING_ID
CONNECT_RID = CONNECT_ROUTING_ID
TCP_ACCEPT_FILTER = 38, _OptType.bytes
IPC_FILTER_PID = 58
IPC_FILTER_UID = 59
IPC_FILTER_GID = 60
IPV4ONLY = 31
DELAY_ATTACH_ON_CONNECT = IMMEDIATE
FAIL_UNROUTABLE = ROUTER_MANDATORY
ROUTER_BEHAVIOR = ROUTER_MANDATORY
# Draft socket options
ZAP_ENFORCE_DOMAIN = 93
LOOPBACK_FASTPATH = 94
METADATA = 95, _OptType.bytes
MULTICAST_LOOP = 96
ROUTER_NOTIFY = 97
XPUB_MANUAL_LAST_VALUE = 98
SOCKS_USERNAME = 99, _OptType.bytes
SOCKS_PASSWORD = 100, _OptType.bytes
IN_BATCH_SIZE = 101
OUT_BATCH_SIZE = 102
WSS_KEY_PEM = 103, _OptType.bytes
WSS_CERT_PEM = 104, _OptType.bytes
WSS_TRUST_PEM = 105, _OptType.bytes
WSS_HOSTNAME = 106, _OptType.bytes
WSS_TRUST_SYSTEM = 107
ONLY_FIRST_SUBSCRIBE = 108
RECONNECT_STOP = 109
HELLO_MSG = 110, _OptType.bytes
DISCONNECT_MSG = 111, _OptType.bytes
PRIORITY = 112
class MessageOption(IntEnum):
"""Options on zmq.Frame objects
.. versionadded:: 23
"""
MORE = 1
SHARED = 3
# Deprecated message options
SRCFD = 2
class Flag(IntFlag):
"""Send/recv flags
.. versionadded:: 23
"""
DONTWAIT = 1
SNDMORE = 2
NOBLOCK = DONTWAIT
class SecurityMechanism(IntEnum):
"""Security mechanisms (as returned by ``socket.get(zmq.MECHANISM)``)
.. versionadded:: 23
"""
NULL = 0
PLAIN = 1
CURVE = 2
GSSAPI = 3
class Event(IntFlag):
"""Socket monitoring events
.. versionadded:: 23
"""
@staticmethod
def _global_name(name):
if name.startswith(("PROTOCOL_ERROR_", "HANDSHAKE_")):
return name
else:
# add EVENT_ prefix
return "EVENT_" + name
CONNECTED = 0x0001
CONNECT_DELAYED = 0x0002
CONNECT_RETRIED = 0x0004
LISTENING = 0x0008
BIND_FAILED = 0x0010
ACCEPTED = 0x0020
ACCEPT_FAILED = 0x0040
CLOSED = 0x0080
CLOSE_FAILED = 0x0100
DISCONNECTED = 0x0200
MONITOR_STOPPED = 0x0400
ALL = 0xFFFF
HANDSHAKE_FAILED_NO_DETAIL = 0x0800
HANDSHAKE_SUCCEEDED = 0x1000
HANDSHAKE_FAILED_PROTOCOL = 0x2000
HANDSHAKE_FAILED_AUTH = 0x4000
PROTOCOL_ERROR_ZMTP_UNSPECIFIED = 0x10000000
PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND = 0x10000001
PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE = 0x10000002
PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE = 0x10000003
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED = 0x10000011
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE = 0x10000012
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO = 0x10000013
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE = 0x10000014
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR = 0x10000015
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY = 0x10000016
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME = 0x10000017
PROTOCOL_ERROR_ZMTP_INVALID_METADATA = 0x10000018
PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC = 0x11000001
PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH = 0x11000002
PROTOCOL_ERROR_ZAP_UNSPECIFIED = 0x20000000
PROTOCOL_ERROR_ZAP_MALFORMED_REPLY = 0x20000001
PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID = 0x20000002
PROTOCOL_ERROR_ZAP_BAD_VERSION = 0x20000003
PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE = 0x20000004
PROTOCOL_ERROR_ZAP_INVALID_METADATA = 0x20000005
PROTOCOL_ERROR_WS_UNSPECIFIED = 0x30000000
# DRAFT Socket monitoring events
PIPES_STATS = 0x10000
ALL_V1 = ALL
ALL_V2 = ALL_V1 | PIPES_STATS
class PollEvent(IntFlag):
"""Which events to poll for in poll methods
.. versionadded: 23
"""
POLLIN = 1
POLLOUT = 2
POLLERR = 4
POLLPRI = 8
class DeviceType(IntEnum):
"""Device type constants for zmq.device
.. versionadded: 23
"""
STREAMER = 1
FORWARDER = 2
QUEUE = 3
# AUTOGENERATED_BELOW_HERE
IO_THREADS: int = ContextOption.IO_THREADS
MAX_SOCKETS: int = ContextOption.MAX_SOCKETS
SOCKET_LIMIT: int = ContextOption.SOCKET_LIMIT
THREAD_PRIORITY: int = ContextOption.THREAD_PRIORITY
THREAD_SCHED_POLICY: int = ContextOption.THREAD_SCHED_POLICY
MAX_MSGSZ: int = ContextOption.MAX_MSGSZ
MSG_T_SIZE: int = ContextOption.MSG_T_SIZE
THREAD_AFFINITY_CPU_ADD: int = ContextOption.THREAD_AFFINITY_CPU_ADD
THREAD_AFFINITY_CPU_REMOVE: int = ContextOption.THREAD_AFFINITY_CPU_REMOVE
THREAD_NAME_PREFIX: int = ContextOption.THREAD_NAME_PREFIX
STREAMER: int = DeviceType.STREAMER
FORWARDER: int = DeviceType.FORWARDER
QUEUE: int = DeviceType.QUEUE
EAGAIN: int = Errno.EAGAIN
EFAULT: int = Errno.EFAULT
EINVAL: int = Errno.EINVAL
ENOTSUP: int = Errno.ENOTSUP
EPROTONOSUPPORT: int = Errno.EPROTONOSUPPORT
ENOBUFS: int = Errno.ENOBUFS
ENETDOWN: int = Errno.ENETDOWN
EADDRINUSE: int = Errno.EADDRINUSE
EADDRNOTAVAIL: int = Errno.EADDRNOTAVAIL
ECONNREFUSED: int = Errno.ECONNREFUSED
EINPROGRESS: int = Errno.EINPROGRESS
ENOTSOCK: int = Errno.ENOTSOCK
EMSGSIZE: int = Errno.EMSGSIZE
EAFNOSUPPORT: int = Errno.EAFNOSUPPORT
ENETUNREACH: int = Errno.ENETUNREACH
ECONNABORTED: int = Errno.ECONNABORTED
ECONNRESET: int = Errno.ECONNRESET
ENOTCONN: int = Errno.ENOTCONN
ETIMEDOUT: int = Errno.ETIMEDOUT
EHOSTUNREACH: int = Errno.EHOSTUNREACH
ENETRESET: int = Errno.ENETRESET
EFSM: int = Errno.EFSM
ENOCOMPATPROTO: int = Errno.ENOCOMPATPROTO
ETERM: int = Errno.ETERM
EMTHREAD: int = Errno.EMTHREAD
EVENT_CONNECTED: int = Event.CONNECTED
EVENT_CONNECT_DELAYED: int = Event.CONNECT_DELAYED
EVENT_CONNECT_RETRIED: int = Event.CONNECT_RETRIED
EVENT_LISTENING: int = Event.LISTENING
EVENT_BIND_FAILED: int = Event.BIND_FAILED
EVENT_ACCEPTED: int = Event.ACCEPTED
EVENT_ACCEPT_FAILED: int = Event.ACCEPT_FAILED
EVENT_CLOSED: int = Event.CLOSED
EVENT_CLOSE_FAILED: int = Event.CLOSE_FAILED
EVENT_DISCONNECTED: int = Event.DISCONNECTED
EVENT_MONITOR_STOPPED: int = Event.MONITOR_STOPPED
EVENT_ALL: int = Event.ALL
HANDSHAKE_FAILED_NO_DETAIL: int = Event.HANDSHAKE_FAILED_NO_DETAIL
HANDSHAKE_SUCCEEDED: int = Event.HANDSHAKE_SUCCEEDED
HANDSHAKE_FAILED_PROTOCOL: int = Event.HANDSHAKE_FAILED_PROTOCOL
HANDSHAKE_FAILED_AUTH: int = Event.HANDSHAKE_FAILED_AUTH
PROTOCOL_ERROR_ZMTP_UNSPECIFIED: int = Event.PROTOCOL_ERROR_ZMTP_UNSPECIFIED
PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND: int = (
Event.PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND
)
PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE: int = Event.PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE
PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE: int = Event.PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY
)
PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME: int = (
Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME
)
PROTOCOL_ERROR_ZMTP_INVALID_METADATA: int = Event.PROTOCOL_ERROR_ZMTP_INVALID_METADATA
PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC: int = Event.PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC
PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH: int = (
Event.PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH
)
PROTOCOL_ERROR_ZAP_UNSPECIFIED: int = Event.PROTOCOL_ERROR_ZAP_UNSPECIFIED
PROTOCOL_ERROR_ZAP_MALFORMED_REPLY: int = Event.PROTOCOL_ERROR_ZAP_MALFORMED_REPLY
PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID: int = Event.PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID
PROTOCOL_ERROR_ZAP_BAD_VERSION: int = Event.PROTOCOL_ERROR_ZAP_BAD_VERSION
PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE: int = (
Event.PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE
)
PROTOCOL_ERROR_ZAP_INVALID_METADATA: int = Event.PROTOCOL_ERROR_ZAP_INVALID_METADATA
PROTOCOL_ERROR_WS_UNSPECIFIED: int = Event.PROTOCOL_ERROR_WS_UNSPECIFIED
EVENT_PIPES_STATS: int = Event.PIPES_STATS
EVENT_ALL_V1: int = Event.ALL_V1
EVENT_ALL_V2: int = Event.ALL_V2
DONTWAIT: int = Flag.DONTWAIT
SNDMORE: int = Flag.SNDMORE
NOBLOCK: int = Flag.NOBLOCK
MORE: int = MessageOption.MORE
SHARED: int = MessageOption.SHARED
SRCFD: int = MessageOption.SRCFD
POLLIN: int = PollEvent.POLLIN
POLLOUT: int = PollEvent.POLLOUT
POLLERR: int = PollEvent.POLLERR
POLLPRI: int = PollEvent.POLLPRI
NULL: int = SecurityMechanism.NULL
PLAIN: int = SecurityMechanism.PLAIN
CURVE: int = SecurityMechanism.CURVE
GSSAPI: int = SecurityMechanism.GSSAPI
HWM: int = SocketOption.HWM
AFFINITY: int = SocketOption.AFFINITY
ROUTING_ID: int = SocketOption.ROUTING_ID
SUBSCRIBE: int = SocketOption.SUBSCRIBE
UNSUBSCRIBE: int = SocketOption.UNSUBSCRIBE
RATE: int = SocketOption.RATE
RECOVERY_IVL: int = SocketOption.RECOVERY_IVL
SNDBUF: int = SocketOption.SNDBUF
RCVBUF: int = SocketOption.RCVBUF
RCVMORE: int = SocketOption.RCVMORE
FD: int = SocketOption.FD
EVENTS: int = SocketOption.EVENTS
TYPE: int = SocketOption.TYPE
LINGER: int = SocketOption.LINGER
RECONNECT_IVL: int = SocketOption.RECONNECT_IVL
BACKLOG: int = SocketOption.BACKLOG
RECONNECT_IVL_MAX: int = SocketOption.RECONNECT_IVL_MAX
MAXMSGSIZE: int = SocketOption.MAXMSGSIZE
SNDHWM: int = SocketOption.SNDHWM
RCVHWM: int = SocketOption.RCVHWM
MULTICAST_HOPS: int = SocketOption.MULTICAST_HOPS
RCVTIMEO: int = SocketOption.RCVTIMEO
SNDTIMEO: int = SocketOption.SNDTIMEO
LAST_ENDPOINT: int = SocketOption.LAST_ENDPOINT
ROUTER_MANDATORY: int = SocketOption.ROUTER_MANDATORY
TCP_KEEPALIVE: int = SocketOption.TCP_KEEPALIVE
TCP_KEEPALIVE_CNT: int = SocketOption.TCP_KEEPALIVE_CNT
TCP_KEEPALIVE_IDLE: int = SocketOption.TCP_KEEPALIVE_IDLE
TCP_KEEPALIVE_INTVL: int = SocketOption.TCP_KEEPALIVE_INTVL
IMMEDIATE: int = SocketOption.IMMEDIATE
XPUB_VERBOSE: int = SocketOption.XPUB_VERBOSE
ROUTER_RAW: int = SocketOption.ROUTER_RAW
IPV6: int = SocketOption.IPV6
MECHANISM: int = SocketOption.MECHANISM
PLAIN_SERVER: int = SocketOption.PLAIN_SERVER
PLAIN_USERNAME: int = SocketOption.PLAIN_USERNAME
PLAIN_PASSWORD: int = SocketOption.PLAIN_PASSWORD
CURVE_SERVER: int = SocketOption.CURVE_SERVER
CURVE_PUBLICKEY: int = SocketOption.CURVE_PUBLICKEY
CURVE_SECRETKEY: int = SocketOption.CURVE_SECRETKEY
CURVE_SERVERKEY: int = SocketOption.CURVE_SERVERKEY
PROBE_ROUTER: int = SocketOption.PROBE_ROUTER
REQ_CORRELATE: int = SocketOption.REQ_CORRELATE
REQ_RELAXED: int = SocketOption.REQ_RELAXED
CONFLATE: int = SocketOption.CONFLATE
ZAP_DOMAIN: int = SocketOption.ZAP_DOMAIN
ROUTER_HANDOVER: int = SocketOption.ROUTER_HANDOVER
TOS: int = SocketOption.TOS
CONNECT_ROUTING_ID: int = SocketOption.CONNECT_ROUTING_ID
GSSAPI_SERVER: int = SocketOption.GSSAPI_SERVER
GSSAPI_PRINCIPAL: int = SocketOption.GSSAPI_PRINCIPAL
GSSAPI_SERVICE_PRINCIPAL: int = SocketOption.GSSAPI_SERVICE_PRINCIPAL
GSSAPI_PLAINTEXT: int = SocketOption.GSSAPI_PLAINTEXT
HANDSHAKE_IVL: int = SocketOption.HANDSHAKE_IVL
SOCKS_PROXY: int = SocketOption.SOCKS_PROXY
XPUB_NODROP: int = SocketOption.XPUB_NODROP
BLOCKY: int = SocketOption.BLOCKY
XPUB_MANUAL: int = SocketOption.XPUB_MANUAL
XPUB_WELCOME_MSG: int = SocketOption.XPUB_WELCOME_MSG
STREAM_NOTIFY: int = SocketOption.STREAM_NOTIFY
INVERT_MATCHING: int = SocketOption.INVERT_MATCHING
HEARTBEAT_IVL: int = SocketOption.HEARTBEAT_IVL
HEARTBEAT_TTL: int = SocketOption.HEARTBEAT_TTL
HEARTBEAT_TIMEOUT: int = SocketOption.HEARTBEAT_TIMEOUT
XPUB_VERBOSER: int = SocketOption.XPUB_VERBOSER
CONNECT_TIMEOUT: int = SocketOption.CONNECT_TIMEOUT
TCP_MAXRT: int = SocketOption.TCP_MAXRT
THREAD_SAFE: int = SocketOption.THREAD_SAFE
MULTICAST_MAXTPDU: int = SocketOption.MULTICAST_MAXTPDU
VMCI_BUFFER_SIZE: int = SocketOption.VMCI_BUFFER_SIZE
VMCI_BUFFER_MIN_SIZE: int = SocketOption.VMCI_BUFFER_MIN_SIZE
VMCI_BUFFER_MAX_SIZE: int = SocketOption.VMCI_BUFFER_MAX_SIZE
VMCI_CONNECT_TIMEOUT: int = SocketOption.VMCI_CONNECT_TIMEOUT
USE_FD: int = SocketOption.USE_FD
GSSAPI_PRINCIPAL_NAMETYPE: int = SocketOption.GSSAPI_PRINCIPAL_NAMETYPE
GSSAPI_SERVICE_PRINCIPAL_NAMETYPE: int = SocketOption.GSSAPI_SERVICE_PRINCIPAL_NAMETYPE
BINDTODEVICE: int = SocketOption.BINDTODEVICE
IDENTITY: int = SocketOption.IDENTITY
CONNECT_RID: int = SocketOption.CONNECT_RID
TCP_ACCEPT_FILTER: int = SocketOption.TCP_ACCEPT_FILTER
IPC_FILTER_PID: int = SocketOption.IPC_FILTER_PID
IPC_FILTER_UID: int = SocketOption.IPC_FILTER_UID
IPC_FILTER_GID: int = SocketOption.IPC_FILTER_GID
IPV4ONLY: int = SocketOption.IPV4ONLY
DELAY_ATTACH_ON_CONNECT: int = SocketOption.DELAY_ATTACH_ON_CONNECT
FAIL_UNROUTABLE: int = SocketOption.FAIL_UNROUTABLE
ROUTER_BEHAVIOR: int = SocketOption.ROUTER_BEHAVIOR
ZAP_ENFORCE_DOMAIN: int = SocketOption.ZAP_ENFORCE_DOMAIN
LOOPBACK_FASTPATH: int = SocketOption.LOOPBACK_FASTPATH
METADATA: int = SocketOption.METADATA
MULTICAST_LOOP: int = SocketOption.MULTICAST_LOOP
ROUTER_NOTIFY: int = SocketOption.ROUTER_NOTIFY
XPUB_MANUAL_LAST_VALUE: int = SocketOption.XPUB_MANUAL_LAST_VALUE
SOCKS_USERNAME: int = SocketOption.SOCKS_USERNAME
SOCKS_PASSWORD: int = SocketOption.SOCKS_PASSWORD
IN_BATCH_SIZE: int = SocketOption.IN_BATCH_SIZE
OUT_BATCH_SIZE: int = SocketOption.OUT_BATCH_SIZE
WSS_KEY_PEM: int = SocketOption.WSS_KEY_PEM
WSS_CERT_PEM: int = SocketOption.WSS_CERT_PEM
WSS_TRUST_PEM: int = SocketOption.WSS_TRUST_PEM
WSS_HOSTNAME: int = SocketOption.WSS_HOSTNAME
WSS_TRUST_SYSTEM: int = SocketOption.WSS_TRUST_SYSTEM
ONLY_FIRST_SUBSCRIBE: int = SocketOption.ONLY_FIRST_SUBSCRIBE
RECONNECT_STOP: int = SocketOption.RECONNECT_STOP
HELLO_MSG: int = SocketOption.HELLO_MSG
DISCONNECT_MSG: int = SocketOption.DISCONNECT_MSG
PRIORITY: int = SocketOption.PRIORITY
PAIR: int = SocketType.PAIR
PUB: int = SocketType.PUB
SUB: int = SocketType.SUB
REQ: int = SocketType.REQ
REP: int = SocketType.REP
DEALER: int = SocketType.DEALER
ROUTER: int = SocketType.ROUTER
PULL: int = SocketType.PULL
PUSH: int = SocketType.PUSH
XPUB: int = SocketType.XPUB
XSUB: int = SocketType.XSUB
STREAM: int = SocketType.STREAM
XREQ: int = SocketType.XREQ
XREP: int = SocketType.XREP
SERVER: int = SocketType.SERVER
CLIENT: int = SocketType.CLIENT
RADIO: int = SocketType.RADIO
DISH: int = SocketType.DISH
GATHER: int = SocketType.GATHER
SCATTER: int = SocketType.SCATTER
DGRAM: int = SocketType.DGRAM
PEER: int = SocketType.PEER
CHANNEL: int = SocketType.CHANNEL
__all__: List[str] = [
"ContextOption",
"IO_THREADS",
"MAX_SOCKETS",
"SOCKET_LIMIT",
"THREAD_PRIORITY",
"THREAD_SCHED_POLICY",
"MAX_MSGSZ",
"MSG_T_SIZE",
"THREAD_AFFINITY_CPU_ADD",
"THREAD_AFFINITY_CPU_REMOVE",
"THREAD_NAME_PREFIX",
"DeviceType",
"STREAMER",
"FORWARDER",
"QUEUE",
"Enum",
"Errno",
"EAGAIN",
"EFAULT",
"EINVAL",
"ENOTSUP",
"EPROTONOSUPPORT",
"ENOBUFS",
"ENETDOWN",
"EADDRINUSE",
"EADDRNOTAVAIL",
"ECONNREFUSED",
"EINPROGRESS",
"ENOTSOCK",
"EMSGSIZE",
"EAFNOSUPPORT",
"ENETUNREACH",
"ECONNABORTED",
"ECONNRESET",
"ENOTCONN",
"ETIMEDOUT",
"EHOSTUNREACH",
"ENETRESET",
"EFSM",
"ENOCOMPATPROTO",
"ETERM",
"EMTHREAD",
"Event",
"EVENT_CONNECTED",
"EVENT_CONNECT_DELAYED",
"EVENT_CONNECT_RETRIED",
"EVENT_LISTENING",
"EVENT_BIND_FAILED",
"EVENT_ACCEPTED",
"EVENT_ACCEPT_FAILED",
"EVENT_CLOSED",
"EVENT_CLOSE_FAILED",
"EVENT_DISCONNECTED",
"EVENT_MONITOR_STOPPED",
"EVENT_ALL",
"HANDSHAKE_FAILED_NO_DETAIL",
"HANDSHAKE_SUCCEEDED",
"HANDSHAKE_FAILED_PROTOCOL",
"HANDSHAKE_FAILED_AUTH",
"PROTOCOL_ERROR_ZMTP_UNSPECIFIED",
"PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND",
"PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE",
"PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY",
"PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME",
"PROTOCOL_ERROR_ZMTP_INVALID_METADATA",
"PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC",
"PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH",
"PROTOCOL_ERROR_ZAP_UNSPECIFIED",
"PROTOCOL_ERROR_ZAP_MALFORMED_REPLY",
"PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID",
"PROTOCOL_ERROR_ZAP_BAD_VERSION",
"PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE",
"PROTOCOL_ERROR_ZAP_INVALID_METADATA",
"PROTOCOL_ERROR_WS_UNSPECIFIED",
"EVENT_PIPES_STATS",
"EVENT_ALL_V1",
"EVENT_ALL_V2",
"Flag",
"DONTWAIT",
"SNDMORE",
"NOBLOCK",
"IntEnum",
"IntFlag",
"MessageOption",
"MORE",
"SHARED",
"SRCFD",
"PollEvent",
"POLLIN",
"POLLOUT",
"POLLERR",
"POLLPRI",
"SecurityMechanism",
"NULL",
"PLAIN",
"CURVE",
"GSSAPI",
"SocketOption",
"HWM",
"AFFINITY",
"ROUTING_ID",
"SUBSCRIBE",
"UNSUBSCRIBE",
"RATE",
"RECOVERY_IVL",
"SNDBUF",
"RCVBUF",
"RCVMORE",
"FD",
"EVENTS",
"TYPE",
"LINGER",
"RECONNECT_IVL",
"BACKLOG",
"RECONNECT_IVL_MAX",
"MAXMSGSIZE",
"SNDHWM",
"RCVHWM",
"MULTICAST_HOPS",
"RCVTIMEO",
"SNDTIMEO",
"LAST_ENDPOINT",
"ROUTER_MANDATORY",
"TCP_KEEPALIVE",
"TCP_KEEPALIVE_CNT",
"TCP_KEEPALIVE_IDLE",
"TCP_KEEPALIVE_INTVL",
"IMMEDIATE",
"XPUB_VERBOSE",
"ROUTER_RAW",
"IPV6",
"MECHANISM",
"PLAIN_SERVER",
"PLAIN_USERNAME",
"PLAIN_PASSWORD",
"CURVE_SERVER",
"CURVE_PUBLICKEY",
"CURVE_SECRETKEY",
"CURVE_SERVERKEY",
"PROBE_ROUTER",
"REQ_CORRELATE",
"REQ_RELAXED",
"CONFLATE",
"ZAP_DOMAIN",
"ROUTER_HANDOVER",
"TOS",
"CONNECT_ROUTING_ID",
"GSSAPI_SERVER",
"GSSAPI_PRINCIPAL",
"GSSAPI_SERVICE_PRINCIPAL",
"GSSAPI_PLAINTEXT",
"HANDSHAKE_IVL",
"SOCKS_PROXY",
"XPUB_NODROP",
"BLOCKY",
"XPUB_MANUAL",
"XPUB_WELCOME_MSG",
"STREAM_NOTIFY",
"INVERT_MATCHING",
"HEARTBEAT_IVL",
"HEARTBEAT_TTL",
"HEARTBEAT_TIMEOUT",
"XPUB_VERBOSER",
"CONNECT_TIMEOUT",
"TCP_MAXRT",
"THREAD_SAFE",
"MULTICAST_MAXTPDU",
"VMCI_BUFFER_SIZE",
"VMCI_BUFFER_MIN_SIZE",
"VMCI_BUFFER_MAX_SIZE",
"VMCI_CONNECT_TIMEOUT",
"USE_FD",
"GSSAPI_PRINCIPAL_NAMETYPE",
"GSSAPI_SERVICE_PRINCIPAL_NAMETYPE",
"BINDTODEVICE",
"IDENTITY",
"CONNECT_RID",
"TCP_ACCEPT_FILTER",
"IPC_FILTER_PID",
"IPC_FILTER_UID",
"IPC_FILTER_GID",
"IPV4ONLY",
"DELAY_ATTACH_ON_CONNECT",
"FAIL_UNROUTABLE",
"ROUTER_BEHAVIOR",
"ZAP_ENFORCE_DOMAIN",
"LOOPBACK_FASTPATH",
"METADATA",
"MULTICAST_LOOP",
"ROUTER_NOTIFY",
"XPUB_MANUAL_LAST_VALUE",
"SOCKS_USERNAME",
"SOCKS_PASSWORD",
"IN_BATCH_SIZE",
"OUT_BATCH_SIZE",
"WSS_KEY_PEM",
"WSS_CERT_PEM",
"WSS_TRUST_PEM",
"WSS_HOSTNAME",
"WSS_TRUST_SYSTEM",
"ONLY_FIRST_SUBSCRIBE",
"RECONNECT_STOP",
"HELLO_MSG",
"DISCONNECT_MSG",
"PRIORITY",
"SocketType",
"PAIR",
"PUB",
"SUB",
"REQ",
"REP",
"DEALER",
"ROUTER",
"PULL",
"PUSH",
"XPUB",
"XSUB",
"STREAM",
"XREQ",
"XREP",
"SERVER",
"CLIENT",
"RADIO",
"DISH",
"GATHER",
"SCATTER",
"DGRAM",
"PEER",
"CHANNEL",
]

View File

@@ -0,0 +1,188 @@
"""Decorators for running functions with context/sockets.
.. versionadded:: 15.3
Like using Contexts and Sockets as context managers, but with decorator syntax.
Context and sockets are closed at the end of the function.
For example::
from zmq.decorators import context, socket
@context()
@socket(zmq.PUSH)
def work(ctx, push):
...
"""
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
__all__ = (
'context',
'socket',
)
from functools import wraps
import zmq
class _Decorator:
'''The mini decorator factory'''
def __init__(self, target=None):
self._target = target
def __call__(self, *dec_args, **dec_kwargs):
"""
The main logic of decorator
Here is how those arguments works::
@out_decorator(*dec_args, *dec_kwargs)
def func(*wrap_args, **wrap_kwargs):
...
And in the ``wrapper``, we simply create ``self.target`` instance via
``with``::
target = self.get_target(*args, **kwargs)
with target(*dec_args, **dec_kwargs) as obj:
...
"""
kw_name, dec_args, dec_kwargs = self.process_decorator_args(
*dec_args, **dec_kwargs
)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
target = self.get_target(*args, **kwargs)
with target(*dec_args, **dec_kwargs) as obj:
# insert our object into args
if kw_name and kw_name not in kwargs:
kwargs[kw_name] = obj
elif kw_name and kw_name in kwargs:
raise TypeError(
"{}() got multiple values for"
" argument '{}'".format(func.__name__, kw_name)
)
else:
args = args + (obj,)
return func(*args, **kwargs)
return wrapper
return decorator
def get_target(self, *args, **kwargs):
"""Return the target function
Allows modifying args/kwargs to be passed.
"""
return self._target
def process_decorator_args(self, *args, **kwargs):
"""Process args passed to the decorator.
args not consumed by the decorator will be passed to the target factory
(Context/Socket constructor).
"""
kw_name = None
if isinstance(kwargs.get('name'), str):
kw_name = kwargs.pop('name')
elif len(args) >= 1 and isinstance(args[0], str):
kw_name = args[0]
args = args[1:]
return kw_name, args, kwargs
class _ContextDecorator(_Decorator):
"""Decorator subclass for Contexts"""
def __init__(self):
super().__init__(zmq.Context)
class _SocketDecorator(_Decorator):
"""Decorator subclass for sockets
Gets the context from other args.
"""
def process_decorator_args(self, *args, **kwargs):
"""Also grab context_name out of kwargs"""
kw_name, args, kwargs = super().process_decorator_args(*args, **kwargs)
self.context_name = kwargs.pop('context_name', 'context')
return kw_name, args, kwargs
def get_target(self, *args, **kwargs):
"""Get context, based on call-time args"""
context = self._get_context(*args, **kwargs)
return context.socket
def _get_context(self, *args, **kwargs):
"""
Find the ``zmq.Context`` from ``args`` and ``kwargs`` at call time.
First, if there is an keyword argument named ``context`` and it is a
``zmq.Context`` instance , we will take it.
Second, we check all the ``args``, take the first ``zmq.Context``
instance.
Finally, we will provide default Context -- ``zmq.Context.instance``
:return: a ``zmq.Context`` instance
"""
if self.context_name in kwargs:
ctx = kwargs[self.context_name]
if isinstance(ctx, zmq.Context):
return ctx
for arg in args:
if isinstance(arg, zmq.Context):
return arg
# not specified by any decorator
return zmq.Context.instance()
def context(*args, **kwargs):
"""Decorator for adding a Context to a function.
Usage::
@context()
def foo(ctx):
...
.. versionadded:: 15.3
:param str name: the keyword argument passed to decorated function
"""
return _ContextDecorator()(*args, **kwargs)
def socket(*args, **kwargs):
"""Decorator for adding a socket to a function.
Usage::
@socket(zmq.PUSH)
def foo(push):
...
.. versionadded:: 15.3
:param str name: the keyword argument passed to decorated function
:param str context_name: the keyword only argument to identify context
object
"""
return _SocketDecorator()(*args, **kwargs)

View File

@@ -0,0 +1,28 @@
"""0MQ Device classes for running in background threads or processes."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq import device
from zmq.devices import (
basedevice,
monitoredqueue,
monitoredqueuedevice,
proxydevice,
proxysteerabledevice,
)
from zmq.devices.basedevice import *
from zmq.devices.monitoredqueue import *
from zmq.devices.monitoredqueuedevice import *
from zmq.devices.proxydevice import *
from zmq.devices.proxysteerabledevice import *
__all__ = ['device']
for submod in (
basedevice,
proxydevice,
proxysteerabledevice,
monitoredqueue,
monitoredqueuedevice,
):
__all__.extend(submod.__all__) # type: ignore

View File

@@ -0,0 +1,302 @@
"""Classes for running 0MQ Devices in the background."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
from multiprocessing import Process
from threading import Thread
from typing import Any, List, Optional, Tuple
import zmq
from zmq import ETERM, QUEUE, REQ, Context, ZMQBindError, ZMQError, device
class Device:
"""A 0MQ Device to be run in the background.
You do not pass Socket instances to this, but rather Socket types::
Device(device_type, in_socket_type, out_socket_type)
For instance::
dev = Device(zmq.QUEUE, zmq.DEALER, zmq.ROUTER)
Similar to zmq.device, but socket types instead of sockets themselves are
passed, and the sockets are created in the work thread, to avoid issues
with thread safety. As a result, additional bind_{in|out} and
connect_{in|out} methods and setsockopt_{in|out} allow users to specify
connections for the sockets.
Parameters
----------
device_type : int
The 0MQ Device type
{in|out}_type : int
zmq socket types, to be passed later to context.socket(). e.g.
zmq.PUB, zmq.SUB, zmq.REQ. If out_type is < 0, then in_socket is used
for both in_socket and out_socket.
Methods
-------
bind_{in_out}(iface)
passthrough for ``{in|out}_socket.bind(iface)``, to be called in the thread
connect_{in_out}(iface)
passthrough for ``{in|out}_socket.connect(iface)``, to be called in the
thread
setsockopt_{in_out}(opt,value)
passthrough for ``{in|out}_socket.setsockopt(opt, value)``, to be called in
the thread
Attributes
----------
daemon : int
sets whether the thread should be run as a daemon
Default is true, because if it is false, the thread will not
exit unless it is killed
context_factory : callable (class attribute)
Function for creating the Context. This will be Context.instance
in ThreadDevices, and Context in ProcessDevices. The only reason
it is not instance() in ProcessDevices is that there may be a stale
Context instance already initialized, and the forked environment
should *never* try to use it.
"""
context_factory = Context.instance
"""Callable that returns a context. Typically either Context.instance or Context,
depending on whether the device should share the global instance or not.
"""
device_type: int
in_type: int
out_type: int
_in_binds: List[str]
_in_connects: List[str]
_in_sockopts: List[Tuple[int, Any]]
_out_binds: List[str]
_out_connects: List[str]
_out_sockopts: List[Tuple[int, Any]]
_random_addrs: List[str]
def __init__(
self,
device_type: int = QUEUE,
in_type: Optional[int] = None,
out_type: Optional[int] = None,
) -> None:
self.device_type = device_type
if in_type is None:
raise TypeError("in_type must be specified")
if out_type is None:
raise TypeError("out_type must be specified")
self.in_type = in_type
self.out_type = out_type
self._in_binds = []
self._in_connects = []
self._in_sockopts = []
self._out_binds = []
self._out_connects = []
self._out_sockopts = []
self._random_addrs = []
self.daemon = True
self.done = False
def bind_in(self, addr: str) -> None:
"""Enqueue ZMQ address for binding on in_socket.
See zmq.Socket.bind for details.
"""
self._in_binds.append(addr)
def bind_in_to_random_port(self, addr: str, *args, **kwargs) -> int:
"""Enqueue a random port on the given interface for binding on
in_socket.
See zmq.Socket.bind_to_random_port for details.
.. versionadded:: 18.0
"""
port = self._reserve_random_port(addr, *args, **kwargs)
self.bind_in('%s:%i' % (addr, port))
return port
def connect_in(self, addr: str) -> None:
"""Enqueue ZMQ address for connecting on in_socket.
See zmq.Socket.connect for details.
"""
self._in_connects.append(addr)
def setsockopt_in(self, opt: int, value: Any) -> None:
"""Enqueue setsockopt(opt, value) for in_socket
See zmq.Socket.setsockopt for details.
"""
self._in_sockopts.append((opt, value))
def bind_out(self, addr: str) -> None:
"""Enqueue ZMQ address for binding on out_socket.
See zmq.Socket.bind for details.
"""
self._out_binds.append(addr)
def bind_out_to_random_port(self, addr: str, *args, **kwargs) -> int:
"""Enqueue a random port on the given interface for binding on
out_socket.
See zmq.Socket.bind_to_random_port for details.
.. versionadded:: 18.0
"""
port = self._reserve_random_port(addr, *args, **kwargs)
self.bind_out('%s:%i' % (addr, port))
return port
def connect_out(self, addr: str):
"""Enqueue ZMQ address for connecting on out_socket.
See zmq.Socket.connect for details.
"""
self._out_connects.append(addr)
def setsockopt_out(self, opt: int, value: Any):
"""Enqueue setsockopt(opt, value) for out_socket
See zmq.Socket.setsockopt for details.
"""
self._out_sockopts.append((opt, value))
def _reserve_random_port(self, addr: str, *args, **kwargs) -> int:
ctx = Context()
binder = ctx.socket(REQ)
for i in range(5):
port = binder.bind_to_random_port(addr, *args, **kwargs)
new_addr = '%s:%i' % (addr, port)
if new_addr in self._random_addrs:
continue
else:
break
else:
raise ZMQBindError("Could not reserve random port.")
self._random_addrs.append(new_addr)
binder.close()
return port
def _setup_sockets(self) -> Tuple[zmq.Socket, zmq.Socket]:
ctx: zmq.Context[zmq.Socket] = self.context_factory() # type: ignore
self._context = ctx
# create the sockets
ins = ctx.socket(self.in_type)
if self.out_type < 0:
outs = ins
else:
outs = ctx.socket(self.out_type)
# set sockopts (must be done first, in case of zmq.IDENTITY)
for opt, value in self._in_sockopts:
ins.setsockopt(opt, value)
for opt, value in self._out_sockopts:
outs.setsockopt(opt, value)
for iface in self._in_binds:
ins.bind(iface)
for iface in self._out_binds:
outs.bind(iface)
for iface in self._in_connects:
ins.connect(iface)
for iface in self._out_connects:
outs.connect(iface)
return ins, outs
def run_device(self) -> None:
"""The runner method.
Do not call me directly, instead call ``self.start()``, just like a Thread.
"""
ins, outs = self._setup_sockets()
device(self.device_type, ins, outs)
def run(self) -> None:
"""wrap run_device in try/catch ETERM"""
try:
self.run_device()
except ZMQError as e:
if e.errno == ETERM:
# silence TERM errors, because this should be a clean shutdown
pass
else:
raise
finally:
self.done = True
def start(self) -> None:
"""Start the device. Override me in subclass for other launchers."""
return self.run()
def join(self, timeout: Optional[float] = None) -> None:
"""wait for me to finish, like Thread.join.
Reimplemented appropriately by subclasses."""
tic = time.time()
toc = tic
while not self.done and not (timeout is not None and toc - tic > timeout):
time.sleep(0.001)
toc = time.time()
class BackgroundDevice(Device):
"""Base class for launching Devices in background processes and threads."""
launcher: Any = None
_launch_class: Any = None
def start(self) -> None:
self.launcher = self._launch_class(target=self.run)
self.launcher.daemon = self.daemon
return self.launcher.start()
def join(self, timeout: Optional[float] = None) -> None:
return self.launcher.join(timeout=timeout)
class ThreadDevice(BackgroundDevice):
"""A Device that will be run in a background Thread.
See Device for details.
"""
_launch_class = Thread
class ProcessDevice(BackgroundDevice):
"""A Device that will be run in a background Process.
See Device for details.
"""
_launch_class = Process
context_factory = Context
"""Callable that returns a context. Typically either Context.instance or Context,
depending on whether the device should share the global instance or not.
"""
__all__ = ['Device', 'ThreadDevice', 'ProcessDevice']

View File

@@ -0,0 +1,177 @@
"""MonitoredQueue class declarations.
Authors
-------
* MinRK
* Brian Granger
"""
#
# Copyright (c) 2010 Min Ragan-Kelley, Brian Granger
#
# This file is part of pyzmq, but is derived and adapted from zmq_queue.cpp
# originally from libzmq-2.1.6, used under LGPLv3
#
# pyzmq is free software; you can redistribute it and/or modify it under
# the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# pyzmq is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
from zmq.backend.cython.libzmq cimport *
#-----------------------------------------------------------------------------
# MonitoredQueue C functions
#-----------------------------------------------------------------------------
cdef inline int _relay(void *insocket_, void *outsocket_, void *sidesocket_,
zmq_msg_t msg, zmq_msg_t side_msg, zmq_msg_t id_msg,
bint swap_ids) nogil:
cdef int rc
cdef int64_t flag_2
cdef int flag_3
cdef int flags
cdef bint more
cdef size_t flagsz
cdef void * flag_ptr
if ZMQ_VERSION_MAJOR < 3:
flagsz = sizeof (int64_t)
flag_ptr = &flag_2
else:
flagsz = sizeof (int)
flag_ptr = &flag_3
if swap_ids:# both router, must send second identity first
# recv two ids into msg, id_msg
rc = zmq_msg_recv(&msg, insocket_, 0)
if rc < 0: return rc
rc = zmq_msg_recv(&id_msg, insocket_, 0)
if rc < 0: return rc
# send second id (id_msg) first
#!!!! always send a copy before the original !!!!
rc = zmq_msg_copy(&side_msg, &id_msg)
if rc < 0: return rc
rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE)
if rc < 0: return rc
rc = zmq_msg_send(&id_msg, sidesocket_, ZMQ_SNDMORE)
if rc < 0: return rc
# send first id (msg) second
rc = zmq_msg_copy(&side_msg, &msg)
if rc < 0: return rc
rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE)
if rc < 0: return rc
rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE)
if rc < 0: return rc
while (True):
rc = zmq_msg_recv(&msg, insocket_, 0)
if rc < 0: return rc
# assert (rc == 0)
rc = zmq_getsockopt (insocket_, ZMQ_RCVMORE, flag_ptr, &flagsz)
if rc < 0: return rc
flags = 0
if ZMQ_VERSION_MAJOR < 3:
if flag_2:
flags |= ZMQ_SNDMORE
else:
if flag_3:
flags |= ZMQ_SNDMORE
# LABEL has been removed:
# rc = zmq_getsockopt (insocket_, ZMQ_RCVLABEL, flag_ptr, &flagsz)
# if flag_3:
# flags |= ZMQ_SNDLABEL
# assert (rc == 0)
rc = zmq_msg_copy(&side_msg, &msg)
if rc < 0: return rc
if flags:
rc = zmq_msg_send(&side_msg, outsocket_, flags)
if rc < 0: return rc
# only SNDMORE for side-socket
rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE)
if rc < 0: return rc
else:
rc = zmq_msg_send(&side_msg, outsocket_, 0)
if rc < 0: return rc
rc = zmq_msg_send(&msg, sidesocket_, 0)
if rc < 0: return rc
break
return rc
# the MonitoredQueue C function, adapted from zmq::queue.cpp :
cdef inline int c_monitored_queue (void *insocket_, void *outsocket_,
void *sidesocket_, zmq_msg_t *in_msg_ptr,
zmq_msg_t *out_msg_ptr, int swap_ids) nogil:
"""The actual C function for a monitored queue device.
See ``monitored_queue()`` for details.
"""
cdef zmq_msg_t msg
cdef int rc = zmq_msg_init (&msg)
cdef zmq_msg_t id_msg
rc = zmq_msg_init (&id_msg)
if rc < 0: return rc
cdef zmq_msg_t side_msg
rc = zmq_msg_init (&side_msg)
if rc < 0: return rc
cdef zmq_pollitem_t items [2]
items [0].socket = insocket_
items [0].fd = 0
items [0].events = ZMQ_POLLIN
items [0].revents = 0
items [1].socket = outsocket_
items [1].fd = 0
items [1].events = ZMQ_POLLIN
items [1].revents = 0
# I don't think sidesocket should be polled?
# items [2].socket = sidesocket_
# items [2].fd = 0
# items [2].events = ZMQ_POLLIN
# items [2].revents = 0
while (True):
# // Wait while there are either requests or replies to process.
rc = zmq_poll (&items [0], 2, -1)
if rc < 0: return rc
# // The algorithm below asumes ratio of request and replies processed
# // under full load to be 1:1. Although processing requests replies
# // first is tempting it is suspectible to DoS attacks (overloading
# // the system with unsolicited replies).
#
# // Process a request.
if (items [0].revents & ZMQ_POLLIN):
# send in_prefix to side socket
rc = zmq_msg_copy(&side_msg, in_msg_ptr)
if rc < 0: return rc
rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE)
if rc < 0: return rc
# relay the rest of the message
rc = _relay(insocket_, outsocket_, sidesocket_, msg, side_msg, id_msg, swap_ids)
if rc < 0: return rc
if (items [1].revents & ZMQ_POLLIN):
# send out_prefix to side socket
rc = zmq_msg_copy(&side_msg, out_msg_ptr)
if rc < 0: return rc
rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE)
if rc < 0: return rc
# relay the rest of the message
rc = _relay(outsocket_, insocket_, sidesocket_, msg, side_msg, id_msg, swap_ids)
if rc < 0: return rc
return rc

View File

@@ -0,0 +1,41 @@
"""pure Python monitored_queue function
For use when Cython extension is unavailable (PyPy).
Authors
-------
* MinRK
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
def _relay(ins, outs, sides, prefix, swap_ids):
msg = ins.recv_multipart()
if swap_ids:
msg[:2] = msg[:2][::-1]
outs.send_multipart(msg)
sides.send_multipart([prefix] + msg)
def monitored_queue(
in_socket, out_socket, mon_socket, in_prefix=b'in', out_prefix=b'out'
):
swap_ids = in_socket.type == zmq.ROUTER and out_socket.type == zmq.ROUTER
poller = zmq.Poller()
poller.register(in_socket, zmq.POLLIN)
poller.register(out_socket, zmq.POLLIN)
while True:
events = dict(poller.poll())
if in_socket in events:
_relay(in_socket, out_socket, mon_socket, in_prefix, swap_ids)
if out_socket in events:
_relay(out_socket, in_socket, mon_socket, out_prefix, swap_ids)
__all__ = ['monitored_queue']

View File

@@ -0,0 +1,62 @@
"""MonitoredQueue classes and functions."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq import PUB
from zmq.devices.monitoredqueue import monitored_queue
from zmq.devices.proxydevice import ProcessProxy, Proxy, ProxyBase, ThreadProxy
class MonitoredQueueBase(ProxyBase):
"""Base class for overriding methods."""
_in_prefix = b''
_out_prefix = b''
def __init__(
self, in_type, out_type, mon_type=PUB, in_prefix=b'in', out_prefix=b'out'
):
ProxyBase.__init__(self, in_type=in_type, out_type=out_type, mon_type=mon_type)
self._in_prefix = in_prefix
self._out_prefix = out_prefix
def run_device(self):
ins, outs, mons = self._setup_sockets()
monitored_queue(ins, outs, mons, self._in_prefix, self._out_prefix)
class MonitoredQueue(MonitoredQueueBase, Proxy):
"""Class for running monitored_queue in the background.
See zmq.devices.Device for most of the spec. MonitoredQueue differs from Proxy,
only in that it adds a ``prefix`` to messages sent on the monitor socket,
with a different prefix for each direction.
MQ also supports ROUTER on both sides, which zmq.proxy does not.
If a message arrives on `in_sock`, it will be prefixed with `in_prefix` on the monitor socket.
If it arrives on out_sock, it will be prefixed with `out_prefix`.
A PUB socket is the most logical choice for the mon_socket, but it is not required.
"""
class ThreadMonitoredQueue(MonitoredQueueBase, ThreadProxy):
"""Run zmq.monitored_queue in a background thread.
See MonitoredQueue and Proxy for details.
"""
class ProcessMonitoredQueue(MonitoredQueueBase, ProcessProxy):
"""Run zmq.monitored_queue in a separate process.
See MonitoredQueue and Proxy for details.
"""
__all__ = ['MonitoredQueue', 'ThreadMonitoredQueue', 'ProcessMonitoredQueue']

View File

@@ -0,0 +1,104 @@
"""Proxy classes and functions."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.devices.basedevice import Device, ProcessDevice, ThreadDevice
class ProxyBase:
"""Base class for overriding methods."""
def __init__(self, in_type, out_type, mon_type=zmq.PUB):
Device.__init__(self, in_type=in_type, out_type=out_type)
self.mon_type = mon_type
self._mon_binds = []
self._mon_connects = []
self._mon_sockopts = []
def bind_mon(self, addr):
"""Enqueue ZMQ address for binding on mon_socket.
See zmq.Socket.bind for details.
"""
self._mon_binds.append(addr)
def bind_mon_to_random_port(self, addr, *args, **kwargs):
"""Enqueue a random port on the given interface for binding on
mon_socket.
See zmq.Socket.bind_to_random_port for details.
.. versionadded:: 18.0
"""
port = self._reserve_random_port(addr, *args, **kwargs)
self.bind_mon('%s:%i' % (addr, port))
return port
def connect_mon(self, addr):
"""Enqueue ZMQ address for connecting on mon_socket.
See zmq.Socket.connect for details.
"""
self._mon_connects.append(addr)
def setsockopt_mon(self, opt, value):
"""Enqueue setsockopt(opt, value) for mon_socket
See zmq.Socket.setsockopt for details.
"""
self._mon_sockopts.append((opt, value))
def _setup_sockets(self):
ins, outs = Device._setup_sockets(self)
ctx = self._context
mons = ctx.socket(self.mon_type)
# set sockopts (must be done first, in case of zmq.IDENTITY)
for opt, value in self._mon_sockopts:
mons.setsockopt(opt, value)
for iface in self._mon_binds:
mons.bind(iface)
for iface in self._mon_connects:
mons.connect(iface)
return ins, outs, mons
def run_device(self):
ins, outs, mons = self._setup_sockets()
zmq.proxy(ins, outs, mons)
class Proxy(ProxyBase, Device):
"""Threadsafe Proxy object.
See zmq.devices.Device for most of the spec. This subclass adds a
<method>_mon version of each <method>_{in|out} method, for configuring the
monitor socket.
A Proxy is a 3-socket ZMQ Device that functions just like a
QUEUE, except each message is also sent out on the monitor socket.
A PUB socket is the most logical choice for the mon_socket, but it is not required.
"""
class ThreadProxy(ProxyBase, ThreadDevice):
"""Proxy in a Thread. See Proxy for more."""
class ProcessProxy(ProxyBase, ProcessDevice):
"""Proxy in a Process. See Proxy for more."""
__all__ = [
'Proxy',
'ThreadProxy',
'ProcessProxy',
]

View File

@@ -0,0 +1,105 @@
"""Classes for running a steerable ZMQ proxy"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.devices.proxydevice import ProcessProxy, Proxy, ThreadProxy
class ProxySteerableBase:
"""Base class for overriding methods."""
def __init__(self, in_type, out_type, mon_type=zmq.PUB, ctrl_type=None):
super().__init__(in_type=in_type, out_type=out_type, mon_type=mon_type)
self.ctrl_type = ctrl_type
self._ctrl_binds = []
self._ctrl_connects = []
self._ctrl_sockopts = []
def bind_ctrl(self, addr):
"""Enqueue ZMQ address for binding on ctrl_socket.
See zmq.Socket.bind for details.
"""
self._ctrl_binds.append(addr)
def bind_ctrl_to_random_port(self, addr, *args, **kwargs):
"""Enqueue a random port on the given interface for binding on
ctrl_socket.
See zmq.Socket.bind_to_random_port for details.
"""
port = self._reserve_random_port(addr, *args, **kwargs)
self.bind_ctrl('%s:%i' % (addr, port))
return port
def connect_ctrl(self, addr):
"""Enqueue ZMQ address for connecting on ctrl_socket.
See zmq.Socket.connect for details.
"""
self._ctrl_connects.append(addr)
def setsockopt_ctrl(self, opt, value):
"""Enqueue setsockopt(opt, value) for ctrl_socket
See zmq.Socket.setsockopt for details.
"""
self._ctrl_sockopts.append((opt, value))
def _setup_sockets(self):
ins, outs, mons = super()._setup_sockets()
ctx = self._context
ctrls = ctx.socket(self.ctrl_type)
for opt, value in self._ctrl_sockopts:
ctrls.setsockopt(opt, value)
for iface in self._ctrl_binds:
ctrls.bind(iface)
for iface in self._ctrl_connects:
ctrls.connect(iface)
return ins, outs, mons, ctrls
def run_device(self):
ins, outs, mons, ctrls = self._setup_sockets()
zmq.proxy_steerable(ins, outs, mons, ctrls)
class ProxySteerable(ProxySteerableBase, Proxy):
"""Class for running a steerable proxy in the background.
See zmq.devices.Proxy for most of the spec. If the control socket is not
NULL, the proxy supports control flow, provided by the socket.
If PAUSE is received on this socket, the proxy suspends its activities. If
RESUME is received, it goes on. If TERMINATE is received, it terminates
smoothly. If the control socket is NULL, the proxy behave exactly as if
zmq.devices.Proxy had been used.
This subclass adds a <method>_ctrl version of each <method>_{in|out}
method, for configuring the control socket.
.. versionadded:: libzmq-4.1
.. versionadded:: 18.0
"""
class ThreadProxySteerable(ProxySteerableBase, ThreadProxy):
"""ProxySteerable in a Thread. See ProxySteerable for details."""
class ProcessProxySteerable(ProxySteerableBase, ProcessProxy):
"""ProxySteerable in a Process. See ProxySteerable for details."""
__all__ = [
'ProxySteerable',
'ThreadProxySteerable',
'ProcessProxySteerable',
]

View File

@@ -0,0 +1,214 @@
"""0MQ Error classes and functions."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from errno import EINTR
from typing import Optional, Tuple, Union
class ZMQBaseError(Exception):
"""Base exception class for 0MQ errors in Python."""
class ZMQError(ZMQBaseError):
"""Wrap an errno style error.
Parameters
----------
errno : int
The ZMQ errno or None. If None, then ``zmq_errno()`` is called and
used.
msg : string
Description of the error or None.
"""
errno: Optional[int] = None
def __init__(self, errno: Optional[int] = None, msg: Optional[str] = None):
"""Wrap an errno style error.
Parameters
----------
errno : int
The ZMQ errno or None. If None, then ``zmq_errno()`` is called and
used.
msg : string
Description of the error or None.
"""
from zmq.backend import strerror, zmq_errno
if errno is None:
errno = zmq_errno()
if isinstance(errno, int):
self.errno = errno
if msg is None:
self.strerror = strerror(errno)
else:
self.strerror = msg
else:
if msg is None:
self.strerror = str(errno)
else:
self.strerror = msg
# flush signals, because there could be a SIGINT
# waiting to pounce, resulting in uncaught exceptions.
# Doing this here means getting SIGINT during a blocking
# libzmq call will raise a *catchable* KeyboardInterrupt
# PyErr_CheckSignals()
def __str__(self) -> str:
return self.strerror
def __repr__(self) -> str:
return f"{self.__class__.__name__}('{str(self)}')"
class ZMQBindError(ZMQBaseError):
"""An error for ``Socket.bind_to_random_port()``.
See Also
--------
.Socket.bind_to_random_port
"""
class NotDone(ZMQBaseError):
"""Raised when timeout is reached while waiting for 0MQ to finish with a Message
See Also
--------
.MessageTracker.wait : object for tracking when ZeroMQ is done
"""
class ContextTerminated(ZMQError):
"""Wrapper for zmq.ETERM
.. versionadded:: 13.0
"""
def __init__(self, errno="ignored", msg="ignored"):
from zmq import ETERM
super().__init__(ETERM)
class Again(ZMQError):
"""Wrapper for zmq.EAGAIN
.. versionadded:: 13.0
"""
def __init__(self, errno="ignored", msg="ignored"):
from zmq import EAGAIN
super().__init__(EAGAIN)
class InterruptedSystemCall(ZMQError, InterruptedError):
"""Wrapper for EINTR
This exception should be caught internally in pyzmq
to retry system calls, and not propagate to the user.
.. versionadded:: 14.7
"""
errno = EINTR
def __init__(self, errno="ignored", msg="ignored"):
super().__init__(EINTR)
def __str__(self):
s = super().__str__()
return s + ": This call should have been retried. Please report this to pyzmq."
def _check_rc(rc, errno=None, error_without_errno=True):
"""internal utility for checking zmq return condition
and raising the appropriate Exception class
"""
if rc == -1:
if errno is None:
from zmq.backend import zmq_errno
errno = zmq_errno()
if errno == 0 and not error_without_errno:
return
from zmq import EAGAIN, ETERM
if errno == EINTR:
raise InterruptedSystemCall(errno)
elif errno == EAGAIN:
raise Again(errno)
elif errno == ETERM:
raise ContextTerminated(errno)
else:
raise ZMQError(errno)
_zmq_version_info = None
_zmq_version = None
class ZMQVersionError(NotImplementedError):
"""Raised when a feature is not provided by the linked version of libzmq.
.. versionadded:: 14.2
"""
min_version = None
def __init__(self, min_version: str, msg: str = "Feature"):
global _zmq_version
if _zmq_version is None:
from zmq import zmq_version
_zmq_version = zmq_version()
self.msg = msg
self.min_version = min_version
self.version = _zmq_version
def __repr__(self):
return "ZMQVersionError('%s')" % str(self)
def __str__(self):
return "{} requires libzmq >= {}, have {}".format(
self.msg,
self.min_version,
self.version,
)
def _check_version(
min_version_info: Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]],
msg: str = "Feature",
):
"""Check for libzmq
raises ZMQVersionError if current zmq version is not at least min_version
min_version_info is a tuple of integers, and will be compared against zmq.zmq_version_info().
"""
global _zmq_version_info
if _zmq_version_info is None:
from zmq import zmq_version_info
_zmq_version_info = zmq_version_info()
if _zmq_version_info < min_version_info:
min_version = ".".join(str(v) for v in min_version_info)
raise ZMQVersionError(min_version, msg)
__all__ = [
"ZMQBaseError",
"ZMQBindError",
"ZMQError",
"NotDone",
"ContextTerminated",
"InterruptedSystemCall",
"Again",
"ZMQVersionError",
]

View File

@@ -0,0 +1,5 @@
"""Tornado eventloop integration for pyzmq"""
from zmq.eventloop.ioloop import IOLoop
__all__ = ['IOLoop']

View File

@@ -0,0 +1,213 @@
"""tornado IOLoop API with zmq compatibility
If you have tornado ≥ 3.0, this is a subclass of tornado's IOLoop,
otherwise we ship a minimal subset of tornado in zmq.eventloop.minitornado.
The minimal shipped version of tornado's IOLoop does not include
support for concurrent futures - this will only be available if you
have tornado ≥ 3.0.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import warnings
from typing import Tuple
from zmq import ETERM, POLLERR, POLLIN, POLLOUT, Poller, ZMQError
tornado_version: Tuple = ()
try:
import tornado
tornado_version = tornado.version_info
except (ImportError, AttributeError):
pass
from .minitornado.ioloop import PeriodicCallback, PollIOLoop
from .minitornado.log import gen_log
class DelayedCallback(PeriodicCallback):
"""Schedules the given callback to be called once.
The callback is called once, after callback_time milliseconds.
`start` must be called after the DelayedCallback is created.
The timeout is calculated from when `start` is called.
"""
def __init__(self, callback, callback_time, io_loop=None):
# PeriodicCallback require callback_time to be positive
warnings.warn(
"""DelayedCallback is deprecated.
Use loop.add_timeout instead.""",
DeprecationWarning,
)
callback_time = max(callback_time, 1e-3)
super().__init__(callback, callback_time, io_loop)
def start(self):
"""Starts the timer."""
self._running = True
self._firstrun = True
self._next_timeout = time.time() + self.callback_time / 1000.0
self.io_loop.add_timeout(self._next_timeout, self._run)
def _run(self):
if not self._running:
return
self._running = False
try:
self.callback()
except Exception:
gen_log.error("Error in delayed callback", exc_info=True)
class ZMQPoller:
"""A poller that can be used in the tornado IOLoop.
This simply wraps a regular zmq.Poller, scaling the timeout
by 1000, so that it is in seconds rather than milliseconds.
"""
def __init__(self):
self._poller = Poller()
@staticmethod
def _map_events(events):
"""translate IOLoop.READ/WRITE/ERROR event masks into zmq.POLLIN/OUT/ERR"""
z_events = 0
if events & IOLoop.READ:
z_events |= POLLIN
if events & IOLoop.WRITE:
z_events |= POLLOUT
if events & IOLoop.ERROR:
z_events |= POLLERR
return z_events
@staticmethod
def _remap_events(z_events):
"""translate zmq.POLLIN/OUT/ERR event masks into IOLoop.READ/WRITE/ERROR"""
events = 0
if z_events & POLLIN:
events |= IOLoop.READ
if z_events & POLLOUT:
events |= IOLoop.WRITE
if z_events & POLLERR:
events |= IOLoop.ERROR
return events
def register(self, fd, events):
return self._poller.register(fd, self._map_events(events))
def modify(self, fd, events):
return self._poller.modify(fd, self._map_events(events))
def unregister(self, fd):
return self._poller.unregister(fd)
def poll(self, timeout):
"""poll in seconds rather than milliseconds.
Event masks will be IOLoop.READ/WRITE/ERROR
"""
z_events = self._poller.poll(1000 * timeout)
return [(fd, self._remap_events(evt)) for (fd, evt) in z_events]
def close(self):
pass
class ZMQIOLoop(PollIOLoop):
"""ZMQ subclass of tornado's IOLoop
Minor modifications, so that .current/.instance return self
"""
_zmq_impl = ZMQPoller
def initialize(self, impl=None, **kwargs):
impl = self._zmq_impl() if impl is None else impl
super().initialize(impl=impl, **kwargs)
@classmethod
def instance(cls, *args, **kwargs):
"""Returns a global `IOLoop` instance.
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
"""
# install ZMQIOLoop as the active IOLoop implementation
# when using tornado 3
if tornado_version >= (3,):
PollIOLoop.configure(cls)
loop = PollIOLoop.instance(*args, **kwargs)
if not isinstance(loop, cls):
warnings.warn(
f"IOLoop.current expected instance of {cls!r}, got {loop!r}",
RuntimeWarning,
stacklevel=2,
)
return loop
@classmethod
def current(cls, *args, **kwargs):
"""Returns the current threads IOLoop."""
# install ZMQIOLoop as the active IOLoop implementation
# when using tornado 3
if tornado_version >= (3,):
PollIOLoop.configure(cls)
loop = PollIOLoop.current(*args, **kwargs)
if not isinstance(loop, cls):
warnings.warn(
f"IOLoop.current expected instance of {cls!r}, got {loop!r}",
RuntimeWarning,
stacklevel=2,
)
return loop
def start(self):
try:
super().start()
except ZMQError as e:
if e.errno == ETERM:
# quietly return on ETERM
pass
else:
raise
# public API name
IOLoop = ZMQIOLoop
def install():
"""set the tornado IOLoop instance with the pyzmq IOLoop.
After calling this function, tornado's IOLoop.instance() and pyzmq's
IOLoop.instance() will return the same object.
An assertion error will be raised if tornado's IOLoop has been initialized
prior to calling this function.
"""
from tornado import ioloop
# check if tornado's IOLoop is already initialized to something other
# than the pyzmq IOLoop instance:
assert (
not ioloop.IOLoop.initialized()
) or ioloop.IOLoop.instance() is IOLoop.instance(), (
"tornado IOLoop already initialized"
)
if tornado_version >= (3,):
# tornado 3 has an official API for registering new defaults, yay!
ioloop.IOLoop.configure(ZMQIOLoop)
else:
# we have to set the global instance explicitly
ioloop.IOLoop._instance = IOLoop.instance()

View File

@@ -0,0 +1,104 @@
"""Future-returning APIs for tornado coroutines.
.. seealso::
:mod:`zmq.asyncio`
"""
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import asyncio
import warnings
from typing import Any, Type
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
import zmq as _zmq
from zmq._future import _AsyncPoller, _AsyncSocket
class CancelledError(Exception):
pass
class _TornadoFuture(Future):
"""Subclass Tornado Future, reinstating cancellation."""
def cancel(self):
if self.done():
return False
self.set_exception(CancelledError())
return True
def cancelled(self):
return self.done() and isinstance(self.exception(), CancelledError)
class _CancellableTornadoTimeout:
def __init__(self, loop, timeout):
self.loop = loop
self.timeout = timeout
def cancel(self):
self.loop.remove_timeout(self.timeout)
# mixin for tornado/asyncio compatibility
class _AsyncTornado:
_Future: Type[asyncio.Future] = _TornadoFuture
_READ = IOLoop.READ
_WRITE = IOLoop.WRITE
def _default_loop(self):
return IOLoop.current()
def _call_later(self, delay, callback):
io_loop = self._get_loop()
timeout = io_loop.call_later(delay, callback)
return _CancellableTornadoTimeout(io_loop, timeout)
class Poller(_AsyncTornado, _AsyncPoller):
def _watch_raw_socket(self, loop, socket, evt, f):
"""Schedule callback for a raw socket"""
loop.add_handler(socket, lambda *args: f(), evt)
def _unwatch_raw_sockets(self, loop, *sockets):
"""Unschedule callback for a raw socket"""
for socket in sockets:
loop.remove_handler(socket)
class Socket(_AsyncTornado, _AsyncSocket):
_poller_class = Poller
Poller._socket_class = Socket
class Context(_zmq.Context[Socket]):
# avoid sharing instance with base Context class
_instance = None
io_loop = None
@staticmethod
def _socket_class(self, socket_type):
return Socket(self, socket_type)
def __init__(self: "Context", *args: Any, **kwargs: Any) -> None:
io_loop = kwargs.pop('io_loop', None)
if io_loop is not None:
warnings.warn(
f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
" The currently active loop will always be used.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,146 @@
"""tornado IOLoop API with zmq compatibility
This module is deprecated in pyzmq 17.
To use zmq with tornado,
eventloop integration is no longer required
and tornado itself should be used.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import warnings
from typing import Any
try:
import tornado
from tornado import ioloop
from tornado.log import gen_log
if not hasattr(ioloop.IOLoop, 'configurable_default'):
raise ImportError(
"Tornado too old: %s" % getattr(tornado, 'version', 'unknown')
)
except ImportError:
from .minitornado import ioloop # type: ignore
from .minitornado.log import gen_log # type: ignore
PeriodicCallback = ioloop.PeriodicCallback # type: ignore
class DelayedCallback(PeriodicCallback): # type: ignore
"""Schedules the given callback to be called once.
The callback is called once, after callback_time milliseconds.
`start` must be called after the DelayedCallback is created.
The timeout is calculated from when `start` is called.
"""
def __init__(self, callback, callback_time, io_loop=None):
# PeriodicCallback require callback_time to be positive
warnings.warn(
"""DelayedCallback is deprecated.
Use loop.add_timeout instead.""",
DeprecationWarning,
)
callback_time = max(callback_time, 1e-3)
super().__init__(callback, callback_time, io_loop)
def start(self):
"""Starts the timer."""
self._running = True
self._firstrun = True
self._next_timeout = time.time() + self.callback_time / 1000.0
self.io_loop.add_timeout(self._next_timeout, self._run)
def _run(self):
if not self._running:
return
self._running = False
try:
self.callback()
except Exception:
gen_log.error("Error in delayed callback", exc_info=True)
def _deprecated():
if _deprecated.called: # type: ignore
return
_deprecated.called = True # type: ignore
warnings.warn(
"zmq.eventloop.ioloop is deprecated in pyzmq 17."
" pyzmq now works with default tornado and asyncio eventloops.",
DeprecationWarning,
stacklevel=3,
)
_deprecated.called = False # type: ignore
_IOLoop: Any
# resolve 'true' default loop
if '.minitornado.' in ioloop.__name__:
from ._deprecated import ZMQIOLoop as _IOLoop # type: ignore
else:
_IOLoop = ioloop.IOLoop
while _IOLoop.configurable_default() is not _IOLoop:
_IOLoop = _IOLoop.configurable_default()
class ZMQIOLoop(_IOLoop):
"""DEPRECATED: No longer needed as of pyzmq-17
PyZMQ tornado integration now works with the default :mod:`tornado.ioloop.IOLoop`.
"""
def __init__(self, *args, **kwargs):
_deprecated()
# super is object, which takes no args
return super().__init__()
@classmethod
def instance(cls, *args, **kwargs):
"""Returns a global `IOLoop` instance.
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
"""
# install ZMQIOLoop as the active IOLoop implementation
# when using tornado 3
ioloop.IOLoop.configure(cls)
_deprecated()
loop = ioloop.IOLoop.instance(*args, **kwargs)
return loop
@classmethod
def current(cls, *args, **kwargs):
"""Returns the current threads IOLoop."""
# install ZMQIOLoop as the active IOLoop implementation
# when using tornado 3
ioloop.IOLoop.configure(cls)
_deprecated()
loop = ioloop.IOLoop.current(*args, **kwargs)
return loop
# public API name
IOLoop = ZMQIOLoop
def install():
"""DEPRECATED
pyzmq 17 no longer needs any special integration for tornado.
"""
_deprecated()
ioloop.IOLoop.configure(ZMQIOLoop)
# if minitornado is used, fallback on deprecated ZMQIOLoop, install implementations
if '.minitornado.' in ioloop.__name__:
from ._deprecated import IOLoop, ZMQIOLoop, install # type: ignore # noqa

View File

@@ -0,0 +1,10 @@
import warnings
class VisibleDeprecationWarning(UserWarning):
"""A DeprecationWarning that users should see."""
warnings.warn("""zmq.eventloop.minitornado is deprecated in pyzmq 14.0 and will be removed.
Install tornado itself to use zmq with the tornado IOLoop.
""",
VisibleDeprecationWarning,
stacklevel=4,
)

View File

@@ -0,0 +1,14 @@
"""pyzmq does not ship tornado's futures,
this just raises informative NotImplementedErrors to avoid having to change too much code.
"""
class NotImplementedFuture(object):
def __init__(self, *args, **kwargs):
raise NotImplementedError("pyzmq does not ship tornado's Futures, "
"install tornado >= 3.0 for future support."
)
Future = TracebackFuture = NotImplementedFuture
def is_future(x):
return isinstance(x, Future)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""minimal subset of tornado.log for zmq.eventloop.minitornado"""
import logging
app_log = logging.getLogger("tornado.application")
gen_log = logging.getLogger("tornado.general")

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of platform-specific functionality.
For each function or class described in `tornado.platform.interface`,
the appropriate platform-specific implementation exists in this module.
Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
if os.name == 'nt':
from .common import Waker
from .windows import set_close_exec
else:
from .posix import set_close_exec, Waker
try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None

View File

@@ -0,0 +1,91 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import socket
from . import interface
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe.
For use on platforms that don't have os.pipe() (or where pipes cannot
be passed to select()), but do have sockets. This includes Windows
and Jython.
"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
a.listen(1)
connect_address = a.getsockname() # assigned (host, port) pair
try:
self.writer.connect(connect_address)
break # success
except socket.error as detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
detail[0] != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.send(b"x")
except (IOError, socket.error):
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result:
break
except (IOError, socket.error):
pass
def close(self):
self.reader.close()
self.writer.close()

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Interfaces for platform-specific functionality.
This module exists primarily for documentation purposes and as base classes
for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, print_function, with_statement
def set_close_exec(fd):
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()
class Waker(object):
"""A socket-like object that can wake another thread from ``select()``.
The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
thread wants to wake up the loop, it calls `wake`. Once it has woken
up, it will call `consume` to do any necessary per-wake cleanup. When
the ``IOLoop`` is closed, it closes its waker too.
"""
def fileno(self):
"""Returns the read file descriptor for this waker.
Must be suitable for use with ``select()`` or equivalent on the
local platform.
"""
raise NotImplementedError()
def write_fileno(self):
"""Returns the write file descriptor for this waker."""
raise NotImplementedError()
def wake(self):
"""Triggers activity on the waker's file descriptor."""
raise NotImplementedError()
def consume(self):
"""Called after the listen has woken up to do any necessary cleanup."""
raise NotImplementedError()
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import fcntl
import os
from . import interface
def set_close_exec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
class Waker(interface.Waker):
def __init__(self):
r, w = os.pipe()
_set_nonblocking(r)
_set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self.reader = os.fdopen(r, "rb", 0)
self.writer = os.fdopen(w, "wb", 0)
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.write(b"x")
except IOError:
pass
def consume(self):
try:
while True:
result = self.reader.read()
if not result:
break
except IOError:
pass
def close(self):
self.reader.close()
self.writer.close()

View File

@@ -0,0 +1,20 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
from __future__ import absolute_import, division, print_function, with_statement
import ctypes
import ctypes.wintypes
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.GetLastError()

View File

@@ -0,0 +1,388 @@
#!/usr/bin/env python
#
# Copyright 2010 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""`StackContext` allows applications to maintain threadlocal-like state
that follows execution as it moves to other execution contexts.
The motivating examples are to eliminate the need for explicit
``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
allow some additional context to be kept for logging.
This is slightly magic, but it's an extension of the idea that an
exception handler is a kind of stack-local state and when that stack
is suspended and resumed in a new context that state needs to be
preserved. `StackContext` shifts the burden of restoring that state
from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
in ``async_callback``) to the mechanisms that transfer control from
one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
thread pools, etc).
Example usage::
@contextlib.contextmanager
def die_on_error():
try:
yield
except Exception:
logging.error("exception in asynchronous operation",exc_info=True)
sys.exit(1)
with StackContext(die_on_error):
# Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a
stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
(for example, if you're writing a thread pool), use
`.stack_context.wrap()` before any asynchronous operations to capture the
stack context from where the operation was started.
* If you're writing an asynchronous library that has some shared
resources (such as a connection pool), create those shared resources
within a ``with stack_context.NullContext():`` block. This will prevent
``StackContexts`` from leaking from one request to another.
* If you want to write something like an exception handler that will
persist across asynchronous calls, create a new `StackContext` (or
`ExceptionStackContext`), and make your asynchronous calls in a ``with``
block that references your `StackContext`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
import threading
from .util import raise_exc_info
class StackContextInconsistentError(Exception):
pass
class _State(threading.local):
def __init__(self):
self.contexts = (tuple(), None)
_state = _State()
class StackContext(object):
"""Establishes the given context as a StackContext that will be transferred.
Note that the parameter is a callable that returns a context
manager, not the context itself. That is, where for a
non-transferable context manager you would say::
with my_context():
StackContext takes the function itself rather than its result::
with StackContext(my_context):
The result of ``with StackContext() as cb:`` is a deactivation
callback. Run this callback when the StackContext is no longer
needed to ensure that it is not propagated any further (note that
deactivating a context does not affect any instances of that
context that are currently pending). This is an advanced feature
and not necessary in most applications.
"""
def __init__(self, context_factory):
self.context_factory = context_factory
self.contexts = []
self.active = True
def _deactivate(self):
self.active = False
# StackContext protocol
def enter(self):
context = self.context_factory()
self.contexts.append(context)
context.__enter__()
def exit(self, type, value, traceback):
context = self.contexts.pop()
context.__exit__(type, value, traceback)
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0] + (self,), self)
_state.contexts = self.new_contexts
try:
self.enter()
except:
_state.contexts = self.old_contexts
raise
return self._deactivate
def __exit__(self, type, value, traceback):
try:
self.exit(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
# Generator coroutines and with-statements with non-local
# effects interact badly. Check here for signs of
# the stack getting out of sync.
# Note that this check comes after restoring _state.context
# so that if it fails things are left in a (relatively)
# consistent state.
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class ExceptionStackContext(object):
"""Specialization of StackContext for exception handling.
The supplied ``exception_handler`` function will be called in the
event of an uncaught exception in this context. The semantics are
similar to a try/finally clause, and intended use cases are to log
an error, close a socket, or similar cleanup actions. The
``exc_info`` triple ``(type, value, traceback)`` will be passed to the
exception_handler function.
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
"""
def __init__(self, exception_handler):
self.exception_handler = exception_handler
self.active = True
def _deactivate(self):
self.active = False
def exit(self, type, value, traceback):
if type is not None:
return self.exception_handler(type, value, traceback)
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0], self)
_state.contexts = self.new_contexts
return self._deactivate
def __exit__(self, type, value, traceback):
try:
if type is not None:
return self.exception_handler(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class NullContext(object):
"""Resets the `StackContext`.
Useful when creating a shared resource on demand (e.g. an
`.AsyncHTTPClient`) where the stack that caused the creating is
not relevant to future operations.
"""
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (tuple(), None)
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
def _remove_deactivated(contexts):
"""Remove deactivated handlers from the chain"""
# Clean ctx handlers
stack_contexts = tuple([h for h in contexts[0] if h.active])
# Find new head
head = contexts[1]
while head is not None and not head.active:
head = head.old_contexts[1]
# Process chain
ctx = head
while ctx is not None:
parent = ctx.old_contexts[1]
while parent is not None:
if parent.active:
break
ctx.old_contexts = parent.old_contexts
parent = parent.old_contexts[1]
ctx = parent
return (stack_contexts, head)
def wrap(fn):
"""Returns a callable object that will restore the current `StackContext`
when executed.
Use this whenever saving a callback to be executed later in a
different execution context (either in a different thread or
asynchronously in the same thread).
"""
# Check if function is already wrapped
if fn is None or hasattr(fn, '_wrapped'):
return fn
# Capture current stack head
# TODO: Any other better way to store contexts and update them in wrapped function?
cap_contexts = [_state.contexts]
if not cap_contexts[0][0] and not cap_contexts[0][1]:
# Fast path when there are no active contexts.
def null_wrapper(*args, **kwargs):
try:
current_state = _state.contexts
_state.contexts = cap_contexts[0]
return fn(*args, **kwargs)
finally:
_state.contexts = current_state
null_wrapper._wrapped = True
return null_wrapper
def wrapped(*args, **kwargs):
ret = None
try:
# Capture old state
current_state = _state.contexts
# Remove deactivated items
cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
# Force new state
_state.contexts = contexts
# Current exception
exc = (None, None, None)
top = None
# Apply stack contexts
last_ctx = 0
stack = contexts[0]
# Apply state
for n in stack:
try:
n.enter()
last_ctx += 1
except:
# Exception happened. Record exception info and store top-most handler
exc = sys.exc_info()
top = n.old_contexts[1]
# Execute callback if no exception happened while restoring state
if top is None:
try:
ret = fn(*args, **kwargs)
except:
exc = sys.exc_info()
top = contexts[1]
# If there was exception, try to handle it by going through the exception chain
if top is not None:
exc = _handle_exception(top, exc)
else:
# Otherwise take shorter path and run stack contexts in reverse order
while last_ctx > 0:
last_ctx -= 1
c = stack[last_ctx]
try:
c.exit(*exc)
except:
exc = sys.exc_info()
top = c.old_contexts[1]
break
else:
top = None
# If if exception happened while unrolling, take longer exception handler path
if top is not None:
exc = _handle_exception(top, exc)
# If exception was not handled, raise it
if exc != (None, None, None):
raise_exc_info(exc)
finally:
_state.contexts = current_state
return ret
wrapped._wrapped = True
return wrapped
def _handle_exception(tail, exc):
while tail is not None:
try:
if tail.exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
tail = tail.old_contexts[1]
return exc
def run_with_stack_context(context, func):
"""Run a coroutine ``func`` in the given `StackContext`.
It is not safe to have a ``yield`` statement within a ``with StackContext``
block, so it is difficult to use stack context with `.gen.coroutine`.
This helper function runs the function in the correct context while
keeping the ``yield`` and ``with`` statements syntactically separate.
Example::
@gen.coroutine
def incorrect():
with StackContext(ctx):
# ERROR: this will raise StackContextInconsistentError
yield other_coroutine()
@gen.coroutine
def correct():
yield run_with_stack_context(StackContext(ctx), other_coroutine)
.. versionadded:: 3.1
"""
with context:
return func()

View File

@@ -0,0 +1,216 @@
"""Miscellaneous utility functions and classes.
This module is used internally by Tornado. It is not necessarily expected
that the functions and classes defined here will be useful to other
applications, but they are documented here in case they are.
The one public-facing part of this module is the `Configurable` class
and its `~Configurable.configure` method, which becomes a part of the
interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
and `.Resolver`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if not isinstance(b'', type('')):
def u(s):
return s
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
# These names don't exist in py3, so use noqa comments to disable
# warnings in flake8.
unicode_type = unicode # noqa
basestring_type = basestring # noqa
def import_object(name):
"""Imports an object by name.
import_object('x') is equivalent to 'import x'.
import_object('x.y.z') is equivalent to 'from x.y import z'.
>>> import tornado.escape
>>> import_object('tornado.escape') is tornado.escape
True
>>> import_object('tornado.escape.utf8') is tornado.escape.utf8
True
>>> import_object('tornado') is tornado
True
>>> import_object('tornado.missing_module')
Traceback (most recent call last):
...
ImportError: No module named missing_module
"""
if isinstance(name, unicode_type) and str is not unicode_type:
# On python 2 a byte string is required.
name = name.encode('utf-8')
if name.count('.') == 0:
return __import__(name, None, None)
parts = name.split('.')
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
try:
return getattr(obj, parts[-1])
except AttributeError:
raise ImportError("No module named %s" % parts[-1])
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
if sys.version_info > (3,):
exec("""
def raise_exc_info(exc_info):
raise exc_info[1].with_traceback(exc_info[2])
def exec_in(code, glob, loc=None):
if isinstance(code, str):
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec(code, glob, loc)
""")
else:
exec("""
def raise_exc_info(exc_info):
raise exc_info[0], exc_info[1], exc_info[2]
def exec_in(code, glob, loc=None):
if isinstance(code, basestring):
# exec(string) inherits the caller's future imports; compile
# the string first to prevent that.
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec code in glob, loc
""")
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instantiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
class Configurable(object):
"""Base class for configurable interfaces.
A configurable interface is an (abstract) class whose constructor
acts as a factory function for one of its implementation subclasses.
The implementation subclass as well as optional keyword arguments to
its initializer can be set globally at runtime with `configure`.
By using the constructor as the factory method, the interface
looks like a normal class, `isinstance` works as usual, etc. This
pattern is most useful when the choice of implementation is likely
to be a global decision (e.g. when `~select.epoll` is available,
always use it instead of `~select.select`), or when a
previously-monolithic class has been split into specialized
subclasses.
Configurable subclasses must define the class methods
`configurable_base` and `configurable_default`, and use the instance
method `initialize` instead of ``__init__``.
"""
__impl_class = None
__impl_kwargs = None
def __new__(cls, *args, **kwargs):
base = cls.configurable_base()
init_kwargs = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
init_kwargs.update(base.__impl_kwargs)
else:
impl = cls
init_kwargs.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(*args, **init_kwargs)
return instance
@classmethod
def configurable_base(cls):
"""Returns the base class of a configurable hierarchy.
This will normally return the class in which it is defined.
(which is *not* necessarily the same as the cls classmethod parameter).
"""
raise NotImplementedError()
@classmethod
def configurable_default(cls):
"""Returns the implementation class to be used if none is configured."""
raise NotImplementedError()
def initialize(self):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
.. versionchanged:: 4.2
Now accepts positional arguments in addition to keyword arguments.
"""
@classmethod
def configure(cls, impl, **kwargs):
"""Sets the class to use when the base class is instantiated.
Keyword arguments will be saved and added to the arguments passed
to the constructor. This can be used to set global defaults for
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes)):
impl = import_object(impl)
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)
base.__impl_class = impl
base.__impl_kwargs = kwargs
@classmethod
def configured_class(cls):
"""Returns the currently configured class."""
base = cls.configurable_base()
if cls.__impl_class is None:
base.__impl_class = cls.configurable_default()
return base.__impl_class
@classmethod
def _save_configuration(cls):
base = cls.configurable_base()
return (base.__impl_class, base.__impl_kwargs)
@classmethod
def _restore_configuration(cls, saved):
base = cls.configurable_base()
base.__impl_class = saved[0]
base.__impl_kwargs = saved[1]
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)

View File

@@ -0,0 +1,665 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A utility class to send to and recv from a non-blocking socket,
using tornado.
.. seealso::
- :mod:`zmq.asyncio`
- :mod:`zmq.eventloop.future`
"""
import pickle
import sys
import warnings
from queue import Queue
from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
import zmq
from zmq._typing import Literal
from zmq.utils import jsonapi
from .ioloop import IOLoop, gen_log
try:
import tornado.ioloop
from tornado.stack_context import wrap as stack_context_wrap # type: ignore
except ImportError:
if "zmq.eventloop.minitornado" in sys.modules:
from .minitornado.stack_context import (
wrap as stack_context_wrap, # type: ignore
)
else:
# tornado 5 deprecates stack_context,
# tornado 6 removes it
def stack_context_wrap(callback):
return callback
class ZMQStream:
"""A utility class to register callbacks when a zmq socket sends and receives
For use with zmq.eventloop.ioloop
There are three main methods
Methods:
* **on_recv(callback, copy=True):**
register a callback to be run every time the socket has something to receive
* **on_send(callback):**
register a callback to be run every time you call send
* **send(self, msg, flags=0, copy=False, callback=None):**
perform a send that will trigger the callback
if callback is passed, on_send is also called.
There are also send_multipart(), send_json(), send_pyobj()
Three other methods for deactivating the callbacks:
* **stop_on_recv():**
turn off the recv callback
* **stop_on_send():**
turn off the send callback
which simply call ``on_<evt>(None)``.
The entire socket interface, excluding direct recv methods, is also
provided, primarily through direct-linking the methods.
e.g.
>>> stream.bind is stream.socket.bind
True
"""
socket: zmq.Socket
io_loop: "tornado.ioloop.IOLoop"
poller: zmq.Poller
_send_queue: Queue
_recv_callback: Optional[Callable]
_send_callback: Optional[Callable]
_close_callback = Optional[Callable]
_state: int = 0
_flushed: bool = False
_recv_copy: bool = False
_fd: int
def __init__(
self, socket: "zmq.Socket", io_loop: Optional["tornado.ioloop.IOLoop"] = None
):
self.socket = socket
self.io_loop = io_loop or IOLoop.current()
self.poller = zmq.Poller()
self._fd = cast(int, self.socket.FD)
self._send_queue = Queue()
self._recv_callback = None
self._send_callback = None
self._close_callback = None
self._recv_copy = False
self._flushed = False
self._state = 0
self._init_io_state()
# shortcircuit some socket methods
self.bind = self.socket.bind
self.bind_to_random_port = self.socket.bind_to_random_port
self.connect = self.socket.connect
self.setsockopt = self.socket.setsockopt
self.getsockopt = self.socket.getsockopt
self.setsockopt_string = self.socket.setsockopt_string
self.getsockopt_string = self.socket.getsockopt_string
self.setsockopt_unicode = self.socket.setsockopt_unicode
self.getsockopt_unicode = self.socket.getsockopt_unicode
def stop_on_recv(self):
"""Disable callback and automatic receiving."""
return self.on_recv(None)
def stop_on_send(self):
"""Disable callback on sending."""
return self.on_send(None)
def stop_on_err(self):
"""DEPRECATED, does nothing"""
gen_log.warn("on_err does nothing, and will be removed")
def on_err(self, callback: Callable):
"""DEPRECATED, does nothing"""
gen_log.warn("on_err does nothing, and will be removed")
@overload
def on_recv(
self,
callback: Callable[[List[bytes]], Any],
) -> None:
...
@overload
def on_recv(
self,
callback: Callable[[List[bytes]], Any],
copy: Literal[True],
) -> None:
...
@overload
def on_recv(
self,
callback: Callable[[List[zmq.Frame]], Any],
copy: Literal[False],
) -> None:
...
@overload
def on_recv(
self,
callback: Union[
Callable[[List[zmq.Frame]], Any],
Callable[[List[bytes]], Any],
],
copy: bool = ...,
):
...
def on_recv(
self,
callback: Union[
Callable[[List[zmq.Frame]], Any],
Callable[[List[bytes]], Any],
],
copy: bool = True,
) -> None:
"""Register a callback for when a message is ready to recv.
There can be only one callback registered at a time, so each
call to `on_recv` replaces previously registered callbacks.
on_recv(None) disables recv event polling.
Use on_recv_stream(callback) instead, to register a callback that will receive
both this ZMQStream and the message, instead of just the message.
Parameters
----------
callback : callable
callback must take exactly one argument, which will be a
list, as returned by socket.recv_multipart()
if callback is None, recv callbacks are disabled.
copy : bool
copy is passed directly to recv, so if copy is False,
callback will receive Message objects. If copy is True,
then callback will receive bytes/str objects.
Returns : None
"""
self._check_closed()
assert callback is None or callable(callback)
self._recv_callback = stack_context_wrap(callback)
self._recv_copy = copy
if callback is None:
self._drop_io_state(zmq.POLLIN)
else:
self._add_io_state(zmq.POLLIN)
@overload
def on_recv_stream(
self,
callback: Callable[["ZMQStream", List[bytes]], Any],
) -> None:
...
@overload
def on_recv_stream(
self,
callback: Callable[["ZMQStream", List[bytes]], Any],
copy: Literal[True],
) -> None:
...
@overload
def on_recv_stream(
self,
callback: Callable[["ZMQStream", List[zmq.Frame]], Any],
copy: Literal[False],
) -> None:
...
@overload
def on_recv_stream(
self,
callback: Union[
Callable[["ZMQStream", List[zmq.Frame]], Any],
Callable[["ZMQStream", List[bytes]], Any],
],
copy: bool = ...,
):
...
def on_recv_stream(
self,
callback: Union[
Callable[["ZMQStream", List[zmq.Frame]], Any],
Callable[["ZMQStream", List[bytes]], Any],
],
copy: bool = True,
):
"""Same as on_recv, but callback will get this stream as first argument
callback must take exactly two arguments, as it will be called as::
callback(stream, msg)
Useful when a single callback should be used with multiple streams.
"""
if callback is None:
self.stop_on_recv()
else:
def stream_callback(msg):
return callback(self, msg)
self.on_recv(stream_callback, copy=copy)
def on_send(
self, callback: Callable[[Sequence[Any], Optional[zmq.MessageTracker]], Any]
):
"""Register a callback to be called on each send
There will be two arguments::
callback(msg, status)
* `msg` will be the list of sendable objects that was just sent
* `status` will be the return result of socket.send_multipart(msg) -
MessageTracker or None.
Non-copying sends return a MessageTracker object whose
`done` attribute will be True when the send is complete.
This allows users to track when an object is safe to write to
again.
The second argument will always be None if copy=True
on the send.
Use on_send_stream(callback) to register a callback that will be passed
this ZMQStream as the first argument, in addition to the other two.
on_send(None) disables recv event polling.
Parameters
----------
callback : callable
callback must take exactly two arguments, which will be
the message being sent (always a list),
and the return result of socket.send_multipart(msg) -
MessageTracker or None.
if callback is None, send callbacks are disabled.
"""
self._check_closed()
assert callback is None or callable(callback)
self._send_callback = stack_context_wrap(callback)
def on_send_stream(
self,
callback: Callable[
["ZMQStream", Sequence[Any], Optional[zmq.MessageTracker]], Any
],
):
"""Same as on_send, but callback will get this stream as first argument
Callback will be passed three arguments::
callback(stream, msg, status)
Useful when a single callback should be used with multiple streams.
"""
if callback is None:
self.stop_on_send()
else:
self.on_send(lambda msg, status: callback(self, msg, status))
def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
"""Send a message, optionally also register a new callback for sends.
See zmq.socket.send for details.
"""
return self.send_multipart(
[msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
)
def send_multipart(
self,
msg: Sequence[Any],
flags: int = 0,
copy: bool = True,
track: bool = False,
callback: Callable = None,
**kwargs: Any
) -> None:
"""Send a multipart message, optionally also register a new callback for sends.
See zmq.socket.send_multipart for details.
"""
kwargs.update(dict(flags=flags, copy=copy, track=track))
self._send_queue.put((msg, kwargs))
callback = callback or self._send_callback
if callback is not None:
self.on_send(callback)
else:
# noop callback
self.on_send(lambda *args: None)
self._add_io_state(zmq.POLLOUT)
def send_string(
self,
u: str,
flags: int = 0,
encoding: str = 'utf-8',
callback: Optional[Callable] = None,
**kwargs: Any
):
"""Send a unicode message with an encoding.
See zmq.socket.send_unicode for details.
"""
if not isinstance(u, str):
raise TypeError("unicode/str objects only")
return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)
send_unicode = send_string
def send_json(
self,
obj: Any,
flags: int = 0,
callback: Optional[Callable] = None,
**kwargs: Any
):
"""Send json-serialized version of an object.
See zmq.socket.send_json for details.
"""
msg = jsonapi.dumps(obj)
return self.send(msg, flags=flags, callback=callback, **kwargs)
def send_pyobj(
self,
obj: Any,
flags: int = 0,
protocol: int = -1,
callback: Optional[Callable] = None,
**kwargs: Any
):
"""Send a Python object as a message using pickle to serialize.
See zmq.socket.send_json for details.
"""
msg = pickle.dumps(obj, protocol)
return self.send(msg, flags, callback=callback, **kwargs)
def _finish_flush(self):
"""callback for unsetting _flushed flag."""
self._flushed = False
def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: Optional[int] = None):
"""Flush pending messages.
This method safely handles all pending incoming and/or outgoing messages,
bypassing the inner loop, passing them to the registered callbacks.
A limit can be specified, to prevent blocking under high load.
flush will return the first time ANY of these conditions are met:
* No more events matching the flag are pending.
* the total number of events handled reaches the limit.
Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback
is registered, unlike normal IOLoop operation. This allows flush to be
used to remove *and ignore* incoming messages.
Parameters
----------
flag : int, default=POLLIN|POLLOUT
0MQ poll flags.
If flag|POLLIN, recv events will be flushed.
If flag|POLLOUT, send events will be flushed.
Both flags can be set at once, which is the default.
limit : None or int, optional
The maximum number of messages to send or receive.
Both send and recv count against this limit.
Returns
-------
int : count of events handled (both send and recv)
"""
self._check_closed()
# unset self._flushed, so callbacks will execute, in case flush has
# already been called this iteration
already_flushed = self._flushed
self._flushed = False
# initialize counters
count = 0
def update_flag():
"""Update the poll flag, to prevent registering POLLOUT events
if we don't have pending sends."""
return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT)
flag = update_flag()
if not flag:
# nothing to do
return 0
self.poller.register(self.socket, flag)
events = self.poller.poll(0)
while events and (not limit or count < limit):
s, event = events[0]
if event & zmq.POLLIN: # receiving
self._handle_recv()
count += 1
if self.socket is None:
# break if socket was closed during callback
break
if event & zmq.POLLOUT and self.sending():
self._handle_send()
count += 1
if self.socket is None:
# break if socket was closed during callback
break
flag = update_flag()
if flag:
self.poller.register(self.socket, flag)
events = self.poller.poll(0)
else:
events = []
if count: # only bypass loop if we actually flushed something
# skip send/recv callbacks this iteration
self._flushed = True
# reregister them at the end of the loop
if not already_flushed: # don't need to do it again
self.io_loop.add_callback(self._finish_flush)
elif already_flushed:
self._flushed = True
# update ioloop poll state, which may have changed
self._rebuild_io_state()
return count
def set_close_callback(self, callback: Optional[Callable]):
"""Call the given callback when the stream is closed."""
self._close_callback = stack_context_wrap(callback)
def close(self, linger: Optional[int] = None) -> None:
"""Close this stream."""
if self.socket is not None:
if self.socket.closed:
# fallback on raw fd for closed sockets
# hopefully this happened promptly after close,
# otherwise somebody else may have the FD
warnings.warn(
"Unregistering FD %s after closing socket. "
"This could result in unregistering handlers for the wrong socket. "
"Please use stream.close() instead of closing the socket directly."
% self._fd,
stacklevel=2,
)
self.io_loop.remove_handler(self._fd)
else:
self.io_loop.remove_handler(self.socket)
self.socket.close(linger)
self.socket = None # type: ignore
if self._close_callback:
self._run_callback(self._close_callback)
def receiving(self) -> bool:
"""Returns True if we are currently receiving from the stream."""
return self._recv_callback is not None
def sending(self) -> bool:
"""Returns True if we are currently sending to the stream."""
return not self._send_queue.empty()
def closed(self) -> bool:
if self.socket is None:
return True
if self.socket.closed:
# underlying socket has been closed, but not by us!
# trigger our cleanup
self.close()
return True
return False
def _run_callback(self, callback, *args, **kwargs):
"""Wrap running callbacks in try/except to allow us to
close our socket."""
try:
# Use a NullContext to ensure that all StackContexts are run
# inside our blanket exception handler rather than outside.
callback(*args, **kwargs)
except Exception:
gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise
def _handle_events(self, fd, events):
"""This method is the actual handler for IOLoop, that gets called whenever
an event on my socket is posted. It dispatches to _handle_recv, etc."""
if not self.socket:
gen_log.warning("Got events for closed stream %s", self)
return
try:
zmq_events = self.socket.EVENTS
except zmq.ContextTerminated:
gen_log.warning("Got events for stream %s after terminating context", self)
return
try:
# dispatch events:
if zmq_events & zmq.POLLIN and self.receiving():
self._handle_recv()
if not self.socket:
return
if zmq_events & zmq.POLLOUT and self.sending():
self._handle_send()
if not self.socket:
return
# rebuild the poll state
self._rebuild_io_state()
except Exception:
gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
raise
def _handle_recv(self):
"""Handle a recv event."""
if self._flushed:
return
try:
msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
# state changed since poll event
pass
else:
raise
else:
if self._recv_callback:
callback = self._recv_callback
self._run_callback(callback, msg)
def _handle_send(self):
"""Handle a send event."""
if self._flushed:
return
if not self.sending():
gen_log.error("Shouldn't have handled a send event")
return
msg, kwargs = self._send_queue.get()
try:
status = self.socket.send_multipart(msg, **kwargs)
except zmq.ZMQError as e:
gen_log.error("SEND Error: %s", e)
status = e
if self._send_callback:
callback = self._send_callback
self._run_callback(callback, msg, status)
def _check_closed(self):
if not self.socket:
raise OSError("Stream is closed")
def _rebuild_io_state(self):
"""rebuild io state based on self.sending() and receiving()"""
if self.socket is None:
return
state = 0
if self.receiving():
state |= zmq.POLLIN
if self.sending():
state |= zmq.POLLOUT
self._state = state
self._update_handler(state)
def _add_io_state(self, state):
"""Add io_state to poller."""
self._state = self._state | state
self._update_handler(self._state)
def _drop_io_state(self, state):
"""Stop poller from watching an io_state."""
self._state = self._state & (~state)
self._update_handler(self._state)
def _update_handler(self, state):
"""Update IOLoop handler with state."""
if self.socket is None:
return
if state & self.socket.events:
# events still exist that haven't been processed
# explicitly schedule handling to avoid missing events due to edge-triggered FDs
self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))
def _init_io_state(self):
"""initialize the ioloop event handler"""
self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)

View File

@@ -0,0 +1,39 @@
# -----------------------------------------------------------------------------
# Copyright (C) 2011-2012 Travis Cline
#
# This file is part of pyzmq
# It is adapted from upstream project zeromq_gevent under the New BSD License
#
# Distributed under the terms of the New BSD License. The full license is in
# the file COPYING.BSD, distributed as part of this software.
# -----------------------------------------------------------------------------
"""zmq.green - gevent compatibility with zeromq.
Usage
-----
Instead of importing zmq directly, do so in the following manner:
..
import zmq.green as zmq
Any calls that would have blocked the current thread will now only block the
current green thread.
This compatibility is accomplished by ensuring the nonblocking flag is set
before any blocking operation and the ØMQ file descriptor is polled internally
to trigger needed events.
"""
from zmq import *
from zmq.green.core import _Context, _Socket
from zmq.green.poll import _Poller
Context = _Context # type: ignore
Socket = _Socket # type: ignore
Poller = _Poller # type: ignore
from zmq.green.device import device # type: ignore

View File

@@ -0,0 +1,320 @@
# -----------------------------------------------------------------------------
# Copyright (C) 2011-2012 Travis Cline
#
# This file is part of pyzmq
# It is adapted from upstream project zeromq_gevent under the New BSD License
#
# Distributed under the terms of the New BSD License. The full license is in
# the file COPYING.BSD, distributed as part of this software.
# -----------------------------------------------------------------------------
"""This module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq <zmq>` to be non blocking
"""
import sys
import time
import warnings
from typing import Tuple
import gevent
from gevent.event import AsyncResult
from gevent.hub import get_hub
import zmq
from zmq import Context as _original_Context
from zmq import Socket as _original_Socket
from .poll import _Poller
if hasattr(zmq, 'RCVTIMEO'):
TIMEOS: Tuple = (zmq.RCVTIMEO, zmq.SNDTIMEO)
else:
TIMEOS = ()
def _stop(evt):
"""simple wrapper for stopping an Event, allowing for method rename in gevent 1.0"""
try:
evt.stop()
except AttributeError:
# gevent<1.0 compat
evt.cancel()
class _Socket(_original_Socket):
"""Green version of :class:`zmq.Socket`
The following methods are overridden:
* send
* recv
To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or receiving
is deferred to the hub if a ``zmq.EAGAIN`` (retry) error is raised.
The `__state_changed` method is triggered when the zmq.FD for the socket is
marked as readable and triggers the necessary read and write events (which
are waited for in the recv and send methods).
Some double underscore prefixes are used to minimize pollution of
:class:`zmq.Socket`'s namespace.
"""
__in_send_multipart = False
__in_recv_multipart = False
__writable = None
__readable = None
_state_event = None
_gevent_bug_timeout = 11.6 # timeout for not trusting gevent
_debug_gevent = False # turn on if you think gevent is missing events
_poller_class = _Poller
_repr_cls = "zmq.green.Socket"
def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self.__in_send_multipart = False
self.__in_recv_multipart = False
self.__setup_events()
def __del__(self):
self.close()
def close(self, linger=None):
super().close(linger)
self.__cleanup_events()
def __cleanup_events(self):
# close the _state_event event, keeps the number of active file descriptors down
if getattr(self, '_state_event', None):
_stop(self._state_event)
self._state_event = None
# if the socket has entered a close state resume any waiting greenlets
self.__writable.set()
self.__readable.set()
def __setup_events(self):
self.__readable = AsyncResult()
self.__writable = AsyncResult()
self.__readable.set()
self.__writable.set()
try:
self._state_event = get_hub().loop.io(
self.getsockopt(zmq.FD), 1
) # read state watcher
self._state_event.start(self.__state_changed)
except AttributeError:
# for gevent<1.0 compatibility
from gevent.core import read_event
self._state_event = read_event(
self.getsockopt(zmq.FD), self.__state_changed, persist=True
)
def __state_changed(self, event=None, _evtype=None):
if self.closed:
self.__cleanup_events()
return
try:
# avoid triggering __state_changed from inside __state_changed
events = super().getsockopt(zmq.EVENTS)
except zmq.ZMQError as exc:
self.__writable.set_exception(exc)
self.__readable.set_exception(exc)
else:
if events & zmq.POLLOUT:
self.__writable.set()
if events & zmq.POLLIN:
self.__readable.set()
def _wait_write(self):
assert self.__writable.ready(), "Only one greenlet can be waiting on this event"
self.__writable = AsyncResult()
# timeout is because libzmq cannot be trusted to properly signal a new send event:
# this is effectively a maximum poll interval of 1s
tic = time.time()
dt = self._gevent_bug_timeout
if dt:
timeout = gevent.Timeout(seconds=dt)
else:
timeout = None
try:
if timeout:
timeout.start()
self.__writable.get(block=True)
except gevent.Timeout as t:
if t is not timeout:
raise
toc = time.time()
# gevent bug: get can raise timeout even on clean return
# don't display zmq bug warning for gevent bug (this is getting ridiculous)
if (
self._debug_gevent
and timeout
and toc - tic > dt
and self.getsockopt(zmq.EVENTS) & zmq.POLLOUT
):
print(
"BUG: gevent may have missed a libzmq send event on %i!" % self.FD,
file=sys.stderr,
)
finally:
if timeout:
timeout.close()
self.__writable.set()
def _wait_read(self):
assert self.__readable.ready(), "Only one greenlet can be waiting on this event"
self.__readable = AsyncResult()
# timeout is because libzmq cannot always be trusted to play nice with libevent.
# I can only confirm that this actually happens for send, but lets be symmetrical
# with our dirty hacks.
# this is effectively a maximum poll interval of 1s
tic = time.time()
dt = self._gevent_bug_timeout
if dt:
timeout = gevent.Timeout(seconds=dt)
else:
timeout = None
try:
if timeout:
timeout.start()
self.__readable.get(block=True)
except gevent.Timeout as t:
if t is not timeout:
raise
toc = time.time()
# gevent bug: get can raise timeout even on clean return
# don't display zmq bug warning for gevent bug (this is getting ridiculous)
if (
self._debug_gevent
and timeout
and toc - tic > dt
and self.getsockopt(zmq.EVENTS) & zmq.POLLIN
):
print(
"BUG: gevent may have missed a libzmq recv event on %i!" % self.FD,
file=sys.stderr,
)
finally:
if timeout:
timeout.close()
self.__readable.set()
def send(self, data, flags=0, copy=True, track=False, **kwargs):
"""send, which will only block current greenlet
state_changed always fires exactly once (success or fail) at the
end of this method.
"""
# if we're given the NOBLOCK flag act as normal and let the EAGAIN get raised
if flags & zmq.NOBLOCK:
try:
msg = super().send(data, flags, copy, track, **kwargs)
finally:
if not self.__in_send_multipart:
self.__state_changed()
return msg
# ensure the zmq.NOBLOCK flag is part of flags
flags |= zmq.NOBLOCK
while (
True
): # Attempt to complete this operation indefinitely, blocking the current greenlet
try:
# attempt the actual call
msg = super().send(data, flags, copy, track)
except zmq.ZMQError as e:
# if the raised ZMQError is not EAGAIN, reraise
if e.errno != zmq.EAGAIN:
if not self.__in_send_multipart:
self.__state_changed()
raise
else:
if not self.__in_send_multipart:
self.__state_changed()
return msg
# defer to the event loop until we're notified the socket is writable
self._wait_write()
def recv(self, flags=0, copy=True, track=False):
"""recv, which will only block current greenlet
state_changed always fires exactly once (success or fail) at the
end of this method.
"""
if flags & zmq.NOBLOCK:
try:
msg = super().recv(flags, copy, track)
finally:
if not self.__in_recv_multipart:
self.__state_changed()
return msg
flags |= zmq.NOBLOCK
while True:
try:
msg = super().recv(flags, copy, track)
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
if not self.__in_recv_multipart:
self.__state_changed()
raise
else:
if not self.__in_recv_multipart:
self.__state_changed()
return msg
self._wait_read()
def send_multipart(self, *args, **kwargs):
"""wrap send_multipart to prevent state_changed on each partial send"""
self.__in_send_multipart = True
try:
msg = super().send_multipart(*args, **kwargs)
finally:
self.__in_send_multipart = False
self.__state_changed()
return msg
def recv_multipart(self, *args, **kwargs):
"""wrap recv_multipart to prevent state_changed on each partial recv"""
self.__in_recv_multipart = True
try:
msg = super().recv_multipart(*args, **kwargs)
finally:
self.__in_recv_multipart = False
self.__state_changed()
return msg
def get(self, opt):
"""trigger state_changed on getsockopt(EVENTS)"""
if opt in TIMEOS:
warnings.warn(
"TIMEO socket options have no effect in zmq.green", UserWarning
)
optval = super().get(opt)
if opt == zmq.EVENTS:
self.__state_changed()
return optval
def set(self, opt, val):
"""set socket option"""
if opt in TIMEOS:
warnings.warn(
"TIMEO socket options have no effect in zmq.green", UserWarning
)
return super().set(opt, val)
class _Context(_original_Context[_Socket]):
"""Replacement for :class:`zmq.Context`
Ensures that the greened Socket above is used in calls to `socket`.
"""
_socket_class = _Socket
_repr_cls = "zmq.green.Context"
# avoid sharing instance with base Context class
_instance = None

View File

@@ -0,0 +1,33 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.green import Poller
def device(device_type, isocket, osocket):
"""Start a zeromq device (gevent-compatible).
Unlike the true zmq.device, this does not release the GIL.
Parameters
----------
device_type : (QUEUE, FORWARDER, STREAMER)
The type of device to start (ignored).
isocket : Socket
The Socket instance for the incoming traffic.
osocket : Socket
The Socket instance for the outbound traffic.
"""
p = Poller()
if osocket == -1:
osocket = isocket
p.register(isocket, zmq.POLLIN)
p.register(osocket, zmq.POLLIN)
while True:
events = dict(p.poll())
if isocket in events:
osocket.send_multipart(isocket.recv_multipart())
if osocket in events:
isocket.send_multipart(osocket.recv_multipart())

View File

@@ -0,0 +1,3 @@
from zmq.green.eventloop.ioloop import IOLoop
__all__ = ['IOLoop']

View File

@@ -0,0 +1 @@
from zmq.eventloop.ioloop import * # noqa

View File

@@ -0,0 +1,11 @@
from zmq.eventloop import zmqstream
from zmq.green.eventloop.ioloop import IOLoop
class ZMQStream(zmqstream.ZMQStream):
def __init__(self, socket, io_loop=None):
io_loop = io_loop or IOLoop.instance()
super().__init__(socket, io_loop=io_loop)
__all__ = ["ZMQStream"]

View File

@@ -0,0 +1,99 @@
import gevent
from gevent import select
import zmq
from zmq import Poller as _original_Poller
class _Poller(_original_Poller):
"""Replacement for :class:`zmq.Poller`
Ensures that the greened Poller below is used in calls to
:meth:`zmq.Poller.poll`.
"""
_gevent_bug_timeout = 1.33 # minimum poll interval, for working around gevent bug
def _get_descriptors(self):
"""Returns three elements tuple with socket descriptors ready
for gevent.select.select
"""
rlist = []
wlist = []
xlist = []
for socket, flags in self.sockets:
if isinstance(socket, zmq.Socket):
rlist.append(socket.getsockopt(zmq.FD))
continue
elif isinstance(socket, int):
fd = socket
elif hasattr(socket, 'fileno'):
try:
fd = int(socket.fileno())
except:
raise ValueError('fileno() must return an valid integer fd')
else:
raise TypeError(
'Socket must be a 0MQ socket, an integer fd '
'or have a fileno() method: %r' % socket
)
if flags & zmq.POLLIN:
rlist.append(fd)
if flags & zmq.POLLOUT:
wlist.append(fd)
if flags & zmq.POLLERR:
xlist.append(fd)
return (rlist, wlist, xlist)
def poll(self, timeout=-1):
"""Overridden method to ensure that the green version of
Poller is used.
Behaves the same as :meth:`zmq.core.Poller.poll`
"""
if timeout is None:
timeout = -1
if timeout < 0:
timeout = -1
rlist = None
wlist = None
xlist = None
if timeout > 0:
tout = gevent.Timeout.start_new(timeout / 1000.0)
else:
tout = None
try:
# Loop until timeout or events available
rlist, wlist, xlist = self._get_descriptors()
while True:
events = super().poll(0)
if events or timeout == 0:
return events
# wait for activity on sockets in a green way
# set a minimum poll frequency,
# because gevent < 1.0 cannot be trusted to catch edge-triggered FD events
_bug_timeout = gevent.Timeout.start_new(self._gevent_bug_timeout)
try:
select.select(rlist, wlist, xlist)
except gevent.Timeout as t:
if t is not _bug_timeout:
raise
finally:
_bug_timeout.cancel()
except gevent.Timeout as t:
if t is not tout:
raise
return []
finally:
if timeout > 0:
tout.cancel()

View File

@@ -0,0 +1,132 @@
"""pyzmq log watcher.
Easily view log messages published by the PUBHandler in zmq.log.handlers
Designed to be run as an executable module - try this to see options:
python -m zmq.log -h
Subscribes to the '' (empty string) topic by default which means it will work
out-of-the-box with a PUBHandler object instantiated with default settings.
If you change the root topic with PUBHandler.setRootTopic() you must pass
the value to this script with the --topic argument.
Note that the default formats for the PUBHandler object selectively include
the log level in the message. This creates redundancy in this script as it
always prints the topic of the message, which includes the log level.
Consider overriding the default formats with PUBHandler.setFormat() to
avoid this issue.
"""
# encoding: utf-8
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import argparse
from datetime import datetime
from typing import Dict
import zmq
parser = argparse.ArgumentParser('ZMQ Log Watcher')
parser.add_argument('zmq_pub_url', type=str, help='URL to a ZMQ publisher socket.')
parser.add_argument(
'-t',
'--topic',
type=str,
default='',
help='Only receive messages that start with this topic.',
)
parser.add_argument(
'--timestamp', action='store_true', help='Append local time to the log messages.'
)
parser.add_argument(
'--separator',
type=str,
default=' | ',
help='String to print between topic and message.',
)
parser.add_argument(
'--dateformat',
type=str,
default='%Y-%d-%m %H:%M',
help='Set alternative date format for use with --timestamp.',
)
parser.add_argument(
'--align',
action='store_true',
default=False,
help='Try to align messages by the width of their topics.',
)
parser.add_argument(
'--color',
action='store_true',
default=False,
help='Color the output based on the error level. Requires the colorama module.',
)
args = parser.parse_args()
if args.color:
import colorama
colorama.init()
colors = {
'DEBUG': colorama.Fore.LIGHTCYAN_EX,
'INFO': colorama.Fore.LIGHTWHITE_EX,
'WARNING': colorama.Fore.YELLOW,
'ERROR': colorama.Fore.LIGHTRED_EX,
'CRITICAL': colorama.Fore.LIGHTRED_EX,
'__RESET__': colorama.Fore.RESET,
}
else:
colors = {}
ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.subscribe(args.topic.encode("utf8"))
sub.connect(args.zmq_pub_url)
topic_widths: Dict[int, int] = {}
while True:
try:
if sub.poll(10, zmq.POLLIN):
topic, msg = sub.recv_multipart()
topics = topic.decode('utf8').strip().split('.')
if args.align:
topics.extend(' ' for extra in range(len(topics), len(topic_widths)))
aligned_parts = []
for key, part in enumerate(topics):
topic_widths[key] = max(len(part), topic_widths.get(key, 0))
fmt = ''.join(('{:<', str(topic_widths[key]), '}'))
aligned_parts.append(fmt.format(part))
if len(topics) == 1:
level = topics[0]
else:
level = topics[1]
fields = {
'msg': msg.decode('utf8').strip(),
'ts': datetime.now().strftime(args.dateformat) + ' '
if args.timestamp
else '',
'aligned': '.'.join(aligned_parts)
if args.align
else topic.decode('utf8').strip(),
'color': colors.get(level, ''),
'color_rst': colors.get('__RESET__', ''),
'sep': args.separator,
}
print('{ts}{color}{aligned}{sep}{msg}{color_rst}'.format(**fields))
except KeyboardInterrupt:
break
sub.disconnect(args.zmq_pub_url)
if args.color:
print(colorama.Fore.RESET)

View File

@@ -0,0 +1,197 @@
"""pyzmq logging handlers.
This mainly defines the PUBHandler object for publishing logging messages over
a zmq.PUB socket.
The PUBHandler can be used with the regular logging module, as in::
>>> import logging
>>> handler = PUBHandler('tcp://127.0.0.1:12345')
>>> handler.root_topic = 'foo'
>>> logger = logging.getLogger('foobar')
>>> logger.setLevel(logging.DEBUG)
>>> logger.addHandler(handler)
After this point, all messages logged by ``logger`` will be published on the
PUB socket.
Code adapted from StarCluster:
https://github.com/jtriley/StarCluster/blob/StarCluster-0.91/starcluster/logger.py
"""
import logging
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from typing import Optional, Union
import zmq
TOPIC_DELIM = "::" # delimiter for splitting topics on the receiving end.
class PUBHandler(logging.Handler):
"""A basic logging handler that emits log messages through a PUB socket.
Takes a PUB socket already bound to interfaces or an interface to bind to.
Example::
sock = context.socket(zmq.PUB)
sock.bind('inproc://log')
handler = PUBHandler(sock)
Or::
handler = PUBHandler('inproc://loc')
These are equivalent.
Log messages handled by this handler are broadcast with ZMQ topics
``this.root_topic`` comes first, followed by the log level
(DEBUG,INFO,etc.), followed by any additional subtopics specified in the
message by: log.debug("subtopic.subsub::the real message")
"""
ctx: zmq.Context
socket: zmq.Socket
def __init__(
self,
interface_or_socket: Union[str, zmq.Socket],
context: Optional[zmq.Context] = None,
root_topic: str = '',
) -> None:
logging.Handler.__init__(self)
self.root_topic = root_topic
self.formatters = {
logging.DEBUG: logging.Formatter(
"%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"
),
logging.INFO: logging.Formatter("%(message)s\n"),
logging.WARN: logging.Formatter(
"%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"
),
logging.ERROR: logging.Formatter(
"%(levelname)s %(filename)s:%(lineno)d - %(message)s - %(exc_info)s\n"
),
logging.CRITICAL: logging.Formatter(
"%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"
),
}
if isinstance(interface_or_socket, zmq.Socket):
self.socket = interface_or_socket
self.ctx = self.socket.context
else:
self.ctx = context or zmq.Context()
self.socket = self.ctx.socket(zmq.PUB)
self.socket.bind(interface_or_socket)
@property
def root_topic(self) -> str:
return self._root_topic
@root_topic.setter
def root_topic(self, value: str):
self.setRootTopic(value)
def setRootTopic(self, root_topic: str):
"""Set the root topic for this handler.
This value is prepended to all messages published by this handler, and it
defaults to the empty string ''. When you subscribe to this socket, you must
set your subscription to an empty string, or to at least the first letter of
the binary representation of this string to ensure you receive any messages
from this handler.
If you use the default empty string root topic, messages will begin with
the binary representation of the log level string (INFO, WARN, etc.).
Note that ZMQ SUB sockets can have multiple subscriptions.
"""
if isinstance(root_topic, bytes):
root_topic = root_topic.decode("utf8")
self._root_topic = root_topic
def setFormatter(self, fmt, level=logging.NOTSET):
"""Set the Formatter for this handler.
If no level is provided, the same format is used for all levels. This
will overwrite all selective formatters set in the object constructor.
"""
if level == logging.NOTSET:
for fmt_level in self.formatters.keys():
self.formatters[fmt_level] = fmt
else:
self.formatters[level] = fmt
def format(self, record):
"""Format a record."""
return self.formatters[record.levelno].format(record)
def emit(self, record):
"""Emit a log message on my socket."""
try:
topic, record.msg = record.msg.split(TOPIC_DELIM, 1)
except ValueError:
topic = ""
try:
bmsg = self.format(record).encode("utf8")
except Exception:
self.handleError(record)
return
topic_list = []
if self.root_topic:
topic_list.append(self.root_topic)
topic_list.append(record.levelname)
if topic:
topic_list.append(topic)
btopic = '.'.join(topic_list).encode("utf8")
self.socket.send_multipart([btopic, bmsg])
class TopicLogger(logging.Logger):
"""A simple wrapper that takes an additional argument to log methods.
All the regular methods exist, but instead of one msg argument, two
arguments: topic, msg are passed.
That is::
logger.debug('msg')
Would become::
logger.debug('topic.sub', 'msg')
"""
def log(self, level, topic, msg, *args, **kwargs):
"""Log 'msg % args' with level and topic.
To pass exception information, use the keyword argument exc_info
with a True value::
logger.log(level, "zmq.fun", "We have a %s",
"mysterious problem", exc_info=1)
"""
logging.Logger.log(self, level, f'{topic}::{msg}', *args, **kwargs)
# Generate the methods of TopicLogger, since they are just adding a
# topic prefix to a message.
for name in "debug warn warning error critical fatal".split():
meth = getattr(logging.Logger, name)
setattr(
TopicLogger,
name,
lambda self, level, topic, msg, *args, **kwargs: meth(
self, level, topic + TOPIC_DELIM + msg, *args, **kwargs
),
)

View File

View File

@@ -0,0 +1 @@
from zmq.ssh.tunnel import *

View File

@@ -0,0 +1,99 @@
#
# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
# Edits Copyright (C) 2010 The IPython Team
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
"""
Sample script showing how to do local port forwarding over paramiko.
This script connects to the requested SSH server and sets up local port
forwarding (the openssh -L option) from a local port through a tunneled
connection to a destination reachable from the SSH server machine.
"""
import logging
import select
import socketserver
logger = logging.getLogger('ssh')
class ForwardServer(socketserver.ThreadingTCPServer):
daemon_threads = True
allow_reuse_address = True
class Handler(socketserver.BaseRequestHandler):
def handle(self):
try:
chan = self.ssh_transport.open_channel(
'direct-tcpip',
(self.chain_host, self.chain_port),
self.request.getpeername(),
)
except Exception as e:
logger.debug(
'Incoming request to %s:%d failed: %s'
% (self.chain_host, self.chain_port, repr(e))
)
return
if chan is None:
logger.debug(
'Incoming request to %s:%d was rejected by the SSH server.'
% (self.chain_host, self.chain_port)
)
return
logger.debug(
'Connected! Tunnel open %r -> %r -> %r'
% (
self.request.getpeername(),
chan.getpeername(),
(self.chain_host, self.chain_port),
)
)
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
data = self.request.recv(1024)
if len(data) == 0:
break
chan.send(data)
if chan in r:
data = chan.recv(1024)
if len(data) == 0:
break
self.request.send(data)
chan.close()
self.request.close()
logger.debug('Tunnel closed ')
def forward_tunnel(local_port, remote_host, remote_port, transport):
# this is a little convoluted, but lets me configure things for the Handler
# object. (SocketServer doesn't give Handlers any way to access the outer
# server normally.)
class SubHander(Handler):
chain_host = remote_host
chain_port = remote_port
ssh_transport = transport
ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
__all__ = ['forward_tunnel']

View File

@@ -0,0 +1,433 @@
"""Basic ssh tunnel utilities, and convenience functions for tunneling
zeromq connections.
"""
# Copyright (C) 2010-2011 IPython Development Team
# Copyright (C) 2011- PyZMQ Developers
#
# Redistributed from IPython under the terms of the BSD License.
import atexit
import os
import re
import signal
import socket
import sys
import warnings
from getpass import getpass, getuser
from multiprocessing import Process
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
import paramiko
SSHException = paramiko.ssh_exception.SSHException
except ImportError:
paramiko = None # type: ignore
class SSHException(Exception): # type: ignore
pass
else:
from .forward import forward_tunnel
try:
import pexpect
except ImportError:
pexpect = None
def select_random_ports(n):
"""Select and return n random ports that are available."""
ports = []
sockets = []
for i in range(n):
sock = socket.socket()
sock.bind(('', 0))
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports
# -----------------------------------------------------------------------------
# Check for passwordless login
# -----------------------------------------------------------------------------
_password_pat = re.compile(rb'pass(word|phrase):', re.IGNORECASE)
def try_passwordless_ssh(server, keyfile, paramiko=None):
"""Attempt to make an ssh connection without a password.
This is mainly used for requiring password input only once
when many tunnels may be connected to the same server.
If paramiko is None, the default for the platform is chosen.
"""
if paramiko is None:
paramiko = sys.platform == 'win32'
if not paramiko:
f = _try_passwordless_openssh
else:
f = _try_passwordless_paramiko
return f(server, keyfile)
def _try_passwordless_openssh(server, keyfile):
"""Try passwordless login with shell ssh command."""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko")
cmd = 'ssh -f ' + server
if keyfile:
cmd += ' -i ' + keyfile
cmd += ' exit'
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop('SSH_ASKPASS', None)
ssh_newkey = 'Are you sure you want to continue connecting'
p = pexpect.spawn(cmd, env=env)
while True:
try:
i = p.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
raise SSHException(
'The authenticity of the host can\'t be established.'
)
except pexpect.TIMEOUT:
continue
except pexpect.EOF:
return True
else:
return False
def _try_passwordless_paramiko(server, keyfile):
"""Try passwordless login with paramiko."""
if paramiko is None:
msg = "Paramiko unavailable, "
if sys.platform == 'win32':
msg += "Paramiko is required for ssh tunneled connections on Windows."
else:
msg += "use OpenSSH."
raise ImportError(msg)
username, server, port = _split_server(server)
client = paramiko.SSHClient()
known_hosts = os.path.expanduser("~/.ssh/known_hosts")
try:
client.load_host_keys(known_hosts)
except FileNotFoundError:
pass
policy_name = os.environ.get("PYZMQ_PARAMIKO_HOST_KEY_POLICY", None)
if policy_name:
policy = getattr(paramiko, f"{policy_name}Policy")
client.set_missing_host_key_policy(policy())
try:
client.connect(
server, port, username=username, key_filename=keyfile, look_for_keys=True
)
except paramiko.AuthenticationException:
return False
else:
client.close()
return True
def tunnel_connection(
socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60
):
"""Connect a socket to an address via an ssh tunnel.
This is a wrapper for socket.connect(addr), when addr is not accessible
from the local machine. It simply creates an ssh tunnel using the remaining args,
and calls socket.connect('tcp://localhost:lport') where lport is the randomly
selected local port of the tunnel.
"""
new_url, tunnel = open_tunnel(
addr,
server,
keyfile=keyfile,
password=password,
paramiko=paramiko,
timeout=timeout,
)
socket.connect(new_url)
return tunnel
def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60):
"""Open a tunneled connection from a 0MQ url.
For use inside tunnel_connection.
Returns
-------
(url, tunnel) : (str, object)
The 0MQ url that has been forwarded, and the tunnel object
"""
lport = select_random_ports(1)[0]
transport, addr = addr.split('://')
ip, rport = addr.split(':')
rport = int(rport)
if paramiko is None:
paramiko = sys.platform == 'win32'
if paramiko:
tunnelf = paramiko_tunnel
else:
tunnelf = openssh_tunnel
tunnel = tunnelf(
lport,
rport,
server,
remoteip=ip,
keyfile=keyfile,
password=password,
timeout=timeout,
)
return 'tcp://127.0.0.1:%i' % lport, tunnel
def openssh_tunnel(
lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60
):
"""Create an ssh tunnel using command-line ssh that connects port lport
on this machine to localhost:rport on server. The tunnel
will automatically close when not in use, remaining open
for a minimum of timeout seconds for an initial connection.
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko_tunnel")
ssh = "ssh "
if keyfile:
ssh += "-i " + keyfile
if ':' in server:
server, port = server.split(':')
ssh += " -p %s" % port
cmd = f"{ssh} -O check {server}"
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")])
cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % (
ssh,
lport,
remoteip,
rport,
server,
)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1))
return pid
cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % (
ssh,
lport,
remoteip,
rport,
server,
timeout,
)
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop('SSH_ASKPASS', None)
ssh_newkey = 'Are you sure you want to continue connecting'
tunnel = pexpect.spawn(cmd, env=env)
failed = False
while True:
try:
i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
raise SSHException(
'The authenticity of the host can\'t be established.'
)
except pexpect.TIMEOUT:
continue
except pexpect.EOF:
if tunnel.exitstatus:
print(tunnel.exitstatus)
print(tunnel.before)
print(tunnel.after)
raise RuntimeError("tunnel '%s' failed to start" % (cmd))
else:
return tunnel.pid
else:
if failed:
print("Password rejected, try again")
password = None
if password is None:
password = getpass("%s's password: " % (server))
tunnel.sendline(password)
failed = True
def _stop_tunnel(cmd):
pexpect.run(cmd)
def _split_server(server):
if '@' in server:
username, server = server.split('@', 1)
else:
username = getuser()
if ':' in server:
server, port = server.split(':')
port = int(port)
else:
port = 22
return username, server, port
def paramiko_tunnel(
lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60
):
"""launch a tunner with paramiko in a subprocess. This should only be used
when shell ssh is unavailable (e.g. Windows).
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
If you are familiar with ssh tunnels, this creates the tunnel:
ssh server -L localhost:lport:remoteip:rport
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if paramiko is None:
raise ImportError("Paramiko not available")
if password is None:
if not _try_passwordless_paramiko(server, keyfile):
password = getpass("%s's password: " % (server))
p = Process(
target=_paramiko_tunnel,
args=(lport, rport, server, remoteip),
kwargs=dict(keyfile=keyfile, password=password),
)
p.daemon = True
p.start()
return p
def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
"""Function for actually starting a paramiko tunnel, to be passed
to multiprocessing.Process(target=this), and not called directly.
"""
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(
server,
port,
username=username,
key_filename=keyfile,
look_for_keys=True,
password=password,
)
# except paramiko.AuthenticationException:
# if password is None:
# password = getpass("%s@%s's password: "%(username, server))
# client.connect(server, port, username=username, password=password)
# else:
# raise
except Exception as e:
print('*** Failed to connect to %s:%d: %r' % (server, port, e))
sys.exit(1)
# Don't let SIGINT kill the tunnel subprocess
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
forward_tunnel(lport, remoteip, rport, client.get_transport())
except KeyboardInterrupt:
print('SIGINT: Port forwarding stopped cleanly')
sys.exit(0)
except Exception as e:
print("Port forwarding stopped uncleanly: %s" % e)
sys.exit(255)
if sys.platform == 'win32':
ssh_tunnel = paramiko_tunnel
else:
ssh_tunnel = openssh_tunnel
__all__ = [
'tunnel_connection',
'ssh_tunnel',
'openssh_tunnel',
'paramiko_tunnel',
'try_passwordless_ssh',
]

View File

@@ -0,0 +1,25 @@
"""pure-Python sugar wrappers for core 0MQ objects."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from zmq import error
from zmq.sugar import context, frame, poll, socket, tracker, version
__all__ = []
for submod in (context, error, frame, poll, socket, tracker, version):
__all__.extend(submod.__all__)
from zmq.error import * # noqa
from zmq.sugar.context import * # noqa
from zmq.sugar.frame import * # noqa
from zmq.sugar.poll import * # noqa
from zmq.sugar.socket import * # noqa
# deprecated:
from zmq.sugar.stopwatch import Stopwatch # noqa
from zmq.sugar.tracker import * # noqa
from zmq.sugar.version import * # noqa
__all__.append('Stopwatch')

View File

@@ -0,0 +1,10 @@
from zmq.error import *
from . import constants as constants
from .constants import *
from .context import *
from .frame import *
from .poll import *
from .socket import *
from .tracker import *
from .version import *

View File

@@ -0,0 +1,76 @@
"""Mixin for mapping set/getattr to self.set/get"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import errno
from typing import TypeVar, Union
from .. import constants
T = TypeVar("T")
OptValT = Union[str, bytes, int]
class AttributeSetter:
def __setattr__(self, key: str, value: OptValT) -> None:
"""set zmq options by attribute"""
if key in self.__dict__:
object.__setattr__(self, key, value)
return
# regular setattr only allowed for class-defined attributes
for cls in self.__class__.mro():
if key in cls.__dict__ or key in getattr(cls, "__annotations__", {}):
object.__setattr__(self, key, value)
return
upper_key = key.upper()
try:
opt = getattr(constants, upper_key)
except AttributeError:
raise AttributeError(
f"{self.__class__.__name__} has no such option: {upper_key}"
)
else:
self._set_attr_opt(upper_key, opt, value)
def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None:
"""override if setattr should do something other than call self.set"""
self.set(opt, value)
def __getattr__(self, key: str) -> OptValT:
"""get zmq options by attribute"""
upper_key = key.upper()
try:
opt = getattr(constants, upper_key)
except AttributeError:
raise AttributeError(
f"{self.__class__.__name__} has no such option: {upper_key}"
) from None
else:
from zmq import ZMQError
try:
return self._get_attr_opt(upper_key, opt)
except ZMQError as e:
# EINVAL will be raised on access for write-only attributes.
# Turn that into an AttributeError
# necessary for mocking
if e.errno in {errno.EINVAL, errno.EFAULT}:
raise AttributeError(f"{key} attribute is write-only")
else:
raise
def _get_attr_opt(self, name, opt) -> OptValT:
"""override if getattr should do something other than call self.get"""
return self.get(opt)
def get(self, opt: int) -> OptValT:
pass
def set(self, opt: int, val: OptValT) -> None:
pass
__all__ = ['AttributeSetter']

View File

@@ -0,0 +1,320 @@
"""Python bindings for 0MQ."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import atexit
import os
import warnings
from threading import Lock
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from weakref import WeakSet
from zmq.backend import Context as ContextBase
from zmq.constants import ContextOption, Errno, SocketOption
from zmq.error import ZMQError
from .attrsettr import AttributeSetter, OptValT
from .socket import Socket
# notice when exiting, to avoid triggering term on exit
_exiting = False
def _notice_atexit() -> None:
global _exiting
_exiting = True
atexit.register(_notice_atexit)
T = TypeVar('T', bound='Context')
ST = TypeVar('ST', bound='Socket', covariant=True)
class Context(ContextBase, AttributeSetter, Generic[ST]):
"""Create a zmq Context
A zmq Context creates sockets via its ``ctx.socket`` method.
"""
sockopts: Dict[int, Any]
_instance: Any = None
_instance_lock = Lock()
_instance_pid: Optional[int] = None
_shadow = False
_sockets: WeakSet
# mypy doesn't like a default value here
_socket_class: Type[ST] = Socket # type: ignore
def __init__(self: "Context[Socket]", io_threads: int = 1, **kwargs: Any) -> None:
super().__init__(io_threads=io_threads, **kwargs)
if kwargs.get('shadow', False):
self._shadow = True
else:
self._shadow = False
self.sockopts = {}
self._sockets = WeakSet()
def __del__(self) -> None:
"""deleting a Context should terminate it, without trying non-threadsafe destroy"""
# Calling locals() here conceals issue #1167 on Windows CPython 3.5.4.
locals()
if not self._shadow and not _exiting and not self.closed:
warnings.warn(
f"unclosed context {self}",
ResourceWarning,
stacklevel=2,
source=self,
)
self.term()
_repr_cls = "zmq.Context"
def __repr__(self) -> str:
cls = self.__class__
# look up _repr_cls on exact class, not inherited
_repr_cls = cls.__dict__.get("_repr_cls", None)
if _repr_cls is None:
_repr_cls = f"{cls.__module__}.{cls.__name__}"
closed = ' closed' if self.closed else ''
if getattr(self, "_sockets", None):
n_sockets = len(self._sockets)
s = 's' if n_sockets > 1 else ''
sockets = f"{n_sockets} socket{s}"
else:
sockets = ""
return f"<{_repr_cls}({sockets}) at {hex(id(self))}{closed}>"
def __enter__(self: T) -> T:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.term()
def __copy__(self: T, memo: Any = None) -> T:
"""Copying a Context creates a shadow copy"""
return self.__class__.shadow(self.underlying)
__deepcopy__ = __copy__
@classmethod
def shadow(cls: Type[T], address: int) -> T:
"""Shadow an existing libzmq context
address is the integer address of the libzmq context
or an FFI pointer to it.
.. versionadded:: 14.1
"""
from zmq.utils.interop import cast_int_addr
address = cast_int_addr(address)
return cls(shadow=address)
@classmethod
def shadow_pyczmq(cls: Type[T], ctx: Any) -> T:
"""Shadow an existing pyczmq context
ctx is the FFI `zctx_t *` pointer
.. versionadded:: 14.1
"""
from pyczmq import zctx # type: ignore
from zmq.utils.interop import cast_int_addr
underlying = zctx.underlying(ctx)
address = cast_int_addr(underlying)
return cls(shadow=address)
# static method copied from tornado IOLoop.instance
@classmethod
def instance(cls: Type[T], io_threads: int = 1) -> T:
"""Returns a global Context instance.
Most single-threaded applications have a single, global Context.
Use this method instead of passing around Context instances
throughout your code.
A common pattern for classes that depend on Contexts is to use
a default argument to enable programs with multiple Contexts
but not require the argument for simpler applications::
class MyClass(object):
def __init__(self, context=None):
self.context = context or Context.instance()
.. versionchanged:: 18.1
When called in a subprocess after forking,
a new global instance is created instead of inheriting
a Context that won't work from the parent process.
"""
if (
cls._instance is None
or cls._instance_pid != os.getpid()
or cls._instance.closed
):
with cls._instance_lock:
if (
cls._instance is None
or cls._instance_pid != os.getpid()
or cls._instance.closed
):
cls._instance = cls(io_threads=io_threads)
cls._instance_pid = os.getpid()
return cls._instance
def term(self) -> None:
"""Close or terminate the context.
Context termination is performed in the following steps:
- Any blocking operations currently in progress on sockets open within context shall
raise :class:`zmq.ContextTerminated`.
With the exception of socket.close(), any further operations on sockets open within this context
shall raise :class:`zmq.ContextTerminated`.
- After interrupting all blocking calls, term shall block until the following conditions are satisfied:
- All sockets open within context have been closed.
- For each socket within context, all messages sent on the socket have either been
physically transferred to a network peer,
or the socket's linger period set with the zmq.LINGER socket option has expired.
For further details regarding socket linger behaviour refer to libzmq documentation for ZMQ_LINGER.
This can be called to close the context by hand. If this is not called,
the context will automatically be closed when it is garbage collected.
"""
super().term()
# -------------------------------------------------------------------------
# Hooks for ctxopt completion
# -------------------------------------------------------------------------
def __dir__(self) -> List[str]:
keys = dir(self.__class__)
keys.extend(ContextOption.__members__)
return keys
# -------------------------------------------------------------------------
# Creating Sockets
# -------------------------------------------------------------------------
def _add_socket(self, socket: Any) -> None:
"""Add a weakref to a socket for Context.destroy / reference counting"""
self._sockets.add(socket)
def _rm_socket(self, socket: Any) -> None:
"""Remove a socket for Context.destroy / reference counting"""
# allow _sockets to be None in case of process teardown
if getattr(self, "_sockets", None) is not None:
self._sockets.discard(socket)
def destroy(self, linger: Optional[float] = None) -> None:
"""Close all sockets associated with this context and then terminate
the context.
.. warning::
destroy involves calling ``zmq_close()``, which is **NOT** threadsafe.
If there are active sockets in other threads, this must not be called.
Parameters
----------
linger : int, optional
If specified, set LINGER on sockets prior to closing them.
"""
if self.closed:
return
sockets = self._sockets
self._sockets = WeakSet()
for s in sockets:
if s and not s.closed:
if linger is not None:
s.setsockopt(SocketOption.LINGER, linger)
s.close()
self.term()
def socket(self: T, socket_type: int, **kwargs: Any) -> ST:
"""Create a Socket associated with this Context.
Parameters
----------
socket_type : int
The socket type, which can be any of the 0MQ socket types:
REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc.
kwargs:
will be passed to the __init__ method of the socket class.
"""
if self.closed:
raise ZMQError(Errno.ENOTSUP)
s: ST = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame
self, socket_type, **kwargs
)
for opt, value in self.sockopts.items():
try:
s.setsockopt(opt, value)
except ZMQError:
# ignore ZMQErrors, which are likely for socket options
# that do not apply to a particular socket type, e.g.
# SUBSCRIBE for non-SUB sockets.
pass
self._add_socket(s)
return s
def setsockopt(self, opt: int, value: Any) -> None:
"""set default socket options for new sockets created by this Context
.. versionadded:: 13.0
"""
self.sockopts[opt] = value
def getsockopt(self, opt: int) -> OptValT:
"""get default socket options for new sockets created by this Context
.. versionadded:: 13.0
"""
return self.sockopts[opt]
def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None:
"""set default sockopts as attributes"""
if name in ContextOption.__members__:
return self.set(opt, value)
elif name in SocketOption.__members__:
self.sockopts[opt] = value
else:
raise AttributeError(f"No such context or socket option: {name}")
def _get_attr_opt(self, name: str, opt: int) -> OptValT:
"""get default sockopts as attributes"""
if name in ContextOption.__members__:
return self.get(opt)
else:
if opt not in self.sockopts:
raise AttributeError(name)
else:
return self.sockopts[opt]
def __delattr__(self, key: str) -> None:
"""delete default sockopts as attributes"""
key = key.upper()
try:
opt = getattr(SocketOption, key)
except AttributeError:
raise AttributeError(f"No such socket option: {key!r}")
else:
if opt not in self.sockopts:
raise AttributeError(key)
else:
del self.sockopts[opt]
__all__ = ['Context']

View File

@@ -0,0 +1,106 @@
"""0MQ Frame pure Python methods."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.backend import Frame as FrameBase
from .attrsettr import AttributeSetter
def _draft(v, feature):
zmq.error._check_version(v, feature)
if not zmq.DRAFT_API:
raise RuntimeError(
"libzmq and pyzmq must be built with draft support for %s" % feature
)
class Frame(FrameBase, AttributeSetter):
"""Frame(data=None, track=False, copy=None, copy_threshold=zmq.COPY_THRESHOLD)
A zmq message Frame class for non-copying send/recvs and access to message properties.
A ``zmq.Frame`` wraps an underlying ``zmq_msg_t``.
Message *properties* can be accessed by treating a Frame like a dictionary (``frame["User-Id"]``).
.. versionadded:: 14.4, libzmq 4
Frames created by ``recv(copy=False)`` can be used to access message properties and attributes,
such as the CURVE User-Id.
For example::
frames = socket.recv_multipart(copy=False)
user_id = frames[0]["User-Id"]
This class is used if you want to do non-copying send and recvs.
When you pass a chunk of bytes to this class, e.g. ``Frame(buf)``, the
ref-count of `buf` is increased by two: once because the Frame saves `buf` as
an instance attribute and another because a ZMQ message is created that
points to the buffer of `buf`. This second ref-count increase makes sure
that `buf` lives until all messages that use it have been sent.
Once 0MQ sends all the messages and it doesn't need the buffer of ``buf``,
0MQ will call ``Py_DECREF(s)``.
Parameters
----------
data : object, optional
any object that provides the buffer interface will be used to
construct the 0MQ message data.
track : bool [default: False]
whether a MessageTracker_ should be created to track this object.
Tracking a message has a cost at creation, because it creates a threadsafe
Event object.
copy : bool [default: use copy_threshold]
Whether to create a copy of the data to pass to libzmq
or share the memory with libzmq.
If unspecified, copy_threshold is used.
copy_threshold: int [default: zmq.COPY_THRESHOLD]
If copy is unspecified, messages smaller than this many bytes
will be copied and messages larger than this will be shared with libzmq.
"""
def __getitem__(self, key):
# map Frame['User-Id'] to Frame.get('User-Id')
return self.get(key)
@property
def group(self):
"""The RADIO-DISH group of the message.
Requires libzmq >= 4.2 and pyzmq built with draft APIs enabled.
.. versionadded:: 17
"""
_draft((4, 2), "RADIO-DISH")
return self.get('group')
@group.setter
def group(self, group):
_draft((4, 2), "RADIO-DISH")
self.set('group', group)
@property
def routing_id(self):
"""The CLIENT-SERVER routing id of the message.
Requires libzmq >= 4.2 and pyzmq built with draft APIs enabled.
.. versionadded:: 17
"""
_draft((4, 2), "CLIENT-SERVER")
return self.get('routing_id')
@routing_id.setter
def routing_id(self, routing_id):
_draft((4, 2), "CLIENT-SERVER")
self.set('routing_id', routing_id)
# keep deprecated alias
Message = Frame
__all__ = ['Frame', 'Message']

View File

@@ -0,0 +1,164 @@
"""0MQ polling related functions and classes."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from typing import Any, Dict, List, Optional, Tuple
from zmq.backend import zmq_poll
from zmq.constants import POLLERR, POLLIN, POLLOUT
# -----------------------------------------------------------------------------
# Polling related methods
# -----------------------------------------------------------------------------
class Poller:
"""A stateful poll interface that mirrors Python's built-in poll."""
sockets: List[Tuple[Any, int]]
_map: Dict
def __init__(self) -> None:
self.sockets = []
self._map = {}
def __contains__(self, socket: Any) -> bool:
return socket in self._map
def register(self, socket: Any, flags: int = POLLIN | POLLOUT):
"""p.register(socket, flags=POLLIN|POLLOUT)
Register a 0MQ socket or native fd for I/O monitoring.
register(s,0) is equivalent to unregister(s).
Parameters
----------
socket : zmq.Socket or native socket
A zmq.Socket or any Python object having a ``fileno()``
method that returns a valid file descriptor.
flags : int
The events to watch for. Can be POLLIN, POLLOUT or POLLIN|POLLOUT.
If `flags=0`, socket will be unregistered.
"""
if flags:
if socket in self._map:
idx = self._map[socket]
self.sockets[idx] = (socket, flags)
else:
idx = len(self.sockets)
self.sockets.append((socket, flags))
self._map[socket] = idx
elif socket in self._map:
# uregister sockets registered with no events
self.unregister(socket)
else:
# ignore new sockets with no events
pass
def modify(self, socket, flags=POLLIN | POLLOUT):
"""Modify the flags for an already registered 0MQ socket or native fd."""
self.register(socket, flags)
def unregister(self, socket: Any):
"""Remove a 0MQ socket or native fd for I/O monitoring.
Parameters
----------
socket : Socket
The socket instance to stop polling.
"""
idx = self._map.pop(socket)
self.sockets.pop(idx)
# shift indices after deletion
for socket, flags in self.sockets[idx:]:
self._map[socket] -= 1
def poll(self, timeout: Optional[int] = None) -> List[Tuple[Any, int]]:
"""Poll the registered 0MQ or native fds for I/O.
If there are currently events ready to be processed, this function will return immediately.
Otherwise, this function will return as soon the first event is available or after timeout
milliseconds have elapsed.
Parameters
----------
timeout : int
The timeout in milliseconds. If None, no `timeout` (infinite). This
is in milliseconds to be compatible with ``select.poll()``.
Returns
-------
events : list of tuples
The list of events that are ready to be processed.
This is a list of tuples of the form ``(socket, event_mask)``, where the 0MQ Socket
or integer fd is the first element, and the poll event mask (POLLIN, POLLOUT) is the second.
It is common to call ``events = dict(poller.poll())``,
which turns the list of tuples into a mapping of ``socket : event_mask``.
"""
if timeout is None or timeout < 0:
timeout = -1
elif isinstance(timeout, float):
timeout = int(timeout)
return zmq_poll(self.sockets, timeout=timeout)
def select(rlist: List, wlist: List, xlist: List, timeout: Optional[float] = None):
"""select(rlist, wlist, xlist, timeout=None) -> (rlist, wlist, xlist)
Return the result of poll as a lists of sockets ready for r/w/exception.
This has the same interface as Python's built-in ``select.select()`` function.
Parameters
----------
timeout : float, int, optional
The timeout in seconds. If None, no timeout (infinite). This is in seconds to be
compatible with ``select.select()``.
rlist : list of sockets/FDs
sockets/FDs to be polled for read events
wlist : list of sockets/FDs
sockets/FDs to be polled for write events
xlist : list of sockets/FDs
sockets/FDs to be polled for error events
Returns
-------
(rlist, wlist, xlist) : tuple of lists of sockets (length 3)
Lists correspond to sockets available for read/write/error events respectively.
"""
if timeout is None:
timeout = -1
# Convert from sec -> ms for zmq_poll.
# zmq_poll accepts 3.x style timeout in ms
timeout = int(timeout * 1000.0)
if timeout < 0:
timeout = -1
sockets = []
for s in set(rlist + wlist + xlist):
flags = 0
if s in rlist:
flags |= POLLIN
if s in wlist:
flags |= POLLOUT
if s in xlist:
flags |= POLLERR
sockets.append((s, flags))
return_sockets = zmq_poll(sockets, timeout)
rlist, wlist, xlist = [], [], []
for s, flags in return_sockets:
if flags & POLLIN:
rlist.append(s)
if flags & POLLOUT:
wlist.append(s)
if flags & POLLERR:
xlist.append(s)
return rlist, wlist, xlist
# -----------------------------------------------------------------------------
# Symbols to export
# -----------------------------------------------------------------------------
__all__ = ['Poller', 'select']

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,36 @@
"""Deprecated Stopwatch implementation"""
# Copyright (c) PyZMQ Development Team.
# Distributed under the terms of the Modified BSD License.
class Stopwatch:
"""Deprecated zmq.Stopwatch implementation
You can use Python's builtin timers (time.monotonic, etc.).
"""
def __init__(self):
import warnings
warnings.warn(
"zmq.Stopwatch is deprecated. Use stdlib time.monotonic and friends instead",
DeprecationWarning,
stacklevel=2,
)
self._start = 0
import time
try:
self._monotonic = time.monotonic
except AttributeError:
self._monotonic = time.time
def start(self):
"""Start the counter"""
self._start = self._monotonic()
def stop(self):
"""Return time since start in microseconds"""
stop = self._monotonic()
return int(1e6 * (stop - self._start))

View File

@@ -0,0 +1,120 @@
"""Tracker for zero-copy messages with 0MQ."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
from threading import Event
from typing import Set, Tuple, Union
from zmq.backend import Frame
from zmq.error import NotDone
class MessageTracker:
"""MessageTracker(*towatch)
A class for tracking if 0MQ is done using one or more messages.
When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread
sends the message at some later time. Often you want to know when 0MQ has
actually sent the message though. This is complicated by the fact that
a single 0MQ message can be sent multiple times using different sockets.
This class allows you to track all of the 0MQ usages of a message.
Parameters
----------
towatch : Event, MessageTracker, Message instances.
This objects to track. This class can track the low-level
Events used by the Message class, other MessageTrackers or
actual Messages.
"""
events: Set[Event]
peers: Set["MessageTracker"]
def __init__(self, *towatch: Tuple[Union["MessageTracker", Event, Frame]]):
"""MessageTracker(*towatch)
Create a message tracker to track a set of mesages.
Parameters
----------
*towatch : tuple of Event, MessageTracker, Message instances.
This list of objects to track. This class can track the low-level
Events used by the Message class, other MessageTrackers or
actual Messages.
"""
self.events = set()
self.peers = set()
for obj in towatch:
if isinstance(obj, Event):
self.events.add(obj)
elif isinstance(obj, MessageTracker):
self.peers.add(obj)
elif isinstance(obj, Frame):
if not obj.tracker:
raise ValueError("Not a tracked message")
self.peers.add(obj.tracker)
else:
raise TypeError("Require Events or Message Frames, not %s" % type(obj))
@property
def done(self):
"""Is 0MQ completely done with the message(s) being tracked?"""
for evt in self.events:
if not evt.is_set():
return False
for pm in self.peers:
if not pm.done:
return False
return True
def wait(self, timeout: Union[float, int] = -1):
"""mt.wait(timeout=-1)
Wait for 0MQ to be done with the message or until `timeout`.
Parameters
----------
timeout : float [default: -1, wait forever]
Maximum time in (s) to wait before raising NotDone.
Returns
-------
None
if done before `timeout`
Raises
------
NotDone
if `timeout` reached before I am done.
"""
tic = time.time()
remaining: float
if timeout is False or timeout < 0:
remaining = 3600 * 24 * 7 # a week
else:
remaining = timeout
for evt in self.events:
if remaining < 0:
raise NotDone
evt.wait(timeout=remaining)
if not evt.is_set():
raise NotDone
toc = time.time()
remaining -= toc - tic
tic = toc
for peer in self.peers:
if remaining < 0:
raise NotDone
peer.wait(timeout=remaining)
toc = time.time()
remaining -= toc - tic
tic = toc
_FINISHED_TRACKER = MessageTracker()
__all__ = ['MessageTracker', '_FINISHED_TRACKER']

View File

@@ -0,0 +1,66 @@
"""PyZMQ and 0MQ version functions."""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import re
from typing import Match, Tuple, Union, cast
from zmq.backend import zmq_version_info
__version__: str = "23.0.0"
_version_pat = re.compile(r"(\d+)\.(\d+)\.(\d+)(.*)")
_match = cast(Match, _version_pat.match(__version__))
_version_groups = _match.groups()
VERSION_MAJOR = int(_version_groups[0])
VERSION_MINOR = int(_version_groups[1])
VERSION_PATCH = int(_version_groups[2])
VERSION_EXTRA = _version_groups[3].lstrip(".")
version_info: Union[Tuple[int, int, int], Tuple[int, int, int, float]] = (
VERSION_MAJOR,
VERSION_MINOR,
VERSION_PATCH,
)
if VERSION_EXTRA:
version_info = (
VERSION_MAJOR,
VERSION_MINOR,
VERSION_PATCH,
float('inf'),
)
__revision__: str = ''
def pyzmq_version() -> str:
"""return the version of pyzmq as a string"""
if __revision__:
return '+'.join([__version__, __revision__[:6]])
else:
return __version__
def pyzmq_version_info() -> Union[Tuple[int, int, int], Tuple[int, int, int, float]]:
"""return the pyzmq version as a tuple of at least three numbers
If pyzmq is a development version, `inf` will be appended after the third integer.
"""
return version_info
def zmq_version() -> str:
"""return the version of libzmq as a string"""
return "%i.%i.%i" % zmq_version_info()
__all__ = [
'zmq_version',
'zmq_version_info',
'pyzmq_version',
'pyzmq_version_info',
'__version__',
'__revision__',
]

View File

@@ -0,0 +1,257 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import os
import platform
import signal
import sys
import time
from functools import partial
from threading import Thread
from typing import List
from unittest import SkipTest, TestCase
from pytest import mark
import zmq
from zmq.utils import jsonapi
try:
import gevent
from zmq import green as gzmq
have_gevent = True
except ImportError:
have_gevent = False
PYPY = platform.python_implementation() == 'PyPy'
# -----------------------------------------------------------------------------
# skip decorators (directly from unittest)
# -----------------------------------------------------------------------------
_id = lambda x: x
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
# -----------------------------------------------------------------------------
# Base test class
# -----------------------------------------------------------------------------
def term_context(ctx, timeout):
"""Terminate a context with a timeout"""
t = Thread(target=ctx.term)
t.daemon = True
t.start()
t.join(timeout=timeout)
if t.is_alive():
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
class BaseZMQTestCase(TestCase):
green = False
teardown_timeout = 10
test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60)
sockets: List[zmq.Socket]
@property
def _is_pyzmq_test(self):
return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0]
@property
def _should_test_timeout(self):
return (
self._is_pyzmq_test
and hasattr(signal, 'SIGALRM')
and self.test_timeout_seconds
)
@property
def Context(self):
if self.green:
return gzmq.Context
else:
return zmq.Context
def socket(self, socket_type):
s = self.context.socket(socket_type)
self.sockets.append(s)
return s
def _alarm_timeout(self, timeout, *args):
raise TimeoutError(f"Test did not complete in {timeout} seconds")
def setUp(self):
super().setUp()
if self.green and not have_gevent:
raise SkipTest("requires gevent")
self.context = self.Context.instance()
self.sockets = []
if self._should_test_timeout:
# use SIGALRM to avoid test hangs
signal.signal(
signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds)
)
signal.alarm(self.test_timeout_seconds)
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close(0)
for ctx in contexts:
try:
term_context(ctx, self.teardown_timeout)
except Exception:
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise
super().tearDown()
def create_bound_pair(
self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'
):
"""Create a bound socket pair using a random port."""
s1 = self.context.socket(type1)
s1.setsockopt(zmq.LINGER, 0)
port = s1.bind_to_random_port(interface)
s2 = self.context.socket(type2)
s2.setsockopt(zmq.LINGER, 0)
s2.connect(f'{interface}:{port}')
self.sockets.extend([s1, s2])
return s1, s2
def ping_pong(self, s1, s2, msg):
s1.send(msg)
msg2 = s2.recv()
s2.send(msg2)
msg3 = s1.recv()
return msg3
def ping_pong_json(self, s1, s2, o):
if jsonapi.jsonmod is None:
raise SkipTest("No json library")
s1.send_json(o)
o2 = s2.recv_json()
s2.send_json(o2)
o3 = s1.recv_json()
return o3
def ping_pong_pyobj(self, s1, s2, o):
s1.send_pyobj(o)
o2 = s2.recv_pyobj()
s2.send_pyobj(o2)
o3 = s1.recv_pyobj()
return o3
def assertRaisesErrno(self, errno, func, *args, **kwargs):
try:
func(*args, **kwargs)
except zmq.ZMQError as e:
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def _select_recv(self, multipart, socket, **kwargs):
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
if zmq.zmq_version_info() >= (3, 1, 0):
# zmq 3.1 has a bug, where poll can return false positives,
# so we wait a little bit just in case
# See LIBZMQ-280 on JIRA
time.sleep(0.1)
r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
assert len(r) > 0, "Should have received a message"
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
recv = socket.recv_multipart if multipart else socket.recv
return recv(**kwargs)
def recv(self, socket, **kwargs):
"""call recv in a way that raises if there is nothing to receive"""
return self._select_recv(False, socket, **kwargs)
def recv_multipart(self, socket, **kwargs):
"""call recv_multipart in a way that raises if there is nothing to receive"""
return self._select_recv(True, socket, **kwargs)
class PollZMQTestCase(BaseZMQTestCase):
pass
class GreenTest:
"""Mixin for making green versions of test classes"""
green = True
teardown_timeout = 10
def assertRaisesErrno(self, errno, func, *args, **kwargs):
if errno == zmq.EAGAIN:
raise SkipTest("Skipping because we're green.")
try:
func(*args, **kwargs)
except zmq.ZMQError:
e = sys.exc_info()[1]
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close()
try:
gevent.joinall(
[gevent.spawn(ctx.term) for ctx in contexts],
timeout=self.teardown_timeout,
raise_error=True,
)
except gevent.Timeout:
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
def skip_green(self):
raise SkipTest("Skipping because we are green")
def skip_green(f):
def skipping_test(self, *args, **kwargs):
if self.green:
raise SkipTest("Skipping because we are green")
else:
return f(self, *args, **kwargs)
return skipping_test

View File

@@ -0,0 +1 @@
"""pytest configuration and fixtures"""

View File

@@ -0,0 +1,498 @@
"""Test asyncio support"""
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import sys
from concurrent.futures import CancelledError
from multiprocessing import Process
import pytest
from pytest import mark
import zmq
import zmq.asyncio as zaio
from zmq.auth.asyncio import AsyncioAuthenticator
from zmq.tests import BaseZMQTestCase
from zmq.tests.test_auth import TestThreadAuthentication
class ProcessForTeardownTest(Process):
def __init__(self, event_loop_policy_class):
Process.__init__(self)
self.event_loop_policy_class = event_loop_policy_class
def run(self):
"""Leave context, socket and event loop upon implicit disposal"""
asyncio.set_event_loop_policy(self.event_loop_policy_class())
actx = zaio.Context.instance()
socket = actx.socket(zmq.PAIR)
socket.bind_to_random_port("tcp://127.0.0.1")
async def never_ending_task(socket):
await socket.recv() # never ever receive anything
loop = asyncio.new_event_loop()
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
try:
loop.run_until_complete(coro)
except asyncio.TimeoutError:
pass # expected timeout
else:
assert False, "never_ending_task was completed unexpectedly"
finally:
loop.close()
class TestAsyncIOSocket(BaseZMQTestCase):
Context = zaio.Context
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
super().tearDown()
self.loop.close()
# verify cleanup of references to selectors
assert zaio._selectors == {}
if 'zmq._asyncio_selector' in sys.modules:
assert zmq._asyncio_selector._selector_loops == set()
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, zaio.Socket)
s.close()
def test_instance_subclass_first(self):
actx = zmq.asyncio.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = zmq.asyncio.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_recv_multipart(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b"hi"]
self.loop.run_until_complete(test())
def test_recv(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b"hi"
assert recvd == b"there"
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
def test_recv_timeout(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with self.assertRaises(zmq.Again):
await f1
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
def test_send_timeout(self):
async def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with self.assertRaises(zmq.Again):
await s.send(b"not going anywhere")
self.loop.run_until_complete(test())
def test_recv_string(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = "πøøπ"
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg
self.loop.run_until_complete(test())
def test_recv_json(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_until_complete(test())
def test_recv_json_cancelled(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
self.loop.run_until_complete(test())
def test_recv_pyobj(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_until_complete(test())
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get("identities", []))
content = json.dumps(msg["content"]).encode("utf8")
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode("utf8"))
return {
"identities": identities,
"content": content,
}
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
await a.send_serialized(msg, serialize)
recvd = await b.recv_serialized(deserialize)
assert recvd["content"] == msg["content"]
assert recvd["identities"]
# bounce back, tests identities
await b.send_serialized(recvd, serialize)
r2 = await a.recv_serialized(deserialize)
assert r2["content"] == msg["content"]
assert not r2["identities"]
self.loop.run_until_complete(test())
def test_custom_serialize_error(self):
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
with pytest.raises(TypeError):
await a.send_serialized(json, json.dumps)
await a.send(b"not json")
with pytest.raises(TypeError):
await b.recv_serialized(json.loads)
self.loop.run_until_complete(test())
def test_recv_dontwait(self):
async def test():
push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = pull.recv(zmq.DONTWAIT)
with self.assertRaises(zmq.Again):
await f
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
assert msg == b"ping"
self.loop.run_until_complete(test())
def test_recv_cancel(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
f = b.poll(timeout=1)
assert not f.done()
evt = await f
assert evt == 0
f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll_base_socket(self):
async def test():
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = zaio.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll_on_closed_socket(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=1)
b.close()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
self.loop.run_until_complete(test())
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Windows does not support polling on files",
)
def test_poll_raw(self):
async def test():
p = zaio.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, "rb")
w = os.fdopen(w, "wb")
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = await p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b"x")
w.flush()
evts = await p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b"x"
r.close()
w.close()
loop = asyncio.new_event_loop()
loop.run_until_complete(test())
def test_multiple_loops(self):
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
async def test():
await a.send(b'buf')
msg = await b.recv()
assert msg == b'buf'
for i in range(3):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
loop.close()
def test_shadow(self):
async def test():
ctx = zmq.Context()
s = ctx.socket(zmq.PULL)
async_s = zaio.Socket(s)
assert isinstance(async_s, self.socket_class)
def test_process_teardown(self):
event_loop_policy_class = type(asyncio.get_event_loop_policy())
proc = ProcessForTeardownTest(event_loop_policy_class)
proc.start()
try:
proc.join(10) # starting new Python process may cost a lot
self.assertEqual(
proc.exitcode,
0,
"Python process died with code %d" % proc.exitcode
if proc.exitcode
else "process teardown hangs",
)
finally:
proc.terminate()
class TestAsyncioAuthentication(TestThreadAuthentication):
"""Test authentication running in a asyncio task"""
Context = zaio.Context
def shortDescription(self):
"""Rewrite doc strings from TestThreadAuthentication from
'threaded' to 'asyncio'.
"""
doc = self._testMethodDoc
if doc:
doc = doc.split("\n")[0].strip()
if doc.startswith("threaded auth"):
doc = doc.replace("threaded auth", "asyncio auth")
return doc
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
super().tearDown()
self.loop.close()
def make_auth(self):
return AsyncioAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
async def go():
result = False
iface = "tcp://127.0.0.1"
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
# set timeouts
server.SNDTIMEO = client.RCVTIMEO = 1000
try:
await server.send_multipart(msg)
except zmq.Again:
return False
try:
rcvd_msg = await client.recv_multipart()
except zmq.Again:
return False
else:
assert rcvd_msg == msg
result = True
return result
return self.loop.run_until_complete(go())
def _select_recv(self, multipart, socket, **kwargs):
recv = socket.recv_multipart if multipart else socket.recv
async def coro():
if not await socket.poll(5000):
raise TimeoutError("Should have received a message")
return await recv(**kwargs)
return self.loop.run_until_complete(coro())

View File

@@ -0,0 +1,579 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import os
import shutil
import tempfile
import warnings
import pytest
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
class BaseAuthTestCase(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4, 0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to have curve support")
super().setUp()
# enable debug logging while we run tests
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
self.auth = self.make_auth()
self.auth.start()
self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
def make_auth(self):
raise NotImplementedError()
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.remove_certs(self.base_dir)
super().tearDown()
def create_certs(self):
"""Create CURVE certificates for a test"""
# Create temporary CURVE keypairs for this test run. We create all keys in a
# temp directory and then move them into the appropriate private or public
# directory.
base_dir = tempfile.mkdtemp()
keys_dir = os.path.join(base_dir, 'certificates')
public_keys_dir = os.path.join(base_dir, 'public_keys')
secret_keys_dir = os.path.join(base_dir, 'private_keys')
os.mkdir(keys_dir)
os.mkdir(public_keys_dir)
os.mkdir(secret_keys_dir)
server_public_file, server_secret_file = zmq.auth.create_certificates(
keys_dir, "server"
)
client_public_file, client_secret_file = zmq.auth.create_certificates(
keys_dir, "client"
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key_secret"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
)
return (base_dir, public_keys_dir, secret_keys_dir)
def remove_certs(self, base_dir):
"""Remove certificates for a test"""
shutil.rmtree(base_dir)
def load_certs(self, secret_keys_dir):
"""Return server and client certificate keys"""
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
return server_public, server_secret, client_public, client_secret
class TestThreadAuthentication(BaseAuthTestCase):
"""Test authentication running in a thread"""
def make_auth(self):
return ThreadAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
result = False
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
# run poll on server twice
# to flush spurious events
server.poll(100, zmq.POLLOUT)
if server.poll(1000, zmq.POLLOUT):
try:
server.send_multipart(msg, zmq.NOBLOCK)
except zmq.Again:
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
return False
else:
return False
if client.poll(1000):
try:
rcvd_msg = client.recv_multipart(zmq.NOBLOCK)
except zmq.Again:
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
else:
assert rcvd_msg == msg
result = True
return result
def test_null(self):
"""threaded auth - NULL"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
self.auth.stop()
self.auth = None
# use a new context, so ZAP isn't inherited
self.context = self.Context()
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
server = self.socket(zmq.PUSH)
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_blacklist(self):
"""threaded auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert not self.can_connect(server, client)
def test_whitelist(self):
"""threaded auth - Whitelist"""
# Whitelist 127.0.0.1, connection should pass"
self.auth.allow('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_plain(self):
"""threaded auth - PLAIN"""
# Try PLAIN authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
assert not self.can_connect(server, client)
# Try PLAIN authentication - with server configured, connection should pass
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
assert self.can_connect(server, client)
# Try PLAIN authentication - with bogus credentials, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Bogus'
assert not self.can_connect(server, client)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
client.close()
server.close()
def test_curve(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
# Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
# Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(server, client)
# Try CURVE authentication - with server configured, connection should pass
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
# Try connecting using NULL and no authentication enabled, connection should pass
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_curve_callback(self):
"""threaded auth - CURVE with callback authentication"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
# Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
# Try CURVE authentication - with callback authentication configured, connection should pass
class CredentialsProvider:
def __init__(self):
self.client = client_public
def callback(self, domain, key):
if key == self.client:
return True
else:
return False
provider = CredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(server, client)
# Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
class WrongCredentialsProvider:
def __init__(self):
self.client = "WrongCredentials"
def callback(self, domain, key):
if key == self.client:
return True
else:
return False
provider = WrongCredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
@skip_pypy
def test_curve_user_id(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# test default user-id map
client.send(b'test')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == client_public.decode("utf8")
# test custom user-id map
self.auth.curve_user_id = lambda client_key: 'custom'
client2 = self.socket(zmq.PUSH)
client2.curve_publickey = client_public
client2.curve_secretkey = client_secret
client2.curve_serverkey = server_public
assert self.can_connect(client2, server)
client2.send(b'test2')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test2'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == 'custom'
def with_ioloop(method, expect_success=True):
"""decorator for running tests with an IOLoop"""
def test_method(self):
r = method(self)
loop = self.io_loop
if expect_success:
self.pullstream.on_recv(self.on_message_succeed)
else:
self.pullstream.on_recv(self.on_message_fail)
loop.call_later(1, self.attempt_connection)
loop.call_later(1.2, self.send_msg)
if expect_success:
loop.call_later(2, self.on_test_timeout_fail)
else:
loop.call_later(2, self.on_test_timeout_succeed)
loop.start()
if self.fail_msg:
self.fail(self.fail_msg)
return r
return test_method
def should_auth(method):
return with_ioloop(method, True)
def should_not_auth(method):
return with_ioloop(method, False)
class TestIOLoopAuthentication(BaseAuthTestCase):
"""Test authentication running in ioloop"""
def setUp(self):
try:
from tornado import ioloop
except ImportError:
pytest.skip("Requires tornado")
from zmq.eventloop import zmqstream
self.fail_msg = None
self.io_loop = ioloop.IOLoop()
super().setUp()
self.server = self.socket(zmq.PUSH)
self.client = self.socket(zmq.PULL)
self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
def make_auth(self):
from zmq.auth.ioloop import IOLoopAuthenticator
return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.io_loop.close(all_fds=True)
super().tearDown()
def attempt_connection(self):
"""Check if client can connect to server using tcp transport"""
iface = 'tcp://127.0.0.1'
port = self.server.bind_to_random_port(iface)
self.client.connect("%s:%i" % (iface, port))
def send_msg(self):
"""Send a message from server to a client"""
msg = [b"Hello World"]
self.pushstream.send_multipart(msg)
def on_message_succeed(self, frames):
"""A message was received, as expected."""
if frames != [b"Hello World"]:
self.fail_msg = "Unexpected message received"
self.io_loop.stop()
def on_message_fail(self, frames):
"""A message was received, unexpectedly."""
self.fail_msg = 'Received messaged unexpectedly, security failed'
self.io_loop.stop()
def on_test_timeout_succeed(self):
"""Test timer expired, indicates test success"""
self.io_loop.stop()
def on_test_timeout_fail(self):
"""Test timer expired, indicates test failure"""
self.fail_msg = 'Test timed out'
self.io_loop.stop()
@should_auth
def test_none(self):
"""ioloop auth - NONE"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
# no auth should be running
self.auth.stop()
self.auth = None
@should_auth
def test_null(self):
"""ioloop auth - NULL"""
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
self.server.zap_domain = b'global'
@should_not_auth
def test_blacklist(self):
"""ioloop auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
self.server.zap_domain = b'global'
@should_auth
def test_whitelist(self):
"""ioloop auth - Whitelist"""
# Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
self.auth.allow('127.0.0.1')
self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
@should_not_auth
def test_plain_unconfigured_server(self):
"""ioloop auth - PLAIN, unconfigured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - without configuring server, connection should fail
self.server.plain_server = True
@should_auth
def test_plain_configured_server(self):
"""ioloop auth - PLAIN, configured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - with server configured, connection should pass
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_plain_bogus_credentials(self):
"""ioloop auth - PLAIN, bogus credentials"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Bogus'
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_curve_unconfigured_server(self):
"""ioloop auth - CURVE, unconfigured server"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_allow_any(self):
"""ioloop auth - CURVE, CURVE_ALLOW_ANY"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_configured_server(self):
"""ioloop auth - CURVE, configured server"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public

View File

@@ -0,0 +1,303 @@
import time
from unittest import TestCase
from zmq.tests import SkipTest
try:
from zmq.backend.cffi import ( # type: ignore
IDENTITY,
POLLIN,
POLLOUT,
PULL,
PUSH,
REP,
REQ,
zmq_version_info,
)
from zmq.backend.cffi._cffi import C, ffi
have_ffi_backend = True
except ImportError:
have_ffi_backend = False
class TestCFFIBackend(TestCase):
def setUp(self):
if not have_ffi_backend:
raise SkipTest('CFFI not available')
def test_zmq_version_info(self):
version = zmq_version_info()
assert version[0] in range(2, 11)
def test_zmq_ctx_new_destroy(self):
ctx = C.zmq_ctx_new()
assert ctx != ffi.NULL
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_socket_open_close(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_setsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[3]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_getsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
option_len = ffi.new('size_t*', 3)
option = ffi.new('char[3]')
ret = C.zmq_getsockopt(socket, IDENTITY, ffi.cast('void*', option), option_len)
assert ret == 0
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, 8)
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind_connect(self):
ctx = C.zmq_ctx_new()
socket1 = C.zmq_socket(ctx, PUSH)
socket2 = C.zmq_socket(ctx, PULL)
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket1
assert ffi.NULL != socket2
assert 0 == C.zmq_close(socket1)
assert 0 == C.zmq_close(socket2)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_msg_init_close(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_size(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
data = C.zmq_msg_data(zmq_msg)
assert ffi.NULL != zmq_msg
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_send(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 0 == C.zmq_msg_close(zmq_msg)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_recv(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_poll(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
receiver_pollitem = ffi.new('zmq_pollitem_t*')
receiver_pollitem.socket = receiver
receiver_pollitem.fd = 0
receiver_pollitem.events = POLLIN | POLLOUT
receiver_pollitem.revents = 0
ret = C.zmq_poll(ffi.NULL, 0, 0)
assert ret == 0
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 0
ret = C.zmq_msg_send(zmq_msg, sender, 0)
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
assert ret == 5
time.sleep(0.2)
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 1
assert int(receiver_pollitem.revents) & POLLIN
assert not int(receiver_pollitem.revents) & POLLOUT
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert ret_recv == 5
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
sender_pollitem = ffi.new('zmq_pollitem_t*')
sender_pollitem.socket = sender
sender_pollitem.fd = 0
sender_pollitem.events = POLLIN | POLLOUT
sender_pollitem.revents = 0
ret = C.zmq_poll(sender_pollitem, 1, 0)
assert ret == 0
zmq_msg_again = ffi.new('zmq_msg_t*')
message_again = ffi.new('char[11]', b'Hello Again')
C.zmq_msg_init_data(
zmq_msg_again,
ffi.cast('void*', message_again),
ffi.cast('size_t', 11),
ffi.NULL,
ffi.NULL,
)
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
time.sleep(0.2)
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
assert int(sender_pollitem.revents) & POLLIN
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
assert 11 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello Again"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), int(C.zmq_msg_size(zmq_msg2)))[:]
)
assert 0 == C.zmq_close(sender)
assert 0 == C.zmq_close(receiver)
assert 0 == C.zmq_ctx_destroy(ctx)
assert 0 == C.zmq_msg_close(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg2)
assert 0 == C.zmq_msg_close(zmq_msg_again)

View File

@@ -0,0 +1,19 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
import zmq.constants
def test_constants():
assert zmq.POLLIN is zmq.PollEvent.POLLIN
assert zmq.PUSH is zmq.SocketType.PUSH
assert zmq.constants.SUBSCRIBE is zmq.SocketOption.SUBSCRIBE
def test_socket_options():
assert zmq.IDENTITY is zmq.SocketOption.ROUTING_ID
assert zmq.IDENTITY._opt_type is zmq.constants._OptType.bytes
assert zmq.AFFINITY._opt_type is zmq.constants._OptType.int64
assert zmq.CURVE_SERVER._opt_type is zmq.constants._OptType.int
assert zmq.FD._opt_type is zmq.constants._OptType.fd

View File

@@ -0,0 +1,401 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import os
import sys
import time
from queue import Queue
from threading import Event, Thread
from unittest import mock
from pytest import mark
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest
class KwargTestSocket(zmq.Socket):
test_kwarg_value = None
def __init__(self, *args, **kwargs):
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
super().__init__(*args, **kwargs)
class KwargTestContext(zmq.Context):
_socket_class = KwargTestSocket
class TestContext(BaseZMQTestCase):
def test_init(self):
c1 = self.Context()
assert isinstance(c1, self.Context)
del c1
c2 = self.Context()
assert isinstance(c2, self.Context)
del c2
c3 = self.Context()
assert isinstance(c3, self.Context)
del c3
_repr_cls = "zmq.Context"
def test_repr(self):
with self.Context() as ctx:
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' not in repr(ctx)
with ctx.socket(zmq.PUSH) as push:
assert f'{self._repr_cls}(1 socket)' in repr(ctx)
with ctx.socket(zmq.PULL) as pull:
assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' in repr(ctx)
def test_dir(self):
ctx = self.Context()
assert 'socket' in dir(ctx)
if zmq.zmq_version_info() > (3,):
assert 'IO_THREADS' in dir(ctx)
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
m = mock.Mock(spec=self.context)
def test_term(self):
c = self.Context()
c.term()
assert c.closed
def test_context_manager(self):
with self.Context() as c:
pass
assert c.closed
def test_fail_init(self):
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
def test_term_hang(self):
rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
req.setsockopt(zmq.LINGER, 0)
req.send(b'hello', copy=False)
req.close()
rep.close()
self.context.term()
def test_instance(self):
ctx = self.Context.instance()
c2 = self.Context.instance(io_threads=2)
assert c2 is ctx
c2.term()
c3 = self.Context.instance()
c4 = self.Context.instance()
assert not c3 is c2
assert not c3.closed
assert c3 is c4
def test_instance_subclass_first(self):
self.context.term()
class SubContext(zmq.Context):
pass
sctx = SubContext.instance()
ctx = zmq.Context.instance()
ctx.term()
sctx.term()
assert type(ctx) is zmq.Context
assert type(sctx) is SubContext
def test_instance_subclass_second(self):
self.context.term()
class SubContextInherit(zmq.Context):
pass
class SubContextNoInherit(zmq.Context):
_instance = None
ctx = zmq.Context.instance()
sctx = SubContextInherit.instance()
sctx2 = SubContextNoInherit.instance()
ctx.term()
sctx.term()
sctx2.term()
assert type(ctx) is zmq.Context
assert type(sctx) is zmq.Context
assert type(sctx2) is SubContextNoInherit
def test_instance_threadsafe(self):
self.context.term() # clear default context
q = Queue()
# slow context initialization,
# to ensure that we are both trying to create one at the same time
class SlowContext(self.Context):
def __init__(self, *a, **kw):
time.sleep(1)
super().__init__(*a, **kw)
def f():
q.put(SlowContext.instance())
# call ctx.instance() in several threads at once
N = 16
threads = [Thread(target=f) for i in range(N)]
[t.start() for t in threads]
# also call it in the main thread (not first)
ctx = SlowContext.instance()
assert isinstance(ctx, SlowContext)
# check that all the threads got the same context
for i in range(N):
thread_ctx = q.get(timeout=5)
assert thread_ctx is ctx
# cleanup
ctx.term()
[t.join(timeout=5) for t in threads]
def test_socket_passes_kwargs(self):
test_kwarg_value = 'testing one two three'
with KwargTestContext() as ctx:
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
assert socket.test_kwarg_value is test_kwarg_value
def test_many_sockets(self):
"""opening and closing many sockets shouldn't cause problems"""
ctx = self.Context()
for i in range(16):
sockets = [ctx.socket(zmq.REP) for i in range(65)]
[s.close() for s in sockets]
# give the reaper a chance
time.sleep(1e-2)
ctx.term()
def test_sockopts(self):
"""setting socket options with ctx attributes"""
ctx = self.Context()
ctx.linger = 5
assert ctx.linger == 5
s = ctx.socket(zmq.REQ)
assert s.linger == 5
assert s.getsockopt(zmq.LINGER) == 5
s.close()
# check that subscribe doesn't get set on sockets that don't subscribe:
ctx.subscribe = b''
s = ctx.socket(zmq.REQ)
s.close()
ctx.term()
@mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
def test_destroy(self):
"""Context.destroy should close sockets"""
ctx = self.Context()
sockets = [ctx.socket(zmq.REP) for i in range(65)]
# close half of the sockets
[s.close() for s in sockets[::2]]
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
for s in sockets:
assert s.closed
def test_destroy_linger(self):
"""Context.destroy should set linger on closing sockets"""
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
req.send(b'hi')
time.sleep(1e-2)
self.context.destroy(linger=0)
# reaper is not instantaneous
time.sleep(1e-2)
for s in (req, rep):
assert s.closed
def test_term_noclose(self):
"""Context.term won't close sockets"""
ctx = self.Context()
s = ctx.socket(zmq.REQ)
assert not s.closed
t = Thread(target=ctx.term)
t.start()
t.join(timeout=0.1)
assert t.is_alive(), "Context should be waiting"
s.close()
t.join(timeout=0.1)
assert not t.is_alive(), "Context should have closed"
def test_gc(self):
"""test close&term by garbage collection alone"""
if PYPY:
raise SkipTest("GC doesn't work ")
# test credit @dln (GH #137):
def gcf():
def inner():
ctx = self.Context()
ctx.socket(zmq.PUSH)
inner()
gc.collect()
t = Thread(target=gcf)
t.start()
t.join(timeout=1)
assert not t.is_alive(), "Garbage collection should have cleaned up context"
def test_cyclic_destroy(self):
"""ctx.destroy should succeed when cyclic ref prevents gc"""
# test credit @dln (GH #137):
class CyclicReference:
def __init__(self, parent=None):
self.parent = parent
def crash(self, sock):
self.sock = sock
self.child = CyclicReference(self)
def crash_zmq():
ctx = self.Context()
sock = ctx.socket(zmq.PULL)
c = CyclicReference()
c.crash(sock)
ctx.destroy()
crash_zmq()
def test_term_thread(self):
"""ctx.term should not crash active threads (#139)"""
ctx = self.Context()
evt = Event()
evt.clear()
def block():
s = ctx.socket(zmq.REP)
s.bind_to_random_port('tcp://127.0.0.1')
evt.set()
try:
s.recv()
except zmq.ZMQError as e:
assert e.errno == zmq.ETERM
return
finally:
s.close()
self.fail("recv should have been interrupted with ETERM")
t = Thread(target=block)
t.start()
evt.wait(1)
assert evt.is_set(), "sync event never fired"
time.sleep(0.01)
ctx.term()
t.join(timeout=1)
assert not t.is_alive(), "term should have interrupted s.recv()"
def test_destroy_no_sockets(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.bind_to_random_port('tcp://127.0.0.1')
s.close()
ctx.destroy()
assert s.closed
assert ctx.closed
def test_ctx_opts(self):
if zmq.zmq_version_info() < (3,):
raise SkipTest("context options require libzmq 3")
ctx = self.Context()
ctx.set(zmq.MAX_SOCKETS, 2)
assert ctx.get(zmq.MAX_SOCKETS) == 2
ctx.max_sockets = 100
assert ctx.max_sockets == 100
assert ctx.get(zmq.MAX_SOCKETS) == 100
def test_copy(self):
c1 = self.Context()
c2 = copy.copy(c1)
c2b = copy.deepcopy(c1)
c3 = copy.deepcopy(c2)
assert c2._shadow
assert c3._shadow
assert c1.underlying == c2.underlying
assert c1.underlying == c3.underlying
assert c1.underlying == c2b.underlying
s = c3.socket(zmq.PUB)
s.close()
c1.term()
def test_shadow(self):
ctx = self.Context()
ctx2 = self.Context.shadow(ctx.underlying)
assert ctx.underlying == ctx2.underlying
s = ctx.socket(zmq.PUB)
s.close()
del ctx2
assert not ctx.closed
s = ctx.socket(zmq.PUB)
ctx2 = self.Context.shadow(ctx.underlying)
s2 = ctx2.socket(zmq.PUB)
s.close()
s2.close()
ctx.term()
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
del ctx2
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket, zstr
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
a = zsocket.new(ctx, zmq.PUSH)
zsocket.bind(a, "inproc://a")
ctx2 = self.Context.shadow_pyczmq(ctx)
b = ctx2.socket(zmq.PULL)
b.connect("inproc://a")
zstr.send(a, b'hi')
rcvd = self.recv(b)
assert rcvd == b'hi'
b.close()
@mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
def test_fork_instance(self):
ctx = self.Context.instance()
parent_ctx_id = id(ctx)
r_fd, w_fd = os.pipe()
reader = os.fdopen(r_fd, 'r')
child_pid = os.fork()
if child_pid == 0:
ctx = self.Context.instance()
writer = os.fdopen(w_fd, 'w')
child_ctx_id = id(ctx)
ctx.term()
writer.write(str(child_ctx_id) + "\n")
writer.flush()
writer.close()
os._exit(0)
else:
os.close(w_fd)
child_id_s = reader.readline()
reader.close()
assert child_id_s
assert int(child_id_s) != parent_ctx_id
ctx.term()
if False: # disable green context tests
class TestContextGreen(GreenTest, TestContext):
"""gevent subclass of context tests"""
# skip tests that use real threads:
test_gc = GreenTest.skip_green
test_term_thread = GreenTest.skip_green
test_destroy_linger = GreenTest.skip_green
_repr_cls = "zmq.green.Context"

Some files were not shown because too many files have changed in this diff Show More