first commit

This commit is contained in:
Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@ -0,0 +1,257 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import os
import platform
import signal
import sys
import time
from functools import partial
from threading import Thread
from typing import List
from unittest import SkipTest, TestCase
from pytest import mark
import zmq
from zmq.utils import jsonapi
try:
import gevent
from zmq import green as gzmq
have_gevent = True
except ImportError:
have_gevent = False
PYPY = platform.python_implementation() == 'PyPy'
# -----------------------------------------------------------------------------
# skip decorators (directly from unittest)
# -----------------------------------------------------------------------------
_id = lambda x: x
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
# -----------------------------------------------------------------------------
# Base test class
# -----------------------------------------------------------------------------
def term_context(ctx, timeout):
"""Terminate a context with a timeout"""
t = Thread(target=ctx.term)
t.daemon = True
t.start()
t.join(timeout=timeout)
if t.is_alive():
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
class BaseZMQTestCase(TestCase):
green = False
teardown_timeout = 10
test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60)
sockets: List[zmq.Socket]
@property
def _is_pyzmq_test(self):
return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0]
@property
def _should_test_timeout(self):
return (
self._is_pyzmq_test
and hasattr(signal, 'SIGALRM')
and self.test_timeout_seconds
)
@property
def Context(self):
if self.green:
return gzmq.Context
else:
return zmq.Context
def socket(self, socket_type):
s = self.context.socket(socket_type)
self.sockets.append(s)
return s
def _alarm_timeout(self, timeout, *args):
raise TimeoutError(f"Test did not complete in {timeout} seconds")
def setUp(self):
super().setUp()
if self.green and not have_gevent:
raise SkipTest("requires gevent")
self.context = self.Context.instance()
self.sockets = []
if self._should_test_timeout:
# use SIGALRM to avoid test hangs
signal.signal(
signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds)
)
signal.alarm(self.test_timeout_seconds)
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close(0)
for ctx in contexts:
try:
term_context(ctx, self.teardown_timeout)
except Exception:
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise
super().tearDown()
def create_bound_pair(
self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'
):
"""Create a bound socket pair using a random port."""
s1 = self.context.socket(type1)
s1.setsockopt(zmq.LINGER, 0)
port = s1.bind_to_random_port(interface)
s2 = self.context.socket(type2)
s2.setsockopt(zmq.LINGER, 0)
s2.connect(f'{interface}:{port}')
self.sockets.extend([s1, s2])
return s1, s2
def ping_pong(self, s1, s2, msg):
s1.send(msg)
msg2 = s2.recv()
s2.send(msg2)
msg3 = s1.recv()
return msg3
def ping_pong_json(self, s1, s2, o):
if jsonapi.jsonmod is None:
raise SkipTest("No json library")
s1.send_json(o)
o2 = s2.recv_json()
s2.send_json(o2)
o3 = s1.recv_json()
return o3
def ping_pong_pyobj(self, s1, s2, o):
s1.send_pyobj(o)
o2 = s2.recv_pyobj()
s2.send_pyobj(o2)
o3 = s1.recv_pyobj()
return o3
def assertRaisesErrno(self, errno, func, *args, **kwargs):
try:
func(*args, **kwargs)
except zmq.ZMQError as e:
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def _select_recv(self, multipart, socket, **kwargs):
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
if zmq.zmq_version_info() >= (3, 1, 0):
# zmq 3.1 has a bug, where poll can return false positives,
# so we wait a little bit just in case
# See LIBZMQ-280 on JIRA
time.sleep(0.1)
r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
assert len(r) > 0, "Should have received a message"
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
recv = socket.recv_multipart if multipart else socket.recv
return recv(**kwargs)
def recv(self, socket, **kwargs):
"""call recv in a way that raises if there is nothing to receive"""
return self._select_recv(False, socket, **kwargs)
def recv_multipart(self, socket, **kwargs):
"""call recv_multipart in a way that raises if there is nothing to receive"""
return self._select_recv(True, socket, **kwargs)
class PollZMQTestCase(BaseZMQTestCase):
pass
class GreenTest:
"""Mixin for making green versions of test classes"""
green = True
teardown_timeout = 10
def assertRaisesErrno(self, errno, func, *args, **kwargs):
if errno == zmq.EAGAIN:
raise SkipTest("Skipping because we're green.")
try:
func(*args, **kwargs)
except zmq.ZMQError:
e = sys.exc_info()[1]
self.assertEqual(
e.errno,
errno,
"wrong error raised, expected '%s' \
got '%s'"
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
)
else:
self.fail("Function did not raise any error")
def tearDown(self):
if self._should_test_timeout:
# cancel the timeout alarm, if there was one
signal.alarm(0)
contexts = {self.context}
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close()
try:
gevent.joinall(
[gevent.spawn(ctx.term) for ctx in contexts],
timeout=self.teardown_timeout,
raise_error=True,
)
except gevent.Timeout:
raise RuntimeError(
"context could not terminate, open sockets likely remain in test"
)
def skip_green(self):
raise SkipTest("Skipping because we are green")
def skip_green(f):
def skipping_test(self, *args, **kwargs):
if self.green:
raise SkipTest("Skipping because we are green")
else:
return f(self, *args, **kwargs)
return skipping_test

View File

@ -0,0 +1 @@
"""pytest configuration and fixtures"""

View File

@ -0,0 +1,498 @@
"""Test asyncio support"""
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import sys
from concurrent.futures import CancelledError
from multiprocessing import Process
import pytest
from pytest import mark
import zmq
import zmq.asyncio as zaio
from zmq.auth.asyncio import AsyncioAuthenticator
from zmq.tests import BaseZMQTestCase
from zmq.tests.test_auth import TestThreadAuthentication
class ProcessForTeardownTest(Process):
def __init__(self, event_loop_policy_class):
Process.__init__(self)
self.event_loop_policy_class = event_loop_policy_class
def run(self):
"""Leave context, socket and event loop upon implicit disposal"""
asyncio.set_event_loop_policy(self.event_loop_policy_class())
actx = zaio.Context.instance()
socket = actx.socket(zmq.PAIR)
socket.bind_to_random_port("tcp://127.0.0.1")
async def never_ending_task(socket):
await socket.recv() # never ever receive anything
loop = asyncio.new_event_loop()
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
try:
loop.run_until_complete(coro)
except asyncio.TimeoutError:
pass # expected timeout
else:
assert False, "never_ending_task was completed unexpectedly"
finally:
loop.close()
class TestAsyncIOSocket(BaseZMQTestCase):
Context = zaio.Context
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
super().tearDown()
self.loop.close()
# verify cleanup of references to selectors
assert zaio._selectors == {}
if 'zmq._asyncio_selector' in sys.modules:
assert zmq._asyncio_selector._selector_loops == set()
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, zaio.Socket)
s.close()
def test_instance_subclass_first(self):
actx = zmq.asyncio.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = zmq.asyncio.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_recv_multipart(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b"hi"]
self.loop.run_until_complete(test())
def test_recv(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b"hi"
assert recvd == b"there"
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
def test_recv_timeout(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with self.assertRaises(zmq.Again):
await f1
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
def test_send_timeout(self):
async def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with self.assertRaises(zmq.Again):
await s.send(b"not going anywhere")
self.loop.run_until_complete(test())
def test_recv_string(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = "πøøπ"
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg
self.loop.run_until_complete(test())
def test_recv_json(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_until_complete(test())
def test_recv_json_cancelled(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
self.loop.run_until_complete(test())
def test_recv_pyobj(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_until_complete(test())
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get("identities", []))
content = json.dumps(msg["content"]).encode("utf8")
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode("utf8"))
return {
"identities": identities,
"content": content,
}
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
await a.send_serialized(msg, serialize)
recvd = await b.recv_serialized(deserialize)
assert recvd["content"] == msg["content"]
assert recvd["identities"]
# bounce back, tests identities
await b.send_serialized(recvd, serialize)
r2 = await a.recv_serialized(deserialize)
assert r2["content"] == msg["content"]
assert not r2["identities"]
self.loop.run_until_complete(test())
def test_custom_serialize_error(self):
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
"content": {
"a": 5,
"b": "bee",
}
}
with pytest.raises(TypeError):
await a.send_serialized(json, json.dumps)
await a.send(b"not json")
with pytest.raises(TypeError):
await b.recv_serialized(json.loads)
self.loop.run_until_complete(test())
def test_recv_dontwait(self):
async def test():
push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = pull.recv(zmq.DONTWAIT)
with self.assertRaises(zmq.Again):
await f
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
assert msg == b"ping"
self.loop.run_until_complete(test())
def test_recv_cancel(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
f = b.poll(timeout=1)
assert not f.done()
evt = await f
assert evt == 0
f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll_base_socket(self):
async def test():
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = zaio.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
self.loop.run_until_complete(test())
def test_poll_on_closed_socket(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=1)
b.close()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
self.loop.run_until_complete(test())
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Windows does not support polling on files",
)
def test_poll_raw(self):
async def test():
p = zaio.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, "rb")
w = os.fdopen(w, "wb")
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = await p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b"x")
w.flush()
evts = await p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b"x"
r.close()
w.close()
loop = asyncio.new_event_loop()
loop.run_until_complete(test())
def test_multiple_loops(self):
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
async def test():
await a.send(b'buf')
msg = await b.recv()
assert msg == b'buf'
for i in range(3):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
loop.close()
def test_shadow(self):
async def test():
ctx = zmq.Context()
s = ctx.socket(zmq.PULL)
async_s = zaio.Socket(s)
assert isinstance(async_s, self.socket_class)
def test_process_teardown(self):
event_loop_policy_class = type(asyncio.get_event_loop_policy())
proc = ProcessForTeardownTest(event_loop_policy_class)
proc.start()
try:
proc.join(10) # starting new Python process may cost a lot
self.assertEqual(
proc.exitcode,
0,
"Python process died with code %d" % proc.exitcode
if proc.exitcode
else "process teardown hangs",
)
finally:
proc.terminate()
class TestAsyncioAuthentication(TestThreadAuthentication):
"""Test authentication running in a asyncio task"""
Context = zaio.Context
def shortDescription(self):
"""Rewrite doc strings from TestThreadAuthentication from
'threaded' to 'asyncio'.
"""
doc = self._testMethodDoc
if doc:
doc = doc.split("\n")[0].strip()
if doc.startswith("threaded auth"):
doc = doc.replace("threaded auth", "asyncio auth")
return doc
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
super().tearDown()
self.loop.close()
def make_auth(self):
return AsyncioAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
async def go():
result = False
iface = "tcp://127.0.0.1"
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
# set timeouts
server.SNDTIMEO = client.RCVTIMEO = 1000
try:
await server.send_multipart(msg)
except zmq.Again:
return False
try:
rcvd_msg = await client.recv_multipart()
except zmq.Again:
return False
else:
assert rcvd_msg == msg
result = True
return result
return self.loop.run_until_complete(go())
def _select_recv(self, multipart, socket, **kwargs):
recv = socket.recv_multipart if multipart else socket.recv
async def coro():
if not await socket.poll(5000):
raise TimeoutError("Should have received a message")
return await recv(**kwargs)
return self.loop.run_until_complete(coro())

View File

@ -0,0 +1,579 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import os
import shutil
import tempfile
import warnings
import pytest
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
class BaseAuthTestCase(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4, 0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to have curve support")
super().setUp()
# enable debug logging while we run tests
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
self.auth = self.make_auth()
self.auth.start()
self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
def make_auth(self):
raise NotImplementedError()
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.remove_certs(self.base_dir)
super().tearDown()
def create_certs(self):
"""Create CURVE certificates for a test"""
# Create temporary CURVE keypairs for this test run. We create all keys in a
# temp directory and then move them into the appropriate private or public
# directory.
base_dir = tempfile.mkdtemp()
keys_dir = os.path.join(base_dir, 'certificates')
public_keys_dir = os.path.join(base_dir, 'public_keys')
secret_keys_dir = os.path.join(base_dir, 'private_keys')
os.mkdir(keys_dir)
os.mkdir(public_keys_dir)
os.mkdir(secret_keys_dir)
server_public_file, server_secret_file = zmq.auth.create_certificates(
keys_dir, "server"
)
client_public_file, client_secret_file = zmq.auth.create_certificates(
keys_dir, "client"
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
)
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key_secret"):
shutil.move(
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
)
return (base_dir, public_keys_dir, secret_keys_dir)
def remove_certs(self, base_dir):
"""Remove certificates for a test"""
shutil.rmtree(base_dir)
def load_certs(self, secret_keys_dir):
"""Return server and client certificate keys"""
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
return server_public, server_secret, client_public, client_secret
class TestThreadAuthentication(BaseAuthTestCase):
"""Test authentication running in a thread"""
def make_auth(self):
return ThreadAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
result = False
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
# run poll on server twice
# to flush spurious events
server.poll(100, zmq.POLLOUT)
if server.poll(1000, zmq.POLLOUT):
try:
server.send_multipart(msg, zmq.NOBLOCK)
except zmq.Again:
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
return False
else:
return False
if client.poll(1000):
try:
rcvd_msg = client.recv_multipart(zmq.NOBLOCK)
except zmq.Again:
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
else:
assert rcvd_msg == msg
result = True
return result
def test_null(self):
"""threaded auth - NULL"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
self.auth.stop()
self.auth = None
# use a new context, so ZAP isn't inherited
self.context = self.Context()
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
server = self.socket(zmq.PUSH)
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_blacklist(self):
"""threaded auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert not self.can_connect(server, client)
def test_whitelist(self):
"""threaded auth - Whitelist"""
# Whitelist 127.0.0.1, connection should pass"
self.auth.allow('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_plain(self):
"""threaded auth - PLAIN"""
# Try PLAIN authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
assert not self.can_connect(server, client)
# Try PLAIN authentication - with server configured, connection should pass
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
assert self.can_connect(server, client)
# Try PLAIN authentication - with bogus credentials, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Bogus'
assert not self.can_connect(server, client)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
client.close()
server.close()
def test_curve(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
# Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
# Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(server, client)
# Try CURVE authentication - with server configured, connection should pass
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
# Try connecting using NULL and no authentication enabled, connection should pass
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
assert self.can_connect(server, client)
def test_curve_callback(self):
"""threaded auth - CURVE with callback authentication"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
# Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
# Try CURVE authentication - with callback authentication configured, connection should pass
class CredentialsProvider:
def __init__(self):
self.client = client_public
def callback(self, domain, key):
if key == self.client:
return True
else:
return False
provider = CredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(server, client)
# Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
class WrongCredentialsProvider:
def __init__(self):
self.client = "WrongCredentials"
def callback(self, domain, key):
if key == self.client:
return True
else:
return False
provider = WrongCredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert not self.can_connect(server, client)
@skip_pypy
def test_curve_user_id(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# test default user-id map
client.send(b'test')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == client_public.decode("utf8")
# test custom user-id map
self.auth.curve_user_id = lambda client_key: 'custom'
client2 = self.socket(zmq.PUSH)
client2.curve_publickey = client_public
client2.curve_secretkey = client_secret
client2.curve_serverkey = server_public
assert self.can_connect(client2, server)
client2.send(b'test2')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test2'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == 'custom'
def with_ioloop(method, expect_success=True):
"""decorator for running tests with an IOLoop"""
def test_method(self):
r = method(self)
loop = self.io_loop
if expect_success:
self.pullstream.on_recv(self.on_message_succeed)
else:
self.pullstream.on_recv(self.on_message_fail)
loop.call_later(1, self.attempt_connection)
loop.call_later(1.2, self.send_msg)
if expect_success:
loop.call_later(2, self.on_test_timeout_fail)
else:
loop.call_later(2, self.on_test_timeout_succeed)
loop.start()
if self.fail_msg:
self.fail(self.fail_msg)
return r
return test_method
def should_auth(method):
return with_ioloop(method, True)
def should_not_auth(method):
return with_ioloop(method, False)
class TestIOLoopAuthentication(BaseAuthTestCase):
"""Test authentication running in ioloop"""
def setUp(self):
try:
from tornado import ioloop
except ImportError:
pytest.skip("Requires tornado")
from zmq.eventloop import zmqstream
self.fail_msg = None
self.io_loop = ioloop.IOLoop()
super().setUp()
self.server = self.socket(zmq.PUSH)
self.client = self.socket(zmq.PULL)
self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
def make_auth(self):
from zmq.auth.ioloop import IOLoopAuthenticator
return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.io_loop.close(all_fds=True)
super().tearDown()
def attempt_connection(self):
"""Check if client can connect to server using tcp transport"""
iface = 'tcp://127.0.0.1'
port = self.server.bind_to_random_port(iface)
self.client.connect("%s:%i" % (iface, port))
def send_msg(self):
"""Send a message from server to a client"""
msg = [b"Hello World"]
self.pushstream.send_multipart(msg)
def on_message_succeed(self, frames):
"""A message was received, as expected."""
if frames != [b"Hello World"]:
self.fail_msg = "Unexpected message received"
self.io_loop.stop()
def on_message_fail(self, frames):
"""A message was received, unexpectedly."""
self.fail_msg = 'Received messaged unexpectedly, security failed'
self.io_loop.stop()
def on_test_timeout_succeed(self):
"""Test timer expired, indicates test success"""
self.io_loop.stop()
def on_test_timeout_fail(self):
"""Test timer expired, indicates test failure"""
self.fail_msg = 'Test timed out'
self.io_loop.stop()
@should_auth
def test_none(self):
"""ioloop auth - NONE"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
# no auth should be running
self.auth.stop()
self.auth = None
@should_auth
def test_null(self):
"""ioloop auth - NULL"""
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
self.server.zap_domain = b'global'
@should_not_auth
def test_blacklist(self):
"""ioloop auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
self.server.zap_domain = b'global'
@should_auth
def test_whitelist(self):
"""ioloop auth - Whitelist"""
# Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
self.auth.allow('127.0.0.1')
self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
@should_not_auth
def test_plain_unconfigured_server(self):
"""ioloop auth - PLAIN, unconfigured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - without configuring server, connection should fail
self.server.plain_server = True
@should_auth
def test_plain_configured_server(self):
"""ioloop auth - PLAIN, configured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - with server configured, connection should pass
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_plain_bogus_credentials(self):
"""ioloop auth - PLAIN, bogus credentials"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Bogus'
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_curve_unconfigured_server(self):
"""ioloop auth - CURVE, unconfigured server"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_allow_any(self):
"""ioloop auth - CURVE, CURVE_ALLOW_ANY"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_configured_server(self):
"""ioloop auth - CURVE, configured server"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public

View File

@ -0,0 +1,303 @@
import time
from unittest import TestCase
from zmq.tests import SkipTest
try:
from zmq.backend.cffi import ( # type: ignore
IDENTITY,
POLLIN,
POLLOUT,
PULL,
PUSH,
REP,
REQ,
zmq_version_info,
)
from zmq.backend.cffi._cffi import C, ffi
have_ffi_backend = True
except ImportError:
have_ffi_backend = False
class TestCFFIBackend(TestCase):
def setUp(self):
if not have_ffi_backend:
raise SkipTest('CFFI not available')
def test_zmq_version_info(self):
version = zmq_version_info()
assert version[0] in range(2, 11)
def test_zmq_ctx_new_destroy(self):
ctx = C.zmq_ctx_new()
assert ctx != ffi.NULL
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_socket_open_close(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_setsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[3]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_getsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
option_len = ffi.new('size_t*', 3)
option = ffi.new('char[3]')
ret = C.zmq_getsockopt(socket, IDENTITY, ffi.cast('void*', option), option_len)
assert ret == 0
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, 8)
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind_connect(self):
ctx = C.zmq_ctx_new()
socket1 = C.zmq_socket(ctx, PUSH)
socket2 = C.zmq_socket(ctx, PULL)
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket1
assert ffi.NULL != socket2
assert 0 == C.zmq_close(socket1)
assert 0 == C.zmq_close(socket2)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_msg_init_close(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_size(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[]', b'Hello')
assert 0 == C.zmq_msg_init_data(
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
)
data = C.zmq_msg_data(zmq_msg)
assert ffi.NULL != zmq_msg
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_send(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 0 == C.zmq_msg_close(zmq_msg)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_recv(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_poll(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(
zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL,
)
receiver_pollitem = ffi.new('zmq_pollitem_t*')
receiver_pollitem.socket = receiver
receiver_pollitem.fd = 0
receiver_pollitem.events = POLLIN | POLLOUT
receiver_pollitem.revents = 0
ret = C.zmq_poll(ffi.NULL, 0, 0)
assert ret == 0
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 0
ret = C.zmq_msg_send(zmq_msg, sender, 0)
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
assert ret == 5
time.sleep(0.2)
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 1
assert int(receiver_pollitem.revents) & POLLIN
assert not int(receiver_pollitem.revents) & POLLOUT
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert ret_recv == 5
assert 5 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
)
sender_pollitem = ffi.new('zmq_pollitem_t*')
sender_pollitem.socket = sender
sender_pollitem.fd = 0
sender_pollitem.events = POLLIN | POLLOUT
sender_pollitem.revents = 0
ret = C.zmq_poll(sender_pollitem, 1, 0)
assert ret == 0
zmq_msg_again = ffi.new('zmq_msg_t*')
message_again = ffi.new('char[11]', b'Hello Again')
C.zmq_msg_init_data(
zmq_msg_again,
ffi.cast('void*', message_again),
ffi.cast('size_t', 11),
ffi.NULL,
ffi.NULL,
)
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
time.sleep(0.2)
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
assert int(sender_pollitem.revents) & POLLIN
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
assert 11 == C.zmq_msg_size(zmq_msg2)
assert (
b"Hello Again"
== ffi.buffer(C.zmq_msg_data(zmq_msg2), int(C.zmq_msg_size(zmq_msg2)))[:]
)
assert 0 == C.zmq_close(sender)
assert 0 == C.zmq_close(receiver)
assert 0 == C.zmq_ctx_destroy(ctx)
assert 0 == C.zmq_msg_close(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg2)
assert 0 == C.zmq_msg_close(zmq_msg_again)

View File

@ -0,0 +1,19 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
import zmq.constants
def test_constants():
assert zmq.POLLIN is zmq.PollEvent.POLLIN
assert zmq.PUSH is zmq.SocketType.PUSH
assert zmq.constants.SUBSCRIBE is zmq.SocketOption.SUBSCRIBE
def test_socket_options():
assert zmq.IDENTITY is zmq.SocketOption.ROUTING_ID
assert zmq.IDENTITY._opt_type is zmq.constants._OptType.bytes
assert zmq.AFFINITY._opt_type is zmq.constants._OptType.int64
assert zmq.CURVE_SERVER._opt_type is zmq.constants._OptType.int
assert zmq.FD._opt_type is zmq.constants._OptType.fd

View File

@ -0,0 +1,401 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import os
import sys
import time
from queue import Queue
from threading import Event, Thread
from unittest import mock
from pytest import mark
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest
class KwargTestSocket(zmq.Socket):
test_kwarg_value = None
def __init__(self, *args, **kwargs):
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
super().__init__(*args, **kwargs)
class KwargTestContext(zmq.Context):
_socket_class = KwargTestSocket
class TestContext(BaseZMQTestCase):
def test_init(self):
c1 = self.Context()
assert isinstance(c1, self.Context)
del c1
c2 = self.Context()
assert isinstance(c2, self.Context)
del c2
c3 = self.Context()
assert isinstance(c3, self.Context)
del c3
_repr_cls = "zmq.Context"
def test_repr(self):
with self.Context() as ctx:
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' not in repr(ctx)
with ctx.socket(zmq.PUSH) as push:
assert f'{self._repr_cls}(1 socket)' in repr(ctx)
with ctx.socket(zmq.PULL) as pull:
assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
assert f'{self._repr_cls}()' in repr(ctx)
assert 'closed' in repr(ctx)
def test_dir(self):
ctx = self.Context()
assert 'socket' in dir(ctx)
if zmq.zmq_version_info() > (3,):
assert 'IO_THREADS' in dir(ctx)
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
m = mock.Mock(spec=self.context)
def test_term(self):
c = self.Context()
c.term()
assert c.closed
def test_context_manager(self):
with self.Context() as c:
pass
assert c.closed
def test_fail_init(self):
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
def test_term_hang(self):
rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
req.setsockopt(zmq.LINGER, 0)
req.send(b'hello', copy=False)
req.close()
rep.close()
self.context.term()
def test_instance(self):
ctx = self.Context.instance()
c2 = self.Context.instance(io_threads=2)
assert c2 is ctx
c2.term()
c3 = self.Context.instance()
c4 = self.Context.instance()
assert not c3 is c2
assert not c3.closed
assert c3 is c4
def test_instance_subclass_first(self):
self.context.term()
class SubContext(zmq.Context):
pass
sctx = SubContext.instance()
ctx = zmq.Context.instance()
ctx.term()
sctx.term()
assert type(ctx) is zmq.Context
assert type(sctx) is SubContext
def test_instance_subclass_second(self):
self.context.term()
class SubContextInherit(zmq.Context):
pass
class SubContextNoInherit(zmq.Context):
_instance = None
ctx = zmq.Context.instance()
sctx = SubContextInherit.instance()
sctx2 = SubContextNoInherit.instance()
ctx.term()
sctx.term()
sctx2.term()
assert type(ctx) is zmq.Context
assert type(sctx) is zmq.Context
assert type(sctx2) is SubContextNoInherit
def test_instance_threadsafe(self):
self.context.term() # clear default context
q = Queue()
# slow context initialization,
# to ensure that we are both trying to create one at the same time
class SlowContext(self.Context):
def __init__(self, *a, **kw):
time.sleep(1)
super().__init__(*a, **kw)
def f():
q.put(SlowContext.instance())
# call ctx.instance() in several threads at once
N = 16
threads = [Thread(target=f) for i in range(N)]
[t.start() for t in threads]
# also call it in the main thread (not first)
ctx = SlowContext.instance()
assert isinstance(ctx, SlowContext)
# check that all the threads got the same context
for i in range(N):
thread_ctx = q.get(timeout=5)
assert thread_ctx is ctx
# cleanup
ctx.term()
[t.join(timeout=5) for t in threads]
def test_socket_passes_kwargs(self):
test_kwarg_value = 'testing one two three'
with KwargTestContext() as ctx:
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
assert socket.test_kwarg_value is test_kwarg_value
def test_many_sockets(self):
"""opening and closing many sockets shouldn't cause problems"""
ctx = self.Context()
for i in range(16):
sockets = [ctx.socket(zmq.REP) for i in range(65)]
[s.close() for s in sockets]
# give the reaper a chance
time.sleep(1e-2)
ctx.term()
def test_sockopts(self):
"""setting socket options with ctx attributes"""
ctx = self.Context()
ctx.linger = 5
assert ctx.linger == 5
s = ctx.socket(zmq.REQ)
assert s.linger == 5
assert s.getsockopt(zmq.LINGER) == 5
s.close()
# check that subscribe doesn't get set on sockets that don't subscribe:
ctx.subscribe = b''
s = ctx.socket(zmq.REQ)
s.close()
ctx.term()
@mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
def test_destroy(self):
"""Context.destroy should close sockets"""
ctx = self.Context()
sockets = [ctx.socket(zmq.REP) for i in range(65)]
# close half of the sockets
[s.close() for s in sockets[::2]]
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
for s in sockets:
assert s.closed
def test_destroy_linger(self):
"""Context.destroy should set linger on closing sockets"""
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
req.send(b'hi')
time.sleep(1e-2)
self.context.destroy(linger=0)
# reaper is not instantaneous
time.sleep(1e-2)
for s in (req, rep):
assert s.closed
def test_term_noclose(self):
"""Context.term won't close sockets"""
ctx = self.Context()
s = ctx.socket(zmq.REQ)
assert not s.closed
t = Thread(target=ctx.term)
t.start()
t.join(timeout=0.1)
assert t.is_alive(), "Context should be waiting"
s.close()
t.join(timeout=0.1)
assert not t.is_alive(), "Context should have closed"
def test_gc(self):
"""test close&term by garbage collection alone"""
if PYPY:
raise SkipTest("GC doesn't work ")
# test credit @dln (GH #137):
def gcf():
def inner():
ctx = self.Context()
ctx.socket(zmq.PUSH)
inner()
gc.collect()
t = Thread(target=gcf)
t.start()
t.join(timeout=1)
assert not t.is_alive(), "Garbage collection should have cleaned up context"
def test_cyclic_destroy(self):
"""ctx.destroy should succeed when cyclic ref prevents gc"""
# test credit @dln (GH #137):
class CyclicReference:
def __init__(self, parent=None):
self.parent = parent
def crash(self, sock):
self.sock = sock
self.child = CyclicReference(self)
def crash_zmq():
ctx = self.Context()
sock = ctx.socket(zmq.PULL)
c = CyclicReference()
c.crash(sock)
ctx.destroy()
crash_zmq()
def test_term_thread(self):
"""ctx.term should not crash active threads (#139)"""
ctx = self.Context()
evt = Event()
evt.clear()
def block():
s = ctx.socket(zmq.REP)
s.bind_to_random_port('tcp://127.0.0.1')
evt.set()
try:
s.recv()
except zmq.ZMQError as e:
assert e.errno == zmq.ETERM
return
finally:
s.close()
self.fail("recv should have been interrupted with ETERM")
t = Thread(target=block)
t.start()
evt.wait(1)
assert evt.is_set(), "sync event never fired"
time.sleep(0.01)
ctx.term()
t.join(timeout=1)
assert not t.is_alive(), "term should have interrupted s.recv()"
def test_destroy_no_sockets(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.bind_to_random_port('tcp://127.0.0.1')
s.close()
ctx.destroy()
assert s.closed
assert ctx.closed
def test_ctx_opts(self):
if zmq.zmq_version_info() < (3,):
raise SkipTest("context options require libzmq 3")
ctx = self.Context()
ctx.set(zmq.MAX_SOCKETS, 2)
assert ctx.get(zmq.MAX_SOCKETS) == 2
ctx.max_sockets = 100
assert ctx.max_sockets == 100
assert ctx.get(zmq.MAX_SOCKETS) == 100
def test_copy(self):
c1 = self.Context()
c2 = copy.copy(c1)
c2b = copy.deepcopy(c1)
c3 = copy.deepcopy(c2)
assert c2._shadow
assert c3._shadow
assert c1.underlying == c2.underlying
assert c1.underlying == c3.underlying
assert c1.underlying == c2b.underlying
s = c3.socket(zmq.PUB)
s.close()
c1.term()
def test_shadow(self):
ctx = self.Context()
ctx2 = self.Context.shadow(ctx.underlying)
assert ctx.underlying == ctx2.underlying
s = ctx.socket(zmq.PUB)
s.close()
del ctx2
assert not ctx.closed
s = ctx.socket(zmq.PUB)
ctx2 = self.Context.shadow(ctx.underlying)
s2 = ctx2.socket(zmq.PUB)
s.close()
s2.close()
ctx.term()
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
del ctx2
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket, zstr
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
a = zsocket.new(ctx, zmq.PUSH)
zsocket.bind(a, "inproc://a")
ctx2 = self.Context.shadow_pyczmq(ctx)
b = ctx2.socket(zmq.PULL)
b.connect("inproc://a")
zstr.send(a, b'hi')
rcvd = self.recv(b)
assert rcvd == b'hi'
b.close()
@mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
def test_fork_instance(self):
ctx = self.Context.instance()
parent_ctx_id = id(ctx)
r_fd, w_fd = os.pipe()
reader = os.fdopen(r_fd, 'r')
child_pid = os.fork()
if child_pid == 0:
ctx = self.Context.instance()
writer = os.fdopen(w_fd, 'w')
child_ctx_id = id(ctx)
ctx.term()
writer.write(str(child_ctx_id) + "\n")
writer.flush()
writer.close()
os._exit(0)
else:
os.close(w_fd)
child_id_s = reader.readline()
reader.close()
assert child_id_s
assert int(child_id_s) != parent_ctx_id
ctx.term()
if False: # disable green context tests
class TestContextGreen(GreenTest, TestContext):
"""gevent subclass of context tests"""
# skip tests that use real threads:
test_gc = GreenTest.skip_green
test_term_thread = GreenTest.skip_green
test_destroy_linger = GreenTest.skip_green
_repr_cls = "zmq.green.Context"

View File

@ -0,0 +1,43 @@
import sys
import pytest
import zmq
@pytest.mark.skipif(
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
)
@pytest.mark.skipif(
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
)
@pytest.mark.parametrize('language_level', [3, 2])
def test_cython(language_level, request, tmpdir):
import pyximport
assert 'zmq.tests.cython_ext' not in sys.modules
importers = pyximport.install(
setup_args=dict(include_dirs=zmq.get_includes()),
language_level=language_level,
build_dir=str(tmpdir),
)
cython_ext = None
def unimport():
pyximport.uninstall(*importers)
sys.modules.pop('zmq.tests.cython_ext', None)
request.addfinalizer(unimport)
# this import tests the compilation
from . import cython_ext
assert hasattr(cython_ext, 'send_recv_test')
# call the compiled function
# this shouldn't do much
msg = b'my msg'
received = cython_ext.send_recv_test(msg)
assert received == msg

View File

@ -0,0 +1,396 @@
import threading
from pytest import fixture, raises
import zmq
from zmq.decorators import context, socket
from zmq.tests import BaseZMQTestCase, term_context
##############################################
# Test cases for @context
##############################################
@fixture(autouse=True)
def term_context_instance(request):
request.addfinalizer(lambda: term_context(zmq.Context.instance(), timeout=10))
def test_ctx():
@context()
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
test()
def test_ctx_orig_args():
@context()
def f(foo, bar, ctx, baz=None):
assert isinstance(ctx, zmq.Context), ctx
assert foo == 42
assert bar is True
assert baz == 'mock'
f(42, True, baz='mock')
def test_ctx_arg_naming():
@context('myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_args():
@context('ctx', 5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_arg_kwarg():
@context('ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kw_naming():
@context(name='myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_kwargs():
@context(name='ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kwargs_default():
@context(name='ctx', io_threads=5)
def test(ctx=None):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_keyword_miss():
@context(name='ctx')
def test(other_name):
pass # the keyword ``ctx`` not found
with raises(TypeError):
test()
def test_ctx_multi_assign():
@context(name='ctx')
def test(ctx):
pass # explosion
with raises(TypeError):
test('mock')
def test_ctx_reinit():
result = {'foo': None, 'bar': None}
@context()
def f(key, ctx):
assert isinstance(ctx, zmq.Context), ctx
result[key] = ctx
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_multi_thread():
@context()
@context()
def f(foo, bar):
assert isinstance(foo, zmq.Context), foo
assert isinstance(bar, zmq.Context), bar
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
##############################################
# Test cases for @socket
##############################################
def test_ctx_skt():
@context()
@socket(zmq.PUB)
def test(ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
assert skt.type == zmq.PUB
test()
def test_skt_name():
@context()
@socket('myskt', zmq.PUB)
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_skt_kwarg():
@context()
@socket(zmq.PUB, name='myskt')
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_ctx_skt_name():
@context('ctx')
@socket('skt', zmq.PUB, context_name='ctx')
def test(ctx, skt):
assert isinstance(skt, zmq.Socket), skt
assert isinstance(ctx, zmq.Context), ctx
assert skt.type == zmq.PUB
test()
def test_skt_default_ctx():
@socket(zmq.PUB)
def test(skt):
assert isinstance(skt, zmq.Socket), skt
assert skt.context is zmq.Context.instance()
assert skt.type == zmq.PUB
test()
def test_skt_reinit():
result = {'foo': None, 'bar': None}
@socket(zmq.PUB)
def f(key, skt):
assert isinstance(skt, zmq.Socket), skt
result[key] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_skt_reinit():
result = {'foo': {'ctx': None, 'skt': None}, 'bar': {'ctx': None, 'skt': None}}
@context()
@socket(zmq.PUB)
def f(key, ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
result[key]['ctx'] = ctx
result[key]['skt'] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo']['ctx'] is not None, result
assert result['foo']['skt'] is not None, result
assert result['bar']['ctx'] is not None, result
assert result['bar']['skt'] is not None, result
assert result['foo']['ctx'] is not result['bar']['ctx'], result
assert result['foo']['skt'] is not result['bar']['skt'], result
def test_skt_type_miss():
@context()
@socket('myskt')
def f(ctx, myskt):
pass # the socket type is missing
with raises(TypeError):
f()
def test_multi_skts():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_single_ctx():
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(ctx, pub, sub, push):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is ctx
assert sub.context is ctx
assert push.context is ctx
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_with_name():
@socket('foo', zmq.PUSH)
@socket('bar', zmq.SUB)
@socket('baz', zmq.PUB)
def test(foo, bar, baz):
assert isinstance(foo, zmq.Socket), foo
assert isinstance(bar, zmq.Socket), bar
assert isinstance(baz, zmq.Socket), baz
assert foo.context is zmq.Context.instance()
assert bar.context is zmq.Context.instance()
assert baz.context is zmq.Context.instance()
assert foo.type == zmq.PUSH
assert bar.type == zmq.SUB
assert baz.type == zmq.PUB
test()
def test_func_return():
@context()
def f(ctx):
assert isinstance(ctx, zmq.Context), ctx
return 'something'
assert f() == 'something'
def test_skt_multi_thread():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def f(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
assert len(set(map(id, [pub, sub, push]))) == 3
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
class TestMethodDecorators(BaseZMQTestCase):
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
assert isinstance(self, TestMethodDecorators), self
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'bar'
assert pub.context is ctx
assert sub.context is ctx
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
def test_multi_skts_method(self):
self.multi_skts_method()
def test_multi_skts_method_other_args(self):
@socket(zmq.PUB)
@socket(zmq.SUB)
def f(foo, pub, sub, bar=None):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'mock'
assert bar == 'fake'
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
f('mock', bar='fake')

View File

@ -0,0 +1,168 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest, have_gevent
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestDevice(BaseZMQTestCase):
def test_device_types(self):
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
assert dev.device_type == devtype
del dev
def test_device_attributes(self):
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
assert dev.in_type == zmq.SUB
assert dev.out_type == zmq.PUB
assert dev.device_type == zmq.QUEUE
assert dev.daemon == True
del dev
def test_single_socket_forwarder_connect(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_in('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_out('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello again'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
def test_single_socket_forwarder_bind(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i' % port)
dev.start()
time.sleep(0.25)
msg = b'hello again'
req.send(msg)
assert msg == self.recv(req)
del dev
req.close()
def test_device_bind_to_random_with_args(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_device_bind_to_random_binderror(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
try:
for i in range(11):
dev.bind_in_to_random_port(iface, min_port=10000, max_port=10010)
except zmq.ZMQBindError as e:
return
else:
self.fail('Should have failed')
def test_proxy(self):
if zmq.zmq_version_info() < (3, 2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
push.send(msg)
self.sockets.extend([push, pull, mon])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
def test_proxy_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (3, 2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
if have_gevent:
import gevent
import zmq.green
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
def test_green_device(self):
rep = self.context.socket(zmq.REP)
req = self.context.socket(zmq.REQ)
self.sockets.extend([req, rep])
port = rep.bind_to_random_port('tcp://127.0.0.1')
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
req.connect('tcp://127.0.0.1:%i' % port)
req.send(b'hi')
timeout = gevent.Timeout(3)
timeout.start()
receiver = gevent.spawn(req.recv)
assert receiver.get(2) == b'hi'
timeout.cancel()
g.kill(block=True)

View File

@ -0,0 +1,47 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import pytest
import zmq
from zmq.tests import BaseZMQTestCase
class TestDraftSockets(BaseZMQTestCase):
def setUp(self):
if not zmq.DRAFT_API:
pytest.skip("draft api unavailable")
super().setUp()
def test_client_server(self):
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
client.send(b'request')
msg = self.recv(server, copy=False)
assert msg.routing_id is not None
server.send(b'reply', routing_id=msg.routing_id)
reply = self.recv(client)
assert reply == b'reply'
def test_radio_dish(self):
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
dish.rcvtimeo = 250
group = 'mygroup'
dish.join(group)
received_count = 0
received = set()
sent = set()
for i in range(10):
msg = str(i).encode('ascii')
sent.add(msg)
radio.send(msg, group=group)
try:
recvd = dish.recv()
except zmq.Again:
time.sleep(0.1)
else:
received.add(recvd)
received_count += 1
# assert that we got *something*
assert len(received.intersection(sent)) >= 5

View File

@ -0,0 +1,37 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from threading import Thread
import zmq
from zmq import Again, ContextTerminated, ZMQError, strerror
from zmq.tests import BaseZMQTestCase
class TestZMQError(BaseZMQTestCase):
def test_strerror(self):
"""test that strerror gets the right type."""
for i in range(10):
e = strerror(i)
assert isinstance(e, str)
def test_zmqerror(self):
for errno in range(10):
e = ZMQError(errno)
assert e.errno == errno
assert str(e) == strerror(errno)
def test_again(self):
s = self.context.socket(zmq.REP)
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
s.close()
def atest_ctxterm(self):
s = self.context.socket(zmq.REP)
t = Thread(target=self.context.term)
t.start()
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
s.close()
t.join()

View File

@ -0,0 +1,26 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
from pytest import mark
import zmq
only_bundled = mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
@mark.skipif('zmq.zmq_version_info() < (4, 1)')
def test_has():
assert not zmq.has('something weird')
@only_bundled
def test_has_curve():
"""bundled libzmq has curve support"""
assert zmq.has('curve')
@only_bundled
def test_has_ipc():
"""bundled libzmq has ipc support"""
assert zmq.has('ipc')

View File

@ -0,0 +1,353 @@
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import json
import os
import sys
from datetime import timedelta
import pytest
gen = pytest.importorskip('tornado.gen')
from tornado.ioloop import IOLoop
import zmq
from zmq.eventloop import future
from zmq.tests import BaseZMQTestCase
class TestFutureSocket(BaseZMQTestCase):
Context = future.Context
def setUp(self):
self.loop = IOLoop()
self.loop.make_current()
super().setUp()
def tearDown(self):
super().tearDown()
if self.loop:
self.loop.close(all_fds=True)
IOLoop.clear_current()
IOLoop.clear_instance()
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, future.Socket)
s.close()
def test_instance_subclass_first(self):
actx = self.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = self.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_recv_multipart(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
await a.send(b"hi")
recvd = await f
assert recvd == [b'hi']
self.loop.run_sync(test)
def test_recv(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.done()
assert f1.result() == b'hi'
assert recvd == b'there'
self.loop.run_sync(test)
def test_recv_cancel(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_recv_timeout(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
await f1
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f2.done()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_send_timeout(self):
async def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
await s.send(b"not going anywhere")
self.loop.run_sync(test)
def test_send_noblock(self):
async def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
await s.send(b"not going anywhere", flags=zmq.NOBLOCK)
self.loop.run_sync(test)
def test_send_multipart_noblock(self):
async def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
await s.send_multipart([b"not going anywhere"], flags=zmq.NOBLOCK)
self.loop.run_sync(test)
def test_recv_string(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = 'πøøπ'
await a.send_string(msg)
recvd = await f
assert f.done()
assert f.result() == msg
assert recvd == msg
self.loop.run_sync(test)
def test_recv_json(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
await a.send_json(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_sync(test)
def test_recv_json_cancelled(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
await gen.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
with pytest.raises(future.CancelledError):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await gen.sleep(0)
# make sure cancelled recv didn't eat up event
recvd = await gen.with_timeout(timedelta(seconds=5), b.recv_json())
assert recvd == obj
self.loop.run_sync(test)
def test_recv_pyobj(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj
self.loop.run_sync(test)
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
await a.send_serialized(msg, serialize)
recvd = await b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
await b.send_serialized(recvd, serialize)
r2 = await a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
self.loop.run_sync(test)
def test_custom_serialize_error(self):
async def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
with pytest.raises(TypeError):
await a.send_serialized(json, json.dumps)
await a.send(b"not json")
with pytest.raises(TypeError):
await b.recv_serialized(json.loads)
self.loop.run_sync(test)
def test_poll(self):
async def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
assert f.done()
assert f.result() == 0
f = b.poll(timeout=1)
assert not f.done()
evt = await f
assert evt == 0
f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b'hi', b'there']
self.loop.run_sync(test)
@pytest.mark.skipif(
sys.platform.startswith('win'), reason='Windows unsupported socket type'
)
def test_poll_base_socket(self):
async def test():
ctx = zmq.Context()
url = 'inproc://test'
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = future.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b'hi', b'there'])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b'hi', b'there']
a.close()
b.close()
ctx.term()
self.loop.run_sync(test)
def test_close_all_fds(self):
s = self.socket(zmq.PUB)
s._get_loop()
self.loop.close(all_fds=True)
self.loop = None # avoid second close later
assert s.closed
@pytest.mark.skipif(
sys.platform.startswith('win'),
reason='Windows does not support polling on files',
)
def test_poll_raw(self):
async def test():
p = future.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = await p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b'x')
w.flush()
evts = await p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b'x'
r.close()
w.close()
self.loop.run_sync(test)

View File

@ -0,0 +1,65 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
# flake8: noqa: F401
from unittest import TestCase
import pytest
class TestImports(TestCase):
"""Test Imports - the quickest test to ensure that we haven't
introduced version-incompatible syntax errors."""
def test_toplevel(self):
"""test toplevel import"""
import zmq
def test_core(self):
"""test core imports"""
from zmq import (
Context,
Frame,
Poller,
Socket,
constants,
device,
proxy,
pyzmq_version,
pyzmq_version_info,
zmq_version,
zmq_version_info,
)
def test_devices(self):
"""test device imports"""
import zmq.devices
from zmq.devices import basedevice, monitoredqueue, monitoredqueuedevice
def test_log(self):
"""test log imports"""
import zmq.log
from zmq.log import handlers
def test_eventloop(self):
"""test eventloop imports"""
try:
import tornado
except ImportError:
pytest.skip('requires tornado')
import zmq.eventloop
from zmq.eventloop import ioloop, zmqstream
def test_utils(self):
"""test util imports"""
import zmq.utils
from zmq.utils import jsonapi, strtypes
def test_ssh(self):
"""test ssh imports"""
from zmq.ssh import tunnel
def test_decorators(self):
"""test decorators imports"""
from zmq.decorators import context, socket

View File

@ -0,0 +1,33 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
class TestIncludes(TestCase):
def test_get_includes(self):
from os.path import basename
includes = zmq.get_includes()
assert isinstance(includes, list)
assert len(includes) >= 2
parent = includes[0]
assert isinstance(parent, str)
utilsdir = includes[1]
assert isinstance(utilsdir, str)
utils = basename(utilsdir)
assert utils == "utils"
def test_get_library_dirs(self):
from os.path import basename
libdirs = zmq.get_library_dirs()
assert isinstance(libdirs, list)
assert len(libdirs) == 1
parent = libdirs[0]
assert isinstance(parent, str)
libdir = basename(parent)
assert libdir == "zmq"

View File

@ -0,0 +1,142 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
import threading
import time
import pytest
import zmq
from zmq.tests import BaseZMQTestCase, have_gevent
try:
from tornado.ioloop import IOLoop as BaseIOLoop
from zmq.eventloop import ioloop
except ImportError:
_tornado = False
else:
_tornado = True
# tornado 5 with asyncio disables custom IOLoop implementations
t5asyncio = False
if _tornado:
import tornado
if tornado.version_info >= (5,) and asyncio:
t5asyncio = True
def printer():
os.system("say hello")
raise Exception
print(time.time())
class Delay(threading.Thread):
def __init__(self, f, delay=1):
self.f = f
self.delay = delay
self.aborted = False
self.cond = threading.Condition()
super().__init__()
def run(self):
self.cond.acquire()
self.cond.wait(self.delay)
self.cond.release()
if not self.aborted:
self.f()
def abort(self):
self.aborted = True
self.cond.acquire()
self.cond.notify()
self.cond.release()
class TestIOLoop(BaseZMQTestCase):
if _tornado:
IOLoop = ioloop.IOLoop
def setUp(self):
if not _tornado:
pytest.skip("tornado required")
super().setUp()
if asyncio:
asyncio.set_event_loop(asyncio.new_event_loop())
def tearDown(self):
super().tearDown()
BaseIOLoop.clear_current()
BaseIOLoop.clear_instance()
def test_simple(self):
"""simple IOLoop creation test"""
loop = self.IOLoop()
loop.make_current()
dc = ioloop.PeriodicCallback(loop.stop, 200)
pc = ioloop.PeriodicCallback(lambda: None, 10)
pc.start()
dc.start()
t = Delay(loop.stop, 1)
t.start()
loop.start()
if t.is_alive():
t.abort()
else:
self.fail("IOLoop failed to exit")
def test_instance(self):
"""IOLoop.instance returns the right object"""
loop = self.IOLoop.instance()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.instance()
assert base_loop is loop
def test_current(self):
"""IOLoop.current returns the right object"""
loop = ioloop.IOLoop.current()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.current()
assert base_loop is loop
def test_close_all(self):
"""Test close(all_fds=True)"""
loop = self.IOLoop.current()
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ)
loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ)
assert req.closed == False
assert rep.closed == False
loop.close(all_fds=True)
assert req.closed == True
assert rep.closed == True
if have_gevent and _tornado:
import zmq.green.eventloop.ioloop as green_ioloop
class TestIOLoopGreen(TestIOLoop):
IOLoop = green_ioloop.IOLoop
def xtest_instance(self):
"""Green IOLoop.instance returns the right object"""
loop = self.IOLoop.instance()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.instance()
assert base_loop is loop
def xtest_current(self):
"""Green IOLoop.current returns the right object"""
loop = self.IOLoop.current()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.current()
assert base_loop is loop

View File

@ -0,0 +1,178 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import time
import zmq
from zmq.log import handlers
from zmq.tests import BaseZMQTestCase
class TestPubLog(BaseZMQTestCase):
iface = 'inproc://zmqlog'
topic = 'zmq'
@property
def logger(self):
# print dir(self)
logger = logging.getLogger('zmqtest')
logger.setLevel(logging.DEBUG)
return logger
def connect_handler(self, topic=None):
topic = self.topic if topic is None else topic
logger = self.logger
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = topic
logger.addHandler(handler)
sub.setsockopt(zmq.SUBSCRIBE, topic.encode())
time.sleep(0.1)
return logger, handler, sub
def test_init_iface(self):
logger = self.logger
ctx = self.context
handler = handlers.PUBHandler(self.iface)
assert not handler.ctx is ctx
self.sockets.append(handler.socket)
# handler.ctx.term()
handler = handlers.PUBHandler(self.iface, self.context)
self.sockets.append(handler.socket)
assert handler.ctx is ctx
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
sub = ctx.socket(zmq.SUB)
self.sockets.append(sub)
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
sub.connect(self.iface)
import time
time.sleep(0.25)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
assert topic == b'zmq.INFO'
assert msg2 == (msg1 + "\n").encode("utf8")
logger.removeHandler(handler)
def test_init_socket(self):
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
logger = self.logger
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
assert handler.socket is pub
assert handler.ctx is pub.context
assert handler.ctx is self.context
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
import time
time.sleep(0.1)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
assert topic == b'zmq.INFO'
assert msg2 == (msg1 + "\n").encode("utf8")
logger.removeHandler(handler)
def test_root_topic(self):
logger, handler, sub = self.connect_handler()
handler.socket.bind(self.iface)
sub2 = sub.context.socket(zmq.SUB)
self.sockets.append(sub2)
sub2.connect(self.iface)
sub2.setsockopt(zmq.SUBSCRIBE, b'')
handler.root_topic = b'twoonly'
msg1 = 'ignored'
logger.info(msg1)
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
topic, msg2 = sub2.recv_multipart()
assert topic == b'twoonly.INFO'
assert msg2 == (msg1 + '\n').encode()
logger.removeHandler(handler)
def test_blank_root_topic(self):
logger, handler, sub_everything = self.connect_handler()
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
handler.socket.bind(self.iface)
sub_only_info = sub_everything.context.socket(zmq.SUB)
self.sockets.append(sub_only_info)
sub_only_info.connect(self.iface)
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
handler.setRootTopic(b'')
msg_debug = 'debug_message'
logger.debug(msg_debug)
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
topic, msg_debug_response = sub_everything.recv_multipart()
assert topic == b'DEBUG'
msg_info = 'info_message'
logger.info(msg_info)
topic, msg_info_response_everything = sub_everything.recv_multipart()
assert topic == b'INFO'
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
assert topic == b'INFO'
assert msg_info_response_everything == msg_info_response_onlyinfo
logger.removeHandler(handler)
def test_unicode_message(self):
logger, handler, sub = self.connect_handler()
base_topic = (self.topic + '.INFO').encode()
for msg, expected in [
('hello', [base_topic, b'hello\n']),
('héllo', [base_topic, 'héllo\n'.encode()]),
('tøpic::héllo', [base_topic + '.tøpic'.encode(), 'héllo\n'.encode()]),
]:
logger.info(msg)
received = sub.recv_multipart()
assert received == expected
logger.removeHandler(handler)
def test_set_info_formatter_via_property(self):
logger, handler, sub = self.connect_handler()
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'info message UNITTEST\n'
logger.removeHandler(handler)
def test_custom_global_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST %(message)s")
handler.setFormatter(formatter)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST info message'
logger.debug('debug message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST debug message'
logger.removeHandler(handler)
def test_custom_debug_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
handler.setFormatter(formatter, logging.DEBUG)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
logger.info('info message')
topic, msg = sub.recv_multipart()
assert msg == b'info message\n'
logger.debug('debug message')
topic, msg = sub.recv_multipart()
assert msg == b'UNITTEST DEBUG debug message'
logger.removeHandler(handler)

View File

@ -0,0 +1,370 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import sys
try:
from sys import getrefcount
except ImportError:
grc = None
else:
grc = getrefcount
import time
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest, skip_pypy
# some useful constants:
x = b'x'
if grc:
rc0 = grc(x)
v = memoryview(x)
view_rc = grc(x) - rc0
def await_gc(obj, rc):
"""wait for refcount on an object to drop to an expected value
Necessary because of the zero-copy gc thread,
which can take some time to receive its DECREF message.
"""
# count refs for this function
if sys.version_info < (3, 11):
my_refs = 2
else:
my_refs = 1
for i in range(50):
# rc + 2 because of the refs in this function
if grc(obj) <= rc + my_refs:
return
time.sleep(0.05)
class TestFrame(BaseZMQTestCase):
def tearDown(self):
super().tearDown()
for i in range(3):
gc.collect()
@skip_pypy
def test_above_30(self):
"""Message above 30 bytes are never copied by 0MQ."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = grc(s)
m = zmq.Frame(s, copy=False)
assert grc(s) == rc + 2
del m
await_gc(s, rc)
assert grc(s) == rc
del s
def test_str(self):
"""Test the str representations of the Frames."""
for i in range(16):
s = (2**i) * x
m = zmq.Frame(s)
m_str = str(m)
m_str_b = m_str.encode()
assert s == m_str_b
def test_bytes(self):
"""Test the Frame.bytes property."""
for i in range(1, 16):
s = (2**i) * x
m = zmq.Frame(s)
b = m.bytes
assert s == m.bytes
if not PYPY:
# check that it copies
assert b is not s
# check that it copies only once
assert b is m.bytes
def test_unicode(self):
"""Test the unicode representations of the Frames."""
s = 'asdf'
self.assertRaises(TypeError, zmq.Frame, s)
for i in range(16):
s = (2**i) * '§'
m = zmq.Frame(s.encode('utf8'))
assert s == m.bytes.decode('utf8')
def test_len(self):
"""Test the len of the Frames."""
for i in range(16):
s = (2**i) * x
m = zmq.Frame(s)
assert len(s) == len(m)
@skip_pypy
def test_lifecycle1(self):
"""Run through a ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = rc_0 = grc(s)
m = zmq.Frame(s, copy=False)
rc += 2
assert grc(s) == rc
m2 = copy.copy(m)
rc += 1
assert grc(s) == rc
# no increase in refcount for accessing buffer
# which references m2 directly
buf = m2.buffer
assert grc(s) == rc
assert s == str(m).encode()
assert s == bytes(m2)
assert s == m.bytes
assert s == bytes(buf)
# assert s is str(m)
# assert s is str(m2)
del m2
assert grc(s) == rc
# buf holds direct reference to m2 which holds
del buf
rc -= 1
assert grc(s) == rc
del m
rc -= 2
await_gc(s, rc)
assert grc(s) == rc
assert rc == rc_0
del s
@skip_pypy
def test_lifecycle2(self):
"""Run through a different ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i) * x
rc = rc_0 = grc(s)
m = zmq.Frame(s, copy=False)
rc += 2
assert grc(s) == rc
m2 = copy.copy(m)
rc += 1
assert grc(s) == rc
# no increase in refcount for accessing buffer
# which references m directly
buf = m.buffer
assert grc(s) == rc
assert s == str(m).encode()
assert s == bytes(m2)
assert s == m2.bytes
assert s == m.bytes
assert s == bytes(buf)
# assert s is str(m)
# assert s is str(m2)
del buf
assert grc(s) == rc
del m
rc -= 1
assert grc(s) == rc
del m2
rc -= 2
await_gc(s, rc)
assert grc(s) == rc
assert rc == rc_0
del s
def test_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
assert not m.tracker.done
pm = zmq.MessageTracker(m)
assert not pm.done
del m
for i in range(3):
gc.collect()
for i in range(10):
if pm.done:
break
time.sleep(0.1)
assert pm.done
def test_no_tracker(self):
m = zmq.Frame(b'asdf', track=False)
assert m.tracker == None
m2 = copy.copy(m)
assert m2.tracker == None
self.assertRaises(ValueError, zmq.MessageTracker, m)
def test_multi_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
m2 = zmq.Frame(b'whoda', copy=False, track=True)
mt = zmq.MessageTracker(m, m2)
assert not m.tracker.done
assert not mt.done
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
del m
for i in range(3):
gc.collect()
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
assert not mt.done
del m2
for i in range(3):
gc.collect()
assert mt.wait(0.1) is None
assert mt.done
def test_buffer_in(self):
"""test using a buffer as input"""
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
zmq.Frame(memoryview(ins))
def test_bad_buffer_in(self):
"""test using a bad object"""
self.assertRaises(TypeError, zmq.Frame, 5)
self.assertRaises(TypeError, zmq.Frame, object())
def test_buffer_out(self):
"""receiving buffered output"""
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
m = zmq.Frame(ins)
outb = m.buffer
assert isinstance(outb, memoryview)
assert outb is m.buffer
assert m.buffer is m.buffer
def test_memoryview_shape(self):
"""memoryview shape info"""
data = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
n = len(data)
f = zmq.Frame(data)
view1 = f.buffer
assert view1.ndim == 1
assert view1.shape == (n,)
assert view1.tobytes() == data
view2 = memoryview(f)
assert view2.ndim == 1
assert view2.shape == (n,)
assert view2.tobytes() == data
def test_multisend(self):
"""ensure that a message remains intact after multiple sends"""
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
s = b"message"
m = zmq.Frame(s)
assert s == m.bytes
a.send(m, copy=False)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=False)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=True)
time.sleep(0.1)
assert s == m.bytes
a.send(m, copy=True)
time.sleep(0.1)
assert s == m.bytes
for i in range(4):
r = b.recv()
assert s == r
assert s == m.bytes
def test_memoryview(self):
"""test messages from memoryview"""
s = b'carrotjuice'
memoryview(s)
m = zmq.Frame(s)
buf = m.buffer
s2 = buf.tobytes()
assert s2 == s
assert m.bytes == s
def test_noncopying_recv(self):
"""check for clobbering message buffers"""
null = b'\0' * 64
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(32):
# try a few times
sb.send(null, copy=False)
m = sa.recv(copy=False)
mb = m.bytes
# buf = memoryview(m)
buf = m.buffer
del m
for i in range(5):
ff = b'\xff' * (40 + i * 10)
sb.send(ff, copy=False)
m2 = sa.recv(copy=False)
b = buf.tobytes()
assert b == null
assert mb == null
assert m2.bytes == ff
assert type(m2.bytes) is bytes
def test_noncopying_memoryview(self):
"""test non-copying memmoryview messages"""
null = b'\0' * 64
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(32):
# try a few times
sb.send(memoryview(null), copy=False)
m = sa.recv(copy=False)
buf = memoryview(m)
for i in range(5):
ff = b'\xff' * (40 + i * 10)
sb.send(memoryview(ff), copy=False)
m2 = sa.recv(copy=False)
buf2 = memoryview(m2)
assert buf.tobytes() == null
assert not buf.readonly
assert buf2.tobytes() == ff
assert not buf2.readonly
assert type(buf) is memoryview
def test_buffer_numpy(self):
"""test non-copying numpy array messages"""
try:
import numpy
from numpy.testing import assert_array_equal
except ImportError:
raise SkipTest("requires numpy")
rand = numpy.random.randint
shapes = [rand(2, 5) for i in range(5)]
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
dtypes = [int, float, '>i4', 'B']
for i in range(1, len(shapes) + 1):
shape = shapes[:i]
for dt in dtypes:
A = numpy.empty(shape, dtype=dt)
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
A['a'] = 1024
A['b'] = 1e9
A['c'] = 'hello there'
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
@skip_pypy
def test_frame_more(self):
"""test Frame.more attribute"""
frame = zmq.Frame(b"hello")
assert not frame.more
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
sa.send_multipart([b'hi', b'there'])
frame = self.recv(sb, copy=False)
assert frame.more
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
assert frame.get(zmq.MORE)
frame = self.recv(sb, copy=False)
assert not frame.more
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
assert not frame.get(zmq.MORE)

View File

@ -0,0 +1,76 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, require_zmq_4
from zmq.utils.monitor import recv_monitor_message
class TestSocketMonitor(BaseZMQTestCase):
@require_zmq_4
def test_monitor(self):
"""Test monitoring interface for sockets."""
s_rep = self.context.socket(zmq.REP)
s_req = self.context.socket(zmq.REQ)
self.sockets.extend([s_rep, s_req])
s_req.bind("tcp://127.0.0.1:6666")
# try monitoring the REP socket
s_rep.monitor(
"inproc://monitor.rep",
zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED,
)
# create listening socket for monitor
s_event = self.context.socket(zmq.PAIR)
self.sockets.append(s_event)
s_event.connect("inproc://monitor.rep")
s_event.linger = 0
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6666")
m = recv_monitor_message(s_event)
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
# test receive event for connected event
m = recv_monitor_message(s_event)
assert m['event'] == zmq.EVENT_CONNECTED
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
# test monitor can be disabled.
s_rep.disable_monitor()
m = recv_monitor_message(s_event)
assert m['event'] == zmq.EVENT_MONITOR_STOPPED
@require_zmq_4
def test_monitor_repeat(self):
s = self.socket(zmq.PULL)
m = s.get_monitor_socket()
self.sockets.append(m)
m2 = s.get_monitor_socket()
assert m is m2
s.disable_monitor()
evt = recv_monitor_message(m)
assert evt['event'] == zmq.EVENT_MONITOR_STOPPED
m.close()
s.close()
@require_zmq_4
def test_monitor_connected(self):
"""Test connected monitoring socket."""
s_rep = self.context.socket(zmq.REP)
s_req = self.context.socket(zmq.REQ)
self.sockets.extend([s_rep, s_req])
s_req.bind("tcp://127.0.0.1:6667")
# try monitoring the REP socket
# create listening socket for monitor
s_event = s_rep.get_monitor_socket()
s_event.linger = 0
self.sockets.append(s_event)
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6667")
m = recv_monitor_message(s_event)
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
# test receive event for connected event
m = recv_monitor_message(s_event)
assert m['event'] == zmq.EVENT_CONNECTED
assert m['endpoint'] == b"tcp://127.0.0.1:6667"

View File

@ -0,0 +1,219 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase
if PYPY or zmq.zmq_version_info() >= (4, 1):
# cleanup of shared Context doesn't work on PyPy
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
devices.Device.context_factory = zmq.Context
class TestMonitoredQueue(BaseZMQTestCase):
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
self.device = devices.ThreadMonitoredQueue(
zmq.PAIR, zmq.PAIR, zmq.PUB, in_prefix, out_prefix
)
alice = self.context.socket(zmq.PAIR)
bob = self.context.socket(zmq.PAIR)
mon = self.context.socket(zmq.SUB)
aport = alice.bind_to_random_port('tcp://127.0.0.1')
bport = bob.bind_to_random_port('tcp://127.0.0.1')
mport = mon.bind_to_random_port('tcp://127.0.0.1')
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
self.device.connect_in("tcp://127.0.0.1:%i" % aport)
self.device.connect_out("tcp://127.0.0.1:%i" % bport)
self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
self.device.start()
time.sleep(0.2)
try:
# this is currenlty necessary to ensure no dropped monitor messages
# see LIBZMQ-248 for more info
mon.recv_multipart(zmq.NOBLOCK)
except zmq.ZMQError:
pass
self.sockets.extend([alice, bob, mon])
return alice, bob, mon
def teardown_device(self):
for socket in self.sockets:
socket.close()
del socket
del self.device
def test_reply(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
self.teardown_device()
def test_queue(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
self.teardown_device()
def test_monitor(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + bobs == mons
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + alices2 == mons
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'in'] + alices3 == mons
mons = self.recv_multipart(mon)
assert [b'out'] + bobs == mons
self.teardown_device()
def test_prefix(self):
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + bobs == mons
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + alices2 == mons
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'foo'] + alices3 == mons
mons = self.recv_multipart(mon)
assert [b'bar'] + bobs == mons
self.teardown_device()
def test_monitor_subscribe(self):
alice, bob, mon = self.build_device(b"out")
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
assert alices == bobs
bobs = self.recv_multipart(bob)
assert alices2 == bobs
bobs = self.recv_multipart(bob)
assert alices3 == bobs
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
assert alices == bobs
mons = self.recv_multipart(mon)
assert [b'out'] + bobs == mons
self.teardown_device()
def test_router_router(self):
"""test router-router MQ devices"""
dev = devices.ThreadMonitoredQueue(
zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out'
)
self.device = dev
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
a = self.context.socket(zmq.DEALER)
a.identity = b'a'
b = self.context.socket(zmq.DEALER)
b.identity = b'b'
self.sockets.extend([a, b])
a.connect('tcp://127.0.0.1:%i' % porta)
b.connect('tcp://127.0.0.1:%i' % portb)
dev.start()
time.sleep(1)
if zmq.zmq_version_info() >= (3, 1, 0):
# flush erroneous poll state, due to LIBZMQ-280
ping_msg = [b'ping', b'pong']
for s in (a, b):
s.send_multipart(ping_msg)
try:
s.recv(zmq.NOBLOCK)
except zmq.ZMQError:
pass
msg = [b'hello', b'there']
a.send_multipart([b'b'] + msg)
bmsg = self.recv_multipart(b)
assert bmsg == [b'a'] + msg
b.send_multipart(bmsg)
amsg = self.recv_multipart(a)
assert amsg == [b'b'] + msg
self.teardown_device()
def test_default_mq_args(self):
self.device = dev = devices.ThreadMonitoredQueue(
zmq.ROUTER, zmq.DEALER, zmq.PUB
)
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
# this will raise if default args are wrong
dev.start()
self.teardown_device()
def test_mq_check_prefix(self):
ins = self.context.socket(zmq.ROUTER)
outs = self.context.socket(zmq.DEALER)
mons = self.context.socket(zmq.PUB)
self.sockets.extend([ins, outs, mons])
ins = 'in'
outs = 'out'
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)

View File

@ -0,0 +1,34 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestMultipart(BaseZMQTestCase):
def test_router_dealer(self):
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
msg1 = b'message1'
dealer.send(msg1)
self.recv(router)
more = router.rcvmore
assert more == True
msg2 = self.recv(router)
assert msg1 == msg2
more = router.rcvmore
assert more == False
def test_basic_multipart(self):
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg = [b'hi', b'there', b'b']
a.send_multipart(msg)
recvd = b.recv_multipart()
assert msg == recvd
if have_gevent:
class TestMultipartGreen(GreenTest, TestMultipart):
pass

View File

@ -0,0 +1,73 @@
"""
Test our typing with mypy
"""
import os
import sys
from subprocess import PIPE, STDOUT, Popen
import pytest
import zmq
pytest.importorskip("mypy")
zmq_dir = os.path.dirname(zmq.__file__)
def resolve_repo_dir(path):
"""Resolve a dir in the repo
Resolved relative to zmq dir
fallback on CWD (e.g. test run from repo, zmq installed, not -e)
"""
resolved_path = os.path.join(os.path.dirname(zmq_dir), path)
# fallback on CWD
if not os.path.exists(resolved_path):
resolved_path = path
return resolved_path
examples_dir = resolve_repo_dir("examples")
mypy_dir = resolve_repo_dir("mypy_tests")
def run_mypy(*mypy_args):
"""Run mypy for a path
Captures output and reports it on errors
"""
p = Popen(
[sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
)
o, _ = p.communicate()
out = o.decode("utf8", "replace")
print(out)
assert p.returncode == 0, out
if os.path.exists(examples_dir):
examples = [
d
for d in os.listdir(examples_dir)
if os.path.isdir(os.path.join(examples_dir, d))
]
@pytest.mark.skipif(
not os.path.exists(examples_dir), reason="only test from examples directory"
)
@pytest.mark.parametrize("example", examples)
def test_mypy_example(example):
example_dir = os.path.join(examples_dir, example)
run_mypy("--disallow-untyped-calls", example_dir)
if os.path.exists(mypy_dir):
mypy_tests = [p for p in os.listdir(mypy_dir) if p.endswith(".py")]
@pytest.mark.parametrize("filename", mypy_tests)
def test_mypy(filename):
run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename))

View File

@ -0,0 +1,52 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
x = b' '
class TestPair(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg1 = b'message1'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(10):
msg = i * x
s1.send(msg)
for i in range(10):
msg = i * x
s2.send(msg)
for i in range(10):
msg = s1.recv()
assert msg == i * x
for i in range(10):
msg = s2.recv()
assert msg == i * x
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10, b=list(range(10)))
self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10, b=range(10))
self.ping_pong_pyobj(s1, s2, o)
if have_gevent:
class TestReqRepGreen(GreenTest, TestPair):
pass

View File

@ -0,0 +1,239 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import sys
import time
from pytest import mark
import zmq
from zmq.tests import GreenTest, PollZMQTestCase, have_gevent
def wait():
time.sleep(0.25)
class TestPoll(PollZMQTestCase):
Poller = zmq.Poller
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
# Poll result should contain both sockets
socks = dict(poller.poll())
# Now make sure that both are send ready.
assert socks[s1] == zmq.POLLOUT
assert socks[s2] == zmq.POLLOUT
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
s1.send(b'msg1')
s2.send(b'msg2')
wait()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT | zmq.POLLIN
assert socks[s2] == zmq.POLLOUT | zmq.POLLIN
# Make sure that both are in POLLOUT after recv.
s1.recv()
s2.recv()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
assert socks[s2] == zmq.POLLOUT
poller.unregister(s1)
poller.unregister(s2)
def test_reqrep(self):
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
# Make sure that s1 is in state 0 and s2 is in POLLOUT
socks = dict(poller.poll())
assert s1 not in socks
assert socks[s2] == zmq.POLLOUT
# Make sure that s2 goes immediately into state 0 after send.
s2.send(b'msg1')
socks = dict(poller.poll())
assert s2 not in socks
# Make sure that s1 goes into POLLIN state after a time.sleep().
time.sleep(0.5)
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLIN
# Make sure that s1 goes into POLLOUT after recv.
s1.recv()
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
# Make sure s1 goes into state 0 after send.
s1.send(b'msg2')
socks = dict(poller.poll())
assert s1 not in socks
# Wait and then see that s2 is in POLLIN.
time.sleep(0.5)
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLIN
# Make sure that s2 is in POLLOUT after recv.
s2.recv()
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLOUT
poller.unregister(s1)
poller.unregister(s2)
def test_no_events(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, 0)
assert s1 in poller
assert s2 not in poller
poller.register(s1, 0)
assert s1 not in poller
def test_pubsub(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
poller.register(s2, zmq.POLLIN)
# Now make sure that both are send ready.
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
assert s2 not in socks
# Make sure that s1 stays in POLLOUT after a send.
s1.send(b'msg1')
socks = dict(poller.poll())
assert socks[s1] == zmq.POLLOUT
# Make sure that s2 is POLLIN after waiting.
wait()
socks = dict(poller.poll())
assert socks[s2] == zmq.POLLIN
# Make sure that s2 goes into 0 after recv.
s2.recv()
socks = dict(poller.poll())
assert s2 not in socks
poller.unregister(s1)
poller.unregister(s2)
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
def test_raw(self):
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
p = self.Poller()
p.register(r, zmq.POLLIN)
socks = dict(p.poll(1))
assert socks == {}
w.write(b'x')
w.flush()
socks = dict(p.poll(1))
assert socks == {r.fileno(): zmq.POLLIN}
w.close()
r.close()
@mark.flaky(reruns=3)
def test_timeout(self):
"""make sure Poller.poll timeout has the right units (milliseconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN)
tic = time.perf_counter()
poller.poll(0.005)
toc = time.perf_counter()
toc - tic < 0.1
tic = time.perf_counter()
poller.poll(50)
toc = time.perf_counter()
assert toc - tic < 0.1
assert toc - tic > 0.01
tic = time.perf_counter()
poller.poll(500)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.1
class TestSelect(PollZMQTestCase):
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
assert s1 in wlist
assert s2 in wlist
assert s1 not in rlist
assert s2 not in rlist
@mark.flaky(reruns=3)
def test_timeout(self):
"""make sure select timeout has the right units (seconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.perf_counter()
r, w, x = zmq.select([s1, s2], [], [], 0.005)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.001
tic = time.perf_counter()
r, w, x = zmq.select([s1, s2], [], [], 0.25)
toc = time.perf_counter()
assert toc - tic < 1
assert toc - tic > 0.1
if have_gevent:
import gevent
from zmq import green as gzmq
class TestPollGreen(GreenTest, TestPoll):
Poller = gzmq.Poller
def test_wakeup(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s2, zmq.POLLIN)
tic = time.perf_counter()
r = gevent.spawn(lambda: poller.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.perf_counter()
assert toc - tic < 1
def test_socket_poll(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.perf_counter()
r = gevent.spawn(lambda: s2.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.perf_counter()
assert toc - tic < 1

View File

@ -0,0 +1,95 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import struct
import time
import zmq
from zmq import devices
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestProxySteerable(BaseZMQTestCase):
def test_proxy_steerable(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
ctrl.send(b'TERMINATE')
dev.join()
def test_proxy_steerable_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend(
[
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max),
]
)
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_proxy_steerable_statistics(self):
if zmq.zmq_version_info() < (4, 3):
raise SkipTest("STATISTICS only in libzmq >= 4.3")
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
assert msg == self.recv(pull)
assert msg == self.recv(mon)
ctrl.send(b'STATISTICS')
stats = self.recv_multipart(ctrl)
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
assert 1 == stats_int[0]
assert len(msg) == stats_int[1]
assert 1 == stats_int[6]
assert len(msg) == stats_int[7]
ctrl.send(b'TERMINATE')
dev.join()

View File

@ -0,0 +1,41 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestPubSub(BaseZMQTestCase):
pass
# We are disabling this test while an issue is being resolved.
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
msg2 = s2.recv() # This is blocking!
assert msg1 == msg2
def test_topic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'x')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
msg1 = b'xmessage'
s1.send(msg1)
msg2 = s2.recv()
assert msg1 == msg2
if have_gevent:
class TestPubSubGreen(GreenTest, TestPubSub):
pass

View File

@ -0,0 +1,61 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
class TestReqRep(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = b'message 1'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
for i in range(10):
msg1 = i * b' '
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_bad_send_recv(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
if zmq.zmq_version() != '2.1.8':
# this doesn't work on 2.1.8
for copy in (True, False):
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
# I have to have this or we die on an Abort trap.
msg1 = b'asdf'
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10, b=list(range(10)))
self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10, b=range(10))
self.ping_pong_pyobj(s1, s2, o)
def test_large_msg(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = 10000 * b'X'
for i in range(10):
msg2 = self.ping_pong(s1, s2, msg1)
assert msg1 == msg2
if have_gevent:
class TestReqRepGreen(GreenTest, TestReqRep):
pass

View File

@ -0,0 +1,94 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import signal
import time
from threading import Thread
from pytest import mark
import zmq
from zmq.tests import BaseZMQTestCase, SkipTest
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
class TestEINTRSysCall(BaseZMQTestCase):
"""Base class for EINTR tests."""
# delay for initial signal delivery
signal_delay = 0.1
# timeout for tests. Must be > signal_delay
timeout = 0.25
timeout_ms = int(timeout * 1e3)
def alarm(self, t=None):
"""start a timer to fire only once
like signal.alarm, but with better resolution than integer seconds.
"""
if not hasattr(signal, 'setitimer'):
raise SkipTest('EINTR tests require setitimer')
if t is None:
t = self.signal_delay
self.timer_fired = False
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
# signal_period ignored, since only one timer event is allowed to fire
signal.setitimer(signal.ITIMER_REAL, t, 1000)
def stop_timer(self, *args):
self.timer_fired = True
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, self.orig_handler)
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_retry_recv(self):
pull = self.socket(zmq.PULL)
pull.rcvtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, pull.recv)
assert self.timer_fired
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_retry_send(self):
push = self.socket(zmq.PUSH)
push.sndtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, push.send, b'buf')
assert self.timer_fired
@mark.flaky(reruns=3)
def test_retry_poll(self):
x, y = self.create_bound_pair()
poller = zmq.Poller()
poller.register(x, zmq.POLLIN)
self.alarm()
def send():
time.sleep(2 * self.signal_delay)
y.send(b'ping')
t = Thread(target=send)
t.start()
evts = dict(poller.poll(2 * self.timeout_ms))
t.join()
assert x in evts
assert self.timer_fired
x.recv()
def test_retry_term(self):
push = self.socket(zmq.PUSH)
push.linger = self.timeout_ms
push.connect('tcp://127.0.0.1:5555')
push.send(b'ping')
time.sleep(0.1)
self.alarm()
self.context.destroy()
assert self.timer_fired
assert self.context.closed
def test_retry_getsockopt(self):
raise SkipTest("TODO: find a way to interrupt getsockopt")
def test_retry_setsockopt(self):
raise SkipTest("TODO: find a way to interrupt setsockopt")

View File

@ -0,0 +1,238 @@
"""Test libzmq security (libzmq >= 3.3.0)"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import contextlib
import os
import time
from threading import Thread
import zmq
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
from zmq.utils import z85
USER = b"admin"
PASS = b"password"
class TestSecurity(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4, 0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to be built with CURVE support")
super().setUp()
def zap_handler(self):
socket = self.context.socket(zmq.REP)
socket.bind("inproc://zeromq.zap.01")
try:
msg = self.recv_multipart(socket)
version, sequence, domain, address, identity, mechanism = msg[:6]
if mechanism == b'PLAIN':
username, password = msg[6:]
elif mechanism == b'CURVE':
msg[6]
assert version == b"1.0"
assert identity == b"IDENT"
reply = [version, sequence]
if (
mechanism == b'CURVE'
or (mechanism == b'PLAIN' and username == USER and password == PASS)
or (mechanism == b'NULL')
):
reply.extend(
[
b"200",
b"OK",
b"anonymous",
b"\5Hello\0\0\0\5World",
]
)
else:
reply.extend(
[
b"400",
b"Invalid username or password",
b"",
b"",
]
)
socket.send_multipart(reply)
finally:
socket.close()
@contextlib.contextmanager
def zap(self):
self.start_zap()
time.sleep(0.5) # allow time for the Thread to start
try:
yield
finally:
self.stop_zap()
def start_zap(self):
self.zap_thread = Thread(target=self.zap_handler)
self.zap_thread.start()
def stop_zap(self):
self.zap_thread.join()
def bounce(self, server, client, test_metadata=True):
msg = [os.urandom(64), os.urandom(64)]
client.send_multipart(msg)
frames = self.recv_multipart(server, copy=False)
recvd = list(map(lambda x: x.bytes, frames))
try:
if test_metadata and not PYPY:
for frame in frames:
assert frame.get('User-Id') == 'anonymous'
assert frame.get('Hello') == 'World'
assert frame['Socket-Type'] == 'DEALER'
except zmq.ZMQVersionError:
pass
assert recvd == msg
server.send_multipart(recvd)
msg2 = self.recv_multipart(client)
assert msg2 == msg
def test_null(self):
"""test NULL (default) security"""
server = self.socket(zmq.DEALER)
client = self.socket(zmq.DEALER)
assert client.MECHANISM == zmq.NULL
assert server.mechanism == zmq.NULL
assert client.plain_server == 0
assert server.plain_server == 0
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client, False)
def test_plain(self):
"""test PLAIN authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
assert client.plain_username == b''
assert client.plain_password == b''
client.plain_username = USER
client.plain_password = PASS
assert client.getsockopt(zmq.PLAIN_USERNAME) == USER
assert client.getsockopt(zmq.PLAIN_PASSWORD) == PASS
assert client.plain_server == 0
assert server.plain_server == 0
server.plain_server = True
assert server.mechanism == zmq.PLAIN
assert client.mechanism == zmq.PLAIN
assert not client.plain_server
assert server.plain_server
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)
def skip_plain_inauth(self):
"""test PLAIN failed authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
client.plain_username = USER
client.plain_password = b'incorrect'
server.plain_server = True
assert server.mechanism == zmq.PLAIN
assert client.mechanism == zmq.PLAIN
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
client.send(b'ping')
server.rcvtimeo = 250
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
def test_keypair(self):
"""test curve_keypair"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
assert type(secret) == bytes
assert type(public) == bytes
assert len(secret) == 40
assert len(public) == 40
# verify that it is indeed Z85
bsecret, bpublic = (z85.decode(key) for key in (public, secret))
assert type(bsecret) == bytes
assert type(bpublic) == bytes
assert len(bsecret) == 32
assert len(bpublic) == 32
def test_curve_public(self):
"""test curve_public"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
if zmq.zmq_version_info() < (4, 2):
raise SkipTest("curve_public is new in libzmq 4.2")
derived_public = zmq.curve_public(secret)
assert type(derived_public) == bytes
assert len(derived_public) == 40
# verify that it is indeed Z85
bpublic = z85.decode(derived_public)
assert type(bpublic) == bytes
assert len(bpublic) == 32
# verify that it is equal to the known public key
assert derived_public == public
def test_curve(self):
"""test CURVE encryption"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
try:
server.curve_server = True
except zmq.ZMQError as e:
# will raise EINVAL if no CURVE support
if e.errno == zmq.EINVAL:
raise SkipTest("CURVE unsupported")
server_public, server_secret = zmq.curve_keypair()
client_public, client_secret = zmq.curve_keypair()
server.curve_secretkey = server_secret
server.curve_publickey = server_public
client.curve_serverkey = server_public
client.curve_publickey = client_public
client.curve_secretkey = client_secret
assert server.mechanism == zmq.CURVE
assert client.mechanism == zmq.CURVE
assert server.get(zmq.CURVE_SERVER) == True
assert client.get(zmq.CURVE_SERVER) == False
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)

View File

@ -0,0 +1,671 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import errno
import json
import os
import platform
import socket
import sys
import time
import warnings
from unittest import mock
import pytest
from pytest import mark
import zmq
from zmq.tests import BaseZMQTestCase, GreenTest, SkipTest, have_gevent, skip_pypy
pypy = platform.python_implementation().lower() == 'pypy'
windows = platform.platform().lower().startswith('windows')
on_ci = bool(os.environ.get('CI'))
# polling on windows is slow
POLL_TIMEOUT = 1000 if windows else 100
class TestSocket(BaseZMQTestCase):
def test_create(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
# Superluminal protocol not yet implemented
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
s.close()
del ctx
def test_context_manager(self):
url = 'inproc://a'
with self.Context() as ctx:
with ctx.socket(zmq.PUSH) as a:
a.bind(url)
with ctx.socket(zmq.PULL) as b:
b.connect(url)
msg = b'hi'
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
assert b.closed == True
assert a.closed == True
assert ctx.closed == True
def test_connectbind_context_managers(self):
url = 'inproc://a'
msg = b'hi'
with self.Context() as ctx:
# Test connect() context manager
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
a.bind(url)
connect_context = b.connect(url)
assert f'connect={url!r}' in repr(connect_context)
with connect_context:
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
# b should now be disconnected, so sending and receiving don't work
with pytest.raises(zmq.Again):
a.send(msg, flags=zmq.DONTWAIT)
with pytest.raises(zmq.Again):
b.recv(flags=zmq.DONTWAIT)
a.unbind(url)
# Test bind() context manager
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
# unbind() just stops accepting of new connections, so we have to disconnect to test that
# unbind happened.
bind_context = a.bind(url)
assert f'bind={url!r}' in repr(bind_context)
with bind_context:
b.connect(url)
a.send(msg)
rcvd = self.recv(b)
assert rcvd == msg
b.disconnect(url)
b.connect(url)
# Since a is unbound from url, b is not connected to anything
with pytest.raises(zmq.Again):
a.send(msg, flags=zmq.DONTWAIT)
with pytest.raises(zmq.Again):
b.recv(flags=zmq.DONTWAIT)
_repr_cls = "zmq.Socket"
def test_repr(self):
with self.context.socket(zmq.PUSH) as s:
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
assert 'closed' not in repr(s)
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
assert 'closed' in repr(s)
def test_dir(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
assert 'send' in dir(s)
assert 'IDENTITY' in dir(s)
assert 'AFFINITY' in dir(s)
assert 'FD' in dir(s)
s.close()
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
s = self.socket(zmq.SUB)
m = mock.Mock(spec=s)
s.close()
def test_bind_unicode(self):
s = self.socket(zmq.PUB)
p = s.bind_to_random_port("tcp://*")
def test_connect_unicode(self):
s = self.socket(zmq.PUB)
s.connect("tcp://127.0.0.1:5555")
def test_bind_to_random_port(self):
# Check that bind_to_random_port do not hide useful exception
ctx = self.Context()
c = ctx.socket(zmq.PUB)
# Invalid format
try:
c.bind_to_random_port('tcp:*')
except zmq.ZMQError as e:
assert e.errno == zmq.EINVAL
# Invalid protocol
try:
c.bind_to_random_port('rand://*')
except zmq.ZMQError as e:
assert e.errno == zmq.EPROTONOSUPPORT
def test_identity(self):
s = self.context.socket(zmq.PULL)
self.sockets.append(s)
ident = b'identity\0\0'
s.identity = ident
assert s.get(zmq.IDENTITY) == ident
def test_unicode_sockopts(self):
"""test setting/getting sockopts with unicode strings"""
topic = "tést"
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
assert s.send_unicode == s.send_unicode
assert p.recv_unicode == p.recv_unicode
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
identb = s.getsockopt(zmq.IDENTITY)
identu = identb.decode('utf16')
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
assert identu == identu2
time.sleep(0.1) # wait for connection/subscription
p.send_unicode(topic, zmq.SNDMORE)
p.send_unicode(topic * 2, encoding='latin-1')
assert topic == s.recv_unicode()
assert topic * 2 == s.recv_unicode(encoding='latin-1')
def test_int_sockopts(self):
"test integer sockopts"
v = zmq.zmq_version_info()
if v < (3, 0):
default_hwm = 0
else:
default_hwm = 1000
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
p.setsockopt(zmq.LINGER, 0)
assert p.getsockopt(zmq.LINGER) == 0
p.setsockopt(zmq.LINGER, -1)
assert p.getsockopt(zmq.LINGER) == -1
assert p.hwm == default_hwm
p.hwm = 11
assert p.hwm == 11
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
assert p.getsockopt(zmq.EVENTS) == zmq.POLLOUT
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt, zmq.EVENTS, 2**7 - 1)
assert p.getsockopt(zmq.TYPE) == p.socket_type
assert p.getsockopt(zmq.TYPE) == zmq.PUB
assert s.getsockopt(zmq.TYPE) == s.socket_type
assert s.getsockopt(zmq.TYPE) == zmq.SUB
# check for overflow / wrong type:
errors = []
backref = {}
constants = zmq.constants
for name in constants.__all__:
value = getattr(constants, name)
if isinstance(value, int):
backref[value] = name
for opt in zmq.constants.SocketOption:
if opt._opt_type not in {
zmq.constants._OptType.int,
zmq.constants._OptType.int64,
}:
continue
if opt.name.startswith(
(
'HWM',
'ROUTER',
'XPUB',
'TCP',
'FAIL',
'REQ_',
'CURVE_',
'PROBE_ROUTER',
'IPC_FILTER',
'GSSAPI',
'STREAM_',
'VMCI_BUFFER_SIZE',
'VMCI_BUFFER_MIN_SIZE',
'VMCI_BUFFER_MAX_SIZE',
'VMCI_CONNECT_TIMEOUT',
'BLOCKY',
'IN_BATCH_SIZE',
'OUT_BATCH_SIZE',
'WSS_TRUST_SYSTEM',
'ONLY_FIRST_SUBSCRIBE',
'PRIORITY',
'RECONNECT_STOP',
)
):
# some sockopts are write-only
continue
try:
n = p.getsockopt(opt)
except zmq.ZMQError as e:
errors.append(f"getsockopt({opt!r}) raised {e}.")
else:
if n > 2**31:
errors.append(
f"getsockopt({opt!r}) returned a ridiculous value."
" It is probably the wrong type."
)
if errors:
self.fail('\n'.join([''] + errors))
def test_bad_sockopts(self):
"""Test that appropriate errors are raised on bad socket options"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
s.setsockopt(zmq.LINGER, 0)
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
# but only int sockopts are allowed through this way, otherwise raise a TypeError
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
# some sockopts are valid in general, but not on every socket:
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
def test_sockopt_roundtrip(self):
"test set/getsockopt roundtrip."
p = self.context.socket(zmq.PUB)
self.sockets.append(p)
p.setsockopt(zmq.LINGER, 11)
assert p.getsockopt(zmq.LINGER) == 11
def test_send_unicode(self):
"test sending unicode objects"
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a, b])
u = "çπ§"
self.assertRaises(TypeError, a.send, u, copy=False)
self.assertRaises(TypeError, a.send, u, copy=True)
a.send_unicode(u)
s = b.recv()
assert s == u.encode('utf8')
assert s.decode('utf8') == u
a.send_unicode(u, encoding='utf16')
s = b.recv_unicode(encoding='utf16')
assert s == u
def test_send_multipart_check_type(self):
"check type on all frames in send_multipart"
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a, b])
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
a.send_multipart([b'b'])
rcvd = self.recv_multipart(b)
assert rcvd == [b'b']
@skip_pypy
def test_tracker(self):
"test the MessageTracker object for tracking when zmq is done with a buffer"
addr = 'tcp://127.0.0.1'
# get a port:
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
iface = "%s:%i" % (addr, port)
sock.close()
time.sleep(0.1)
a = self.context.socket(zmq.PUSH)
b = self.context.socket(zmq.PULL)
self.sockets.extend([a, b])
a.connect(iface)
time.sleep(0.1)
p1 = a.send(b'something', copy=False, track=True)
assert isinstance(p1, zmq.MessageTracker)
assert p1 is zmq._FINISHED_TRACKER
# small message, should start done
assert p1.done
# disable zero-copy threshold
a.copy_threshold = 0
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
assert isinstance(p2, zmq.MessageTracker)
assert not p2.done
b.bind(iface)
msg = self.recv_multipart(b)
for i in range(10):
if p1.done:
break
time.sleep(0.1)
assert p1.done == True
assert msg == [b'something']
msg = self.recv_multipart(b)
for i in range(10):
if p2.done:
break
time.sleep(0.1)
assert p2.done == True
assert msg == [b'something', b'else']
m = zmq.Frame(b"again", copy=False, track=True)
assert m.tracker.done == False
p1 = a.send(m, copy=False)
p2 = a.send(m, copy=False)
assert m.tracker.done == False
assert p1.done == False
assert p2.done == False
msg = self.recv_multipart(b)
assert m.tracker.done == False
assert msg == [b'again']
msg = self.recv_multipart(b)
assert m.tracker.done == False
assert msg == [b'again']
assert p1.done == False
assert p2.done == False
m.tracker
del m
for i in range(10):
if p1.done:
break
time.sleep(0.1)
assert p1.done == True
assert p2.done == True
m = zmq.Frame(b'something', track=False)
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
def test_close(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.close()
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
del ctx
def test_attr(self):
"""set setting/getting sockopts as attributes"""
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
linger = 10
s.linger = linger
assert linger == s.linger
assert linger == s.getsockopt(zmq.LINGER)
assert s.fd == s.getsockopt(zmq.FD)
def test_bad_attr(self):
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
try:
s.apple = 'foo'
except AttributeError:
pass
else:
self.fail("bad setattr should have raised AttributeError")
try:
s.apple
except AttributeError:
pass
else:
self.fail("bad getattr should have raised AttributeError")
def test_subclass(self):
"""subclasses can assign attributes"""
class S(zmq.Socket):
a = None
def __init__(self, *a, **kw):
self.a = -1
super().__init__(*a, **kw)
s = S(self.context, zmq.REP)
self.sockets.append(s)
assert s.a == -1
s.a = 1
assert s.a == 1
a = s.a
assert a == 1
def test_recv_multipart(self):
a, b = self.create_bound_pair()
msg = b'hi'
for i in range(3):
a.send(msg)
time.sleep(0.1)
for i in range(3):
assert self.recv_multipart(b) == [msg]
def test_close_after_destroy(self):
"""s.close() after ctx.destroy() should be fine"""
ctx = self.Context()
s = ctx.socket(zmq.REP)
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
s.close()
assert s.closed
def test_poll(self):
a, b = self.create_bound_pair()
time.time()
evt = a.poll(POLL_TIMEOUT)
assert evt == 0
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
assert evt == zmq.POLLOUT
msg = b'hi'
a.send(msg)
evt = b.poll(POLL_TIMEOUT)
assert evt == zmq.POLLIN
msg2 = self.recv(b)
evt = b.poll(POLL_TIMEOUT)
assert evt == 0
assert msg2 == msg
def test_ipc_path_max_length(self):
"""IPC_PATH_MAX_LEN is a sensible value"""
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
assert zmq.IPC_PATH_MAX_LEN > 30, msg
assert zmq.IPC_PATH_MAX_LEN < 1025, msg
def test_ipc_path_max_length_msg(self):
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
try:
s.bind('ipc://{}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
except zmq.ZMQError as e:
assert str(zmq.IPC_PATH_MAX_LEN) in e.strerror
@mark.skipif(windows, reason="ipc not supported on Windows.")
def test_ipc_path_no_such_file_or_directory_message(self):
"""Display the ipc path in case of an ENOENT exception"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
invalid_path = '/foo/bar'
with pytest.raises(zmq.ZMQError) as error:
s.bind(f'ipc://{invalid_path}')
assert error.value.errno == errno.ENOENT
error_message = str(error.value)
assert invalid_path in error_message
assert "no such file or directory" in error_message.lower()
def test_hwm(self):
zmq3 = zmq.zmq_version_info()[0] >= 3
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
s = self.context.socket(stype)
s.hwm = 100
assert s.hwm == 100
if zmq3:
try:
assert s.sndhwm == 100
except AttributeError:
pass
try:
assert s.rcvhwm == 100
except AttributeError:
pass
s.close()
def test_copy(self):
s = self.socket(zmq.PUB)
scopy = copy.copy(s)
sdcopy = copy.deepcopy(s)
assert scopy._shadow
assert sdcopy._shadow
assert s.underlying == scopy.underlying
assert s.underlying == sdcopy.underlying
s.close()
def test_send_buffer(self):
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
for buffer_type in (memoryview, bytearray):
rawbytes = str(buffer_type).encode('ascii')
msg = buffer_type(rawbytes)
a.send(msg)
recvd = b.recv()
assert recvd == rawbytes
def test_shadow(self):
p = self.socket(zmq.PUSH)
p.bind("tcp://127.0.0.1:5555")
p2 = zmq.Socket.shadow(p.underlying)
assert p.underlying == p2.underlying
s = self.socket(zmq.PULL)
s2 = zmq.Socket.shadow(s.underlying)
self.assertNotEqual(s.underlying, p.underlying)
assert s.underlying == s2.underlying
s2.connect("tcp://127.0.0.1:5555")
sent = b'hi'
p2.send(sent)
rcvd = self.recv(s2)
assert rcvd == sent
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
ca = zsocket.new(ctx, zmq.PUSH)
cb = zsocket.new(ctx, zmq.PULL)
a = zmq.Socket.shadow(ca)
b = zmq.Socket.shadow(cb)
a.bind("inproc://a")
b.connect("inproc://a")
a.send(b'hi')
rcvd = self.recv(b)
assert rcvd == b'hi'
def test_subscribe_method(self):
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
sub.subscribe('prefix')
sub.subscribe = 'c'
p = zmq.Poller()
p.register(sub, zmq.POLLIN)
# wait for subscription handshake
for i in range(100):
pub.send(b'canary')
events = p.poll(250)
if events:
break
self.recv(sub)
pub.send(b'prefixmessage')
msg = self.recv(sub)
assert msg == b'prefixmessage'
sub.unsubscribe('prefix')
pub.send(b'prefixmessage')
events = p.poll(1000)
assert events == []
# CI often can't handle how much memory PyPy uses on this test
@mark.skipif(
(pypy and on_ci) or (sys.maxsize < 2**32) or (windows),
reason="only run on 64b and not on CI.",
)
@mark.large
def test_large_send(self):
c = os.urandom(1)
N = 2**31 + 1
try:
buf = c * N
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
a, b = self.create_bound_pair()
try:
a.send(buf, copy=False)
rcvd = b.recv(copy=False)
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
# sample the front and back of the received message
# without checking the whole content
byte = ord(c)
view = memoryview(rcvd)
assert len(view) == N
assert view[0] == byte
assert view[-1] == byte
def test_custom_serialize(self):
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
a.send_serialized(msg, serialize)
recvd = b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
b.send_serialized(recvd, serialize)
r2 = a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
if have_gevent and not windows:
import gevent
class TestSocketGreen(GreenTest, TestSocket):
test_bad_attr = GreenTest.skip_green
test_close_after_destroy = GreenTest.skip_green
_repr_cls = "zmq.green.Socket"
def test_timeout(self):
a, b = self.create_bound_pair()
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
timeout = gevent.Timeout(0.1)
timeout.start()
self.assertRaises(gevent.Timeout, b.recv)
g.kill()
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_warn_set_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.rcvtimeo = 5
s.close()
assert len(w) == 1
assert w[0].category == UserWarning
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_warn_get_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.sndtimeo
s.close()
assert len(w) == 1
assert w[0].category == UserWarning

View File

@ -0,0 +1,9 @@
from zmq.ssh.tunnel import select_random_ports
def test_random_ports():
for i in range(4096):
ports = select_random_ports(10)
assert len(ports) == 10
for p in ports:
assert ports.count(p) == 1

View File

@ -0,0 +1,43 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
from zmq.sugar import version
class TestVersion(TestCase):
def test_pyzmq_version(self):
vs = zmq.pyzmq_version()
vs2 = zmq.__version__
assert isinstance(vs, str)
if zmq.__revision__:
assert vs == '@'.join(vs2, zmq.__revision__)
else:
assert vs == vs2
if version.VERSION_EXTRA:
assert version.VERSION_EXTRA in vs
assert version.VERSION_EXTRA in vs2
def test_pyzmq_version_info(self):
info = zmq.pyzmq_version_info()
assert isinstance(info, tuple)
for n in info[:3]:
assert isinstance(n, int)
if version.VERSION_EXTRA:
assert len(info) == 4
assert info[-1] == float('inf')
else:
assert len(info) == 3
def test_zmq_version_info(self):
info = zmq.zmq_version_info()
assert isinstance(info, tuple)
for n in info[:3]:
assert isinstance(n, int)
def test_zmq_version(self):
v = zmq.zmq_version()
assert isinstance(v, str)

View File

@ -0,0 +1,58 @@
import sys
import time
from functools import wraps
from pytest import mark
from zmq.tests import BaseZMQTestCase
from zmq.utils.win32 import allow_interrupt
def count_calls(f):
@wraps(f)
def _(*args, **kwds):
try:
return f(*args, **kwds)
finally:
_.__calls__ += 1
_.__calls__ = 0
return _
@mark.new_console
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
@mark.new_console
@mark.skipif(not sys.platform.startswith('win'), reason='Windows only test')
def test_handler(self):
@count_calls
def interrupt_polling():
print('Caught CTRL-C!')
from ctypes import windll
from ctypes.wintypes import BOOL, DWORD
kernel32 = windll.LoadLibrary('kernel32')
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
GenerateConsoleCtrlEvent.restype = BOOL
# Simulate CTRL-C event while handler is active.
try:
with allow_interrupt(interrupt_polling) as context:
result = GenerateConsoleCtrlEvent(0, 0)
# Sleep so that we give time to the handler to
# capture the Ctrl-C event.
time.sleep(0.5)
except KeyboardInterrupt:
pass
else:
if result == 0:
raise OSError()
else:
self.fail('Expecting `KeyboardInterrupt` exception!')
# Make sure our handler was called.
assert interrupt_polling.__calls__ == 1

View File

@ -0,0 +1,65 @@
"""Test Z85 encoding
confirm values and roundtrip with test values from the reference implementation.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
from zmq.utils import z85
class TestZ85(TestCase):
def test_client_public(self):
client_public = (
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B"
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A"
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26"
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
)
encoded = z85.encode(client_public)
assert encoded == b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID"
decoded = z85.decode(encoded)
assert decoded == client_public
def test_client_secret(self):
client_secret = (
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67"
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89"
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51"
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
)
encoded = z85.encode(client_secret)
assert encoded == b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs"
decoded = z85.decode(encoded)
assert decoded == client_secret
def test_server_public(self):
server_public = (
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96"
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0"
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5"
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
)
encoded = z85.encode(server_public)
assert encoded == b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7"
decoded = z85.decode(encoded)
assert decoded == server_public
def test_server_secret(self):
server_secret = (
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D"
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0"
b"\x4D\x48\x96\x3F\x79\x25\x98\x77"
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
)
encoded = z85.encode(server_secret)
assert encoded == b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6"
decoded = z85.decode(encoded)
assert decoded == server_secret

View File

@ -0,0 +1,83 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
from unittest import TestCase
import pytest
import zmq
try:
import tornado
from tornado import gen
from zmq.eventloop import ioloop, zmqstream
except ImportError:
tornado = None # type: ignore
class TestZMQStream(TestCase):
def setUp(self):
if tornado is None:
pytest.skip()
if asyncio:
asyncio.set_event_loop(asyncio.new_event_loop())
self.context = zmq.Context()
self.loop = ioloop.IOLoop()
self.loop.make_current()
self.push = zmqstream.ZMQStream(self.context.socket(zmq.PUSH))
self.pull = zmqstream.ZMQStream(self.context.socket(zmq.PULL))
port = self.push.bind_to_random_port('tcp://127.0.0.1')
self.pull.connect('tcp://127.0.0.1:%i' % port)
self.stream = self.push
def tearDown(self):
self.loop.close(all_fds=True)
self.context.term()
ioloop.IOLoop.clear_current()
def run_until_timeout(self, timeout=10):
timed_out = []
@gen.coroutine
def sleep_timeout():
yield gen.sleep(timeout)
timed_out[:] = ['timed out']
self.loop.stop()
self.loop.add_callback(lambda: sleep_timeout())
self.loop.start()
assert not timed_out
def test_callable_check(self):
"""Ensure callable check works (py3k)."""
self.stream.on_send(lambda *args: None)
self.stream.on_recv(lambda *args: None)
self.assertRaises(AssertionError, self.stream.on_recv, 1)
self.assertRaises(AssertionError, self.stream.on_send, 1)
self.assertRaises(AssertionError, self.stream.on_recv, zmq)
def test_on_recv_basic(self):
sent = [b'basic']
def callback(msg):
assert msg == sent
self.loop.stop()
self.loop.add_callback(lambda: self.push.send_multipart(sent))
self.pull.on_recv(callback)
self.run_until_timeout()
def test_on_recv_wake(self):
sent = [b'wake']
def callback(msg):
assert msg == sent
self.loop.stop()
self.pull.on_recv(callback)
self.loop.call_later(1, lambda: self.push.send_multipart(sent))
self.run_until_timeout()