mirror of
https://github.com/aykhans/AzSuicideDataVisualization.git
synced 2025-07-01 14:07:48 +00:00
first commit
This commit is contained in:
257
.venv/Lib/site-packages/zmq/tests/__init__.py
Normal file
257
.venv/Lib/site-packages/zmq/tests/__init__.py
Normal file
@ -0,0 +1,257 @@
|
||||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
from unittest import SkipTest, TestCase
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.utils import jsonapi
|
||||
|
||||
try:
|
||||
import gevent
|
||||
|
||||
from zmq import green as gzmq
|
||||
|
||||
have_gevent = True
|
||||
except ImportError:
|
||||
have_gevent = False
|
||||
|
||||
|
||||
PYPY = platform.python_implementation() == 'PyPy'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# skip decorators (directly from unittest)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_id = lambda x: x
|
||||
|
||||
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
|
||||
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base test class
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def term_context(ctx, timeout):
|
||||
"""Terminate a context with a timeout"""
|
||||
t = Thread(target=ctx.term)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
t.join(timeout=timeout)
|
||||
if t.is_alive():
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise RuntimeError(
|
||||
"context could not terminate, open sockets likely remain in test"
|
||||
)
|
||||
|
||||
|
||||
class BaseZMQTestCase(TestCase):
|
||||
green = False
|
||||
teardown_timeout = 10
|
||||
test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60)
|
||||
sockets: List[zmq.Socket]
|
||||
|
||||
@property
|
||||
def _is_pyzmq_test(self):
|
||||
return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0]
|
||||
|
||||
@property
|
||||
def _should_test_timeout(self):
|
||||
return (
|
||||
self._is_pyzmq_test
|
||||
and hasattr(signal, 'SIGALRM')
|
||||
and self.test_timeout_seconds
|
||||
)
|
||||
|
||||
@property
|
||||
def Context(self):
|
||||
if self.green:
|
||||
return gzmq.Context
|
||||
else:
|
||||
return zmq.Context
|
||||
|
||||
def socket(self, socket_type):
|
||||
s = self.context.socket(socket_type)
|
||||
self.sockets.append(s)
|
||||
return s
|
||||
|
||||
def _alarm_timeout(self, timeout, *args):
|
||||
raise TimeoutError(f"Test did not complete in {timeout} seconds")
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if self.green and not have_gevent:
|
||||
raise SkipTest("requires gevent")
|
||||
|
||||
self.context = self.Context.instance()
|
||||
self.sockets = []
|
||||
if self._should_test_timeout:
|
||||
# use SIGALRM to avoid test hangs
|
||||
signal.signal(
|
||||
signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds)
|
||||
)
|
||||
signal.alarm(self.test_timeout_seconds)
|
||||
|
||||
def tearDown(self):
|
||||
if self._should_test_timeout:
|
||||
# cancel the timeout alarm, if there was one
|
||||
signal.alarm(0)
|
||||
contexts = {self.context}
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close(0)
|
||||
for ctx in contexts:
|
||||
try:
|
||||
term_context(ctx, self.teardown_timeout)
|
||||
except Exception:
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise
|
||||
|
||||
super().tearDown()
|
||||
|
||||
def create_bound_pair(
|
||||
self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'
|
||||
):
|
||||
"""Create a bound socket pair using a random port."""
|
||||
s1 = self.context.socket(type1)
|
||||
s1.setsockopt(zmq.LINGER, 0)
|
||||
port = s1.bind_to_random_port(interface)
|
||||
s2 = self.context.socket(type2)
|
||||
s2.setsockopt(zmq.LINGER, 0)
|
||||
s2.connect(f'{interface}:{port}')
|
||||
self.sockets.extend([s1, s2])
|
||||
return s1, s2
|
||||
|
||||
def ping_pong(self, s1, s2, msg):
|
||||
s1.send(msg)
|
||||
msg2 = s2.recv()
|
||||
s2.send(msg2)
|
||||
msg3 = s1.recv()
|
||||
return msg3
|
||||
|
||||
def ping_pong_json(self, s1, s2, o):
|
||||
if jsonapi.jsonmod is None:
|
||||
raise SkipTest("No json library")
|
||||
s1.send_json(o)
|
||||
o2 = s2.recv_json()
|
||||
s2.send_json(o2)
|
||||
o3 = s1.recv_json()
|
||||
return o3
|
||||
|
||||
def ping_pong_pyobj(self, s1, s2, o):
|
||||
s1.send_pyobj(o)
|
||||
o2 = s2.recv_pyobj()
|
||||
s2.send_pyobj(o2)
|
||||
o3 = s1.recv_pyobj()
|
||||
return o3
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(
|
||||
e.errno,
|
||||
errno,
|
||||
"wrong error raised, expected '%s' \
|
||||
got '%s'"
|
||||
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
|
||||
)
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def _select_recv(self, multipart, socket, **kwargs):
|
||||
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
|
||||
if zmq.zmq_version_info() >= (3, 1, 0):
|
||||
# zmq 3.1 has a bug, where poll can return false positives,
|
||||
# so we wait a little bit just in case
|
||||
# See LIBZMQ-280 on JIRA
|
||||
time.sleep(0.1)
|
||||
|
||||
r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
|
||||
assert len(r) > 0, "Should have received a message"
|
||||
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
|
||||
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
return recv(**kwargs)
|
||||
|
||||
def recv(self, socket, **kwargs):
|
||||
"""call recv in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(False, socket, **kwargs)
|
||||
|
||||
def recv_multipart(self, socket, **kwargs):
|
||||
"""call recv_multipart in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(True, socket, **kwargs)
|
||||
|
||||
|
||||
class PollZMQTestCase(BaseZMQTestCase):
|
||||
pass
|
||||
|
||||
|
||||
class GreenTest:
|
||||
"""Mixin for making green versions of test classes"""
|
||||
|
||||
green = True
|
||||
teardown_timeout = 10
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
if errno == zmq.EAGAIN:
|
||||
raise SkipTest("Skipping because we're green.")
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError:
|
||||
e = sys.exc_info()[1]
|
||||
self.assertEqual(
|
||||
e.errno,
|
||||
errno,
|
||||
"wrong error raised, expected '%s' \
|
||||
got '%s'"
|
||||
% (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
|
||||
)
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def tearDown(self):
|
||||
if self._should_test_timeout:
|
||||
# cancel the timeout alarm, if there was one
|
||||
signal.alarm(0)
|
||||
contexts = {self.context}
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close()
|
||||
try:
|
||||
gevent.joinall(
|
||||
[gevent.spawn(ctx.term) for ctx in contexts],
|
||||
timeout=self.teardown_timeout,
|
||||
raise_error=True,
|
||||
)
|
||||
except gevent.Timeout:
|
||||
raise RuntimeError(
|
||||
"context could not terminate, open sockets likely remain in test"
|
||||
)
|
||||
|
||||
def skip_green(self):
|
||||
raise SkipTest("Skipping because we are green")
|
||||
|
||||
|
||||
def skip_green(f):
|
||||
def skipping_test(self, *args, **kwargs):
|
||||
if self.green:
|
||||
raise SkipTest("Skipping because we are green")
|
||||
else:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return skipping_test
|
1
.venv/Lib/site-packages/zmq/tests/conftest.py
Normal file
1
.venv/Lib/site-packages/zmq/tests/conftest.py
Normal file
@ -0,0 +1 @@
|
||||
"""pytest configuration and fixtures"""
|
498
.venv/Lib/site-packages/zmq/tests/test_asyncio.py
Normal file
498
.venv/Lib/site-packages/zmq/tests/test_asyncio.py
Normal file
@ -0,0 +1,498 @@
|
||||
"""Test asyncio support"""
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import CancelledError
|
||||
from multiprocessing import Process
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as zaio
|
||||
from zmq.auth.asyncio import AsyncioAuthenticator
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
from zmq.tests.test_auth import TestThreadAuthentication
|
||||
|
||||
|
||||
class ProcessForTeardownTest(Process):
|
||||
def __init__(self, event_loop_policy_class):
|
||||
Process.__init__(self)
|
||||
self.event_loop_policy_class = event_loop_policy_class
|
||||
|
||||
def run(self):
|
||||
"""Leave context, socket and event loop upon implicit disposal"""
|
||||
asyncio.set_event_loop_policy(self.event_loop_policy_class())
|
||||
|
||||
actx = zaio.Context.instance()
|
||||
socket = actx.socket(zmq.PAIR)
|
||||
socket.bind_to_random_port("tcp://127.0.0.1")
|
||||
|
||||
async def never_ending_task(socket):
|
||||
await socket.recv() # never ever receive anything
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError:
|
||||
pass # expected timeout
|
||||
else:
|
||||
assert False, "never_ending_task was completed unexpectedly"
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestAsyncIOSocket(BaseZMQTestCase):
|
||||
Context = zaio.Context
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.loop.close()
|
||||
# verify cleanup of references to selectors
|
||||
assert zaio._selectors == {}
|
||||
if 'zmq._asyncio_selector' in sys.modules:
|
||||
assert zmq._asyncio_selector._selector_loops == set()
|
||||
|
||||
def test_socket_class(self):
|
||||
s = self.context.socket(zmq.PUSH)
|
||||
assert isinstance(s, zaio.Socket)
|
||||
s.close()
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
actx = zmq.asyncio.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
ctx = zmq.Context.instance()
|
||||
actx = zmq.asyncio.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
def test_recv_multipart(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
await a.send(b"hi")
|
||||
recvd = await f
|
||||
assert recvd == [b"hi"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.done()
|
||||
assert f1.result() == b"hi"
|
||||
assert recvd == b"there"
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
|
||||
def test_recv_timeout(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with self.assertRaises(zmq.Again):
|
||||
await f1
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f2.done()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
|
||||
def test_send_timeout(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with self.assertRaises(zmq.Again):
|
||||
await s.send(b"not going anywhere")
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_string(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = "πøøπ"
|
||||
await a.send_string(msg)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == msg
|
||||
assert recvd == msg
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_json(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_json_cancelled(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
await asyncio.sleep(0)
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
# CancelledError change in 3.8 https://bugs.python.org/issue32528
|
||||
if sys.version_info < (3, 8):
|
||||
with pytest.raises(CancelledError):
|
||||
recvd = await f
|
||||
else:
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = await b.poll(timeout=5)
|
||||
assert events
|
||||
await asyncio.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
f = b.recv_json()
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_pyobj(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_pyobj(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_custom_serialize(self):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get("identities", []))
|
||||
content = json.dumps(msg["content"]).encode("utf8")
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode("utf8"))
|
||||
return {
|
||||
"identities": identities,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
"content": {
|
||||
"a": 5,
|
||||
"b": "bee",
|
||||
}
|
||||
}
|
||||
await a.send_serialized(msg, serialize)
|
||||
recvd = await b.recv_serialized(deserialize)
|
||||
assert recvd["content"] == msg["content"]
|
||||
assert recvd["identities"]
|
||||
# bounce back, tests identities
|
||||
await b.send_serialized(recvd, serialize)
|
||||
r2 = await a.recv_serialized(deserialize)
|
||||
assert r2["content"] == msg["content"]
|
||||
assert not r2["identities"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_custom_serialize_error(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
"content": {
|
||||
"a": 5,
|
||||
"b": "bee",
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
await a.send_serialized(json, json.dumps)
|
||||
|
||||
await a.send(b"not json")
|
||||
with pytest.raises(TypeError):
|
||||
await b.recv_serialized(json.loads)
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_dontwait(self):
|
||||
async def test():
|
||||
push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
with self.assertRaises(zmq.Again):
|
||||
await f
|
||||
await push.send(b"ping")
|
||||
await pull.poll() # ensure message will be waiting
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
assert f.done()
|
||||
msg = await f
|
||||
assert msg == b"ping"
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_cancel(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.poll(timeout=0)
|
||||
await asyncio.sleep(0)
|
||||
assert f.result() == 0
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = await f
|
||||
|
||||
assert evt == 0
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == zmq.POLLIN
|
||||
recvd = await b.recv_multipart()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll_base_socket(self):
|
||||
async def test():
|
||||
ctx = zmq.Context()
|
||||
url = "inproc://test"
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = zaio.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == [(b, zmq.POLLIN)]
|
||||
recvd = b.recv_multipart()
|
||||
assert recvd == [b"hi", b"there"]
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll_on_closed_socket(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
b.close()
|
||||
|
||||
# The test might stall if we try to await f directly so instead just make a few
|
||||
# passes through the event loop to schedule and execute all callbacks
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(0)
|
||||
if f.cancelled():
|
||||
break
|
||||
assert f.cancelled()
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"),
|
||||
reason="Windows does not support polling on files",
|
||||
)
|
||||
def test_poll_raw(self):
|
||||
async def test():
|
||||
p = zaio.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, "rb")
|
||||
w = os.fdopen(w, "wb")
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = await p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b"x")
|
||||
w.flush()
|
||||
evts = await p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b"x"
|
||||
r.close()
|
||||
w.close()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(test())
|
||||
|
||||
def test_multiple_loops(self):
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
|
||||
async def test():
|
||||
await a.send(b'buf')
|
||||
msg = await b.recv()
|
||||
assert msg == b'buf'
|
||||
|
||||
for i in range(3):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
|
||||
loop.close()
|
||||
|
||||
def test_shadow(self):
|
||||
async def test():
|
||||
ctx = zmq.Context()
|
||||
s = ctx.socket(zmq.PULL)
|
||||
async_s = zaio.Socket(s)
|
||||
assert isinstance(async_s, self.socket_class)
|
||||
|
||||
def test_process_teardown(self):
|
||||
event_loop_policy_class = type(asyncio.get_event_loop_policy())
|
||||
proc = ProcessForTeardownTest(event_loop_policy_class)
|
||||
proc.start()
|
||||
try:
|
||||
proc.join(10) # starting new Python process may cost a lot
|
||||
self.assertEqual(
|
||||
proc.exitcode,
|
||||
0,
|
||||
"Python process died with code %d" % proc.exitcode
|
||||
if proc.exitcode
|
||||
else "process teardown hangs",
|
||||
)
|
||||
finally:
|
||||
proc.terminate()
|
||||
|
||||
|
||||
class TestAsyncioAuthentication(TestThreadAuthentication):
|
||||
"""Test authentication running in a asyncio task"""
|
||||
|
||||
Context = zaio.Context
|
||||
|
||||
def shortDescription(self):
|
||||
"""Rewrite doc strings from TestThreadAuthentication from
|
||||
'threaded' to 'asyncio'.
|
||||
"""
|
||||
doc = self._testMethodDoc
|
||||
if doc:
|
||||
doc = doc.split("\n")[0].strip()
|
||||
if doc.startswith("threaded auth"):
|
||||
doc = doc.replace("threaded auth", "asyncio auth")
|
||||
return doc
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.loop.close()
|
||||
|
||||
def make_auth(self):
|
||||
return AsyncioAuthenticator(self.context)
|
||||
|
||||
def can_connect(self, server, client):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
|
||||
async def go():
|
||||
result = False
|
||||
iface = "tcp://127.0.0.1"
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
|
||||
# set timeouts
|
||||
server.SNDTIMEO = client.RCVTIMEO = 1000
|
||||
try:
|
||||
await server.send_multipart(msg)
|
||||
except zmq.Again:
|
||||
return False
|
||||
try:
|
||||
rcvd_msg = await client.recv_multipart()
|
||||
except zmq.Again:
|
||||
return False
|
||||
else:
|
||||
assert rcvd_msg == msg
|
||||
result = True
|
||||
return result
|
||||
|
||||
return self.loop.run_until_complete(go())
|
||||
|
||||
def _select_recv(self, multipart, socket, **kwargs):
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
|
||||
async def coro():
|
||||
if not await socket.poll(5000):
|
||||
raise TimeoutError("Should have received a message")
|
||||
return await recv(**kwargs)
|
||||
|
||||
return self.loop.run_until_complete(coro())
|
579
.venv/Lib/site-packages/zmq/tests/test_auth.py
Normal file
579
.venv/Lib/site-packages/zmq/tests/test_auth.py
Normal file
@ -0,0 +1,579 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq.auth
|
||||
from zmq.auth.thread import ThreadAuthenticator
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
|
||||
|
||||
|
||||
class BaseAuthTestCase(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if zmq.zmq_version_info() < (4, 0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to have curve support")
|
||||
super().setUp()
|
||||
# enable debug logging while we run tests
|
||||
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
|
||||
self.auth = self.make_auth()
|
||||
self.auth.start()
|
||||
self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
|
||||
|
||||
def make_auth(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def tearDown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.remove_certs(self.base_dir)
|
||||
super().tearDown()
|
||||
|
||||
def create_certs(self):
|
||||
"""Create CURVE certificates for a test"""
|
||||
|
||||
# Create temporary CURVE keypairs for this test run. We create all keys in a
|
||||
# temp directory and then move them into the appropriate private or public
|
||||
# directory.
|
||||
|
||||
base_dir = tempfile.mkdtemp()
|
||||
keys_dir = os.path.join(base_dir, 'certificates')
|
||||
public_keys_dir = os.path.join(base_dir, 'public_keys')
|
||||
secret_keys_dir = os.path.join(base_dir, 'private_keys')
|
||||
|
||||
os.mkdir(keys_dir)
|
||||
os.mkdir(public_keys_dir)
|
||||
os.mkdir(secret_keys_dir)
|
||||
|
||||
server_public_file, server_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "server"
|
||||
)
|
||||
client_public_file, client_secret_file = zmq.auth.create_certificates(
|
||||
keys_dir, "client"
|
||||
)
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.')
|
||||
)
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key_secret"):
|
||||
shutil.move(
|
||||
os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.')
|
||||
)
|
||||
|
||||
return (base_dir, public_keys_dir, secret_keys_dir)
|
||||
|
||||
def remove_certs(self, base_dir):
|
||||
"""Remove certificates for a test"""
|
||||
shutil.rmtree(base_dir)
|
||||
|
||||
def load_certs(self, secret_keys_dir):
|
||||
"""Return server and client certificate keys"""
|
||||
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
|
||||
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
|
||||
|
||||
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
|
||||
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
|
||||
|
||||
return server_public, server_secret, client_public, client_secret
|
||||
|
||||
|
||||
class TestThreadAuthentication(BaseAuthTestCase):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
return ThreadAuthenticator(self.context)
|
||||
|
||||
def can_connect(self, server, client):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
result = False
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
# run poll on server twice
|
||||
# to flush spurious events
|
||||
server.poll(100, zmq.POLLOUT)
|
||||
|
||||
if server.poll(1000, zmq.POLLOUT):
|
||||
try:
|
||||
server.send_multipart(msg, zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
if client.poll(1000):
|
||||
try:
|
||||
rcvd_msg = client.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
|
||||
else:
|
||||
assert rcvd_msg == msg
|
||||
result = True
|
||||
return result
|
||||
|
||||
def test_null(self):
|
||||
"""threaded auth - NULL"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
# use a new context, so ZAP isn't inherited
|
||||
self.context = self.Context()
|
||||
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
def test_blacklist(self):
|
||||
"""threaded auth - Blacklist"""
|
||||
# Blacklist 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
server = self.socket(zmq.PUSH)
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
def test_whitelist(self):
|
||||
"""threaded auth - Whitelist"""
|
||||
# Whitelist 127.0.0.1, connection should pass"
|
||||
self.auth.allow('127.0.0.1')
|
||||
server = self.socket(zmq.PUSH)
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
def test_plain(self):
|
||||
"""threaded auth - PLAIN"""
|
||||
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Password'
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
# Try PLAIN authentication - with server configured, connection should pass
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Password'
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
# Try PLAIN authentication - with bogus credentials, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Bogus'
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
assert self.can_connect(server, client)
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
def test_curve(self):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
# Try CURVE authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
# Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
|
||||
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
# Try CURVE authentication - with server configured, connection should pass
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
server = self.socket(zmq.PULL)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PUSH)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(client, server)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# Try connecting using NULL and no authentication enabled, connection should pass
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
def test_curve_callback(self):
|
||||
"""threaded auth - CURVE with callback authentication"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
# Try CURVE authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
# Try CURVE authentication - with callback authentication configured, connection should pass
|
||||
|
||||
class CredentialsProvider:
|
||||
def __init__(self):
|
||||
self.client = client_public
|
||||
|
||||
def callback(self, domain, key):
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
provider = CredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(server, client)
|
||||
|
||||
# Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
|
||||
|
||||
class WrongCredentialsProvider:
|
||||
def __init__(self):
|
||||
self.client = "WrongCredentials"
|
||||
|
||||
def callback(self, domain, key):
|
||||
if key == self.client:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
provider = WrongCredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert not self.can_connect(server, client)
|
||||
|
||||
@skip_pypy
|
||||
def test_curve_user_id(self):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
server = self.socket(zmq.PULL)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PUSH)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(client, server)
|
||||
|
||||
# test default user-id map
|
||||
client.send(b'test')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.bytes == b'test'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == client_public.decode("utf8")
|
||||
|
||||
# test custom user-id map
|
||||
self.auth.curve_user_id = lambda client_key: 'custom'
|
||||
|
||||
client2 = self.socket(zmq.PUSH)
|
||||
client2.curve_publickey = client_public
|
||||
client2.curve_secretkey = client_secret
|
||||
client2.curve_serverkey = server_public
|
||||
assert self.can_connect(client2, server)
|
||||
|
||||
client2.send(b'test2')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.bytes == b'test2'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == 'custom'
|
||||
|
||||
|
||||
def with_ioloop(method, expect_success=True):
|
||||
"""decorator for running tests with an IOLoop"""
|
||||
|
||||
def test_method(self):
|
||||
r = method(self)
|
||||
|
||||
loop = self.io_loop
|
||||
if expect_success:
|
||||
self.pullstream.on_recv(self.on_message_succeed)
|
||||
else:
|
||||
self.pullstream.on_recv(self.on_message_fail)
|
||||
|
||||
loop.call_later(1, self.attempt_connection)
|
||||
loop.call_later(1.2, self.send_msg)
|
||||
|
||||
if expect_success:
|
||||
loop.call_later(2, self.on_test_timeout_fail)
|
||||
else:
|
||||
loop.call_later(2, self.on_test_timeout_succeed)
|
||||
|
||||
loop.start()
|
||||
if self.fail_msg:
|
||||
self.fail(self.fail_msg)
|
||||
|
||||
return r
|
||||
|
||||
return test_method
|
||||
|
||||
|
||||
def should_auth(method):
|
||||
return with_ioloop(method, True)
|
||||
|
||||
|
||||
def should_not_auth(method):
|
||||
return with_ioloop(method, False)
|
||||
|
||||
|
||||
class TestIOLoopAuthentication(BaseAuthTestCase):
|
||||
"""Test authentication running in ioloop"""
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
from tornado import ioloop
|
||||
except ImportError:
|
||||
pytest.skip("Requires tornado")
|
||||
from zmq.eventloop import zmqstream
|
||||
|
||||
self.fail_msg = None
|
||||
self.io_loop = ioloop.IOLoop()
|
||||
super().setUp()
|
||||
self.server = self.socket(zmq.PUSH)
|
||||
self.client = self.socket(zmq.PULL)
|
||||
self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
|
||||
self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.ioloop import IOLoopAuthenticator
|
||||
|
||||
return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
|
||||
|
||||
def tearDown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.io_loop.close(all_fds=True)
|
||||
super().tearDown()
|
||||
|
||||
def attempt_connection(self):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = self.server.bind_to_random_port(iface)
|
||||
self.client.connect("%s:%i" % (iface, port))
|
||||
|
||||
def send_msg(self):
|
||||
"""Send a message from server to a client"""
|
||||
msg = [b"Hello World"]
|
||||
self.pushstream.send_multipart(msg)
|
||||
|
||||
def on_message_succeed(self, frames):
|
||||
"""A message was received, as expected."""
|
||||
if frames != [b"Hello World"]:
|
||||
self.fail_msg = "Unexpected message received"
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_message_fail(self, frames):
|
||||
"""A message was received, unexpectedly."""
|
||||
self.fail_msg = 'Received messaged unexpectedly, security failed'
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_test_timeout_succeed(self):
|
||||
"""Test timer expired, indicates test success"""
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_test_timeout_fail(self):
|
||||
"""Test timer expired, indicates test failure"""
|
||||
self.fail_msg = 'Test timed out'
|
||||
self.io_loop.stop()
|
||||
|
||||
@should_auth
|
||||
def test_none(self):
|
||||
"""ioloop auth - NONE"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
# no auth should be running
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
@should_auth
|
||||
def test_null(self):
|
||||
"""ioloop auth - NULL"""
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
self.server.zap_domain = b'global'
|
||||
|
||||
@should_not_auth
|
||||
def test_blacklist(self):
|
||||
"""ioloop auth - Blacklist"""
|
||||
# Blacklist 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
self.server.zap_domain = b'global'
|
||||
|
||||
@should_auth
|
||||
def test_whitelist(self):
|
||||
"""ioloop auth - Whitelist"""
|
||||
# Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
|
||||
|
||||
@should_not_auth
|
||||
def test_plain_unconfigured_server(self):
|
||||
"""ioloop auth - PLAIN, unconfigured server"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Password'
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
self.server.plain_server = True
|
||||
|
||||
@should_auth
|
||||
def test_plain_configured_server(self):
|
||||
"""ioloop auth - PLAIN, configured server"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Password'
|
||||
# Try PLAIN authentication - with server configured, connection should pass
|
||||
self.server.plain_server = True
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
|
||||
@should_not_auth
|
||||
def test_plain_bogus_credentials(self):
|
||||
"""ioloop auth - PLAIN, bogus credentials"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Bogus'
|
||||
self.server.plain_server = True
|
||||
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
|
||||
@should_not_auth
|
||||
def test_curve_unconfigured_server(self):
|
||||
"""ioloop auth - CURVE, unconfigured server"""
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
||||
|
||||
@should_auth
|
||||
def test_curve_allow_any(self):
|
||||
"""ioloop auth - CURVE, CURVE_ALLOW_ANY"""
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.allow('127.0.0.1')
|
||||
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
||||
|
||||
@should_auth
|
||||
def test_curve_configured_server(self):
|
||||
"""ioloop auth - CURVE, configured server"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
303
.venv/Lib/site-packages/zmq/tests/test_cffi_backend.py
Normal file
303
.venv/Lib/site-packages/zmq/tests/test_cffi_backend.py
Normal file
@ -0,0 +1,303 @@
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
from zmq.tests import SkipTest
|
||||
|
||||
try:
|
||||
from zmq.backend.cffi import ( # type: ignore
|
||||
IDENTITY,
|
||||
POLLIN,
|
||||
POLLOUT,
|
||||
PULL,
|
||||
PUSH,
|
||||
REP,
|
||||
REQ,
|
||||
zmq_version_info,
|
||||
)
|
||||
from zmq.backend.cffi._cffi import C, ffi
|
||||
|
||||
have_ffi_backend = True
|
||||
except ImportError:
|
||||
have_ffi_backend = False
|
||||
|
||||
|
||||
class TestCFFIBackend(TestCase):
|
||||
def setUp(self):
|
||||
if not have_ffi_backend:
|
||||
raise SkipTest('CFFI not available')
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
version = zmq_version_info()
|
||||
|
||||
assert version[0] in range(2, 11)
|
||||
|
||||
def test_zmq_ctx_new_destroy(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_socket_open_close(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_setsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[3]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
|
||||
assert ret == 0
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_getsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
assert ret == 0
|
||||
|
||||
option_len = ffi.new('size_t*', 3)
|
||||
option = ffi.new('char[3]')
|
||||
ret = C.zmq_getsockopt(socket, IDENTITY, ffi.cast('void*', option), option_len)
|
||||
|
||||
assert ret == 0
|
||||
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
|
||||
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
|
||||
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, 8)
|
||||
|
||||
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind_connect(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
socket1 = C.zmq_socket(ctx, PUSH)
|
||||
socket2 = C.zmq_socket(ctx, PULL)
|
||||
|
||||
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
|
||||
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket1
|
||||
assert ffi.NULL != socket2
|
||||
assert 0 == C.zmq_close(socket1)
|
||||
assert 0 == C.zmq_close(socket2)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_msg_init_close(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_size(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
assert 0 == C.zmq_msg_init_data(
|
||||
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
|
||||
)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[]', b'Hello')
|
||||
assert 0 == C.zmq_msg_init_data(
|
||||
zmq_msg, ffi.cast('void*', message), 5, ffi.NULL, ffi.NULL
|
||||
)
|
||||
|
||||
data = C.zmq_msg_data(zmq_msg)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_send(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_recv(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
|
||||
)
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_poll(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
|
||||
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
receiver_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
receiver_pollitem.socket = receiver
|
||||
receiver_pollitem.fd = 0
|
||||
receiver_pollitem.events = POLLIN | POLLOUT
|
||||
receiver_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(ffi.NULL, 0, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
|
||||
assert ret == 5
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 1
|
||||
|
||||
assert int(receiver_pollitem.revents) & POLLIN
|
||||
assert not int(receiver_pollitem.revents) & POLLOUT
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert ret_recv == 5
|
||||
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), C.zmq_msg_size(zmq_msg2))[:]
|
||||
)
|
||||
|
||||
sender_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
sender_pollitem.socket = sender
|
||||
sender_pollitem.fd = 0
|
||||
sender_pollitem.events = POLLIN | POLLOUT
|
||||
sender_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
zmq_msg_again = ffi.new('zmq_msg_t*')
|
||||
message_again = ffi.new('char[11]', b'Hello Again')
|
||||
|
||||
C.zmq_msg_init_data(
|
||||
zmq_msg_again,
|
||||
ffi.cast('void*', message_again),
|
||||
ffi.cast('size_t', 11),
|
||||
ffi.NULL,
|
||||
ffi.NULL,
|
||||
)
|
||||
|
||||
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert int(sender_pollitem.revents) & POLLIN
|
||||
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
|
||||
assert 11 == C.zmq_msg_size(zmq_msg2)
|
||||
assert (
|
||||
b"Hello Again"
|
||||
== ffi.buffer(C.zmq_msg_data(zmq_msg2), int(C.zmq_msg_size(zmq_msg2)))[:]
|
||||
)
|
||||
assert 0 == C.zmq_close(sender)
|
||||
assert 0 == C.zmq_close(receiver)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg2)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg_again)
|
19
.venv/Lib/site-packages/zmq/tests/test_constants.py
Normal file
19
.venv/Lib/site-packages/zmq/tests/test_constants.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import zmq
|
||||
import zmq.constants
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert zmq.POLLIN is zmq.PollEvent.POLLIN
|
||||
assert zmq.PUSH is zmq.SocketType.PUSH
|
||||
assert zmq.constants.SUBSCRIBE is zmq.SocketOption.SUBSCRIBE
|
||||
|
||||
|
||||
def test_socket_options():
|
||||
assert zmq.IDENTITY is zmq.SocketOption.ROUTING_ID
|
||||
assert zmq.IDENTITY._opt_type is zmq.constants._OptType.bytes
|
||||
assert zmq.AFFINITY._opt_type is zmq.constants._OptType.int64
|
||||
assert zmq.CURVE_SERVER._opt_type is zmq.constants._OptType.int
|
||||
assert zmq.FD._opt_type is zmq.constants._OptType.fd
|
401
.venv/Lib/site-packages/zmq/tests/test_context.py
Normal file
401
.venv/Lib/site-packages/zmq/tests/test_context.py
Normal file
@ -0,0 +1,401 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Event, Thread
|
||||
from unittest import mock
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest
|
||||
|
||||
|
||||
class KwargTestSocket(zmq.Socket):
|
||||
test_kwarg_value = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class KwargTestContext(zmq.Context):
|
||||
_socket_class = KwargTestSocket
|
||||
|
||||
|
||||
class TestContext(BaseZMQTestCase):
|
||||
def test_init(self):
|
||||
c1 = self.Context()
|
||||
assert isinstance(c1, self.Context)
|
||||
del c1
|
||||
c2 = self.Context()
|
||||
assert isinstance(c2, self.Context)
|
||||
del c2
|
||||
c3 = self.Context()
|
||||
assert isinstance(c3, self.Context)
|
||||
del c3
|
||||
|
||||
_repr_cls = "zmq.Context"
|
||||
|
||||
def test_repr(self):
|
||||
with self.Context() as ctx:
|
||||
assert f'{self._repr_cls}()' in repr(ctx)
|
||||
assert 'closed' not in repr(ctx)
|
||||
with ctx.socket(zmq.PUSH) as push:
|
||||
assert f'{self._repr_cls}(1 socket)' in repr(ctx)
|
||||
with ctx.socket(zmq.PULL) as pull:
|
||||
assert f'{self._repr_cls}(2 sockets)' in repr(ctx)
|
||||
assert f'{self._repr_cls}()' in repr(ctx)
|
||||
assert 'closed' in repr(ctx)
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
assert 'socket' in dir(ctx)
|
||||
if zmq.zmq_version_info() > (3,):
|
||||
assert 'IO_THREADS' in dir(ctx)
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
m = mock.Mock(spec=self.context)
|
||||
|
||||
def test_term(self):
|
||||
c = self.Context()
|
||||
c.term()
|
||||
assert c.closed
|
||||
|
||||
def test_context_manager(self):
|
||||
with self.Context() as c:
|
||||
pass
|
||||
assert c.closed
|
||||
|
||||
def test_fail_init(self):
|
||||
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
|
||||
|
||||
def test_term_hang(self):
|
||||
rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
req.setsockopt(zmq.LINGER, 0)
|
||||
req.send(b'hello', copy=False)
|
||||
req.close()
|
||||
rep.close()
|
||||
self.context.term()
|
||||
|
||||
def test_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
c2 = self.Context.instance(io_threads=2)
|
||||
assert c2 is ctx
|
||||
c2.term()
|
||||
c3 = self.Context.instance()
|
||||
c4 = self.Context.instance()
|
||||
assert not c3 is c2
|
||||
assert not c3.closed
|
||||
assert c3 is c4
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
self.context.term()
|
||||
|
||||
class SubContext(zmq.Context):
|
||||
pass
|
||||
|
||||
sctx = SubContext.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is SubContext
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
self.context.term()
|
||||
|
||||
class SubContextInherit(zmq.Context):
|
||||
pass
|
||||
|
||||
class SubContextNoInherit(zmq.Context):
|
||||
_instance = None
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
sctx = SubContextInherit.instance()
|
||||
sctx2 = SubContextNoInherit.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
sctx2.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is zmq.Context
|
||||
assert type(sctx2) is SubContextNoInherit
|
||||
|
||||
def test_instance_threadsafe(self):
|
||||
self.context.term() # clear default context
|
||||
|
||||
q = Queue()
|
||||
# slow context initialization,
|
||||
# to ensure that we are both trying to create one at the same time
|
||||
class SlowContext(self.Context):
|
||||
def __init__(self, *a, **kw):
|
||||
time.sleep(1)
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
def f():
|
||||
q.put(SlowContext.instance())
|
||||
|
||||
# call ctx.instance() in several threads at once
|
||||
N = 16
|
||||
threads = [Thread(target=f) for i in range(N)]
|
||||
[t.start() for t in threads]
|
||||
# also call it in the main thread (not first)
|
||||
ctx = SlowContext.instance()
|
||||
assert isinstance(ctx, SlowContext)
|
||||
# check that all the threads got the same context
|
||||
for i in range(N):
|
||||
thread_ctx = q.get(timeout=5)
|
||||
assert thread_ctx is ctx
|
||||
# cleanup
|
||||
ctx.term()
|
||||
[t.join(timeout=5) for t in threads]
|
||||
|
||||
def test_socket_passes_kwargs(self):
|
||||
test_kwarg_value = 'testing one two three'
|
||||
with KwargTestContext() as ctx:
|
||||
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
|
||||
assert socket.test_kwarg_value is test_kwarg_value
|
||||
|
||||
def test_many_sockets(self):
|
||||
"""opening and closing many sockets shouldn't cause problems"""
|
||||
ctx = self.Context()
|
||||
for i in range(16):
|
||||
sockets = [ctx.socket(zmq.REP) for i in range(65)]
|
||||
[s.close() for s in sockets]
|
||||
# give the reaper a chance
|
||||
time.sleep(1e-2)
|
||||
ctx.term()
|
||||
|
||||
def test_sockopts(self):
|
||||
"""setting socket options with ctx attributes"""
|
||||
ctx = self.Context()
|
||||
ctx.linger = 5
|
||||
assert ctx.linger == 5
|
||||
s = ctx.socket(zmq.REQ)
|
||||
assert s.linger == 5
|
||||
assert s.getsockopt(zmq.LINGER) == 5
|
||||
s.close()
|
||||
# check that subscribe doesn't get set on sockets that don't subscribe:
|
||||
ctx.subscribe = b''
|
||||
s = ctx.socket(zmq.REQ)
|
||||
s.close()
|
||||
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='Segfaults on Windows')
|
||||
def test_destroy(self):
|
||||
"""Context.destroy should close sockets"""
|
||||
ctx = self.Context()
|
||||
sockets = [ctx.socket(zmq.REP) for i in range(65)]
|
||||
|
||||
# close half of the sockets
|
||||
[s.close() for s in sockets[::2]]
|
||||
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in sockets:
|
||||
assert s.closed
|
||||
|
||||
def test_destroy_linger(self):
|
||||
"""Context.destroy should set linger on closing sockets"""
|
||||
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
req.send(b'hi')
|
||||
time.sleep(1e-2)
|
||||
self.context.destroy(linger=0)
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in (req, rep):
|
||||
assert s.closed
|
||||
|
||||
def test_term_noclose(self):
|
||||
"""Context.term won't close sockets"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REQ)
|
||||
assert not s.closed
|
||||
t = Thread(target=ctx.term)
|
||||
t.start()
|
||||
t.join(timeout=0.1)
|
||||
assert t.is_alive(), "Context should be waiting"
|
||||
s.close()
|
||||
t.join(timeout=0.1)
|
||||
assert not t.is_alive(), "Context should have closed"
|
||||
|
||||
def test_gc(self):
|
||||
"""test close&term by garbage collection alone"""
|
||||
if PYPY:
|
||||
raise SkipTest("GC doesn't work ")
|
||||
|
||||
# test credit @dln (GH #137):
|
||||
def gcf():
|
||||
def inner():
|
||||
ctx = self.Context()
|
||||
ctx.socket(zmq.PUSH)
|
||||
|
||||
inner()
|
||||
gc.collect()
|
||||
|
||||
t = Thread(target=gcf)
|
||||
t.start()
|
||||
t.join(timeout=1)
|
||||
assert not t.is_alive(), "Garbage collection should have cleaned up context"
|
||||
|
||||
def test_cyclic_destroy(self):
|
||||
"""ctx.destroy should succeed when cyclic ref prevents gc"""
|
||||
# test credit @dln (GH #137):
|
||||
class CyclicReference:
|
||||
def __init__(self, parent=None):
|
||||
self.parent = parent
|
||||
|
||||
def crash(self, sock):
|
||||
self.sock = sock
|
||||
self.child = CyclicReference(self)
|
||||
|
||||
def crash_zmq():
|
||||
ctx = self.Context()
|
||||
sock = ctx.socket(zmq.PULL)
|
||||
c = CyclicReference()
|
||||
c.crash(sock)
|
||||
ctx.destroy()
|
||||
|
||||
crash_zmq()
|
||||
|
||||
def test_term_thread(self):
|
||||
"""ctx.term should not crash active threads (#139)"""
|
||||
ctx = self.Context()
|
||||
evt = Event()
|
||||
evt.clear()
|
||||
|
||||
def block():
|
||||
s = ctx.socket(zmq.REP)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
evt.set()
|
||||
try:
|
||||
s.recv()
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.ETERM
|
||||
return
|
||||
finally:
|
||||
s.close()
|
||||
self.fail("recv should have been interrupted with ETERM")
|
||||
|
||||
t = Thread(target=block)
|
||||
t.start()
|
||||
|
||||
evt.wait(1)
|
||||
assert evt.is_set(), "sync event never fired"
|
||||
time.sleep(0.01)
|
||||
ctx.term()
|
||||
t.join(timeout=1)
|
||||
assert not t.is_alive(), "term should have interrupted s.recv()"
|
||||
|
||||
def test_destroy_no_sockets(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
s.close()
|
||||
ctx.destroy()
|
||||
assert s.closed
|
||||
assert ctx.closed
|
||||
|
||||
def test_ctx_opts(self):
|
||||
if zmq.zmq_version_info() < (3,):
|
||||
raise SkipTest("context options require libzmq 3")
|
||||
ctx = self.Context()
|
||||
ctx.set(zmq.MAX_SOCKETS, 2)
|
||||
assert ctx.get(zmq.MAX_SOCKETS) == 2
|
||||
ctx.max_sockets = 100
|
||||
assert ctx.max_sockets == 100
|
||||
assert ctx.get(zmq.MAX_SOCKETS) == 100
|
||||
|
||||
def test_copy(self):
|
||||
c1 = self.Context()
|
||||
c2 = copy.copy(c1)
|
||||
c2b = copy.deepcopy(c1)
|
||||
c3 = copy.deepcopy(c2)
|
||||
assert c2._shadow
|
||||
assert c3._shadow
|
||||
assert c1.underlying == c2.underlying
|
||||
assert c1.underlying == c3.underlying
|
||||
assert c1.underlying == c2b.underlying
|
||||
s = c3.socket(zmq.PUB)
|
||||
s.close()
|
||||
c1.term()
|
||||
|
||||
def test_shadow(self):
|
||||
ctx = self.Context()
|
||||
ctx2 = self.Context.shadow(ctx.underlying)
|
||||
assert ctx.underlying == ctx2.underlying
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
del ctx2
|
||||
assert not ctx.closed
|
||||
s = ctx.socket(zmq.PUB)
|
||||
ctx2 = self.Context.shadow(ctx.underlying)
|
||||
s2 = ctx2.socket(zmq.PUB)
|
||||
s.close()
|
||||
s2.close()
|
||||
ctx.term()
|
||||
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
|
||||
del ctx2
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket, zstr
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
a = zsocket.new(ctx, zmq.PUSH)
|
||||
zsocket.bind(a, "inproc://a")
|
||||
ctx2 = self.Context.shadow_pyczmq(ctx)
|
||||
b = ctx2.socket(zmq.PULL)
|
||||
b.connect("inproc://a")
|
||||
zstr.send(a, b'hi')
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == b'hi'
|
||||
b.close()
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='No fork on Windows')
|
||||
def test_fork_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
parent_ctx_id = id(ctx)
|
||||
r_fd, w_fd = os.pipe()
|
||||
reader = os.fdopen(r_fd, 'r')
|
||||
child_pid = os.fork()
|
||||
if child_pid == 0:
|
||||
ctx = self.Context.instance()
|
||||
writer = os.fdopen(w_fd, 'w')
|
||||
child_ctx_id = id(ctx)
|
||||
ctx.term()
|
||||
writer.write(str(child_ctx_id) + "\n")
|
||||
writer.flush()
|
||||
writer.close()
|
||||
os._exit(0)
|
||||
else:
|
||||
os.close(w_fd)
|
||||
|
||||
child_id_s = reader.readline()
|
||||
reader.close()
|
||||
assert child_id_s
|
||||
assert int(child_id_s) != parent_ctx_id
|
||||
ctx.term()
|
||||
|
||||
|
||||
if False: # disable green context tests
|
||||
|
||||
class TestContextGreen(GreenTest, TestContext):
|
||||
"""gevent subclass of context tests"""
|
||||
|
||||
# skip tests that use real threads:
|
||||
test_gc = GreenTest.skip_green
|
||||
test_term_thread = GreenTest.skip_green
|
||||
test_destroy_linger = GreenTest.skip_green
|
||||
_repr_cls = "zmq.green.Context"
|
43
.venv/Lib/site-packages/zmq/tests/test_cython.py
Normal file
43
.venv/Lib/site-packages/zmq/tests/test_cython.py
Normal file
@ -0,0 +1,43 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize('language_level', [3, 2])
|
||||
def test_cython(language_level, request, tmpdir):
|
||||
import pyximport
|
||||
|
||||
assert 'zmq.tests.cython_ext' not in sys.modules
|
||||
|
||||
importers = pyximport.install(
|
||||
setup_args=dict(include_dirs=zmq.get_includes()),
|
||||
language_level=language_level,
|
||||
build_dir=str(tmpdir),
|
||||
)
|
||||
|
||||
cython_ext = None
|
||||
|
||||
def unimport():
|
||||
pyximport.uninstall(*importers)
|
||||
sys.modules.pop('zmq.tests.cython_ext', None)
|
||||
|
||||
request.addfinalizer(unimport)
|
||||
|
||||
# this import tests the compilation
|
||||
from . import cython_ext
|
||||
|
||||
assert hasattr(cython_ext, 'send_recv_test')
|
||||
|
||||
# call the compiled function
|
||||
# this shouldn't do much
|
||||
msg = b'my msg'
|
||||
received = cython_ext.send_recv_test(msg)
|
||||
assert received == msg
|
396
.venv/Lib/site-packages/zmq/tests/test_decorators.py
Normal file
396
.venv/Lib/site-packages/zmq/tests/test_decorators.py
Normal file
@ -0,0 +1,396 @@
|
||||
import threading
|
||||
|
||||
from pytest import fixture, raises
|
||||
|
||||
import zmq
|
||||
from zmq.decorators import context, socket
|
||||
from zmq.tests import BaseZMQTestCase, term_context
|
||||
|
||||
##############################################
|
||||
# Test cases for @context
|
||||
##############################################
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def term_context_instance(request):
|
||||
request.addfinalizer(lambda: term_context(zmq.Context.instance(), timeout=10))
|
||||
|
||||
|
||||
def test_ctx():
|
||||
@context()
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_orig_args():
|
||||
@context()
|
||||
def f(foo, bar, ctx, baz=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert foo == 42
|
||||
assert bar is True
|
||||
assert baz == 'mock'
|
||||
|
||||
f(42, True, baz='mock')
|
||||
|
||||
|
||||
def test_ctx_arg_naming():
|
||||
@context('myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_args():
|
||||
@context('ctx', 5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_arg_kwarg():
|
||||
@context('ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kw_naming():
|
||||
@context(name='myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs_default():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_keyword_miss():
|
||||
@context(name='ctx')
|
||||
def test(other_name):
|
||||
pass # the keyword ``ctx`` not found
|
||||
|
||||
with raises(TypeError):
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_multi_assign():
|
||||
@context(name='ctx')
|
||||
def test(ctx):
|
||||
pass # explosion
|
||||
|
||||
with raises(TypeError):
|
||||
test('mock')
|
||||
|
||||
|
||||
def test_ctx_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@context()
|
||||
def f(key, ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
result[key] = ctx
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_multi_thread():
|
||||
@context()
|
||||
@context()
|
||||
def f(foo, bar):
|
||||
assert isinstance(foo, zmq.Context), foo
|
||||
assert isinstance(bar, zmq.Context), bar
|
||||
|
||||
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
##############################################
|
||||
# Test cases for @socket
|
||||
##############################################
|
||||
|
||||
|
||||
def test_ctx_skt():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def test(ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_name():
|
||||
@context()
|
||||
@socket('myskt', zmq.PUB)
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_kwarg():
|
||||
@context()
|
||||
@socket(zmq.PUB, name='myskt')
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_skt_name():
|
||||
@context('ctx')
|
||||
@socket('skt', zmq.PUB, context_name='ctx')
|
||||
def test(ctx, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_default_ctx():
|
||||
@socket(zmq.PUB)
|
||||
def test(skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.context is zmq.Context.instance()
|
||||
assert skt.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@socket(zmq.PUB)
|
||||
def f(key, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_skt_reinit():
|
||||
result = {'foo': {'ctx': None, 'skt': None}, 'bar': {'ctx': None, 'skt': None}}
|
||||
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def f(key, ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key]['ctx'] = ctx
|
||||
result[key]['skt'] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo']['ctx'] is not None, result
|
||||
assert result['foo']['skt'] is not None, result
|
||||
assert result['bar']['ctx'] is not None, result
|
||||
assert result['bar']['skt'] is not None, result
|
||||
assert result['foo']['ctx'] is not result['bar']['ctx'], result
|
||||
assert result['foo']['skt'] is not result['bar']['skt'], result
|
||||
|
||||
|
||||
def test_skt_type_miss():
|
||||
@context()
|
||||
@socket('myskt')
|
||||
def f(ctx, myskt):
|
||||
pass # the socket type is missing
|
||||
|
||||
with raises(TypeError):
|
||||
f()
|
||||
|
||||
|
||||
def test_multi_skts():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_single_ctx():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(ctx, pub, sub, push):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
assert push.context is ctx
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_with_name():
|
||||
@socket('foo', zmq.PUSH)
|
||||
@socket('bar', zmq.SUB)
|
||||
@socket('baz', zmq.PUB)
|
||||
def test(foo, bar, baz):
|
||||
assert isinstance(foo, zmq.Socket), foo
|
||||
assert isinstance(bar, zmq.Socket), bar
|
||||
assert isinstance(baz, zmq.Socket), baz
|
||||
|
||||
assert foo.context is zmq.Context.instance()
|
||||
assert bar.context is zmq.Context.instance()
|
||||
assert baz.context is zmq.Context.instance()
|
||||
|
||||
assert foo.type == zmq.PUSH
|
||||
assert bar.type == zmq.SUB
|
||||
assert baz.type == zmq.PUB
|
||||
|
||||
test()
|
||||
|
||||
|
||||
def test_func_return():
|
||||
@context()
|
||||
def f(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
return 'something'
|
||||
|
||||
assert f() == 'something'
|
||||
|
||||
|
||||
def test_skt_multi_thread():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def f(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
assert len(set(map(id, [pub, sub, push]))) == 3
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
class TestMethodDecorators(BaseZMQTestCase):
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
|
||||
assert isinstance(self, TestMethodDecorators), self
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert foo == 'bar'
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
|
||||
def test_multi_skts_method(self):
|
||||
self.multi_skts_method()
|
||||
|
||||
def test_multi_skts_method_other_args(self):
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def f(foo, pub, sub, bar=None):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
|
||||
assert foo == 'mock'
|
||||
assert bar == 'fake'
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
|
||||
f('mock', bar='fake')
|
168
.venv/Lib/site-packages/zmq/tests/test_device.py
Normal file
168
.venv/Lib/site-packages/zmq/tests/test_device.py
Normal file
@ -0,0 +1,168 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, GreenTest, SkipTest, have_gevent
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestDevice(BaseZMQTestCase):
|
||||
def test_device_types(self):
|
||||
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
|
||||
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
|
||||
assert dev.device_type == devtype
|
||||
del dev
|
||||
|
||||
def test_device_attributes(self):
|
||||
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
|
||||
assert dev.in_type == zmq.SUB
|
||||
assert dev.out_type == zmq.PUB
|
||||
assert dev.device_type == zmq.QUEUE
|
||||
assert dev.daemon == True
|
||||
del dev
|
||||
|
||||
def test_single_socket_forwarder_connect(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_in('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_out('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_single_socket_forwarder_bind(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
assert msg == self.recv(req)
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_device_bind_to_random_with_args(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_device_bind_to_random_binderror(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
try:
|
||||
for i in range(11):
|
||||
dev.bind_in_to_random_port(iface, min_port=10000, max_port=10010)
|
||||
except zmq.ZMQBindError as e:
|
||||
return
|
||||
else:
|
||||
self.fail('Should have failed')
|
||||
|
||||
def test_proxy(self):
|
||||
if zmq.zmq_version_info() < (3, 2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
|
||||
def test_proxy_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (3, 2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
|
||||
import zmq.green
|
||||
|
||||
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
|
||||
def test_green_device(self):
|
||||
rep = self.context.socket(zmq.REP)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([req, rep])
|
||||
port = rep.bind_to_random_port('tcp://127.0.0.1')
|
||||
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
req.send(b'hi')
|
||||
timeout = gevent.Timeout(3)
|
||||
timeout.start()
|
||||
receiver = gevent.spawn(req.recv)
|
||||
assert receiver.get(2) == b'hi'
|
||||
timeout.cancel()
|
||||
g.kill(block=True)
|
47
.venv/Lib/site-packages/zmq/tests/test_draft.py
Normal file
47
.venv/Lib/site-packages/zmq/tests/test_draft.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestDraftSockets(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if not zmq.DRAFT_API:
|
||||
pytest.skip("draft api unavailable")
|
||||
super().setUp()
|
||||
|
||||
def test_client_server(self):
|
||||
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
|
||||
client.send(b'request')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.routing_id is not None
|
||||
server.send(b'reply', routing_id=msg.routing_id)
|
||||
reply = self.recv(client)
|
||||
assert reply == b'reply'
|
||||
|
||||
def test_radio_dish(self):
|
||||
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
|
||||
dish.rcvtimeo = 250
|
||||
group = 'mygroup'
|
||||
dish.join(group)
|
||||
received_count = 0
|
||||
received = set()
|
||||
sent = set()
|
||||
for i in range(10):
|
||||
msg = str(i).encode('ascii')
|
||||
sent.add(msg)
|
||||
radio.send(msg, group=group)
|
||||
try:
|
||||
recvd = dish.recv()
|
||||
except zmq.Again:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
received.add(recvd)
|
||||
received_count += 1
|
||||
# assert that we got *something*
|
||||
assert len(received.intersection(sent)) >= 5
|
37
.venv/Lib/site-packages/zmq/tests/test_error.py
Normal file
37
.venv/Lib/site-packages/zmq/tests/test_error.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq import Again, ContextTerminated, ZMQError, strerror
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestZMQError(BaseZMQTestCase):
|
||||
def test_strerror(self):
|
||||
"""test that strerror gets the right type."""
|
||||
for i in range(10):
|
||||
e = strerror(i)
|
||||
assert isinstance(e, str)
|
||||
|
||||
def test_zmqerror(self):
|
||||
for errno in range(10):
|
||||
e = ZMQError(errno)
|
||||
assert e.errno == errno
|
||||
assert str(e) == strerror(errno)
|
||||
|
||||
def test_again(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
|
||||
def atest_ctxterm(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
t = Thread(target=self.context.term)
|
||||
t.start()
|
||||
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
t.join()
|
26
.venv/Lib/site-packages/zmq/tests/test_etc.py
Normal file
26
.venv/Lib/site-packages/zmq/tests/test_etc.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
|
||||
only_bundled = mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
|
||||
|
||||
|
||||
@mark.skipif('zmq.zmq_version_info() < (4, 1)')
|
||||
def test_has():
|
||||
assert not zmq.has('something weird')
|
||||
|
||||
|
||||
@only_bundled
|
||||
def test_has_curve():
|
||||
"""bundled libzmq has curve support"""
|
||||
assert zmq.has('curve')
|
||||
|
||||
|
||||
@only_bundled
|
||||
def test_has_ipc():
|
||||
"""bundled libzmq has ipc support"""
|
||||
assert zmq.has('ipc')
|
353
.venv/Lib/site-packages/zmq/tests/test_future.py
Normal file
353
.venv/Lib/site-packages/zmq/tests/test_future.py
Normal file
@ -0,0 +1,353 @@
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
gen = pytest.importorskip('tornado.gen')
|
||||
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
import zmq
|
||||
from zmq.eventloop import future
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestFutureSocket(BaseZMQTestCase):
|
||||
Context = future.Context
|
||||
|
||||
def setUp(self):
|
||||
self.loop = IOLoop()
|
||||
self.loop.make_current()
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
if self.loop:
|
||||
self.loop.close(all_fds=True)
|
||||
IOLoop.clear_current()
|
||||
IOLoop.clear_instance()
|
||||
|
||||
def test_socket_class(self):
|
||||
s = self.context.socket(zmq.PUSH)
|
||||
assert isinstance(s, future.Socket)
|
||||
s.close()
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
actx = self.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
ctx = zmq.Context.instance()
|
||||
actx = self.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_recv_multipart(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
await a.send(b"hi")
|
||||
recvd = await f
|
||||
assert recvd == [b'hi']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.done()
|
||||
assert f1.result() == b'hi'
|
||||
assert recvd == b'there'
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_cancel(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_recv_timeout(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with pytest.raises(zmq.Again):
|
||||
await f1
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
recvd = await f2
|
||||
assert f2.done()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_send_timeout(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send(b"not going anywhere")
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_send_noblock(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send(b"not going anywhere", flags=zmq.NOBLOCK)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_send_multipart_noblock(self):
|
||||
async def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
await s.send_multipart([b"not going anywhere"], flags=zmq.NOBLOCK)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_string(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = 'πøøπ'
|
||||
await a.send_string(msg)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == msg
|
||||
assert recvd == msg
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json_cancelled(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
await gen.sleep(0)
|
||||
obj = dict(a=5)
|
||||
await a.send_json(obj)
|
||||
with pytest.raises(future.CancelledError):
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = await b.poll(timeout=5)
|
||||
assert events
|
||||
await gen.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
recvd = await gen.with_timeout(timedelta(seconds=5), b.recv_json())
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_pyobj(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
await a.send_pyobj(obj)
|
||||
recvd = await f
|
||||
assert f.done()
|
||||
assert f.result() == obj
|
||||
assert recvd == obj
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize(self):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
await a.send_serialized(msg, serialize)
|
||||
recvd = await b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
await b.send_serialized(recvd, serialize)
|
||||
r2 = await a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize_error(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
await a.send_serialized(json, json.dumps)
|
||||
|
||||
await a.send(b"not json")
|
||||
with pytest.raises(TypeError):
|
||||
await b.recv_serialized(json.loads)
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_poll(self):
|
||||
async def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.poll(timeout=0)
|
||||
assert f.done()
|
||||
assert f.result() == 0
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = await f
|
||||
assert evt == 0
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
await a.send_multipart([b"hi", b"there"])
|
||||
evt = await f
|
||||
assert evt == zmq.POLLIN
|
||||
recvd = await b.recv_multipart()
|
||||
assert recvd == [b'hi', b'there']
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'), reason='Windows unsupported socket type'
|
||||
)
|
||||
def test_poll_base_socket(self):
|
||||
async def test():
|
||||
ctx = zmq.Context()
|
||||
url = 'inproc://test'
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = future.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b'hi', b'there'])
|
||||
evt = await f
|
||||
assert evt == [(b, zmq.POLLIN)]
|
||||
recvd = b.recv_multipart()
|
||||
assert recvd == [b'hi', b'there']
|
||||
a.close()
|
||||
b.close()
|
||||
ctx.term()
|
||||
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_close_all_fds(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
s._get_loop()
|
||||
self.loop.close(all_fds=True)
|
||||
self.loop = None # avoid second close later
|
||||
assert s.closed
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Windows does not support polling on files',
|
||||
)
|
||||
def test_poll_raw(self):
|
||||
async def test():
|
||||
p = future.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = await p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
evts = await p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b'x'
|
||||
r.close()
|
||||
w.close()
|
||||
|
||||
self.loop.run_sync(test)
|
65
.venv/Lib/site-packages/zmq/tests/test_imports.py
Normal file
65
.venv/Lib/site-packages/zmq/tests/test_imports.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestImports(TestCase):
|
||||
"""Test Imports - the quickest test to ensure that we haven't
|
||||
introduced version-incompatible syntax errors."""
|
||||
|
||||
def test_toplevel(self):
|
||||
"""test toplevel import"""
|
||||
import zmq
|
||||
|
||||
def test_core(self):
|
||||
"""test core imports"""
|
||||
from zmq import (
|
||||
Context,
|
||||
Frame,
|
||||
Poller,
|
||||
Socket,
|
||||
constants,
|
||||
device,
|
||||
proxy,
|
||||
pyzmq_version,
|
||||
pyzmq_version_info,
|
||||
zmq_version,
|
||||
zmq_version_info,
|
||||
)
|
||||
|
||||
def test_devices(self):
|
||||
"""test device imports"""
|
||||
import zmq.devices
|
||||
from zmq.devices import basedevice, monitoredqueue, monitoredqueuedevice
|
||||
|
||||
def test_log(self):
|
||||
"""test log imports"""
|
||||
import zmq.log
|
||||
from zmq.log import handlers
|
||||
|
||||
def test_eventloop(self):
|
||||
"""test eventloop imports"""
|
||||
try:
|
||||
import tornado
|
||||
except ImportError:
|
||||
pytest.skip('requires tornado')
|
||||
import zmq.eventloop
|
||||
from zmq.eventloop import ioloop, zmqstream
|
||||
|
||||
def test_utils(self):
|
||||
"""test util imports"""
|
||||
import zmq.utils
|
||||
from zmq.utils import jsonapi, strtypes
|
||||
|
||||
def test_ssh(self):
|
||||
"""test ssh imports"""
|
||||
from zmq.ssh import tunnel
|
||||
|
||||
def test_decorators(self):
|
||||
"""test decorators imports"""
|
||||
from zmq.decorators import context, socket
|
33
.venv/Lib/site-packages/zmq/tests/test_includes.py
Normal file
33
.venv/Lib/site-packages/zmq/tests/test_includes.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
class TestIncludes(TestCase):
|
||||
def test_get_includes(self):
|
||||
from os.path import basename
|
||||
|
||||
includes = zmq.get_includes()
|
||||
assert isinstance(includes, list)
|
||||
assert len(includes) >= 2
|
||||
parent = includes[0]
|
||||
assert isinstance(parent, str)
|
||||
utilsdir = includes[1]
|
||||
assert isinstance(utilsdir, str)
|
||||
utils = basename(utilsdir)
|
||||
assert utils == "utils"
|
||||
|
||||
def test_get_library_dirs(self):
|
||||
from os.path import basename
|
||||
|
||||
libdirs = zmq.get_library_dirs()
|
||||
assert isinstance(libdirs, list)
|
||||
assert len(libdirs) == 1
|
||||
parent = libdirs[0]
|
||||
assert isinstance(parent, str)
|
||||
libdir = basename(parent)
|
||||
assert libdir == "zmq"
|
142
.venv/Lib/site-packages/zmq/tests/test_ioloop.py
Normal file
142
.venv/Lib/site-packages/zmq/tests/test_ioloop.py
Normal file
@ -0,0 +1,142 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, have_gevent
|
||||
|
||||
try:
|
||||
from tornado.ioloop import IOLoop as BaseIOLoop
|
||||
|
||||
from zmq.eventloop import ioloop
|
||||
except ImportError:
|
||||
_tornado = False
|
||||
else:
|
||||
_tornado = True
|
||||
|
||||
# tornado 5 with asyncio disables custom IOLoop implementations
|
||||
t5asyncio = False
|
||||
if _tornado:
|
||||
import tornado
|
||||
|
||||
if tornado.version_info >= (5,) and asyncio:
|
||||
t5asyncio = True
|
||||
|
||||
|
||||
def printer():
|
||||
os.system("say hello")
|
||||
raise Exception
|
||||
print(time.time())
|
||||
|
||||
|
||||
class Delay(threading.Thread):
|
||||
def __init__(self, f, delay=1):
|
||||
self.f = f
|
||||
self.delay = delay
|
||||
self.aborted = False
|
||||
self.cond = threading.Condition()
|
||||
super().__init__()
|
||||
|
||||
def run(self):
|
||||
self.cond.acquire()
|
||||
self.cond.wait(self.delay)
|
||||
self.cond.release()
|
||||
if not self.aborted:
|
||||
self.f()
|
||||
|
||||
def abort(self):
|
||||
self.aborted = True
|
||||
self.cond.acquire()
|
||||
self.cond.notify()
|
||||
self.cond.release()
|
||||
|
||||
|
||||
class TestIOLoop(BaseZMQTestCase):
|
||||
if _tornado:
|
||||
IOLoop = ioloop.IOLoop
|
||||
|
||||
def setUp(self):
|
||||
if not _tornado:
|
||||
pytest.skip("tornado required")
|
||||
super().setUp()
|
||||
if asyncio:
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
BaseIOLoop.clear_current()
|
||||
BaseIOLoop.clear_instance()
|
||||
|
||||
def test_simple(self):
|
||||
"""simple IOLoop creation test"""
|
||||
loop = self.IOLoop()
|
||||
loop.make_current()
|
||||
dc = ioloop.PeriodicCallback(loop.stop, 200)
|
||||
pc = ioloop.PeriodicCallback(lambda: None, 10)
|
||||
pc.start()
|
||||
dc.start()
|
||||
t = Delay(loop.stop, 1)
|
||||
t.start()
|
||||
loop.start()
|
||||
if t.is_alive():
|
||||
t.abort()
|
||||
else:
|
||||
self.fail("IOLoop failed to exit")
|
||||
|
||||
def test_instance(self):
|
||||
"""IOLoop.instance returns the right object"""
|
||||
loop = self.IOLoop.instance()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.instance()
|
||||
assert base_loop is loop
|
||||
|
||||
def test_current(self):
|
||||
"""IOLoop.current returns the right object"""
|
||||
loop = ioloop.IOLoop.current()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.current()
|
||||
assert base_loop is loop
|
||||
|
||||
def test_close_all(self):
|
||||
"""Test close(all_fds=True)"""
|
||||
loop = self.IOLoop.current()
|
||||
req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ)
|
||||
loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ)
|
||||
assert req.closed == False
|
||||
assert rep.closed == False
|
||||
loop.close(all_fds=True)
|
||||
assert req.closed == True
|
||||
assert rep.closed == True
|
||||
|
||||
|
||||
if have_gevent and _tornado:
|
||||
import zmq.green.eventloop.ioloop as green_ioloop
|
||||
|
||||
class TestIOLoopGreen(TestIOLoop):
|
||||
IOLoop = green_ioloop.IOLoop
|
||||
|
||||
def xtest_instance(self):
|
||||
"""Green IOLoop.instance returns the right object"""
|
||||
loop = self.IOLoop.instance()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.instance()
|
||||
assert base_loop is loop
|
||||
|
||||
def xtest_current(self):
|
||||
"""Green IOLoop.current returns the right object"""
|
||||
loop = self.IOLoop.current()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.current()
|
||||
assert base_loop is loop
|
178
.venv/Lib/site-packages/zmq/tests/test_log.py
Normal file
178
.venv/Lib/site-packages/zmq/tests/test_log.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.log import handlers
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestPubLog(BaseZMQTestCase):
|
||||
|
||||
iface = 'inproc://zmqlog'
|
||||
topic = 'zmq'
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
# print dir(self)
|
||||
logger = logging.getLogger('zmqtest')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
return logger
|
||||
|
||||
def connect_handler(self, topic=None):
|
||||
topic = self.topic if topic is None else topic
|
||||
logger = self.logger
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = topic
|
||||
logger.addHandler(handler)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
time.sleep(0.1)
|
||||
return logger, handler, sub
|
||||
|
||||
def test_init_iface(self):
|
||||
logger = self.logger
|
||||
ctx = self.context
|
||||
handler = handlers.PUBHandler(self.iface)
|
||||
assert not handler.ctx is ctx
|
||||
self.sockets.append(handler.socket)
|
||||
# handler.ctx.term()
|
||||
handler = handlers.PUBHandler(self.iface, self.context)
|
||||
self.sockets.append(handler.socket)
|
||||
assert handler.ctx is ctx
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
sub = ctx.socket(zmq.SUB)
|
||||
self.sockets.append(sub)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
|
||||
sub.connect(self.iface)
|
||||
import time
|
||||
|
||||
time.sleep(0.25)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
assert topic == b'zmq.INFO'
|
||||
assert msg2 == (msg1 + "\n").encode("utf8")
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_init_socket(self):
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
logger = self.logger
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
|
||||
assert handler.socket is pub
|
||||
assert handler.ctx is pub.context
|
||||
assert handler.ctx is self.context
|
||||
sub.setsockopt(zmq.SUBSCRIBE, self.topic.encode())
|
||||
import time
|
||||
|
||||
time.sleep(0.1)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
assert topic == b'zmq.INFO'
|
||||
assert msg2 == (msg1 + "\n").encode("utf8")
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_root_topic(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.socket.bind(self.iface)
|
||||
sub2 = sub.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub2)
|
||||
sub2.connect(self.iface)
|
||||
sub2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.root_topic = b'twoonly'
|
||||
msg1 = 'ignored'
|
||||
logger.info(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
|
||||
topic, msg2 = sub2.recv_multipart()
|
||||
assert topic == b'twoonly.INFO'
|
||||
assert msg2 == (msg1 + '\n').encode()
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_blank_root_topic(self):
|
||||
logger, handler, sub_everything = self.connect_handler()
|
||||
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.socket.bind(self.iface)
|
||||
sub_only_info = sub_everything.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub_only_info)
|
||||
sub_only_info.connect(self.iface)
|
||||
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
|
||||
handler.setRootTopic(b'')
|
||||
msg_debug = 'debug_message'
|
||||
logger.debug(msg_debug)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
|
||||
topic, msg_debug_response = sub_everything.recv_multipart()
|
||||
assert topic == b'DEBUG'
|
||||
msg_info = 'info_message'
|
||||
logger.info(msg_info)
|
||||
topic, msg_info_response_everything = sub_everything.recv_multipart()
|
||||
assert topic == b'INFO'
|
||||
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
|
||||
assert topic == b'INFO'
|
||||
assert msg_info_response_everything == msg_info_response_onlyinfo
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_unicode_message(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
base_topic = (self.topic + '.INFO').encode()
|
||||
for msg, expected in [
|
||||
('hello', [base_topic, b'hello\n']),
|
||||
('héllo', [base_topic, 'héllo\n'.encode()]),
|
||||
('tøpic::héllo', [base_topic + '.tøpic'.encode(), 'héllo\n'.encode()]),
|
||||
]:
|
||||
logger.info(msg)
|
||||
received = sub.recv_multipart()
|
||||
assert received == expected
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_set_info_formatter_via_property(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'info message UNITTEST\n'
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_global_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST info message'
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST debug message'
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_debug_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
|
||||
handler.setFormatter(formatter, logging.DEBUG)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, handler.root_topic.encode())
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'info message\n'
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
assert msg == b'UNITTEST DEBUG debug message'
|
||||
logger.removeHandler(handler)
|
370
.venv/Lib/site-packages/zmq/tests/test_message.py
Normal file
370
.venv/Lib/site-packages/zmq/tests/test_message.py
Normal file
@ -0,0 +1,370 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import sys
|
||||
|
||||
try:
|
||||
from sys import getrefcount
|
||||
except ImportError:
|
||||
grc = None
|
||||
else:
|
||||
grc = getrefcount
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest, skip_pypy
|
||||
|
||||
# some useful constants:
|
||||
|
||||
x = b'x'
|
||||
|
||||
if grc:
|
||||
rc0 = grc(x)
|
||||
v = memoryview(x)
|
||||
view_rc = grc(x) - rc0
|
||||
|
||||
|
||||
def await_gc(obj, rc):
|
||||
"""wait for refcount on an object to drop to an expected value
|
||||
|
||||
Necessary because of the zero-copy gc thread,
|
||||
which can take some time to receive its DECREF message.
|
||||
"""
|
||||
# count refs for this function
|
||||
if sys.version_info < (3, 11):
|
||||
my_refs = 2
|
||||
else:
|
||||
my_refs = 1
|
||||
for i in range(50):
|
||||
# rc + 2 because of the refs in this function
|
||||
if grc(obj) <= rc + my_refs:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
class TestFrame(BaseZMQTestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
|
||||
@skip_pypy
|
||||
def test_above_30(self):
|
||||
"""Message above 30 bytes are never copied by 0MQ."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
assert grc(s) == rc + 2
|
||||
del m
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
del s
|
||||
|
||||
def test_str(self):
|
||||
"""Test the str representations of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
m_str = str(m)
|
||||
m_str_b = m_str.encode()
|
||||
assert s == m_str_b
|
||||
|
||||
def test_bytes(self):
|
||||
"""Test the Frame.bytes property."""
|
||||
for i in range(1, 16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
b = m.bytes
|
||||
assert s == m.bytes
|
||||
if not PYPY:
|
||||
# check that it copies
|
||||
assert b is not s
|
||||
# check that it copies only once
|
||||
assert b is m.bytes
|
||||
|
||||
def test_unicode(self):
|
||||
"""Test the unicode representations of the Frames."""
|
||||
s = 'asdf'
|
||||
self.assertRaises(TypeError, zmq.Frame, s)
|
||||
for i in range(16):
|
||||
s = (2**i) * '§'
|
||||
m = zmq.Frame(s.encode('utf8'))
|
||||
assert s == m.bytes.decode('utf8')
|
||||
|
||||
def test_len(self):
|
||||
"""Test the len of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i) * x
|
||||
m = zmq.Frame(s)
|
||||
assert len(s) == len(m)
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle1(self):
|
||||
"""Run through a ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = rc_0 = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
assert grc(s) == rc
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
assert grc(s) == rc
|
||||
# no increase in refcount for accessing buffer
|
||||
# which references m2 directly
|
||||
buf = m2.buffer
|
||||
assert grc(s) == rc
|
||||
|
||||
assert s == str(m).encode()
|
||||
assert s == bytes(m2)
|
||||
assert s == m.bytes
|
||||
assert s == bytes(buf)
|
||||
# assert s is str(m)
|
||||
# assert s is str(m2)
|
||||
del m2
|
||||
assert grc(s) == rc
|
||||
# buf holds direct reference to m2 which holds
|
||||
del buf
|
||||
rc -= 1
|
||||
assert grc(s) == rc
|
||||
del m
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
assert rc == rc_0
|
||||
del s
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle2(self):
|
||||
"""Run through a different ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i) * x
|
||||
rc = rc_0 = grc(s)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
assert grc(s) == rc
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
assert grc(s) == rc
|
||||
# no increase in refcount for accessing buffer
|
||||
# which references m directly
|
||||
buf = m.buffer
|
||||
assert grc(s) == rc
|
||||
assert s == str(m).encode()
|
||||
assert s == bytes(m2)
|
||||
assert s == m2.bytes
|
||||
assert s == m.bytes
|
||||
assert s == bytes(buf)
|
||||
# assert s is str(m)
|
||||
# assert s is str(m2)
|
||||
del buf
|
||||
assert grc(s) == rc
|
||||
del m
|
||||
rc -= 1
|
||||
assert grc(s) == rc
|
||||
del m2
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
assert grc(s) == rc
|
||||
assert rc == rc_0
|
||||
del s
|
||||
|
||||
def test_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
assert not m.tracker.done
|
||||
pm = zmq.MessageTracker(m)
|
||||
assert not pm.done
|
||||
del m
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
for i in range(10):
|
||||
if pm.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert pm.done
|
||||
|
||||
def test_no_tracker(self):
|
||||
m = zmq.Frame(b'asdf', track=False)
|
||||
assert m.tracker == None
|
||||
m2 = copy.copy(m)
|
||||
assert m2.tracker == None
|
||||
self.assertRaises(ValueError, zmq.MessageTracker, m)
|
||||
|
||||
def test_multi_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
m2 = zmq.Frame(b'whoda', copy=False, track=True)
|
||||
mt = zmq.MessageTracker(m, m2)
|
||||
assert not m.tracker.done
|
||||
assert not mt.done
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
del m
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
assert not mt.done
|
||||
del m2
|
||||
for i in range(3):
|
||||
gc.collect()
|
||||
assert mt.wait(0.1) is None
|
||||
assert mt.done
|
||||
|
||||
def test_buffer_in(self):
|
||||
"""test using a buffer as input"""
|
||||
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
zmq.Frame(memoryview(ins))
|
||||
|
||||
def test_bad_buffer_in(self):
|
||||
"""test using a bad object"""
|
||||
self.assertRaises(TypeError, zmq.Frame, 5)
|
||||
self.assertRaises(TypeError, zmq.Frame, object())
|
||||
|
||||
def test_buffer_out(self):
|
||||
"""receiving buffered output"""
|
||||
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
m = zmq.Frame(ins)
|
||||
outb = m.buffer
|
||||
assert isinstance(outb, memoryview)
|
||||
assert outb is m.buffer
|
||||
assert m.buffer is m.buffer
|
||||
|
||||
def test_memoryview_shape(self):
|
||||
"""memoryview shape info"""
|
||||
data = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
|
||||
n = len(data)
|
||||
f = zmq.Frame(data)
|
||||
view1 = f.buffer
|
||||
assert view1.ndim == 1
|
||||
assert view1.shape == (n,)
|
||||
assert view1.tobytes() == data
|
||||
view2 = memoryview(f)
|
||||
assert view2.ndim == 1
|
||||
assert view2.shape == (n,)
|
||||
assert view2.tobytes() == data
|
||||
|
||||
def test_multisend(self):
|
||||
"""ensure that a message remains intact after multiple sends"""
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
s = b"message"
|
||||
m = zmq.Frame(s)
|
||||
assert s == m.bytes
|
||||
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
assert s == m.bytes
|
||||
for i in range(4):
|
||||
r = b.recv()
|
||||
assert s == r
|
||||
assert s == m.bytes
|
||||
|
||||
def test_memoryview(self):
|
||||
"""test messages from memoryview"""
|
||||
s = b'carrotjuice'
|
||||
memoryview(s)
|
||||
m = zmq.Frame(s)
|
||||
buf = m.buffer
|
||||
s2 = buf.tobytes()
|
||||
assert s2 == s
|
||||
assert m.bytes == s
|
||||
|
||||
def test_noncopying_recv(self):
|
||||
"""check for clobbering message buffers"""
|
||||
null = b'\0' * 64
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
for i in range(32):
|
||||
# try a few times
|
||||
sb.send(null, copy=False)
|
||||
m = sa.recv(copy=False)
|
||||
mb = m.bytes
|
||||
# buf = memoryview(m)
|
||||
buf = m.buffer
|
||||
del m
|
||||
for i in range(5):
|
||||
ff = b'\xff' * (40 + i * 10)
|
||||
sb.send(ff, copy=False)
|
||||
m2 = sa.recv(copy=False)
|
||||
b = buf.tobytes()
|
||||
assert b == null
|
||||
assert mb == null
|
||||
assert m2.bytes == ff
|
||||
assert type(m2.bytes) is bytes
|
||||
|
||||
def test_noncopying_memoryview(self):
|
||||
"""test non-copying memmoryview messages"""
|
||||
null = b'\0' * 64
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
for i in range(32):
|
||||
# try a few times
|
||||
sb.send(memoryview(null), copy=False)
|
||||
m = sa.recv(copy=False)
|
||||
buf = memoryview(m)
|
||||
for i in range(5):
|
||||
ff = b'\xff' * (40 + i * 10)
|
||||
sb.send(memoryview(ff), copy=False)
|
||||
m2 = sa.recv(copy=False)
|
||||
buf2 = memoryview(m2)
|
||||
assert buf.tobytes() == null
|
||||
assert not buf.readonly
|
||||
assert buf2.tobytes() == ff
|
||||
assert not buf2.readonly
|
||||
assert type(buf) is memoryview
|
||||
|
||||
def test_buffer_numpy(self):
|
||||
"""test non-copying numpy array messages"""
|
||||
try:
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal
|
||||
except ImportError:
|
||||
raise SkipTest("requires numpy")
|
||||
rand = numpy.random.randint
|
||||
shapes = [rand(2, 5) for i in range(5)]
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
dtypes = [int, float, '>i4', 'B']
|
||||
for i in range(1, len(shapes) + 1):
|
||||
shape = shapes[:i]
|
||||
for dt in dtypes:
|
||||
A = numpy.empty(shape, dtype=dt)
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
|
||||
A['a'] = 1024
|
||||
A['b'] = 1e9
|
||||
A['c'] = 'hello there'
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
@skip_pypy
|
||||
def test_frame_more(self):
|
||||
"""test Frame.more attribute"""
|
||||
frame = zmq.Frame(b"hello")
|
||||
assert not frame.more
|
||||
sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
sa.send_multipart([b'hi', b'there'])
|
||||
frame = self.recv(sb, copy=False)
|
||||
assert frame.more
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
assert frame.get(zmq.MORE)
|
||||
frame = self.recv(sb, copy=False)
|
||||
assert not frame.more
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
assert not frame.get(zmq.MORE)
|
76
.venv/Lib/site-packages/zmq/tests/test_monitor.py
Normal file
76
.venv/Lib/site-packages/zmq/tests/test_monitor.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, require_zmq_4
|
||||
from zmq.utils.monitor import recv_monitor_message
|
||||
|
||||
|
||||
class TestSocketMonitor(BaseZMQTestCase):
|
||||
@require_zmq_4
|
||||
def test_monitor(self):
|
||||
"""Test monitoring interface for sockets."""
|
||||
s_rep = self.context.socket(zmq.REP)
|
||||
s_req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([s_rep, s_req])
|
||||
s_req.bind("tcp://127.0.0.1:6666")
|
||||
# try monitoring the REP socket
|
||||
|
||||
s_rep.monitor(
|
||||
"inproc://monitor.rep",
|
||||
zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED,
|
||||
)
|
||||
# create listening socket for monitor
|
||||
s_event = self.context.socket(zmq.PAIR)
|
||||
self.sockets.append(s_event)
|
||||
s_event.connect("inproc://monitor.rep")
|
||||
s_event.linger = 0
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6666")
|
||||
m = recv_monitor_message(s_event)
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
assert m['event'] == zmq.EVENT_CONNECTED
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6666"
|
||||
|
||||
# test monitor can be disabled.
|
||||
s_rep.disable_monitor()
|
||||
m = recv_monitor_message(s_event)
|
||||
assert m['event'] == zmq.EVENT_MONITOR_STOPPED
|
||||
|
||||
@require_zmq_4
|
||||
def test_monitor_repeat(self):
|
||||
s = self.socket(zmq.PULL)
|
||||
m = s.get_monitor_socket()
|
||||
self.sockets.append(m)
|
||||
m2 = s.get_monitor_socket()
|
||||
assert m is m2
|
||||
s.disable_monitor()
|
||||
evt = recv_monitor_message(m)
|
||||
assert evt['event'] == zmq.EVENT_MONITOR_STOPPED
|
||||
m.close()
|
||||
s.close()
|
||||
|
||||
@require_zmq_4
|
||||
def test_monitor_connected(self):
|
||||
"""Test connected monitoring socket."""
|
||||
s_rep = self.context.socket(zmq.REP)
|
||||
s_req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([s_rep, s_req])
|
||||
s_req.bind("tcp://127.0.0.1:6667")
|
||||
# try monitoring the REP socket
|
||||
# create listening socket for monitor
|
||||
s_event = s_rep.get_monitor_socket()
|
||||
s_event.linger = 0
|
||||
self.sockets.append(s_event)
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6667")
|
||||
m = recv_monitor_message(s_event)
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
assert m['event'] == zmq.EVENT_CONNECTED
|
||||
assert m['endpoint'] == b"tcp://127.0.0.1:6667"
|
219
.venv/Lib/site-packages/zmq/tests/test_monqueue.py
Normal file
219
.venv/Lib/site-packages/zmq/tests/test_monqueue.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase
|
||||
|
||||
if PYPY or zmq.zmq_version_info() >= (4, 1):
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestMonitoredQueue(BaseZMQTestCase):
|
||||
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
|
||||
self.device = devices.ThreadMonitoredQueue(
|
||||
zmq.PAIR, zmq.PAIR, zmq.PUB, in_prefix, out_prefix
|
||||
)
|
||||
alice = self.context.socket(zmq.PAIR)
|
||||
bob = self.context.socket(zmq.PAIR)
|
||||
mon = self.context.socket(zmq.SUB)
|
||||
|
||||
aport = alice.bind_to_random_port('tcp://127.0.0.1')
|
||||
bport = bob.bind_to_random_port('tcp://127.0.0.1')
|
||||
mport = mon.bind_to_random_port('tcp://127.0.0.1')
|
||||
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
|
||||
|
||||
self.device.connect_in("tcp://127.0.0.1:%i" % aport)
|
||||
self.device.connect_out("tcp://127.0.0.1:%i" % bport)
|
||||
self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
|
||||
self.device.start()
|
||||
time.sleep(0.2)
|
||||
try:
|
||||
# this is currenlty necessary to ensure no dropped monitor messages
|
||||
# see LIBZMQ-248 for more info
|
||||
mon.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
self.sockets.extend([alice, bob, mon])
|
||||
return alice, bob, mon
|
||||
|
||||
def teardown_device(self):
|
||||
for socket in self.sockets:
|
||||
socket.close()
|
||||
del socket
|
||||
del self.device
|
||||
|
||||
def test_reply(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
self.teardown_device()
|
||||
|
||||
def test_queue(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + bobs == mons
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + alices2 == mons
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'in'] + alices3 == mons
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'out'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_prefix(self):
|
||||
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + bobs == mons
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + alices2 == mons
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'foo'] + alices3 == mons
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'bar'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor_subscribe(self):
|
||||
alice, bob, mon = self.build_device(b"out")
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices2 == bobs
|
||||
bobs = self.recv_multipart(bob)
|
||||
assert alices3 == bobs
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
assert alices == bobs
|
||||
mons = self.recv_multipart(mon)
|
||||
assert [b'out'] + bobs == mons
|
||||
self.teardown_device()
|
||||
|
||||
def test_router_router(self):
|
||||
"""test router-router MQ devices"""
|
||||
dev = devices.ThreadMonitoredQueue(
|
||||
zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out'
|
||||
)
|
||||
self.device = dev
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
|
||||
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
|
||||
a = self.context.socket(zmq.DEALER)
|
||||
a.identity = b'a'
|
||||
b = self.context.socket(zmq.DEALER)
|
||||
b.identity = b'b'
|
||||
self.sockets.extend([a, b])
|
||||
|
||||
a.connect('tcp://127.0.0.1:%i' % porta)
|
||||
b.connect('tcp://127.0.0.1:%i' % portb)
|
||||
dev.start()
|
||||
time.sleep(1)
|
||||
if zmq.zmq_version_info() >= (3, 1, 0):
|
||||
# flush erroneous poll state, due to LIBZMQ-280
|
||||
ping_msg = [b'ping', b'pong']
|
||||
for s in (a, b):
|
||||
s.send_multipart(ping_msg)
|
||||
try:
|
||||
s.recv(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
msg = [b'hello', b'there']
|
||||
a.send_multipart([b'b'] + msg)
|
||||
bmsg = self.recv_multipart(b)
|
||||
assert bmsg == [b'a'] + msg
|
||||
b.send_multipart(bmsg)
|
||||
amsg = self.recv_multipart(a)
|
||||
assert amsg == [b'b'] + msg
|
||||
self.teardown_device()
|
||||
|
||||
def test_default_mq_args(self):
|
||||
self.device = dev = devices.ThreadMonitoredQueue(
|
||||
zmq.ROUTER, zmq.DEALER, zmq.PUB
|
||||
)
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
# this will raise if default args are wrong
|
||||
dev.start()
|
||||
self.teardown_device()
|
||||
|
||||
def test_mq_check_prefix(self):
|
||||
ins = self.context.socket(zmq.ROUTER)
|
||||
outs = self.context.socket(zmq.DEALER)
|
||||
mons = self.context.socket(zmq.PUB)
|
||||
self.sockets.extend([ins, outs, mons])
|
||||
|
||||
ins = 'in'
|
||||
outs = 'out'
|
||||
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)
|
34
.venv/Lib/site-packages/zmq/tests/test_multipart.py
Normal file
34
.venv/Lib/site-packages/zmq/tests/test_multipart.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestMultipart(BaseZMQTestCase):
|
||||
def test_router_dealer(self):
|
||||
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
|
||||
msg1 = b'message1'
|
||||
dealer.send(msg1)
|
||||
self.recv(router)
|
||||
more = router.rcvmore
|
||||
assert more == True
|
||||
msg2 = self.recv(router)
|
||||
assert msg1 == msg2
|
||||
more = router.rcvmore
|
||||
assert more == False
|
||||
|
||||
def test_basic_multipart(self):
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
msg = [b'hi', b'there', b'b']
|
||||
a.send_multipart(msg)
|
||||
recvd = b.recv_multipart()
|
||||
assert msg == recvd
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestMultipartGreen(GreenTest, TestMultipart):
|
||||
pass
|
73
.venv/Lib/site-packages/zmq/tests/test_mypy.py
Normal file
73
.venv/Lib/site-packages/zmq/tests/test_mypy.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
Test our typing with mypy
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
pytest.importorskip("mypy")
|
||||
|
||||
zmq_dir = os.path.dirname(zmq.__file__)
|
||||
|
||||
|
||||
def resolve_repo_dir(path):
|
||||
"""Resolve a dir in the repo
|
||||
|
||||
Resolved relative to zmq dir
|
||||
|
||||
fallback on CWD (e.g. test run from repo, zmq installed, not -e)
|
||||
"""
|
||||
resolved_path = os.path.join(os.path.dirname(zmq_dir), path)
|
||||
|
||||
# fallback on CWD
|
||||
if not os.path.exists(resolved_path):
|
||||
resolved_path = path
|
||||
return resolved_path
|
||||
|
||||
|
||||
examples_dir = resolve_repo_dir("examples")
|
||||
mypy_dir = resolve_repo_dir("mypy_tests")
|
||||
|
||||
|
||||
def run_mypy(*mypy_args):
|
||||
"""Run mypy for a path
|
||||
|
||||
Captures output and reports it on errors
|
||||
"""
|
||||
p = Popen(
|
||||
[sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
|
||||
)
|
||||
o, _ = p.communicate()
|
||||
out = o.decode("utf8", "replace")
|
||||
print(out)
|
||||
assert p.returncode == 0, out
|
||||
|
||||
|
||||
if os.path.exists(examples_dir):
|
||||
examples = [
|
||||
d
|
||||
for d in os.listdir(examples_dir)
|
||||
if os.path.isdir(os.path.join(examples_dir, d))
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.path.exists(examples_dir), reason="only test from examples directory"
|
||||
)
|
||||
@pytest.mark.parametrize("example", examples)
|
||||
def test_mypy_example(example):
|
||||
example_dir = os.path.join(examples_dir, example)
|
||||
run_mypy("--disallow-untyped-calls", example_dir)
|
||||
|
||||
|
||||
if os.path.exists(mypy_dir):
|
||||
mypy_tests = [p for p in os.listdir(mypy_dir) if p.endswith(".py")]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filename", mypy_tests)
|
||||
def test_mypy(filename):
|
||||
run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename))
|
52
.venv/Lib/site-packages/zmq/tests/test_pair.py
Normal file
52
.venv/Lib/site-packages/zmq/tests/test_pair.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
x = b' '
|
||||
|
||||
|
||||
class TestPair(BaseZMQTestCase):
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
msg1 = b'message1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
for i in range(10):
|
||||
msg = i * x
|
||||
s1.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = i * x
|
||||
s2.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = s1.recv()
|
||||
assert msg == i * x
|
||||
|
||||
for i in range(10):
|
||||
msg = s2.recv()
|
||||
assert msg == i * x
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10, b=list(range(10)))
|
||||
self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10, b=range(10))
|
||||
self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestReqRepGreen(GreenTest, TestPair):
|
||||
pass
|
239
.venv/Lib/site-packages/zmq/tests/test_poll.py
Normal file
239
.venv/Lib/site-packages/zmq/tests/test_poll.py
Normal file
@ -0,0 +1,239 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import GreenTest, PollZMQTestCase, have_gevent
|
||||
|
||||
|
||||
def wait():
|
||||
time.sleep(0.25)
|
||||
|
||||
|
||||
class TestPoll(PollZMQTestCase):
|
||||
|
||||
Poller = zmq.Poller
|
||||
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
|
||||
# Poll result should contain both sockets
|
||||
socks = dict(poller.poll())
|
||||
# Now make sure that both are send ready.
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
|
||||
s1.send(b'msg1')
|
||||
s2.send(b'msg2')
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT | zmq.POLLIN
|
||||
assert socks[s2] == zmq.POLLOUT | zmq.POLLIN
|
||||
# Make sure that both are in POLLOUT after recv.
|
||||
s1.recv()
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
def test_reqrep(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN | zmq.POLLOUT)
|
||||
|
||||
# Make sure that s1 is in state 0 and s2 is in POLLOUT
|
||||
socks = dict(poller.poll())
|
||||
assert s1 not in socks
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
# Make sure that s2 goes immediately into state 0 after send.
|
||||
s2.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
assert s2 not in socks
|
||||
|
||||
# Make sure that s1 goes into POLLIN state after a time.sleep().
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLIN
|
||||
|
||||
# Make sure that s1 goes into POLLOUT after recv.
|
||||
s1.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
|
||||
# Make sure s1 goes into state 0 after send.
|
||||
s1.send(b'msg2')
|
||||
socks = dict(poller.poll())
|
||||
assert s1 not in socks
|
||||
|
||||
# Wait and then see that s2 is in POLLIN.
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLIN
|
||||
|
||||
# Make sure that s2 is in POLLOUT after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLOUT
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
def test_no_events(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, 0)
|
||||
assert s1 in poller
|
||||
assert s2 not in poller
|
||||
poller.register(s1, 0)
|
||||
assert s1 not in poller
|
||||
|
||||
def test_pubsub(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN | zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
# Now make sure that both are send ready.
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
assert s2 not in socks
|
||||
# Make sure that s1 stays in POLLOUT after a send.
|
||||
s1.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s1] == zmq.POLLOUT
|
||||
|
||||
# Make sure that s2 is POLLIN after waiting.
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
assert socks[s2] == zmq.POLLIN
|
||||
|
||||
# Make sure that s2 goes into 0 after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
assert s2 not in socks
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
|
||||
def test_raw(self):
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
p = self.Poller()
|
||||
p.register(r, zmq.POLLIN)
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {}
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {r.fileno(): zmq.POLLIN}
|
||||
w.close()
|
||||
r.close()
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_timeout(self):
|
||||
"""make sure Poller.poll timeout has the right units (milliseconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN)
|
||||
tic = time.perf_counter()
|
||||
poller.poll(0.005)
|
||||
toc = time.perf_counter()
|
||||
toc - tic < 0.1
|
||||
tic = time.perf_counter()
|
||||
poller.poll(50)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 0.1
|
||||
assert toc - tic > 0.01
|
||||
tic = time.perf_counter()
|
||||
poller.poll(500)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.1
|
||||
|
||||
|
||||
class TestSelect(PollZMQTestCase):
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
|
||||
assert s1 in wlist
|
||||
assert s2 in wlist
|
||||
assert s1 not in rlist
|
||||
assert s2 not in rlist
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_timeout(self):
|
||||
"""make sure select timeout has the right units (seconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
tic = time.perf_counter()
|
||||
r, w, x = zmq.select([s1, s2], [], [], 0.005)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.001
|
||||
tic = time.perf_counter()
|
||||
r, w, x = zmq.select([s1, s2], [], [], 0.25)
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
assert toc - tic > 0.1
|
||||
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
|
||||
from zmq import green as gzmq
|
||||
|
||||
class TestPollGreen(GreenTest, TestPoll):
|
||||
Poller = gzmq.Poller
|
||||
|
||||
def test_wakeup(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
tic = time.perf_counter()
|
||||
r = gevent.spawn(lambda: poller.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
||||
|
||||
def test_socket_poll(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
tic = time.perf_counter()
|
||||
r = gevent.spawn(lambda: s2.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.perf_counter()
|
||||
assert toc - tic < 1
|
95
.venv/Lib/site-packages/zmq/tests/test_proxy_steerable.py
Normal file
95
.venv/Lib/site-packages/zmq/tests/test_proxy_steerable.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import struct
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestProxySteerable(BaseZMQTestCase):
|
||||
def test_proxy_steerable(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
||||
|
||||
def test_proxy_steerable_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend(
|
||||
[
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max),
|
||||
]
|
||||
)
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_proxy_steerable_statistics(self):
|
||||
if zmq.zmq_version_info() < (4, 3):
|
||||
raise SkipTest("STATISTICS only in libzmq >= 4.3")
|
||||
dev = devices.ThreadProxySteerable(zmq.PULL, zmq.PUSH, zmq.PUSH, zmq.PAIR)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
assert msg == self.recv(pull)
|
||||
assert msg == self.recv(mon)
|
||||
ctrl.send(b'STATISTICS')
|
||||
stats = self.recv_multipart(ctrl)
|
||||
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
|
||||
assert 1 == stats_int[0]
|
||||
assert len(msg) == stats_int[1]
|
||||
assert 1 == stats_int[6]
|
||||
assert len(msg) == stats_int[7]
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
41
.venv/Lib/site-packages/zmq/tests/test_pubsub.py
Normal file
41
.venv/Lib/site-packages/zmq/tests/test_pubsub.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestPubSub(BaseZMQTestCase):
|
||||
|
||||
pass
|
||||
|
||||
# We are disabling this test while an issue is being resolved.
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv() # This is blocking!
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_topic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'x')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
|
||||
msg1 = b'xmessage'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv()
|
||||
assert msg1 == msg2
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestPubSubGreen(GreenTest, TestPubSub):
|
||||
pass
|
61
.venv/Lib/site-packages/zmq/tests/test_reqrep.py
Normal file
61
.venv/Lib/site-packages/zmq/tests/test_reqrep.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, have_gevent
|
||||
|
||||
|
||||
class TestReqRep(BaseZMQTestCase):
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
msg1 = b'message 1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
for i in range(10):
|
||||
msg1 = i * b' '
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_bad_send_recv(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
if zmq.zmq_version() != '2.1.8':
|
||||
# this doesn't work on 2.1.8
|
||||
for copy in (True, False):
|
||||
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
|
||||
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
|
||||
|
||||
# I have to have this or we die on an Abort trap.
|
||||
msg1 = b'asdf'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10, b=list(range(10)))
|
||||
self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10, b=range(10))
|
||||
self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
def test_large_msg(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
msg1 = 10000 * b'X'
|
||||
|
||||
for i in range(10):
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
assert msg1 == msg2
|
||||
|
||||
|
||||
if have_gevent:
|
||||
|
||||
class TestReqRepGreen(GreenTest, TestReqRep):
|
||||
pass
|
94
.venv/Lib/site-packages/zmq/tests/test_retry_eintr.py
Normal file
94
.venv/Lib/site-packages/zmq/tests/test_retry_eintr.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import signal
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest
|
||||
|
||||
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
|
||||
|
||||
|
||||
class TestEINTRSysCall(BaseZMQTestCase):
|
||||
"""Base class for EINTR tests."""
|
||||
|
||||
# delay for initial signal delivery
|
||||
signal_delay = 0.1
|
||||
# timeout for tests. Must be > signal_delay
|
||||
timeout = 0.25
|
||||
timeout_ms = int(timeout * 1e3)
|
||||
|
||||
def alarm(self, t=None):
|
||||
"""start a timer to fire only once
|
||||
|
||||
like signal.alarm, but with better resolution than integer seconds.
|
||||
"""
|
||||
if not hasattr(signal, 'setitimer'):
|
||||
raise SkipTest('EINTR tests require setitimer')
|
||||
if t is None:
|
||||
t = self.signal_delay
|
||||
self.timer_fired = False
|
||||
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
|
||||
# signal_period ignored, since only one timer event is allowed to fire
|
||||
signal.setitimer(signal.ITIMER_REAL, t, 1000)
|
||||
|
||||
def stop_timer(self, *args):
|
||||
self.timer_fired = True
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, self.orig_handler)
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_retry_recv(self):
|
||||
pull = self.socket(zmq.PULL)
|
||||
pull.rcvtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, pull.recv)
|
||||
assert self.timer_fired
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_retry_send(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.sndtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, push.send, b'buf')
|
||||
assert self.timer_fired
|
||||
|
||||
@mark.flaky(reruns=3)
|
||||
def test_retry_poll(self):
|
||||
x, y = self.create_bound_pair()
|
||||
poller = zmq.Poller()
|
||||
poller.register(x, zmq.POLLIN)
|
||||
self.alarm()
|
||||
|
||||
def send():
|
||||
time.sleep(2 * self.signal_delay)
|
||||
y.send(b'ping')
|
||||
|
||||
t = Thread(target=send)
|
||||
t.start()
|
||||
evts = dict(poller.poll(2 * self.timeout_ms))
|
||||
t.join()
|
||||
assert x in evts
|
||||
assert self.timer_fired
|
||||
x.recv()
|
||||
|
||||
def test_retry_term(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.linger = self.timeout_ms
|
||||
push.connect('tcp://127.0.0.1:5555')
|
||||
push.send(b'ping')
|
||||
time.sleep(0.1)
|
||||
self.alarm()
|
||||
self.context.destroy()
|
||||
assert self.timer_fired
|
||||
assert self.context.closed
|
||||
|
||||
def test_retry_getsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt getsockopt")
|
||||
|
||||
def test_retry_setsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt setsockopt")
|
238
.venv/Lib/site-packages/zmq/tests/test_security.py
Normal file
238
.venv/Lib/site-packages/zmq/tests/test_security.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""Test libzmq security (libzmq >= 3.3.0)"""
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq.tests import PYPY, BaseZMQTestCase, SkipTest
|
||||
from zmq.utils import z85
|
||||
|
||||
USER = b"admin"
|
||||
PASS = b"password"
|
||||
|
||||
|
||||
class TestSecurity(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if zmq.zmq_version_info() < (4, 0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to be built with CURVE support")
|
||||
super().setUp()
|
||||
|
||||
def zap_handler(self):
|
||||
socket = self.context.socket(zmq.REP)
|
||||
socket.bind("inproc://zeromq.zap.01")
|
||||
try:
|
||||
msg = self.recv_multipart(socket)
|
||||
|
||||
version, sequence, domain, address, identity, mechanism = msg[:6]
|
||||
if mechanism == b'PLAIN':
|
||||
username, password = msg[6:]
|
||||
elif mechanism == b'CURVE':
|
||||
msg[6]
|
||||
|
||||
assert version == b"1.0"
|
||||
assert identity == b"IDENT"
|
||||
reply = [version, sequence]
|
||||
if (
|
||||
mechanism == b'CURVE'
|
||||
or (mechanism == b'PLAIN' and username == USER and password == PASS)
|
||||
or (mechanism == b'NULL')
|
||||
):
|
||||
reply.extend(
|
||||
[
|
||||
b"200",
|
||||
b"OK",
|
||||
b"anonymous",
|
||||
b"\5Hello\0\0\0\5World",
|
||||
]
|
||||
)
|
||||
else:
|
||||
reply.extend(
|
||||
[
|
||||
b"400",
|
||||
b"Invalid username or password",
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
socket.send_multipart(reply)
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zap(self):
|
||||
self.start_zap()
|
||||
time.sleep(0.5) # allow time for the Thread to start
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.stop_zap()
|
||||
|
||||
def start_zap(self):
|
||||
self.zap_thread = Thread(target=self.zap_handler)
|
||||
self.zap_thread.start()
|
||||
|
||||
def stop_zap(self):
|
||||
self.zap_thread.join()
|
||||
|
||||
def bounce(self, server, client, test_metadata=True):
|
||||
msg = [os.urandom(64), os.urandom(64)]
|
||||
client.send_multipart(msg)
|
||||
frames = self.recv_multipart(server, copy=False)
|
||||
recvd = list(map(lambda x: x.bytes, frames))
|
||||
|
||||
try:
|
||||
if test_metadata and not PYPY:
|
||||
for frame in frames:
|
||||
assert frame.get('User-Id') == 'anonymous'
|
||||
assert frame.get('Hello') == 'World'
|
||||
assert frame['Socket-Type'] == 'DEALER'
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
|
||||
assert recvd == msg
|
||||
server.send_multipart(recvd)
|
||||
msg2 = self.recv_multipart(client)
|
||||
assert msg2 == msg
|
||||
|
||||
def test_null(self):
|
||||
"""test NULL (default) security"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
client = self.socket(zmq.DEALER)
|
||||
assert client.MECHANISM == zmq.NULL
|
||||
assert server.mechanism == zmq.NULL
|
||||
assert client.plain_server == 0
|
||||
assert server.plain_server == 0
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client, False)
|
||||
|
||||
def test_plain(self):
|
||||
"""test PLAIN authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
assert client.plain_username == b''
|
||||
assert client.plain_password == b''
|
||||
client.plain_username = USER
|
||||
client.plain_password = PASS
|
||||
assert client.getsockopt(zmq.PLAIN_USERNAME) == USER
|
||||
assert client.getsockopt(zmq.PLAIN_PASSWORD) == PASS
|
||||
assert client.plain_server == 0
|
||||
assert server.plain_server == 0
|
||||
server.plain_server = True
|
||||
assert server.mechanism == zmq.PLAIN
|
||||
assert client.mechanism == zmq.PLAIN
|
||||
|
||||
assert not client.plain_server
|
||||
assert server.plain_server
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
||||
|
||||
def skip_plain_inauth(self):
|
||||
"""test PLAIN failed authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
client.plain_username = USER
|
||||
client.plain_password = b'incorrect'
|
||||
server.plain_server = True
|
||||
assert server.mechanism == zmq.PLAIN
|
||||
assert client.mechanism == zmq.PLAIN
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
client.send(b'ping')
|
||||
server.rcvtimeo = 250
|
||||
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
|
||||
|
||||
def test_keypair(self):
|
||||
"""test curve_keypair"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
assert type(secret) == bytes
|
||||
assert type(public) == bytes
|
||||
assert len(secret) == 40
|
||||
assert len(public) == 40
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bsecret, bpublic = (z85.decode(key) for key in (public, secret))
|
||||
assert type(bsecret) == bytes
|
||||
assert type(bpublic) == bytes
|
||||
assert len(bsecret) == 32
|
||||
assert len(bpublic) == 32
|
||||
|
||||
def test_curve_public(self):
|
||||
"""test curve_public"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
if zmq.zmq_version_info() < (4, 2):
|
||||
raise SkipTest("curve_public is new in libzmq 4.2")
|
||||
|
||||
derived_public = zmq.curve_public(secret)
|
||||
|
||||
assert type(derived_public) == bytes
|
||||
assert len(derived_public) == 40
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bpublic = z85.decode(derived_public)
|
||||
assert type(bpublic) == bytes
|
||||
assert len(bpublic) == 32
|
||||
|
||||
# verify that it is equal to the known public key
|
||||
assert derived_public == public
|
||||
|
||||
def test_curve(self):
|
||||
"""test CURVE encryption"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
try:
|
||||
server.curve_server = True
|
||||
except zmq.ZMQError as e:
|
||||
# will raise EINVAL if no CURVE support
|
||||
if e.errno == zmq.EINVAL:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
server_public, server_secret = zmq.curve_keypair()
|
||||
client_public, client_secret = zmq.curve_keypair()
|
||||
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_publickey = server_public
|
||||
client.curve_serverkey = server_public
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
|
||||
assert server.mechanism == zmq.CURVE
|
||||
assert client.mechanism == zmq.CURVE
|
||||
|
||||
assert server.get(zmq.CURVE_SERVER) == True
|
||||
assert client.get(zmq.CURVE_SERVER) == False
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
671
.venv/Lib/site-packages/zmq/tests/test_socket.py
Normal file
671
.venv/Lib/site-packages/zmq/tests/test_socket.py
Normal file
@ -0,0 +1,671 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, GreenTest, SkipTest, have_gevent, skip_pypy
|
||||
|
||||
pypy = platform.python_implementation().lower() == 'pypy'
|
||||
windows = platform.platform().lower().startswith('windows')
|
||||
on_ci = bool(os.environ.get('CI'))
|
||||
|
||||
# polling on windows is slow
|
||||
POLL_TIMEOUT = 1000 if windows else 100
|
||||
|
||||
|
||||
class TestSocket(BaseZMQTestCase):
|
||||
def test_create(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
# Superluminal protocol not yet implemented
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
|
||||
s.close()
|
||||
del ctx
|
||||
|
||||
def test_context_manager(self):
|
||||
url = 'inproc://a'
|
||||
with self.Context() as ctx:
|
||||
with ctx.socket(zmq.PUSH) as a:
|
||||
a.bind(url)
|
||||
with ctx.socket(zmq.PULL) as b:
|
||||
b.connect(url)
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
assert b.closed == True
|
||||
assert a.closed == True
|
||||
assert ctx.closed == True
|
||||
|
||||
def test_connectbind_context_managers(self):
|
||||
url = 'inproc://a'
|
||||
msg = b'hi'
|
||||
with self.Context() as ctx:
|
||||
# Test connect() context manager
|
||||
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
|
||||
a.bind(url)
|
||||
connect_context = b.connect(url)
|
||||
assert f'connect={url!r}' in repr(connect_context)
|
||||
with connect_context:
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
# b should now be disconnected, so sending and receiving don't work
|
||||
with pytest.raises(zmq.Again):
|
||||
a.send(msg, flags=zmq.DONTWAIT)
|
||||
with pytest.raises(zmq.Again):
|
||||
b.recv(flags=zmq.DONTWAIT)
|
||||
a.unbind(url)
|
||||
# Test bind() context manager
|
||||
with ctx.socket(zmq.PUSH) as a, ctx.socket(zmq.PULL) as b:
|
||||
# unbind() just stops accepting of new connections, so we have to disconnect to test that
|
||||
# unbind happened.
|
||||
bind_context = a.bind(url)
|
||||
assert f'bind={url!r}' in repr(bind_context)
|
||||
with bind_context:
|
||||
b.connect(url)
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == msg
|
||||
b.disconnect(url)
|
||||
b.connect(url)
|
||||
# Since a is unbound from url, b is not connected to anything
|
||||
with pytest.raises(zmq.Again):
|
||||
a.send(msg, flags=zmq.DONTWAIT)
|
||||
with pytest.raises(zmq.Again):
|
||||
b.recv(flags=zmq.DONTWAIT)
|
||||
|
||||
_repr_cls = "zmq.Socket"
|
||||
|
||||
def test_repr(self):
|
||||
with self.context.socket(zmq.PUSH) as s:
|
||||
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
|
||||
assert 'closed' not in repr(s)
|
||||
assert f'{self._repr_cls}(zmq.PUSH)' in repr(s)
|
||||
assert 'closed' in repr(s)
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
assert 'send' in dir(s)
|
||||
assert 'IDENTITY' in dir(s)
|
||||
assert 'AFFINITY' in dir(s)
|
||||
assert 'FD' in dir(s)
|
||||
s.close()
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
s = self.socket(zmq.SUB)
|
||||
m = mock.Mock(spec=s)
|
||||
s.close()
|
||||
|
||||
def test_bind_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
p = s.bind_to_random_port("tcp://*")
|
||||
|
||||
def test_connect_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
s.connect("tcp://127.0.0.1:5555")
|
||||
|
||||
def test_bind_to_random_port(self):
|
||||
# Check that bind_to_random_port do not hide useful exception
|
||||
ctx = self.Context()
|
||||
c = ctx.socket(zmq.PUB)
|
||||
# Invalid format
|
||||
try:
|
||||
c.bind_to_random_port('tcp:*')
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.EINVAL
|
||||
# Invalid protocol
|
||||
try:
|
||||
c.bind_to_random_port('rand://*')
|
||||
except zmq.ZMQError as e:
|
||||
assert e.errno == zmq.EPROTONOSUPPORT
|
||||
|
||||
def test_identity(self):
|
||||
s = self.context.socket(zmq.PULL)
|
||||
self.sockets.append(s)
|
||||
ident = b'identity\0\0'
|
||||
s.identity = ident
|
||||
assert s.get(zmq.IDENTITY) == ident
|
||||
|
||||
def test_unicode_sockopts(self):
|
||||
"""test setting/getting sockopts with unicode strings"""
|
||||
topic = "tést"
|
||||
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
assert s.send_unicode == s.send_unicode
|
||||
assert p.recv_unicode == p.recv_unicode
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
|
||||
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
|
||||
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
|
||||
|
||||
identb = s.getsockopt(zmq.IDENTITY)
|
||||
identu = identb.decode('utf16')
|
||||
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
|
||||
assert identu == identu2
|
||||
time.sleep(0.1) # wait for connection/subscription
|
||||
p.send_unicode(topic, zmq.SNDMORE)
|
||||
p.send_unicode(topic * 2, encoding='latin-1')
|
||||
assert topic == s.recv_unicode()
|
||||
assert topic * 2 == s.recv_unicode(encoding='latin-1')
|
||||
|
||||
def test_int_sockopts(self):
|
||||
"test integer sockopts"
|
||||
v = zmq.zmq_version_info()
|
||||
if v < (3, 0):
|
||||
default_hwm = 0
|
||||
else:
|
||||
default_hwm = 1000
|
||||
p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
p.setsockopt(zmq.LINGER, 0)
|
||||
assert p.getsockopt(zmq.LINGER) == 0
|
||||
p.setsockopt(zmq.LINGER, -1)
|
||||
assert p.getsockopt(zmq.LINGER) == -1
|
||||
assert p.hwm == default_hwm
|
||||
p.hwm = 11
|
||||
assert p.hwm == 11
|
||||
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
|
||||
assert p.getsockopt(zmq.EVENTS) == zmq.POLLOUT
|
||||
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt, zmq.EVENTS, 2**7 - 1)
|
||||
assert p.getsockopt(zmq.TYPE) == p.socket_type
|
||||
assert p.getsockopt(zmq.TYPE) == zmq.PUB
|
||||
assert s.getsockopt(zmq.TYPE) == s.socket_type
|
||||
assert s.getsockopt(zmq.TYPE) == zmq.SUB
|
||||
|
||||
# check for overflow / wrong type:
|
||||
errors = []
|
||||
backref = {}
|
||||
constants = zmq.constants
|
||||
for name in constants.__all__:
|
||||
value = getattr(constants, name)
|
||||
if isinstance(value, int):
|
||||
backref[value] = name
|
||||
for opt in zmq.constants.SocketOption:
|
||||
if opt._opt_type not in {
|
||||
zmq.constants._OptType.int,
|
||||
zmq.constants._OptType.int64,
|
||||
}:
|
||||
continue
|
||||
if opt.name.startswith(
|
||||
(
|
||||
'HWM',
|
||||
'ROUTER',
|
||||
'XPUB',
|
||||
'TCP',
|
||||
'FAIL',
|
||||
'REQ_',
|
||||
'CURVE_',
|
||||
'PROBE_ROUTER',
|
||||
'IPC_FILTER',
|
||||
'GSSAPI',
|
||||
'STREAM_',
|
||||
'VMCI_BUFFER_SIZE',
|
||||
'VMCI_BUFFER_MIN_SIZE',
|
||||
'VMCI_BUFFER_MAX_SIZE',
|
||||
'VMCI_CONNECT_TIMEOUT',
|
||||
'BLOCKY',
|
||||
'IN_BATCH_SIZE',
|
||||
'OUT_BATCH_SIZE',
|
||||
'WSS_TRUST_SYSTEM',
|
||||
'ONLY_FIRST_SUBSCRIBE',
|
||||
'PRIORITY',
|
||||
'RECONNECT_STOP',
|
||||
)
|
||||
):
|
||||
# some sockopts are write-only
|
||||
continue
|
||||
try:
|
||||
n = p.getsockopt(opt)
|
||||
except zmq.ZMQError as e:
|
||||
errors.append(f"getsockopt({opt!r}) raised {e}.")
|
||||
else:
|
||||
if n > 2**31:
|
||||
errors.append(
|
||||
f"getsockopt({opt!r}) returned a ridiculous value."
|
||||
" It is probably the wrong type."
|
||||
)
|
||||
if errors:
|
||||
self.fail('\n'.join([''] + errors))
|
||||
|
||||
def test_bad_sockopts(self):
|
||||
"""Test that appropriate errors are raised on bad socket options"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
s.setsockopt(zmq.LINGER, 0)
|
||||
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
|
||||
# but only int sockopts are allowed through this way, otherwise raise a TypeError
|
||||
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
|
||||
# some sockopts are valid in general, but not on every socket:
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
|
||||
|
||||
def test_sockopt_roundtrip(self):
|
||||
"test set/getsockopt roundtrip."
|
||||
p = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(p)
|
||||
p.setsockopt(zmq.LINGER, 11)
|
||||
assert p.getsockopt(zmq.LINGER) == 11
|
||||
|
||||
def test_send_unicode(self):
|
||||
"test sending unicode objects"
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a, b])
|
||||
u = "çπ§"
|
||||
self.assertRaises(TypeError, a.send, u, copy=False)
|
||||
self.assertRaises(TypeError, a.send, u, copy=True)
|
||||
a.send_unicode(u)
|
||||
s = b.recv()
|
||||
assert s == u.encode('utf8')
|
||||
assert s.decode('utf8') == u
|
||||
a.send_unicode(u, encoding='utf16')
|
||||
s = b.recv_unicode(encoding='utf16')
|
||||
assert s == u
|
||||
|
||||
def test_send_multipart_check_type(self):
|
||||
"check type on all frames in send_multipart"
|
||||
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a, b])
|
||||
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
|
||||
a.send_multipart([b'b'])
|
||||
rcvd = self.recv_multipart(b)
|
||||
assert rcvd == [b'b']
|
||||
|
||||
@skip_pypy
|
||||
def test_tracker(self):
|
||||
"test the MessageTracker object for tracking when zmq is done with a buffer"
|
||||
addr = 'tcp://127.0.0.1'
|
||||
# get a port:
|
||||
sock = socket.socket()
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
port = sock.getsockname()[1]
|
||||
iface = "%s:%i" % (addr, port)
|
||||
sock.close()
|
||||
time.sleep(0.1)
|
||||
|
||||
a = self.context.socket(zmq.PUSH)
|
||||
b = self.context.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.connect(iface)
|
||||
time.sleep(0.1)
|
||||
p1 = a.send(b'something', copy=False, track=True)
|
||||
assert isinstance(p1, zmq.MessageTracker)
|
||||
assert p1 is zmq._FINISHED_TRACKER
|
||||
# small message, should start done
|
||||
assert p1.done
|
||||
|
||||
# disable zero-copy threshold
|
||||
a.copy_threshold = 0
|
||||
|
||||
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
|
||||
assert isinstance(p2, zmq.MessageTracker)
|
||||
assert not p2.done
|
||||
|
||||
b.bind(iface)
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p1.done == True
|
||||
assert msg == [b'something']
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p2.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p2.done == True
|
||||
assert msg == [b'something', b'else']
|
||||
m = zmq.Frame(b"again", copy=False, track=True)
|
||||
assert m.tracker.done == False
|
||||
p1 = a.send(m, copy=False)
|
||||
p2 = a.send(m, copy=False)
|
||||
assert m.tracker.done == False
|
||||
assert p1.done == False
|
||||
assert p2.done == False
|
||||
msg = self.recv_multipart(b)
|
||||
assert m.tracker.done == False
|
||||
assert msg == [b'again']
|
||||
msg = self.recv_multipart(b)
|
||||
assert m.tracker.done == False
|
||||
assert msg == [b'again']
|
||||
assert p1.done == False
|
||||
assert p2.done == False
|
||||
m.tracker
|
||||
del m
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
assert p1.done == True
|
||||
assert p2.done == True
|
||||
m = zmq.Frame(b'something', track=False)
|
||||
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
|
||||
|
||||
def test_close(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
|
||||
del ctx
|
||||
|
||||
def test_attr(self):
|
||||
"""set setting/getting sockopts as attributes"""
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
linger = 10
|
||||
s.linger = linger
|
||||
assert linger == s.linger
|
||||
assert linger == s.getsockopt(zmq.LINGER)
|
||||
assert s.fd == s.getsockopt(zmq.FD)
|
||||
|
||||
def test_bad_attr(self):
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.apple = 'foo'
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad setattr should have raised AttributeError")
|
||||
try:
|
||||
s.apple
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad getattr should have raised AttributeError")
|
||||
|
||||
def test_subclass(self):
|
||||
"""subclasses can assign attributes"""
|
||||
|
||||
class S(zmq.Socket):
|
||||
a = None
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
self.a = -1
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
s = S(self.context, zmq.REP)
|
||||
self.sockets.append(s)
|
||||
assert s.a == -1
|
||||
s.a = 1
|
||||
assert s.a == 1
|
||||
a = s.a
|
||||
assert a == 1
|
||||
|
||||
def test_recv_multipart(self):
|
||||
a, b = self.create_bound_pair()
|
||||
msg = b'hi'
|
||||
for i in range(3):
|
||||
a.send(msg)
|
||||
time.sleep(0.1)
|
||||
for i in range(3):
|
||||
assert self.recv_multipart(b) == [msg]
|
||||
|
||||
def test_close_after_destroy(self):
|
||||
"""s.close() after ctx.destroy() should be fine"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REP)
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
s.close()
|
||||
assert s.closed
|
||||
|
||||
def test_poll(self):
|
||||
a, b = self.create_bound_pair()
|
||||
time.time()
|
||||
evt = a.poll(POLL_TIMEOUT)
|
||||
assert evt == 0
|
||||
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
|
||||
assert evt == zmq.POLLOUT
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
assert evt == zmq.POLLIN
|
||||
msg2 = self.recv(b)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
assert evt == 0
|
||||
assert msg2 == msg
|
||||
|
||||
def test_ipc_path_max_length(self):
|
||||
"""IPC_PATH_MAX_LEN is a sensible value"""
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
|
||||
assert zmq.IPC_PATH_MAX_LEN > 30, msg
|
||||
assert zmq.IPC_PATH_MAX_LEN < 1025, msg
|
||||
|
||||
def test_ipc_path_max_length_msg(self):
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.bind('ipc://{}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
|
||||
except zmq.ZMQError as e:
|
||||
assert str(zmq.IPC_PATH_MAX_LEN) in e.strerror
|
||||
|
||||
@mark.skipif(windows, reason="ipc not supported on Windows.")
|
||||
def test_ipc_path_no_such_file_or_directory_message(self):
|
||||
"""Display the ipc path in case of an ENOENT exception"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
invalid_path = '/foo/bar'
|
||||
with pytest.raises(zmq.ZMQError) as error:
|
||||
s.bind(f'ipc://{invalid_path}')
|
||||
assert error.value.errno == errno.ENOENT
|
||||
error_message = str(error.value)
|
||||
assert invalid_path in error_message
|
||||
assert "no such file or directory" in error_message.lower()
|
||||
|
||||
def test_hwm(self):
|
||||
zmq3 = zmq.zmq_version_info()[0] >= 3
|
||||
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
|
||||
s = self.context.socket(stype)
|
||||
s.hwm = 100
|
||||
assert s.hwm == 100
|
||||
if zmq3:
|
||||
try:
|
||||
assert s.sndhwm == 100
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
assert s.rcvhwm == 100
|
||||
except AttributeError:
|
||||
pass
|
||||
s.close()
|
||||
|
||||
def test_copy(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
scopy = copy.copy(s)
|
||||
sdcopy = copy.deepcopy(s)
|
||||
assert scopy._shadow
|
||||
assert sdcopy._shadow
|
||||
assert s.underlying == scopy.underlying
|
||||
assert s.underlying == sdcopy.underlying
|
||||
s.close()
|
||||
|
||||
def test_send_buffer(self):
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
for buffer_type in (memoryview, bytearray):
|
||||
rawbytes = str(buffer_type).encode('ascii')
|
||||
msg = buffer_type(rawbytes)
|
||||
a.send(msg)
|
||||
recvd = b.recv()
|
||||
assert recvd == rawbytes
|
||||
|
||||
def test_shadow(self):
|
||||
p = self.socket(zmq.PUSH)
|
||||
p.bind("tcp://127.0.0.1:5555")
|
||||
p2 = zmq.Socket.shadow(p.underlying)
|
||||
assert p.underlying == p2.underlying
|
||||
s = self.socket(zmq.PULL)
|
||||
s2 = zmq.Socket.shadow(s.underlying)
|
||||
self.assertNotEqual(s.underlying, p.underlying)
|
||||
assert s.underlying == s2.underlying
|
||||
s2.connect("tcp://127.0.0.1:5555")
|
||||
sent = b'hi'
|
||||
p2.send(sent)
|
||||
rcvd = self.recv(s2)
|
||||
assert rcvd == sent
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
ca = zsocket.new(ctx, zmq.PUSH)
|
||||
cb = zsocket.new(ctx, zmq.PULL)
|
||||
a = zmq.Socket.shadow(ca)
|
||||
b = zmq.Socket.shadow(cb)
|
||||
a.bind("inproc://a")
|
||||
b.connect("inproc://a")
|
||||
a.send(b'hi')
|
||||
rcvd = self.recv(b)
|
||||
assert rcvd == b'hi'
|
||||
|
||||
def test_subscribe_method(self):
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
sub.subscribe('prefix')
|
||||
sub.subscribe = 'c'
|
||||
p = zmq.Poller()
|
||||
p.register(sub, zmq.POLLIN)
|
||||
# wait for subscription handshake
|
||||
for i in range(100):
|
||||
pub.send(b'canary')
|
||||
events = p.poll(250)
|
||||
if events:
|
||||
break
|
||||
self.recv(sub)
|
||||
pub.send(b'prefixmessage')
|
||||
msg = self.recv(sub)
|
||||
assert msg == b'prefixmessage'
|
||||
sub.unsubscribe('prefix')
|
||||
pub.send(b'prefixmessage')
|
||||
events = p.poll(1000)
|
||||
assert events == []
|
||||
|
||||
# CI often can't handle how much memory PyPy uses on this test
|
||||
@mark.skipif(
|
||||
(pypy and on_ci) or (sys.maxsize < 2**32) or (windows),
|
||||
reason="only run on 64b and not on CI.",
|
||||
)
|
||||
@mark.large
|
||||
def test_large_send(self):
|
||||
c = os.urandom(1)
|
||||
N = 2**31 + 1
|
||||
try:
|
||||
buf = c * N
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
a, b = self.create_bound_pair()
|
||||
try:
|
||||
a.send(buf, copy=False)
|
||||
rcvd = b.recv(copy=False)
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
# sample the front and back of the received message
|
||||
# without checking the whole content
|
||||
byte = ord(c)
|
||||
view = memoryview(rcvd)
|
||||
assert len(view) == N
|
||||
assert view[0] == byte
|
||||
assert view[-1] == byte
|
||||
|
||||
def test_custom_serialize(self):
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
a.send_serialized(msg, serialize)
|
||||
recvd = b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
b.send_serialized(recvd, serialize)
|
||||
r2 = a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
|
||||
|
||||
if have_gevent and not windows:
|
||||
import gevent
|
||||
|
||||
class TestSocketGreen(GreenTest, TestSocket):
|
||||
test_bad_attr = GreenTest.skip_green
|
||||
test_close_after_destroy = GreenTest.skip_green
|
||||
_repr_cls = "zmq.green.Socket"
|
||||
|
||||
def test_timeout(self):
|
||||
a, b = self.create_bound_pair()
|
||||
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
|
||||
timeout = gevent.Timeout(0.1)
|
||||
timeout.start()
|
||||
self.assertRaises(gevent.Timeout, b.recv)
|
||||
g.kill()
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_warn_set_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.rcvtimeo = 5
|
||||
s.close()
|
||||
assert len(w) == 1
|
||||
assert w[0].category == UserWarning
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_warn_get_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.sndtimeo
|
||||
s.close()
|
||||
assert len(w) == 1
|
||||
assert w[0].category == UserWarning
|
9
.venv/Lib/site-packages/zmq/tests/test_ssh.py
Normal file
9
.venv/Lib/site-packages/zmq/tests/test_ssh.py
Normal file
@ -0,0 +1,9 @@
|
||||
from zmq.ssh.tunnel import select_random_ports
|
||||
|
||||
|
||||
def test_random_ports():
|
||||
for i in range(4096):
|
||||
ports = select_random_ports(10)
|
||||
assert len(ports) == 10
|
||||
for p in ports:
|
||||
assert ports.count(p) == 1
|
43
.venv/Lib/site-packages/zmq/tests/test_version.py
Normal file
43
.venv/Lib/site-packages/zmq/tests/test_version.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq.sugar import version
|
||||
|
||||
|
||||
class TestVersion(TestCase):
|
||||
def test_pyzmq_version(self):
|
||||
vs = zmq.pyzmq_version()
|
||||
vs2 = zmq.__version__
|
||||
assert isinstance(vs, str)
|
||||
if zmq.__revision__:
|
||||
assert vs == '@'.join(vs2, zmq.__revision__)
|
||||
else:
|
||||
assert vs == vs2
|
||||
if version.VERSION_EXTRA:
|
||||
assert version.VERSION_EXTRA in vs
|
||||
assert version.VERSION_EXTRA in vs2
|
||||
|
||||
def test_pyzmq_version_info(self):
|
||||
info = zmq.pyzmq_version_info()
|
||||
assert isinstance(info, tuple)
|
||||
for n in info[:3]:
|
||||
assert isinstance(n, int)
|
||||
if version.VERSION_EXTRA:
|
||||
assert len(info) == 4
|
||||
assert info[-1] == float('inf')
|
||||
else:
|
||||
assert len(info) == 3
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
info = zmq.zmq_version_info()
|
||||
assert isinstance(info, tuple)
|
||||
for n in info[:3]:
|
||||
assert isinstance(n, int)
|
||||
|
||||
def test_zmq_version(self):
|
||||
v = zmq.zmq_version()
|
||||
assert isinstance(v, str)
|
58
.venv/Lib/site-packages/zmq/tests/test_win32_shim.py
Normal file
58
.venv/Lib/site-packages/zmq/tests/test_win32_shim.py
Normal file
@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
from zmq.utils.win32 import allow_interrupt
|
||||
|
||||
|
||||
def count_calls(f):
|
||||
@wraps(f)
|
||||
def _(*args, **kwds):
|
||||
try:
|
||||
return f(*args, **kwds)
|
||||
finally:
|
||||
_.__calls__ += 1
|
||||
|
||||
_.__calls__ = 0
|
||||
return _
|
||||
|
||||
|
||||
@mark.new_console
|
||||
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
|
||||
@mark.new_console
|
||||
@mark.skipif(not sys.platform.startswith('win'), reason='Windows only test')
|
||||
def test_handler(self):
|
||||
@count_calls
|
||||
def interrupt_polling():
|
||||
print('Caught CTRL-C!')
|
||||
|
||||
from ctypes import windll
|
||||
from ctypes.wintypes import BOOL, DWORD
|
||||
|
||||
kernel32 = windll.LoadLibrary('kernel32')
|
||||
|
||||
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
|
||||
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
|
||||
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
|
||||
GenerateConsoleCtrlEvent.restype = BOOL
|
||||
|
||||
# Simulate CTRL-C event while handler is active.
|
||||
try:
|
||||
with allow_interrupt(interrupt_polling) as context:
|
||||
result = GenerateConsoleCtrlEvent(0, 0)
|
||||
# Sleep so that we give time to the handler to
|
||||
# capture the Ctrl-C event.
|
||||
time.sleep(0.5)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
if result == 0:
|
||||
raise OSError()
|
||||
else:
|
||||
self.fail('Expecting `KeyboardInterrupt` exception!')
|
||||
|
||||
# Make sure our handler was called.
|
||||
assert interrupt_polling.__calls__ == 1
|
65
.venv/Lib/site-packages/zmq/tests/test_z85.py
Normal file
65
.venv/Lib/site-packages/zmq/tests/test_z85.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Test Z85 encoding
|
||||
|
||||
confirm values and roundtrip with test values from the reference implementation.
|
||||
"""
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from zmq.utils import z85
|
||||
|
||||
|
||||
class TestZ85(TestCase):
|
||||
def test_client_public(self):
|
||||
client_public = (
|
||||
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B"
|
||||
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A"
|
||||
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26"
|
||||
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
|
||||
)
|
||||
encoded = z85.encode(client_public)
|
||||
|
||||
assert encoded == b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == client_public
|
||||
|
||||
def test_client_secret(self):
|
||||
client_secret = (
|
||||
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67"
|
||||
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89"
|
||||
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51"
|
||||
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
|
||||
)
|
||||
encoded = z85.encode(client_secret)
|
||||
|
||||
assert encoded == b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == client_secret
|
||||
|
||||
def test_server_public(self):
|
||||
server_public = (
|
||||
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96"
|
||||
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0"
|
||||
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5"
|
||||
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
|
||||
)
|
||||
encoded = z85.encode(server_public)
|
||||
|
||||
assert encoded == b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == server_public
|
||||
|
||||
def test_server_secret(self):
|
||||
server_secret = (
|
||||
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D"
|
||||
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0"
|
||||
b"\x4D\x48\x96\x3F\x79\x25\x98\x77"
|
||||
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
|
||||
)
|
||||
encoded = z85.encode(server_secret)
|
||||
|
||||
assert encoded == b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6"
|
||||
decoded = z85.decode(encoded)
|
||||
assert decoded == server_secret
|
83
.venv/Lib/site-packages/zmq/tests/test_zmqstream.py
Normal file
83
.venv/Lib/site-packages/zmq/tests/test_zmqstream.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import asyncio
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
try:
|
||||
import tornado
|
||||
from tornado import gen
|
||||
|
||||
from zmq.eventloop import ioloop, zmqstream
|
||||
except ImportError:
|
||||
tornado = None # type: ignore
|
||||
|
||||
|
||||
class TestZMQStream(TestCase):
|
||||
def setUp(self):
|
||||
if tornado is None:
|
||||
pytest.skip()
|
||||
if asyncio:
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
self.context = zmq.Context()
|
||||
self.loop = ioloop.IOLoop()
|
||||
self.loop.make_current()
|
||||
self.push = zmqstream.ZMQStream(self.context.socket(zmq.PUSH))
|
||||
self.pull = zmqstream.ZMQStream(self.context.socket(zmq.PULL))
|
||||
port = self.push.bind_to_random_port('tcp://127.0.0.1')
|
||||
self.pull.connect('tcp://127.0.0.1:%i' % port)
|
||||
self.stream = self.push
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close(all_fds=True)
|
||||
self.context.term()
|
||||
ioloop.IOLoop.clear_current()
|
||||
|
||||
def run_until_timeout(self, timeout=10):
|
||||
timed_out = []
|
||||
|
||||
@gen.coroutine
|
||||
def sleep_timeout():
|
||||
yield gen.sleep(timeout)
|
||||
timed_out[:] = ['timed out']
|
||||
self.loop.stop()
|
||||
|
||||
self.loop.add_callback(lambda: sleep_timeout())
|
||||
self.loop.start()
|
||||
assert not timed_out
|
||||
|
||||
def test_callable_check(self):
|
||||
"""Ensure callable check works (py3k)."""
|
||||
|
||||
self.stream.on_send(lambda *args: None)
|
||||
self.stream.on_recv(lambda *args: None)
|
||||
self.assertRaises(AssertionError, self.stream.on_recv, 1)
|
||||
self.assertRaises(AssertionError, self.stream.on_send, 1)
|
||||
self.assertRaises(AssertionError, self.stream.on_recv, zmq)
|
||||
|
||||
def test_on_recv_basic(self):
|
||||
sent = [b'basic']
|
||||
|
||||
def callback(msg):
|
||||
assert msg == sent
|
||||
self.loop.stop()
|
||||
|
||||
self.loop.add_callback(lambda: self.push.send_multipart(sent))
|
||||
self.pull.on_recv(callback)
|
||||
self.run_until_timeout()
|
||||
|
||||
def test_on_recv_wake(self):
|
||||
sent = [b'wake']
|
||||
|
||||
def callback(msg):
|
||||
assert msg == sent
|
||||
self.loop.stop()
|
||||
|
||||
self.pull.on_recv(callback)
|
||||
self.loop.call_later(1, lambda: self.push.send_multipart(sent))
|
||||
self.run_until_timeout()
|
Reference in New Issue
Block a user