first commit

Ayxan
2022-05-23 00:16:32 +04:00
commit d660f2a4ca
24786 changed files with 4428337 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
"""Client-side implementations of the Jupyter protocol"""
from ._version import __version__ # noqa
from ._version import protocol_version # noqa
from ._version import protocol_version_info # noqa
from ._version import version_info # noqa
from .asynchronous import AsyncKernelClient # noqa
from .blocking import BlockingKernelClient # noqa
from .client import KernelClient # noqa
from .connect import * # noqa
from .launcher import * # noqa
from .manager import AsyncKernelManager # noqa
from .manager import KernelManager # noqa
from .manager import run_kernel # noqa
from .multikernelmanager import AsyncMultiKernelManager # noqa
from .multikernelmanager import MultiKernelManager # noqa
from .provisioning import KernelProvisionerBase # noqa
from .provisioning import LocalProvisioner # noqa

View File

@@ -0,0 +1,20 @@
import re
from typing import List
from typing import Union
__version__ = "7.3.1"
# Build up version_info tuple for backwards compatibility
pattern = r'(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?P<rest>.*)'
match = re.match(pattern, __version__)
if match:
parts: List[Union[int, str]] = [int(match[part]) for part in ['major', 'minor', 'patch']]
if match['rest']:
parts.append(match['rest'])
else:
parts = []
version_info = tuple(parts)
protocol_version_info = (5, 3)
protocol_version = "%i.%i" % protocol_version_info
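# A minimal sketch of what the parsing above yields (the dev suffix is
# illustrative):
#     __version__ = "7.3.1"       -> version_info == (7, 3, 1)
#     __version__ = "7.3.1.dev0"  -> version_info == (7, 3, 1, ".dev0")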

View File

@@ -0,0 +1,411 @@
"""Adapters for Jupyter msg spec versions."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from jupyter_client import protocol_version_info
def code_to_line(code: str, cursor_pos: int) -> Tuple[str, int]:
"""Turn a multiline code block and cursor position into a single line
and new cursor position.
For adapting ``complete_request`` and ``object_info_request``.
"""
if not code:
return "", 0
for line in code.splitlines(True):
n = len(line)
if cursor_pos > n:
cursor_pos -= n
else:
break
return line, cursor_pos
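# A worked example of the loop above: with code "ab\ncd" and cursor_pos 4,
# the first line "ab\n" (length 3) is consumed and the offset rewound:
#     code_to_line("ab\ncd", 4) -> ("cd", 1)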
_match_bracket = re.compile(r"\([^\(\)]+\)", re.UNICODE)
_end_bracket = re.compile(r"\([^\(]*$", re.UNICODE)
_identifier = re.compile(r"[a-z_][0-9a-z._]*", re.I | re.UNICODE)
def extract_oname_v4(code: str, cursor_pos: int) -> str:
"""Reimplement token-finding logic from IPython 2.x javascript
for adapting object_info_request from v5 to v4
"""
line, _ = code_to_line(code, cursor_pos)
oldline = line
line = _match_bracket.sub("", line)
while oldline != line:
oldline = line
line = _match_bracket.sub("", line)
# remove everything after last open bracket
line = _end_bracket.sub("", line)
matches = _identifier.findall(line)
if matches:
return matches[-1]
else:
return ""
class Adapter(object):
"""Base class for adapting messages
Override message_type(msg) methods to create adapters.
"""
msg_type_map: Dict[str, str] = {}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
return msg
def update_metadata(self, msg: Dict[str, Any]) -> Dict[str, Any]:
return msg
def update_msg_type(self, msg: Dict[str, Any]) -> Dict[str, Any]:
header = msg["header"]
msg_type = header["msg_type"]
if msg_type in self.msg_type_map:
msg["msg_type"] = header["msg_type"] = self.msg_type_map[msg_type]
return msg
def handle_reply_status_error(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""This will be called *instead of* the regular handler
on any reply with status != ok
"""
return msg
def __call__(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg = self.update_header(msg)
msg = self.update_metadata(msg)
msg = self.update_msg_type(msg)
header = msg["header"]
handler = getattr(self, header["msg_type"], None)
if handler is None:
return msg
# handle status=error replies separately (no change, at present)
if msg["content"].get("status", None) in {"error", "aborted"}:
return self.handle_reply_status_error(msg)
return handler(msg)
def _version_str_to_list(version: str) -> List[int]:
"""convert a version string to a list of ints
non-int segments are excluded
"""
v = []
for part in version.split("."):
try:
v.append(int(part))
except ValueError:
pass
return v
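# For example, non-integer segments are simply dropped:
#     _version_str_to_list("5.3.0dev") -> [5, 3]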
class V5toV4(Adapter):
"""Adapt msg protocol v5 to v4"""
version = "4.1"
msg_type_map = {
"execute_result": "pyout",
"execute_input": "pyin",
"error": "pyerr",
"inspect_request": "object_info_request",
"inspect_reply": "object_info_reply",
}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg["header"].pop("version", None)
msg["parent_header"].pop("version", None)
return msg
# shell channel
def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
v4c = {}
content = msg["content"]
for key in ("language_version", "protocol_version"):
if key in content:
v4c[key] = _version_str_to_list(content[key])
if content.get("implementation", "") == "ipython" and "implementation_version" in content:
v4c["ipython_version"] = _version_str_to_list(content["implementation_version"])
language_info = content.get("language_info", {})
language = language_info.get("name", "")
v4c.setdefault("language", language)
if "version" in language_info:
v4c.setdefault("language_version", _version_str_to_list(language_info["version"]))
msg["content"] = v4c
return msg
def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content.setdefault("user_variables", [])
return msg
def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content.setdefault("user_variables", {})
# TODO: handle payloads
return msg
def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
code = content["code"]
cursor_pos = content["cursor_pos"]
line, cursor_pos = code_to_line(code, cursor_pos)
new_content = msg["content"] = {}
new_content["text"] = ""
new_content["line"] = line
new_content["block"] = None
new_content["cursor_pos"] = cursor_pos
return msg
def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
cursor_start = content.pop("cursor_start")
cursor_end = content.pop("cursor_end")
match_len = cursor_end - cursor_start
content["matched_text"] = content["matches"][0][:match_len]
content.pop("metadata", None)
return msg
def object_info_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
code = content["code"]
cursor_pos = content["cursor_pos"]
line, _ = code_to_line(code, cursor_pos)
new_content = msg["content"] = {}
new_content["oname"] = extract_oname_v4(code, cursor_pos)
new_content["detail_level"] = content["detail_level"]
return msg
def object_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
msg["content"] = {"found": False, "oname": "unknown"}
return msg
# iopub channel
def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content["data"] = content.pop("text")
return msg
def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content.setdefault("source", "display")
data = content["data"]
if "application/json" in data:
try:
data["application/json"] = json.dumps(data["application/json"])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg["content"].pop("password", None)
return msg
class V4toV5(Adapter):
"""Convert msg spec V4 to V5"""
version = "5.0"
# invert message renames above
msg_type_map = {v: k for k, v in V5toV4.msg_type_map.items()}
def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg["header"]["version"] = self.version
if msg["parent_header"]:
msg["parent_header"]["version"] = self.version
return msg
# shell channel
def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
for key in ("protocol_version", "ipython_version"):
if key in content:
content[key] = ".".join(map(str, content[key]))
content.setdefault("protocol_version", "4.1")
if content["language"].startswith("python") and "ipython_version" in content:
content["implementation"] = "ipython"
content["implementation_version"] = content.pop("ipython_version")
language = content.pop("language")
language_info = content.setdefault("language_info", {})
language_info.setdefault("name", language)
if "language_version" in content:
language_version = ".".join(map(str, content.pop("language_version")))
language_info.setdefault("version", language_version)
content["banner"] = ""
return msg
def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
user_variables = content.pop("user_variables", [])
user_expressions = content.setdefault("user_expressions", {})
for v in user_variables:
user_expressions[v] = v
return msg
def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
user_expressions = content.setdefault("user_expressions", {})
user_variables = content.pop("user_variables", {})
if user_variables:
user_expressions.update(user_variables)
# Pager payloads became a mime bundle
for payload in content.get("payload", []):
if payload.get("source", None) == "page" and ("text" in payload):
if "data" not in payload:
payload["data"] = {}
payload["data"]["text/plain"] = payload.pop("text")
return msg
def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
old_content = msg["content"]
new_content = msg["content"] = {}
new_content["code"] = old_content["line"]
new_content["cursor_pos"] = old_content["cursor_pos"]
return msg
def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
# complete_reply needs more context than we have to get cursor_start and end.
# use special end=null to indicate current cursor position and negative offset
# for start relative to the cursor.
# start=None indicates that start == end (accounts for no -0).
content = msg["content"]
new_content = msg["content"] = {"status": "ok"}
new_content["matches"] = content["matches"]
if content["matched_text"]:
new_content["cursor_start"] = -len(content["matched_text"])
else:
# no -0, use None to indicate that start == end
new_content["cursor_start"] = None
new_content["cursor_end"] = None
new_content["metadata"] = {}
return msg
def inspect_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
name = content["oname"]
new_content = msg["content"] = {}
new_content["code"] = name
new_content["cursor_pos"] = len(name)
new_content["detail_level"] = content["detail_level"]
return msg
def inspect_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
content = msg["content"]
new_content = msg["content"] = {"status": "ok"}
found = new_content["found"] = content["found"]
new_content["data"] = data = {}
new_content["metadata"] = {}
if found:
lines = []
for key in ("call_def", "init_definition", "definition"):
if content.get(key, False):
lines.append(content[key])
break
for key in ("call_docstring", "init_docstring", "docstring"):
if content.get(key, False):
lines.append(content[key])
break
if not lines:
lines.append("<empty docstring>")
data["text/plain"] = "\n".join(lines)
return msg
# iopub channel
def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content["text"] = content.pop("data")
return msg
def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
content = msg["content"]
content.pop("source", None)
data = content["data"]
if "application/json" in data:
try:
data["application/json"] = json.loads(data["application/json"])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
msg["content"].setdefault("password", False)
return msg
def adapt(msg: Dict[str, Any], to_version: int = protocol_version_info[0]) -> Dict[str, Any]:
"""Adapt a single message to a target version
Parameters
----------
msg : dict
A Jupyter message.
to_version : int, optional
The target major version.
If unspecified, adapt to the current version.
Returns
-------
msg : dict
A Jupyter message appropriate in the new version.
"""
from .session import utcnow
header = msg["header"]
if "date" not in header:
header["date"] = utcnow()
if "version" in header:
from_version = int(header["version"].split(".")[0])
else:
# assume last version before adding the key to the header
from_version = 4
adapter = adapters.get((from_version, to_version), None)
if adapter is None:
return msg
return adapter(msg)
# one adapter per major version from,to
adapters = {
(5, 4): V5toV4(),
(4, 5): V4toV5(),
}
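# A minimal sketch of the registry above in action, adapting a v4 stream
# message to v5 (only the fields the stream handlers touch are shown):
#     v4_msg = {
#         "header": {"msg_type": "stream", "version": "4.1"},
#         "parent_header": {}, "metadata": {},
#         "content": {"name": "stdout", "data": "hi"},
#     }
#     v5_msg = adapt(v4_msg, to_version=5)
#     v5_msg["content"]["text"]    # -> "hi"
#     v5_msg["header"]["version"]  # -> "5.0"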

View File

@@ -0,0 +1 @@
from .client import AsyncKernelClient # noqa

View File

@@ -0,0 +1,63 @@
"""Implements an async kernel client"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from traitlets import Type
from jupyter_client.channels import HBChannel
from jupyter_client.channels import ZMQSocketChannel
from jupyter_client.client import KernelClient
from jupyter_client.client import reqrep
def wrapped(meth, channel):
def _(self, *args, **kwargs):
reply = kwargs.pop("reply", False)
timeout = kwargs.pop("timeout", None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return self._async_recv_reply(msg_id, timeout=timeout, channel=channel)
return _
class AsyncKernelClient(KernelClient):
"""A KernelClient with async APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
get_shell_msg = KernelClient._async_get_shell_msg
get_iopub_msg = KernelClient._async_get_iopub_msg
get_stdin_msg = KernelClient._async_get_stdin_msg
get_control_msg = KernelClient._async_get_control_msg
wait_for_ready = KernelClient._async_wait_for_ready
# The classes to use for the various channels
shell_channel_class = Type(ZMQSocketChannel)
iopub_channel_class = Type(ZMQSocketChannel)
stdin_channel_class = Type(ZMQSocketChannel)
hb_channel_class = Type(HBChannel)
control_channel_class = Type(ZMQSocketChannel)
_recv_reply = KernelClient._async_recv_reply
# replies come on the shell channel
execute = reqrep(wrapped, KernelClient.execute)
history = reqrep(wrapped, KernelClient.history)
complete = reqrep(wrapped, KernelClient.complete)
inspect = reqrep(wrapped, KernelClient.inspect)
kernel_info = reqrep(wrapped, KernelClient.kernel_info)
comm_info = reqrep(wrapped, KernelClient.comm_info)
is_alive = KernelClient._async_is_alive
execute_interactive = KernelClient._async_execute_interactive
# replies come on the control channel
shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")
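# A minimal usage sketch (the connection-file name is hypothetical);
# reply=True engages the reqrep/wrapped machinery above to await the reply:
#     client = AsyncKernelClient(connection_file="kernel-1234.json")
#     client.load_connection_file()
#     client.start_channels()
#     reply = await client.execute("1 + 1", reply=True, timeout=5)
#     client.stop_channels()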

View File

@@ -0,0 +1 @@
from .client import BlockingKernelClient # noqa

View File

@@ -0,0 +1,67 @@
"""Implements a fully blocking kernel client.
Useful for test suites and blocking terminal interfaces.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from traitlets import Type
from ..utils import run_sync
from jupyter_client.channels import HBChannel
from jupyter_client.channels import ZMQSocketChannel
from jupyter_client.client import KernelClient
from jupyter_client.client import reqrep
def wrapped(meth, channel):
def _(self, *args, **kwargs):
reply = kwargs.pop("reply", False)
timeout = kwargs.pop("timeout", None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return run_sync(self._async_recv_reply)(msg_id, timeout=timeout, channel=channel)
return _
class BlockingKernelClient(KernelClient):
"""A KernelClient with blocking APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
get_shell_msg = run_sync(KernelClient._async_get_shell_msg)
get_iopub_msg = run_sync(KernelClient._async_get_iopub_msg)
get_stdin_msg = run_sync(KernelClient._async_get_stdin_msg)
get_control_msg = run_sync(KernelClient._async_get_control_msg)
wait_for_ready = run_sync(KernelClient._async_wait_for_ready)
# The classes to use for the various channels
shell_channel_class = Type(ZMQSocketChannel)
iopub_channel_class = Type(ZMQSocketChannel)
stdin_channel_class = Type(ZMQSocketChannel)
hb_channel_class = Type(HBChannel)
control_channel_class = Type(ZMQSocketChannel)
_recv_reply = run_sync(KernelClient._async_recv_reply)
# replies come on the shell channel
execute = reqrep(wrapped, KernelClient.execute)
history = reqrep(wrapped, KernelClient.history)
complete = reqrep(wrapped, KernelClient.complete)
inspect = reqrep(wrapped, KernelClient.inspect)
kernel_info = reqrep(wrapped, KernelClient.kernel_info)
comm_info = reqrep(wrapped, KernelClient.comm_info)
is_alive = run_sync(KernelClient._async_is_alive)
execute_interactive = run_sync(KernelClient._async_execute_interactive)
# replies come on the control channel
shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")
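# Blocking counterpart of the async sketch above (file name hypothetical);
# run_sync drives the underlying coroutines, so no await is needed:
#     client = BlockingKernelClient(connection_file="kernel-1234.json")
#     client.load_connection_file()
#     client.start_channels()
#     reply = client.execute("1 + 1", reply=True, timeout=5)
#     client.stop_channels()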

View File

@@ -0,0 +1,267 @@
"""Base classes to manage a Client's interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import atexit
import time
import typing as t
from queue import Empty
from threading import Event
from threading import Thread
import zmq.asyncio
from .channelsabc import HBChannelABC
from .session import Session
from jupyter_client import protocol_version_info
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
# -----------------------------------------------------------------------------
# Constants and exceptions
# -----------------------------------------------------------------------------
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception):
pass
class HBChannel(Thread):
"""The heartbeat channel which monitors the kernel heartbeat.
Note that the heartbeat channel is paused by default. As long as you start
this channel, the kernel manager will ensure that it is paused and un-paused
as appropriate.
"""
session = None
socket = None
address = None
_exiting = False
time_to_dead: float = 1.0
_running = None
_pause = None
_beating = None
def __init__(
self,
context: t.Optional[zmq.asyncio.Context] = None,
session: t.Optional[Session] = None,
address: t.Union[t.Tuple[str, int], str] = "",
):
"""Create the heartbeat monitor thread.
Parameters
----------
context : :class:`zmq.asyncio.Context`
The ZMQ context to use.
session : :class:`session.Session`
The session to use.
address : zmq url or (ip, port) tuple
The address that the kernel heartbeat socket is listening on.
"""
super().__init__()
self.daemon = True
self.context = context
self.session = session
if isinstance(address, tuple):
if address[1] == 0:
message = "The port number for a channel cannot be 0."
raise InvalidPortNumber(message)
address_str = "tcp://%s:%i" % address
else:
address_str = address
self.address = address_str
# running is False until `.start()` is called
self._running = False
self._exit = Event()
# don't start paused
self._pause = False
self.poller = zmq.Poller()
@staticmethod
@atexit.register
def _notice_exit() -> None:
# Class definitions can be torn down during interpreter shutdown.
# We only need to set _exiting flag if this hasn't happened.
if HBChannel is not None:
HBChannel._exiting = True
def _create_socket(self) -> None:
if self.socket is not None:
# close previous socket, before opening a new one
self.poller.unregister(self.socket)
self.socket.close()
assert self.context is not None
self.socket = self.context.socket(zmq.REQ)
self.socket.linger = 1000
self.socket.connect(self.address)
self.poller.register(self.socket, zmq.POLLIN)
def run(self) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._async_run())
loop.close()
async def _async_run(self) -> None:
"""The thread's main activity. Call start() instead."""
self._create_socket()
self._running = True
self._beating = True
assert self.socket is not None
while self._running:
if self._pause:
# just sleep, and skip the rest of the loop
self._exit.wait(self.time_to_dead)
continue
since_last_heartbeat = 0.0
# no need to catch EFSM here, because the previous event was
# either a recv or connect, which cannot be followed by EFSM
await self.socket.send(b"ping")
request_time = time.time()
# Wait until timeout
self._exit.wait(self.time_to_dead)
# poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
self._beating = bool(self.poller.poll(0))
if self._beating:
# the poll above guarantees we have something to recv
await self.socket.recv()
continue
elif self._running:
# nothing was received within the time limit, signal heart failure
since_last_heartbeat = time.time() - request_time
self.call_handlers(since_last_heartbeat)
# and close/reopen the socket, because the REQ/REP cycle has been broken
self._create_socket()
continue
def pause(self) -> None:
"""Pause the heartbeat."""
self._pause = True
def unpause(self) -> None:
"""Unpause the heartbeat."""
self._pause = False
def is_beating(self) -> bool:
"""Is the heartbeat running and responsive (and not paused)."""
if self.is_alive() and not self._pause and self._beating:
return True
else:
return False
def stop(self) -> None:
"""Stop the channel's event loop and join its thread."""
self._running = False
self._exit.set()
self.join()
self.close()
def close(self) -> None:
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
def call_handlers(self, since_last_heartbeat: float) -> None:
"""This method is called in the ioloop thread when a message arrives.
Subclasses should override this method to handle incoming messages.
It is important to remember that this method runs in that thread, so
some logic must be done to ensure that the application-level
handlers are called in the application thread.
"""
pass
HBChannelABC.register(HBChannel)
class ZMQSocketChannel(object):
"""A ZMQ socket in an async API"""
def __init__(
self, socket: zmq.sugar.socket.Socket, session: Session, loop: t.Any = None
) -> None:
"""Create a channel.
Parameters
----------
socket : :class:`zmq.asyncio.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
Unused here, for other implementations
"""
super().__init__()
self.socket: t.Optional[zmq.sugar.socket.Socket] = socket
self.session = session
async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
assert self.socket is not None
msg = await self.socket.recv_multipart(**kwargs)
ident, smsg = self.session.feed_identities(msg)
return self.session.deserialize(smsg)
async def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
"""Gets a message if there is one that is ready."""
assert self.socket is not None
if timeout is not None:
timeout *= 1000 # seconds to ms
ready = await self.socket.poll(timeout)
if ready:
res = await self._recv()
return res
else:
raise Empty
async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
"""Get all messages that are currently ready."""
msgs = []
while True:
try:
msgs.append(await self.get_msg())
except Empty:
break
return msgs
async def msg_ready(self) -> bool:
"""Is there a message that has been received?"""
assert self.socket is not None
return bool(await self.socket.poll(timeout=0))
def close(self) -> None:
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
stop = close
def is_alive(self) -> bool:
return self.socket is not None
def send(self, msg: t.Dict[str, t.Any]) -> None:
"""Pass a message to the ZMQ socket to send"""
assert self.socket is not None
self.session.send(self.socket, msg)
def start(self) -> None:
pass
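# A sketch of driving the heartbeat channel directly (context, session
# and the address are assumed to come from an existing client):
#     hb = HBChannel(context, session, "tcp://127.0.0.1:51234")
#     hb.start()         # spawns the thread and its event loop
#     time.sleep(2)      # allow a couple of ping cycles
#     hb.is_beating()    # -> True while the kernel answers pings
#     hb.stop()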

View File

@@ -0,0 +1,45 @@
"""Abstract base classes for kernel client channels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
class ChannelABC(object, metaclass=abc.ABCMeta):
"""A base class for all channel ABCs."""
@abc.abstractmethod
def start(self):
pass
@abc.abstractmethod
def stop(self):
pass
@abc.abstractmethod
def is_alive(self):
pass
class HBChannelABC(ChannelABC):
"""HBChannel ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.channels.HBChannel`
"""
@abc.abstractproperty
def time_to_dead(self):
pass
@abc.abstractmethod
def pause(self):
pass
@abc.abstractmethod
def unpause(self):
pass
@abc.abstractmethod
def is_beating(self):
pass

View File

@@ -0,0 +1,805 @@
"""Base class to manage the interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import sys
import time
import typing as t
from functools import partial
from getpass import getpass
from queue import Empty
import zmq.asyncio
from traitlets import Any
from traitlets import Bool
from traitlets import Instance
from traitlets import Type
from .channelsabc import ChannelABC
from .channelsabc import HBChannelABC
from .clientabc import KernelClientABC
from .connect import ConnectionFileMixin
from .session import Session
from .utils import ensure_async
from jupyter_client.channels import major_protocol_version
# Some utilities to validate message structure; these might get moved
# elsewhere if they prove to have more generic utility
def validate_string_dict(dct: t.Dict[str, str]) -> None:
"""Validate that the input is a dict with string keys and values.
Raises ValueError if not."""
for k, v in dct.items():
if not isinstance(k, str):
raise ValueError("key %r in dict must be a string" % k)
if not isinstance(v, str):
raise ValueError("value %r in dict must be a string" % v)
def reqrep(wrapped: t.Callable, meth: t.Callable, channel: str = "shell") -> t.Callable:
wrapped = wrapped(meth, channel)
if not meth.__doc__:
# python -OO removes docstrings,
# so don't bother building the wrapped docstring
return wrapped
basedoc, _ = meth.__doc__.split("Returns\n", 1)
parts = [basedoc.strip()]
if "Parameters" not in basedoc:
parts.append(
"""
Parameters
----------
"""
)
parts.append(
"""
reply: bool (default: False)
Whether to wait for and return reply
timeout: float or None (default: None)
Timeout to use when waiting for a reply
Returns
-------
msg_id: str
The msg_id of the request sent, if reply=False (default)
reply: dict
The reply message for this request, if reply=True
"""
)
wrapped.__doc__ = "\n".join(parts)
return wrapped
class KernelClient(ConnectionFileMixin):
"""Communicates with a single kernel on any host via zmq channels.
There are five channels associated with each kernel:
* shell: for request/reply calls to the kernel.
* iopub: for the kernel to publish results to frontends.
* hb: for monitoring the kernel's heartbeat.
* stdin: for frontends to reply to raw_input calls in the kernel.
* control: for kernel management calls to the kernel.
The messages that can be sent on these channels are exposed as methods of the
client (KernelClient.execute, complete, history, etc.). These methods only
send the message, they don't wait for a reply. To get results, use e.g.
:meth:`get_shell_msg` to fetch messages from the shell channel.
"""
# The PyZMQ Context to use for communication with the kernel.
context = Instance(zmq.asyncio.Context)
_created_context = Bool(False)
def _context_default(self) -> zmq.asyncio.Context:
self._created_context = True
return zmq.asyncio.Context()
# The classes to use for the various channels
shell_channel_class = Type(ChannelABC)
iopub_channel_class = Type(ChannelABC)
stdin_channel_class = Type(ChannelABC)
hb_channel_class = Type(HBChannelABC)
control_channel_class = Type(ChannelABC)
# Protected traits
_shell_channel = Any()
_iopub_channel = Any()
_stdin_channel = Any()
_hb_channel = Any()
_control_channel = Any()
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin: bool = True
def __del__(self):
"""Handle garbage collection. Destroy context if applicable."""
if self._created_context and self.context and not self.context.closed:
if self.channels_running:
if self.log:
self.log.warning("Could not destroy zmq context for %s", self)
else:
if self.log:
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__
except AttributeError:
pass
else:
super_del()
# --------------------------------------------------------------------------
# Channel proxy methods
# --------------------------------------------------------------------------
async def _async_get_shell_msg(self, *args: Any, **kwargs: Any) -> t.Dict[str, t.Any]:
"""Get a message from the shell channel"""
return await self.shell_channel.get_msg(*args, **kwargs)
async def _async_get_iopub_msg(self, *args: Any, **kwargs: Any) -> t.Dict[str, t.Any]:
"""Get a message from the iopub channel"""
return await self.iopub_channel.get_msg(*args, **kwargs)
async def _async_get_stdin_msg(self, *args: Any, **kwargs: Any) -> t.Dict[str, t.Any]:
"""Get a message from the stdin channel"""
return await self.stdin_channel.get_msg(*args, **kwargs)
async def _async_get_control_msg(self, *args: Any, **kwargs: Any) -> t.Dict[str, t.Any]:
"""Get a message from the control channel"""
return await self.control_channel.get_msg(*args, **kwargs)
async def _async_wait_for_ready(self, timeout: t.Optional[float] = None) -> None:
"""Waits for a response when a client is blocked
- Sets future time for timeout
- Blocks on shell channel until a message is received
- Exit if the kernel has died
- If the client times out before receiving a message from the kernel, raise RuntimeError
- Flush the IOPub channel
"""
if timeout is None:
timeout = float("inf")
abs_timeout = time.time() + timeout
from .manager import KernelManager
if not isinstance(self.parent, KernelManager):
# This Client was not created by a KernelManager,
# so wait for kernel to become responsive to heartbeats
# before checking for kernel_info reply
while not await ensure_async(self.is_alive()):
if time.time() > abs_timeout:
raise RuntimeError(
"Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout
)
await asyncio.sleep(0.2)
# Wait for kernel info reply on shell channel
while True:
self.kernel_info()
try:
msg = await self.shell_channel.get_msg(timeout=1)
except Empty:
pass
else:
if msg["msg_type"] == "kernel_info_reply":
# Checking that IOPub is connected. If it is not connected, start over.
try:
await self.iopub_channel.get_msg(timeout=0.2)
except Empty:
pass
else:
self._handle_kernel_info_reply(msg)
break
if not await ensure_async(self.is_alive()):
raise RuntimeError("Kernel died before replying to kernel_info")
# Check whether we have passed the absolute timeout
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond in %d seconds" % timeout)
# Flush IOPub channel
while True:
try:
msg = await self.iopub_channel.get_msg(timeout=0.2)
except Empty:
break
async def _async_recv_reply(
self, msg_id: str, timeout: t.Optional[float] = None, channel: str = "shell"
) -> t.Dict[str, t.Any]:
"""Receive and return the reply for a given request"""
if timeout is not None:
deadline = time.monotonic() + timeout
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
try:
if channel == "control":
reply = await self._async_get_control_msg(timeout=timeout)
else:
reply = await self._async_get_shell_msg(timeout=timeout)
except Empty as e:
raise TimeoutError("Timeout waiting for reply") from e
if reply["parent_header"].get("msg_id") != msg_id:
# not my reply, someone may have forgotten to retrieve theirs
continue
return reply
def _stdin_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
"""Handle an input request"""
content = msg["content"]
if content.get("password", False):
prompt = getpass
else:
prompt = input # type: ignore
try:
raw_data = prompt(content["prompt"])
except EOFError:
# turn EOFError into EOF character
raw_data = "\x04"
except KeyboardInterrupt:
sys.stdout.write("\n")
return
# only send stdin reply if there *was not* another request
# or execution finished while we were reading.
if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()):
self.input(raw_data)
def _output_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
"""Default hook for redisplaying plain-text output"""
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
stream = getattr(sys, content["name"])
stream.write(content["text"])
elif msg_type in ("display_data", "execute_result"):
sys.stdout.write(content["data"].get("text/plain", ""))
elif msg_type == "error":
print("\n".join(content["traceback"]), file=sys.stderr)
def _output_hook_kernel(
self,
session: Session,
socket: zmq.sugar.socket.Socket,
parent_header: Any,
msg: t.Dict[str, t.Any],
) -> None:
"""Output hook when running inside an IPython kernel
adds rich output support.
"""
msg_type = msg["header"]["msg_type"]
if msg_type in ("display_data", "execute_result", "error"):
session.send(socket, msg_type, msg["content"], parent=parent_header)
else:
self._output_hook_default(msg)
# --------------------------------------------------------------------------
# Channel management methods
# --------------------------------------------------------------------------
def start_channels(
self,
shell: bool = True,
iopub: bool = True,
stdin: bool = True,
hb: bool = True,
control: bool = True,
) -> None:
"""Starts the channels for this kernel.
This will create the channels if they do not exist and then start
them (their activity runs in a thread). If port numbers of 0 are
being used (random ports) then you must first call
:meth:`start_kernel`. If the channels have been stopped and you
call this, :class:`RuntimeError` will be raised.
"""
if iopub:
self.iopub_channel.start()
if shell:
self.shell_channel.start()
if stdin:
self.stdin_channel.start()
self.allow_stdin = True
else:
self.allow_stdin = False
if hb:
self.hb_channel.start()
if control:
self.control_channel.start()
def stop_channels(self) -> None:
"""Stops all the running channels for this kernel.
This stops their event loops and joins their threads.
"""
if self.shell_channel.is_alive():
self.shell_channel.stop()
if self.iopub_channel.is_alive():
self.iopub_channel.stop()
if self.stdin_channel.is_alive():
self.stdin_channel.stop()
if self.hb_channel.is_alive():
self.hb_channel.stop()
if self.control_channel.is_alive():
self.control_channel.stop()
@property
def channels_running(self) -> bool:
"""Are any of the channels created and running?"""
return (
(self._shell_channel and self.shell_channel.is_alive())
or (self._iopub_channel and self.iopub_channel.is_alive())
or (self._stdin_channel and self.stdin_channel.is_alive())
or (self._hb_channel and self.hb_channel.is_alive())
or (self._control_channel and self.control_channel.is_alive())
)
ioloop = None # Overridden in subclasses that use pyzmq event loop
@property
def shell_channel(self) -> t.Any:
"""Get the shell channel object for this kernel."""
if self._shell_channel is None:
url = self._make_url("shell")
self.log.debug("connecting shell channel to %s", url)
socket = self.connect_shell(identity=self.session.bsession)
self._shell_channel = self.shell_channel_class(socket, self.session, self.ioloop)
return self._shell_channel
@property
def iopub_channel(self) -> t.Any:
"""Get the iopub channel object for this kernel."""
if self._iopub_channel is None:
url = self._make_url("iopub")
self.log.debug("connecting iopub channel to %s", url)
socket = self.connect_iopub()
self._iopub_channel = self.iopub_channel_class(socket, self.session, self.ioloop)
return self._iopub_channel
@property
def stdin_channel(self) -> t.Any:
"""Get the stdin channel object for this kernel."""
if self._stdin_channel is None:
url = self._make_url("stdin")
self.log.debug("connecting stdin channel to %s", url)
socket = self.connect_stdin(identity=self.session.bsession)
self._stdin_channel = self.stdin_channel_class(socket, self.session, self.ioloop)
return self._stdin_channel
@property
def hb_channel(self) -> t.Any:
"""Get the hb channel object for this kernel."""
if self._hb_channel is None:
url = self._make_url("hb")
self.log.debug("connecting heartbeat channel to %s", url)
self._hb_channel = self.hb_channel_class(self.context, self.session, url)
return self._hb_channel
@property
def control_channel(self) -> t.Any:
"""Get the control channel object for this kernel."""
if self._control_channel is None:
url = self._make_url("control")
self.log.debug("connecting control channel to %s", url)
socket = self.connect_control(identity=self.session.bsession)
self._control_channel = self.control_channel_class(socket, self.session, self.ioloop)
return self._control_channel
async def _async_is_alive(self) -> bool:
"""Is the kernel process still running?"""
from .manager import KernelManager
if isinstance(self.parent, KernelManager):
# This KernelClient was created by a KernelManager,
# we can ask the parent KernelManager:
return await ensure_async(self.parent.is_alive())
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
async def _async_execute_interactive(
self,
code: str,
silent: bool = False,
store_history: bool = True,
user_expressions: t.Optional[t.Dict[str, t.Any]] = None,
allow_stdin: t.Optional[bool] = None,
stop_on_error: bool = True,
timeout: t.Optional[float] = None,
output_hook: t.Optional[t.Callable] = None,
stdin_hook: t.Optional[t.Callable] = None,
) -> t.Dict[str, t.Any]:
"""Execute code in the kernel interactively
Output will be redisplayed, and stdin prompts will be relayed as well.
If an IPython kernel is detected, rich output will be displayed.
You can pass a custom output_hook callable that will be called
with every IOPub message that is produced instead of the default redisplay.
.. versionadded:: 5.0
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
If set, the kernel will execute the code as quietly as possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
timeout: float or None (default: None)
Timeout to use when waiting for a reply
output_hook: callable(msg)
Function to be called with output messages.
If not specified, output will be redisplayed.
stdin_hook: callable(msg)
Function to be called with stdin_request messages.
If not specified, input/getpass will be called.
Returns
-------
reply: dict
The reply message for this request
"""
if not self.iopub_channel.is_alive():
raise RuntimeError("IOPub channel must be running to receive output")
if allow_stdin is None:
allow_stdin = self.allow_stdin
if allow_stdin and not self.stdin_channel.is_alive():
raise RuntimeError("stdin channel must be running to allow input")
msg_id = self.execute(
code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)
if stdin_hook is None:
stdin_hook = self._stdin_hook_default
if output_hook is None:
# detect IPython kernel
if "IPython" in sys.modules:
from IPython import get_ipython # type: ignore
ip = get_ipython()
in_kernel = getattr(ip, "kernel", False)
if in_kernel:
output_hook = partial(
self._output_hook_kernel,
ip.display_pub.session,
ip.display_pub.pub_socket,
ip.display_pub.parent_header,
)
if output_hook is None:
# default: redisplay plain-text outputs
output_hook = self._output_hook_default
# set deadline based on timeout
if timeout is not None:
deadline = time.monotonic() + timeout
else:
timeout_ms = None
poller = zmq.Poller()
iopub_socket = self.iopub_channel.socket
poller.register(iopub_socket, zmq.POLLIN)
if allow_stdin:
stdin_socket = self.stdin_channel.socket
poller.register(stdin_socket, zmq.POLLIN)
else:
stdin_socket = None
# wait for output and redisplay it
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
timeout_ms = int(1000 * timeout)
events = dict(poller.poll(timeout_ms))
if not events:
raise TimeoutError("Timeout waiting for output")
if stdin_socket in events:
req = await self.stdin_channel.get_msg(timeout=0)
stdin_hook(req)
continue
if iopub_socket not in events:
continue
msg = await self.iopub_channel.get_msg(timeout=0)
if msg["parent_header"].get("msg_id") != msg_id:
# not from my request
continue
output_hook(msg)
# stop on idle
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
):
break
# output is done, get the reply
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
return await self._async_recv_reply(msg_id, timeout=timeout)
# Methods to send specific messages on channels
def execute(
self,
code: str,
silent: bool = False,
store_history: bool = True,
user_expressions: t.Optional[t.Dict[str, t.Any]] = None,
allow_stdin: t.Optional[bool] = None,
stop_on_error: bool = True,
) -> str:
"""Execute code in the kernel.
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
If set, the kernel will execute the code as quietly as possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
Returns
-------
The msg_id of the message sent.
"""
if user_expressions is None:
user_expressions = {}
if allow_stdin is None:
allow_stdin = self.allow_stdin
# Don't waste network traffic if inputs are invalid
if not isinstance(code, str):
raise ValueError("code %r must be a string" % code)
validate_string_dict(user_expressions)
# Create class for content/msg creation. Related to, but possibly
# not in Session.
content = dict(
code=code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)
msg = self.session.msg("execute_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def complete(self, code: str, cursor_pos: t.Optional[int] = None) -> str:
"""Tab complete text in the kernel's namespace.
Parameters
----------
code : str
The context in which completion is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the completion was requested.
Default: ``len(code)``
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = dict(code=code, cursor_pos=cursor_pos)
msg = self.session.msg("complete_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def inspect(self, code: str, cursor_pos: t.Optional[int] = None, detail_level: int = 0) -> str:
"""Get metadata information about an object in the kernel's namespace.
It is up to the kernel to determine the appropriate object to inspect.
Parameters
----------
code : str
The context in which info is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the info was requested.
Default: ``len(code)``
detail_level : int, optional
The level of detail for the introspection (0-2)
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = dict(
code=code,
cursor_pos=cursor_pos,
detail_level=detail_level,
)
msg = self.session.msg("inspect_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def history(
self,
raw: bool = True,
output: bool = False,
hist_access_type: str = "range",
**kwargs: Any,
) -> str:
"""Get entries from the kernel's history list.
Parameters
----------
raw : bool
If True, return the raw input.
output : bool
If True, then return the output as well.
hist_access_type : str
'range' (fill in session, start and stop params), 'tail' (fill in n)
or 'search' (fill in pattern param).
session : int
For a range request, the session from which to get lines. Session
numbers are positive integers; negative ones count back from the
current session.
start : int
The first line number of a history range.
stop : int
The final (excluded) line number of a history range.
n : int
The number of lines of history to get for a tail request.
pattern : str
The glob-syntax pattern for a search request.
Returns
-------
The ID of the message sent.
"""
if hist_access_type == "range":
kwargs.setdefault("session", 0)
kwargs.setdefault("start", 0)
content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwargs)
msg = self.session.msg("history_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def kernel_info(self) -> str:
"""Request kernel info
Returns
-------
The msg_id of the message sent
"""
msg = self.session.msg("kernel_info_request")
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def comm_info(self, target_name: t.Optional[str] = None) -> str:
"""Request comm info
Returns
-------
The msg_id of the message sent
"""
if target_name is None:
content = {}
else:
content = dict(target_name=target_name)
msg = self.session.msg("comm_info_request", content)
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def _handle_kernel_info_reply(self, msg: t.Dict[str, t.Any]) -> None:
"""handle kernel info reply
sets protocol adaptation version. This might
be run from a separate thread.
"""
adapt_version = int(msg["content"]["protocol_version"].split(".")[0])
if adapt_version != major_protocol_version:
self.session.adapt_version = adapt_version
def is_complete(self, code: str) -> str:
"""Ask the kernel whether some code is complete and ready to execute."""
msg = self.session.msg("is_complete_request", {"code": code})
self.shell_channel.send(msg)
return msg["header"]["msg_id"]
def input(self, string: str) -> None:
"""Send a string of raw input to the kernel.
This should only be called in response to the kernel sending an
``input_request`` message on the stdin channel.
"""
content = dict(value=string)
msg = self.session.msg("input_reply", content)
self.stdin_channel.send(msg)
def shutdown(self, restart: bool = False) -> str:
"""Request an immediate kernel shutdown on the control channel.
Upon receipt of the (empty) reply, client code can safely assume that
the kernel has shut down and it's safe to forcefully terminate it if
it's still alive.
The kernel will send the reply via a function registered with Python's
atexit module, ensuring it's truly done as the kernel is done with all
normal operation.
Returns
-------
The msg_id of the message sent
"""
# Send quit message to kernel. Once we implement kernel-side setattr,
# this should probably be done that way, but for now this will do.
msg = self.session.msg("shutdown_request", {"restart": restart})
self.control_channel.send(msg)
return msg["header"]["msg_id"]
KernelClientABC.register(KernelClient)
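# A sketch of the send/receive split described in the class docstring:
# request methods only return a msg_id; the reply is matched on
# parent_header (client here is any started KernelClient subclass):
#     msg_id = client.execute("print('hi')")
#     reply = await client._async_recv_reply(msg_id, timeout=5)
#     assert reply["parent_header"]["msg_id"] == msg_id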

View File

@@ -0,0 +1,84 @@
"""Abstract base class for kernel clients"""
# -----------------------------------------------------------------------------
# Copyright (c) The Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Main kernel client class
# -----------------------------------------------------------------------------
class KernelClientABC(object, metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.client.KernelClient`
"""
@abc.abstractproperty
def kernel(self):
pass
@abc.abstractproperty
def shell_channel_class(self):
pass
@abc.abstractproperty
def iopub_channel_class(self):
pass
@abc.abstractproperty
def hb_channel_class(self):
pass
@abc.abstractproperty
def stdin_channel_class(self):
pass
@abc.abstractproperty
def control_channel_class(self):
pass
# --------------------------------------------------------------------------
# Channel management methods
# --------------------------------------------------------------------------
@abc.abstractmethod
def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True):
pass
@abc.abstractmethod
def stop_channels(self):
pass
@abc.abstractproperty
def channels_running(self):
pass
@abc.abstractproperty
def shell_channel(self):
pass
@abc.abstractproperty
def iopub_channel(self):
pass
@abc.abstractproperty
def stdin_channel(self):
pass
@abc.abstractproperty
def hb_channel(self):
pass
@abc.abstractproperty
def control_channel(self):
pass

View File

@@ -0,0 +1,676 @@
"""Utilities for connecting to jupyter kernels
The :class:`ConnectionFileMixin` class in this module encapsulates the logic
related to writing and reading connection files.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import glob
import json
import os
import socket
import stat
import tempfile
import warnings
from getpass import getpass
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import zmq
from jupyter_core.paths import jupyter_data_dir
from jupyter_core.paths import jupyter_runtime_dir
from jupyter_core.paths import secure_write
from traitlets import Bool
from traitlets import CaselessStrEnum
from traitlets import Instance
from traitlets import Int
from traitlets import Integer
from traitlets import observe
from traitlets import Type
from traitlets import Unicode
from traitlets.config import LoggingConfigurable
from traitlets.config import SingletonConfigurable
from .localinterfaces import localhost
from .utils import _filefind
# Define custom type for kernel connection info
KernelConnectionInfo = Dict[str, Union[int, str, bytes]]
def write_connection_file(
fname: Optional[str] = None,
shell_port: Union[Integer, Int, int] = 0,
iopub_port: Union[Integer, Int, int] = 0,
stdin_port: Union[Integer, Int, int] = 0,
hb_port: Union[Integer, Int, int] = 0,
control_port: Union[Integer, Int, int] = 0,
ip: str = "",
key: bytes = b"",
transport: str = "tcp",
signature_scheme: str = "hmac-sha256",
kernel_name: str = "",
) -> Tuple[str, KernelConnectionInfo]:
"""Generates a JSON config file, including the selection of random ports.
Parameters
----------
fname : unicode
The path to the file to write
shell_port : int, optional
The port to use for ROUTER (shell) channel.
iopub_port : int, optional
The port to use for the SUB channel.
stdin_port : int, optional
The port to use for the ROUTER (raw input) channel.
control_port : int, optional
The port to use for the ROUTER (control) channel.
hb_port : int, optional
The port to use for the heartbeat REP channel.
ip : str, optional
The ip address the kernel will bind to.
key : str, optional
The Session key used for message authentication.
signature_scheme : str, optional
The scheme used for message authentication.
This has the form 'digest-hash', where 'digest'
is the scheme used for digests, and 'hash' is the name of the hash function
used by the digest scheme.
Currently, 'hmac' is the only supported digest scheme,
and 'sha256' is the default hash function.
kernel_name : str, optional
The name of the kernel currently connected to.
"""
if not ip:
ip = localhost()
# default to temporary connector file
if not fname:
fd, fname = tempfile.mkstemp(".json")
os.close(fd)
# Find open ports as necessary.
ports: List[int] = []
sockets: List[socket.socket] = []
ports_needed = (
int(shell_port <= 0)
+ int(iopub_port <= 0)
+ int(stdin_port <= 0)
+ int(control_port <= 0)
+ int(hb_port <= 0)
)
if transport == "tcp":
for _ in range(ports_needed):
sock = socket.socket()
# struct.pack('ii', (0,0)) is 8 null bytes
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
sock.bind((ip, 0))
sockets.append(sock)
for sock in sockets:
port = sock.getsockname()[1]
sock.close()
ports.append(port)
else:
N = 1
for _ in range(ports_needed):
while os.path.exists("%s-%s" % (ip, str(N))):
N += 1
ports.append(N)
N += 1
if shell_port <= 0:
shell_port = ports.pop(0)
if iopub_port <= 0:
iopub_port = ports.pop(0)
if stdin_port <= 0:
stdin_port = ports.pop(0)
if control_port <= 0:
control_port = ports.pop(0)
if hb_port <= 0:
hb_port = ports.pop(0)
cfg: KernelConnectionInfo = dict(
shell_port=shell_port,
iopub_port=iopub_port,
stdin_port=stdin_port,
control_port=control_port,
hb_port=hb_port,
)
cfg["ip"] = ip
cfg["key"] = key.decode()
cfg["transport"] = transport
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name
# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability: the file contains secrets
# that would let others execute arbitrary code as you
with secure_write(fname) as f:
f.write(json.dumps(cfg, indent=2))
if hasattr(stat, "S_ISVTX"):
# set the sticky bit on the parent directory of the file
# to ensure only owner can remove it
runtime_dir = os.path.dirname(fname)
if runtime_dir:
permissions = os.stat(runtime_dir).st_mode
new_permissions = permissions | stat.S_ISVTX
if new_permissions != permissions:
try:
os.chmod(runtime_dir, new_permissions)
except OSError as e:
if e.errno == errno.EPERM:
# suppress permission errors setting sticky bit on runtime_dir,
# which we may not own.
pass
return fname, cfg
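# Example: let all five ports be chosen at random (the key is illustrative):
#     fname, cfg = write_connection_file(key=b"a0b1c2", kernel_name="python3")
#     cfg["shell_port"], cfg["hb_port"]  # two of the five random ports
#     cfg["transport"], cfg["ip"]        # -> "tcp", localhost()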
def find_connection_file(
filename: str = "kernel-*.json",
path: Optional[Union[str, List[str]]] = None,
profile: Optional[str] = None,
) -> str:
"""find a connection file, and return its absolute path.
The current working directory and optional search path
will be searched for the file if it is not given by absolute path.
If the argument does not match an existing file, it will be interpreted as a
fileglob, and the matching file in the profile's security dir with
the latest access time will be used.
Parameters
----------
filename : str
The connection file or fileglob to search for.
path : str or list of str, optional
Paths in which to search for connection files.
Returns
-------
str : The absolute path of the connection file.
"""
if profile is not None:
warnings.warn("Jupyter has no profiles. profile=%s has been ignored." % profile)
if path is None:
path = [".", jupyter_runtime_dir()]
if isinstance(path, str):
path = [path]
try:
# first, try explicit name
return _filefind(filename, path)
except IOError:
pass
# not found by full name
if "*" in filename:
# given as a glob already
pat = filename
else:
# accept any substring match
pat = "*%s*" % filename
matches = []
for p in path:
matches.extend(glob.glob(os.path.join(p, pat)))
matches = [os.path.abspath(m) for m in matches]
if not matches:
raise IOError("Could not find %r in %r" % (filename, path))
elif len(matches) == 1:
return matches[0]
else:
# get most recent match, by access time:
return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
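# Examples (file names are illustrative):
#     find_connection_file()                  # newest kernel-*.json match
#     find_connection_file("1234")            # substring -> glob "*1234*"
#     find_connection_file("kernel-1234.json", path="/tmp")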
def tunnel_to_kernel(
connection_info: Union[str, KernelConnectionInfo],
sshserver: str,
sshkey: Optional[str] = None,
) -> Tuple[Any, ...]:
"""tunnel connections to a kernel via ssh
This will open five SSH tunnels from localhost on this machine to the
ports associated with the kernel. They can be either direct
    localhost-localhost tunnels, or, if an intermediate server is necessary,
    the kernel must be listening on a public IP.
Parameters
----------
connection_info : dict or str (path)
Either a connection dict, or the path to a JSON connection file
sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
`user@server:port` string. ssh config aliases are respected.
sshkey : str [optional]
Path to file containing ssh key to use for authentication.
Only necessary if your ssh config does not already associate
a keyfile with the host.
Returns
-------
(shell, iopub, stdin, hb, control) : ints
The five ports on localhost that have been forwarded to the kernel.
"""
from .ssh import tunnel
if isinstance(connection_info, str):
# it's a path, unpack it
with open(connection_info) as f:
connection_info = json.loads(f.read())
cf = cast(Dict[str, Any], connection_info)
lports = tunnel.select_random_ports(5)
rports = (
cf["shell_port"],
cf["iopub_port"],
cf["stdin_port"],
cf["hb_port"],
cf["control_port"],
)
remote_ip = cf["ip"]
if tunnel.try_passwordless_ssh(sshserver, sshkey):
password: Union[bool, str] = False
else:
password = getpass("SSH Password for %s: " % sshserver)
for lp, rp in zip(lports, rports):
tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
return tuple(lports)
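
# Illustrative sketch (editor's addition): forward a kernel's five ports
# over ssh. The connection-file path, server, and key path are placeholders
# and must exist for this to run.
def _demo_tunnel_to_kernel() -> None:
    shell, iopub, stdin, hb, control = tunnel_to_kernel(
        "/path/to/kernel-1234.json",  # hypothetical connection file
        "user@gateway.example.com",  # hypothetical ssh server
        sshkey="~/.ssh/id_rsa",
    )
    print("kernel shell channel forwarded to localhost:%i" % shell)
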
# -----------------------------------------------------------------------------
# Mixin for classes that work with connection files
# -----------------------------------------------------------------------------
channel_socket_types = {
"hb": zmq.REQ,
"shell": zmq.DEALER,
"iopub": zmq.SUB,
"stdin": zmq.DEALER,
"control": zmq.DEALER,
}
port_names = ["%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")]
class ConnectionFileMixin(LoggingConfigurable):
"""Mixin for configurable classes that work with connection files"""
data_dir = Unicode()
def _data_dir_default(self):
return jupyter_data_dir()
# The addresses for the communication channels
connection_file = Unicode(
"",
config=True,
help="""JSON file in which to store connection info [default: kernel-<pid>.json]
This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the
    Jupyter runtime directory, but it can be specified by absolute path.
""",
)
_connection_file_written = Bool(False)
transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
kernel_name = Unicode()
ip = Unicode(
config=True,
help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!""",
)
def _ip_default(self):
if self.transport == "ipc":
if self.connection_file:
return os.path.splitext(self.connection_file)[0] + "-ipc"
else:
return "kernel-ipc"
else:
return localhost()
@observe("ip")
def _ip_changed(self, change):
if change["new"] == "*":
self.ip = "0.0.0.0"
# protected traits
hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")
# names of the ports with random assignment
_random_port_names: Optional[List[str]] = None
@property
def ports(self) -> List[int]:
return [getattr(self, name) for name in port_names]
# The Session to use for communication with the kernel.
session = Instance("jupyter_client.session.Session")
def _session_default(self):
from jupyter_client.session import Session
return Session(parent=self)
# --------------------------------------------------------------------------
# Connection and ipc file management
# --------------------------------------------------------------------------
def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
"""Return the connection info as a dict
Parameters
----------
session : bool [default: False]
            If True, our session object will be included in the connection info.
If False (default), the configuration parameters of our session object will be included,
rather than the session object itself.
Returns
-------
connect_info : dict
dictionary of connection information.
"""
info = dict(
transport=self.transport,
ip=self.ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
)
if session:
# add *clone* of my session,
# so that state such as digest_history is not shared.
info["session"] = self.session.clone()
else:
# add session info
info.update(
dict(
signature_scheme=self.session.signature_scheme,
key=self.session.key,
)
)
return info
# factory for blocking clients
blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")
def blocking_client(self):
"""Make a blocking client connected to my kernel"""
info = self.get_connection_info()
bc = self.blocking_class(parent=self)
bc.load_connection_info(info)
return bc
def cleanup_connection_file(self) -> None:
"""Cleanup connection file *if we wrote it*
Will not raise if the connection file was already removed somehow.
"""
if self._connection_file_written:
# cleanup connection files on full shutdown of kernel we started
self._connection_file_written = False
try:
os.remove(self.connection_file)
except (OSError, AttributeError):
pass
def cleanup_ipc_files(self) -> None:
"""Cleanup ipc files if we wrote them."""
if self.transport != "ipc":
return
for port in self.ports:
ipcfile = "%s-%i" % (self.ip, port)
try:
os.remove(ipcfile)
except OSError:
pass
def _record_random_port_names(self) -> None:
"""Records which of the ports are randomly assigned.
Records on first invocation, if the transport is tcp.
Does nothing on later invocations."""
if self.transport != "tcp":
return
if self._random_port_names is not None:
return
self._random_port_names = []
for name in port_names:
if getattr(self, name) <= 0:
self._random_port_names.append(name)
def cleanup_random_ports(self) -> None:
"""Forgets randomly assigned port numbers and cleans up the connection file.
Does nothing if no port numbers have been randomly assigned.
In particular, does nothing unless the transport is tcp.
"""
if not self._random_port_names:
return
for name in self._random_port_names:
setattr(self, name, 0)
self.cleanup_connection_file()
def write_connection_file(self) -> None:
"""Write connection info to JSON dict in self.connection_file."""
if self._connection_file_written and os.path.exists(self.connection_file):
return
self.connection_file, cfg = write_connection_file(
self.connection_file,
transport=self.transport,
ip=self.ip,
key=self.session.key,
stdin_port=self.stdin_port,
iopub_port=self.iopub_port,
shell_port=self.shell_port,
hb_port=self.hb_port,
control_port=self.control_port,
signature_scheme=self.session.signature_scheme,
kernel_name=self.kernel_name,
)
# write_connection_file also sets default ports:
self._record_random_port_names()
for name in port_names:
setattr(self, name, cfg[name])
self._connection_file_written = True
def load_connection_file(self, connection_file: Optional[str] = None) -> None:
"""Load connection info from JSON dict in self.connection_file.
Parameters
----------
connection_file: unicode, optional
Path to connection file to load.
If unspecified, use self.connection_file
"""
if connection_file is None:
connection_file = self.connection_file
self.log.debug("Loading connection file %s", connection_file)
with open(connection_file) as f:
info = json.load(f)
self.load_connection_info(info)
def load_connection_info(self, info: KernelConnectionInfo) -> None:
"""Load connection info from a dict containing connection info.
Typically this data comes from a connection file
and is called by load_connection_file.
Parameters
----------
info: dict
Dictionary containing connection_info.
See the connection_file spec for details.
"""
self.transport = info.get("transport", self.transport)
self.ip = info.get("ip", self._ip_default())
self._record_random_port_names()
for name in port_names:
if getattr(self, name) == 0 and name in info:
# not overridden by config or cl_args
setattr(self, name, info[name])
if "key" in info:
key = info["key"]
if isinstance(key, str):
key = key.encode()
assert isinstance(key, bytes)
self.session.key = key
if "signature_scheme" in info:
self.session.signature_scheme = info["signature_scheme"]
def _force_connection_info(self, info: KernelConnectionInfo) -> None:
"""Unconditionally loads connection info from a dict containing connection info.
Overwrites connection info-based attributes, regardless of their current values
and writes this information to the connection file.
"""
# Reset current ports to 0 and indicate file has not been written to enable override
self._connection_file_written = False
for name in port_names:
setattr(self, name, 0)
self.load_connection_info(info)
self.write_connection_file()
# --------------------------------------------------------------------------
# Creating connected sockets
# --------------------------------------------------------------------------
def _make_url(self, channel: str) -> str:
"""Make a ZeroMQ URL for a given channel."""
transport = self.transport
ip = self.ip
port = getattr(self, "%s_port" % channel)
if transport == "tcp":
return "tcp://%s:%i" % (ip, port)
else:
return "%s://%s-%s" % (transport, ip, port)
def _create_connected_socket(
self, channel: str, identity: Optional[bytes] = None
) -> zmq.sugar.socket.Socket:
"""Create a zmq Socket and connect it to the kernel."""
url = self._make_url(channel)
socket_type = channel_socket_types[channel]
self.log.debug("Connecting to: %s" % url)
sock = self.context.socket(socket_type)
# set linger to 1s to prevent hangs at exit
sock.linger = 1000
if identity:
sock.identity = identity
sock.connect(url)
return sock
def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the IOPub channel"""
sock = self._create_connected_socket("iopub", identity=identity)
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock
def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Shell channel"""
return self._create_connected_socket("shell", identity=identity)
def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the StdIn channel"""
return self._create_connected_socket("stdin", identity=identity)
def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Heartbeat channel"""
return self._create_connected_socket("hb", identity=identity)
def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
"""return zmq Socket connected to the Control channel"""
return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
"""
Used to keep track of local ports in order to prevent race conditions that
can occur between port acquisition and usage by the kernel. All locally-
provisioned kernels should use this mechanism to limit the possibility of
race conditions. Note that this does not preclude other applications from
acquiring a cached but unused port, thereby re-introducing the issue this
class is attempting to resolve (minimize).
See: https://github.com/jupyter/jupyter_client/issues/487
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.currently_used_ports: Set[int] = set()
def find_available_port(self, ip: str) -> int:
while True:
tmp_sock = socket.socket()
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
tmp_sock.bind((ip, 0))
port = tmp_sock.getsockname()[1]
tmp_sock.close()
# This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels from being assigned the same port.
if port not in self.currently_used_ports:
self.currently_used_ports.add(port)
return port
def return_port(self, port: int) -> None:
if port in self.currently_used_ports: # Tolerate uncached ports
self.currently_used_ports.remove(port)
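
# Illustrative sketch (editor's addition): reserve a local port through the
# singleton cache, then hand it back once the kernel no longer needs it.
def _demo_local_port_cache() -> None:
    cache = LocalPortCache.instance()
    port = cache.find_available_port("127.0.0.1")
    try:
        print("reserved port", port)
    finally:
        cache.return_port(port)
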
__all__ = [
"write_connection_file",
"find_connection_file",
"tunnel_to_kernel",
"KernelConnectionInfo",
"LocalPortCache",
]

View File

@@ -0,0 +1,373 @@
""" A minimal application base mixin for all ZMQ based IPython frontends.
This is not a complete console app, as subprocess will not be able to receive
input, there is no real readline support, among other limitations. This is a
refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import atexit
import os
import signal
import sys
import uuid
import warnings
from typing import cast
from jupyter_core.application import base_aliases
from jupyter_core.application import base_flags
from traitlets import CBool
from traitlets import CUnicode
from traitlets import Dict
from traitlets import List
from traitlets import Type
from traitlets import Unicode
from traitlets.config.application import boolean_flag
from . import connect
from . import find_connection_file
from . import KernelManager
from . import tunnel_to_kernel
from .blocking import BlockingKernelClient
from .kernelspec import NoSuchKernel
from .localinterfaces import localhost
from .restarter import KernelRestarter
from .session import Session
from .utils import _filefind
ConnectionFileMixin = connect.ConnectionFileMixin
# -----------------------------------------------------------------------------
# Aliases and Flags
# -----------------------------------------------------------------------------
flags = {}
flags.update(base_flags)
# the flags that are specific to the frontend
# these must be scrubbed before being passed to the kernel,
# or it will raise an error on unrecognized flags
app_flags = {
"existing": (
{"JupyterConsoleApp": {"existing": "kernel*.json"}},
"Connect to an existing kernel. If no argument specified, guess most recent",
),
}
app_flags.update(
boolean_flag(
"confirm-exit",
"JupyterConsoleApp.confirm_exit",
"""Set to display confirmation dialog on exit. You can always use 'exit' or
    'quit' to force a direct exit without any confirmation. This can also
be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
""",
"""Don't prompt the user when exiting. This will terminate the kernel
if it is owned by the frontend, and leave it alive if it is external.
This can also be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
""",
)
)
flags.update(app_flags)
aliases = {}
aliases.update(base_aliases)
# also scrub aliases from the frontend
app_aliases = dict(
ip="JupyterConsoleApp.ip",
transport="JupyterConsoleApp.transport",
hb="JupyterConsoleApp.hb_port",
shell="JupyterConsoleApp.shell_port",
iopub="JupyterConsoleApp.iopub_port",
stdin="JupyterConsoleApp.stdin_port",
control="JupyterConsoleApp.control_port",
existing="JupyterConsoleApp.existing",
f="JupyterConsoleApp.connection_file",
kernel="JupyterConsoleApp.kernel_name",
ssh="JupyterConsoleApp.sshserver",
sshkey="JupyterConsoleApp.sshkey",
)
aliases.update(app_aliases)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
classes = [KernelManager, KernelRestarter, Session]
class JupyterConsoleApp(ConnectionFileMixin):
name = "jupyter-console-mixin"
description = """
The Jupyter Console Mixin.
        This class contains the common portions of console clients (QtConsole,
        ZMQ-based terminal console, etc.). It is not a full console, in that
        launched terminal subprocesses will not be able to accept input.
        A console using this mixin supports various extra features beyond
        the single-process terminal IPython shell, such as connecting to an
        existing kernel, via:
            jupyter console <appname> --existing
        as well as tunneling via SSH.
"""
classes = classes
flags = Dict(flags)
aliases = Dict(aliases)
kernel_manager_class = Type(
default_value=KernelManager,
config=True,
help="The kernel manager class to use.",
)
kernel_client_class = BlockingKernelClient
kernel_argv = List(Unicode())
# connection info:
sshserver = Unicode("", config=True, help="""The SSH server to use to connect to the kernel.""")
sshkey = Unicode(
"",
config=True,
help="""Path to the ssh key to use for logging in to the ssh server.""",
)
def _connection_file_default(self) -> str:
return "kernel-%i.json" % os.getpid()
existing = CUnicode("", config=True, help="""Connect to an already running kernel""")
kernel_name = Unicode(
"python", config=True, help="""The name of the default kernel to start."""
)
confirm_exit = CBool(
True,
config=True,
help="""
        Set to display confirmation dialog on exit. You can always use 'exit'
        or 'quit' to force a direct exit without any confirmation.""",
)
def build_kernel_argv(self, argv: object = None) -> None:
"""build argv to be passed to kernel subprocess
Override in subclasses if any args should be passed to the kernel
"""
self.kernel_argv = self.extra_args
def init_connection_file(self) -> None:
"""find the connection file, and load the info if found.
        The current working directory and the runtime directory will be
        searched for the file if it is not given by absolute path.
When attempting to connect to an existing kernel and the `--existing`
argument does not match an existing file, it will be interpreted as a
        fileglob, and the matching file in the runtime directory with the
        latest access time will be used.
After this method is called, self.connection_file contains the *full path*
to the connection file, never just its name.
"""
if self.existing:
try:
cf = find_connection_file(self.existing, [".", self.runtime_dir])
except Exception:
self.log.critical(
"Could not find existing kernel connection file %s", self.existing
)
self.exit(1)
self.log.debug("Connecting to existing kernel: %s" % cf)
self.connection_file = cf
else:
# not existing, check if we are going to write the file
# and ensure that self.connection_file is a full path, not just the shortname
try:
cf = find_connection_file(self.connection_file, [self.runtime_dir])
except Exception:
# file might not exist
if self.connection_file == os.path.basename(self.connection_file):
                # just shortname, put it in the runtime dir
cf = os.path.join(self.runtime_dir, self.connection_file)
else:
cf = self.connection_file
self.connection_file = cf
try:
self.connection_file = _filefind(self.connection_file, [".", self.runtime_dir])
except IOError:
self.log.debug("Connection File not found: %s", self.connection_file)
return
# should load_connection_file only be used for existing?
# as it is now, this allows reusing ports if an existing
# file is requested
try:
self.load_connection_file()
except Exception:
self.log.error(
"Failed to load connection file: %r",
self.connection_file,
exc_info=True,
)
self.exit(1)
def init_ssh(self) -> None:
"""set up ssh tunnels, if needed."""
if not self.existing or (not self.sshserver and not self.sshkey):
return
self.load_connection_file()
transport = self.transport
ip = self.ip
if transport != "tcp":
self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
sys.exit(-1)
if self.sshkey and not self.sshserver:
# specifying just the key implies that we are connecting directly
self.sshserver = ip
ip = localhost()
# build connection dict for tunnels:
info = dict(
ip=ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
)
self.log.info("Forwarding connections to %s via %s" % (ip, self.sshserver))
# tunnels return a new set of ports, which will be on localhost:
self.ip = localhost()
try:
newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
except: # noqa
# even catch KeyboardInterrupt
self.log.error("Could not setup tunnels", exc_info=True)
self.exit(1)
(
self.shell_port,
self.iopub_port,
self.stdin_port,
self.hb_port,
self.control_port,
) = newports
cf = self.connection_file
root, ext = os.path.splitext(cf)
self.connection_file = root + "-ssh" + ext
self.write_connection_file() # write the new connection file
self.log.info("To connect another client via this tunnel, use:")
self.log.info("--existing %s" % os.path.basename(self.connection_file))
def _new_connection_file(self) -> str:
cf = ""
while not cf:
# we don't need a 128b id to distinguish kernels, use more readable
# 48b node segment (12 hex chars). Users running more than 32k simultaneous
# kernels can subclass.
ident = str(uuid.uuid4()).split("-")[-1]
cf = os.path.join(self.runtime_dir, "kernel-%s.json" % ident)
# only keep if it's actually new. Protect against unlikely collision
# in 48b random search space
cf = cf if not os.path.exists(cf) else ""
return cf
def init_kernel_manager(self) -> None:
        # Don't let Qt or ZMQ swallow KeyboardInterrupts.
if self.existing:
self.kernel_manager = None
return
signal.signal(signal.SIGINT, signal.SIG_DFL)
# Create a KernelManager and start a kernel.
try:
self.kernel_manager = self.kernel_manager_class(
ip=self.ip,
session=self.session,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
kernel_name=self.kernel_name,
parent=self,
data_dir=self.data_dir,
)
except NoSuchKernel:
self.log.critical("Could not find kernel %s", self.kernel_name)
self.exit(1)
self.kernel_manager = cast(KernelManager, self.kernel_manager)
self.kernel_manager.client_factory = self.kernel_client_class
kwargs = {}
kwargs["extra_arguments"] = self.kernel_argv
self.kernel_manager.start_kernel(**kwargs)
atexit.register(self.kernel_manager.cleanup_ipc_files)
if self.sshserver:
# ssh, write new connection file
self.kernel_manager.write_connection_file()
# in case KM defaults / ssh writing changes things:
km = self.kernel_manager
self.shell_port = km.shell_port
self.iopub_port = km.iopub_port
self.stdin_port = km.stdin_port
self.hb_port = km.hb_port
self.control_port = km.control_port
self.connection_file = km.connection_file
atexit.register(self.kernel_manager.cleanup_connection_file)
def init_kernel_client(self) -> None:
if self.kernel_manager is not None:
self.kernel_client = self.kernel_manager.client()
else:
self.kernel_client = self.kernel_client_class(
session=self.session,
ip=self.ip,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
parent=self,
)
self.kernel_client.start_channels()
def initialize(self, argv: object = None) -> None:
"""
Classes which mix this class in should call:
JupyterConsoleApp.initialize(self,argv)
"""
if self._dispatching:
return
self.init_connection_file()
self.init_ssh()
self.init_kernel_manager()
self.init_kernel_client()
class IPythonConsoleApp(JupyterConsoleApp):
def __init__(self, *args, **kwargs):
warnings.warn("IPythonConsoleApp is deprecated. Use JupyterConsoleApp")
super().__init__(*args, **kwargs)
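
# Illustrative sketch (editor's addition): the smallest possible frontend
# built on the mixin, following the same pattern as jupyter_console. The
# class name and the executed statement are placeholders.
from jupyter_core.application import JupyterApp


class _DemoConsole(JupyterApp, JupyterConsoleApp):
    name = "demo-console"

    def initialize(self, argv=None):
        super().initialize(argv)
        JupyterConsoleApp.initialize(self, argv)

    def start(self):
        # channels are started by init_kernel_client; run one statement
        self.kernel_client.execute("print('hello from the kernel')")


if __name__ == "__main__":
    _DemoConsole.launch_instance()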

View File

@@ -0,0 +1,4 @@
from .manager import AsyncIOLoopKernelManager # noqa
from .manager import IOLoopKernelManager # noqa
from .restarter import AsyncIOLoopKernelRestarter # noqa
from .restarter import IOLoopKernelRestarter # noqa

View File

@@ -0,0 +1,98 @@
"""A kernel manager with a tornado IOLoop"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from tornado import ioloop
from traitlets import Instance
from traitlets import Type
from zmq.eventloop.zmqstream import ZMQStream
from .restarter import AsyncIOLoopKernelRestarter
from .restarter import IOLoopKernelRestarter
from jupyter_client.manager import AsyncKernelManager
from jupyter_client.manager import KernelManager
def as_zmqstream(f):
def wrapped(self, *args, **kwargs):
socket = f(self, *args, **kwargs)
return ZMQStream(socket, self.loop)
return wrapped
class IOLoopKernelManager(KernelManager):
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self):
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=IOLoopKernelRestarter,
klass=IOLoopKernelRestarter,
help=(
"Type of KernelRestarter to use. "
"Must be a subclass of IOLoopKernelRestarter.\n"
"Override this to customize how kernel restarts are managed."
),
config=True,
)
_restarter = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True)
def start_restarter(self):
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop, parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self):
if self.autorestart:
if self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(KernelManager.connect_shell)
connect_control = as_zmqstream(KernelManager.connect_control)
connect_iopub = as_zmqstream(KernelManager.connect_iopub)
connect_stdin = as_zmqstream(KernelManager.connect_stdin)
connect_hb = as_zmqstream(KernelManager.connect_hb)
class AsyncIOLoopKernelManager(AsyncKernelManager):
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self):
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=AsyncIOLoopKernelRestarter,
klass=AsyncIOLoopKernelRestarter,
help=(
"Type of KernelRestarter to use. "
"Must be a subclass of AsyncIOLoopKernelManager.\n"
"Override this to customize how kernel restarts are managed."
),
config=True,
)
_restarter = Instance("jupyter_client.ioloop.AsyncIOLoopKernelRestarter", allow_none=True)
def start_restarter(self):
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop, parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self):
if self.autorestart:
if self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(AsyncKernelManager.connect_shell)
connect_control = as_zmqstream(AsyncKernelManager.connect_control)
connect_iopub = as_zmqstream(AsyncKernelManager.connect_iopub)
connect_stdin = as_zmqstream(AsyncKernelManager.connect_stdin)
connect_hb = as_zmqstream(AsyncKernelManager.connect_hb)
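
# Illustrative sketch (editor's addition): with the ZMQStream-wrapped
# connect methods, IOPub traffic can be consumed from a tornado callback.
# Assumes a "python3" kernelspec is installed; the loop must be running
# for on_recv callbacks to fire.
def _demo_iopub_stream() -> None:
    km = IOLoopKernelManager(kernel_name="python3")
    km.start_kernel()
    stream = km.connect_iopub()
    stream.on_recv(lambda msg_list: print("iopub message, %i frames" % len(msg_list)))
    km.loop.start()  # blocks; Ctrl-C to stop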

View File

@@ -0,0 +1,101 @@
"""A basic in process kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import time
import warnings
from traitlets import Instance
from zmq.eventloop import ioloop
from jupyter_client.restarter import KernelRestarter
from jupyter_client.utils import run_sync
class IOLoopKernelRestarter(KernelRestarter):
"""Monitor and autorestart a kernel."""
loop = Instance("tornado.ioloop.IOLoop")
def _loop_default(self):
warnings.warn(
"IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2",
DeprecationWarning,
stacklevel=4,
)
return ioloop.IOLoop.current()
_pcallback = None
def start(self):
"""Start the polling of the kernel."""
if self._pcallback is None:
if asyncio.iscoroutinefunction(self.poll):
cb = run_sync(self.poll)
else:
cb = self.poll
self._pcallback = ioloop.PeriodicCallback(
cb,
1000 * self.time_to_dead,
)
self._pcallback.start()
def stop(self):
"""Stop the kernel polling."""
if self._pcallback is not None:
self._pcallback.stop()
self._pcallback = None
class AsyncIOLoopKernelRestarter(IOLoopKernelRestarter):
async def poll(self):
if self.debug:
self.log.debug("Polling kernel...")
is_alive = await self.kernel_manager.is_alive()
now = time.time()
if not is_alive:
self._last_dead = now
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count > self.restart_limit:
self.log.warning("AsyncIOLoopKernelRestarter: restart failed")
self._fire_callbacks("dead")
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info(
"AsyncIOLoopKernelRestarter: restarting kernel (%i/%i), %s random ports",
self._restart_count,
self.restart_limit,
"new" if newports else "keep",
)
self._fire_callbacks("restart")
await self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
# Since `is_alive` only tests that the kernel process is alive, it does not
# indicate that the kernel has successfully completed startup. To solve this
# correctly, we would need to wait for a kernel info reply, but it is not
# necessarily appropriate to start a kernel client + channels in the
# restarter. Therefore, we use "has been alive continuously for X time" as a
# heuristic for a stable start up.
# See https://github.com/jupyter/jupyter_client/pull/717 for details.
stable_start_time = self.stable_start_time
if self.kernel_manager.provisioner:
stable_start_time = self.kernel_manager.provisioner.get_stable_start_time(
recommended=stable_start_time
)
if self._initial_startup and now - self._last_dead >= stable_start_time:
self._initial_startup = False
if self._restarting and now - self._last_dead >= stable_start_time:
self.log.debug("AsyncIOLoopKernelRestarter: restart apparently succeeded")
self._restarting = False
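
# Illustrative sketch (editor's addition): restarters are normally created
# by the kernel manager's start_restarter(); this shows the callback hooks
# directly. "km" is assumed to be a started IOLoopKernelManager.
def _demo_restarter(km) -> None:
    restarter = IOLoopKernelRestarter(kernel_manager=km, parent=km, log=km.log)
    restarter.add_callback(lambda: print("kernel restarted"), event="restart")
    restarter.add_callback(lambda: print("kernel died"), event="dead")
    restarter.start()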

View File

@@ -0,0 +1,192 @@
"""Utilities to manipulate JSON objects."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import math
import numbers
import re
import types
import warnings
from binascii import b2a_base64
from collections.abc import Iterable
from datetime import datetime
from typing import Optional
from typing import Union
from dateutil.parser import parse as _dateutil_parse # type: ignore
from dateutil.tz import tzlocal # type: ignore
next_attr_name = "__next__" # Not sure what downstream library uses this, but left it to be safe
# -----------------------------------------------------------------------------
# Globals and constants
# -----------------------------------------------------------------------------
# timestamp formats
ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
ISO8601_PAT = re.compile(
r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$"
)
# holy crap, strptime is not threadsafe.
# Calling it once at import seems to help.
datetime.strptime("1", "%d")
# -----------------------------------------------------------------------------
# Classes and functions
# -----------------------------------------------------------------------------
def _ensure_tzinfo(dt: datetime) -> datetime:
"""Ensure a datetime object has tzinfo
If no tzinfo is present, add tzlocal
"""
if not dt.tzinfo:
# No more naïve datetime objects!
warnings.warn(
"Interpreting naive datetime as local %s. Please add timezone info to timestamps." % dt,
DeprecationWarning,
stacklevel=4,
)
dt = dt.replace(tzinfo=tzlocal())
return dt
def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]:
"""parse an ISO8601 date string
If it is None or not a valid ISO8601 timestamp,
it will be returned unmodified.
Otherwise, it will return a datetime object.
"""
if s is None:
return s
m = ISO8601_PAT.match(s)
if m:
dt = _dateutil_parse(s)
return _ensure_tzinfo(dt)
return s
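
# Illustrative example (editor's addition): a valid ISO8601 string becomes
# a timezone-aware datetime; anything else is returned unchanged.
def _demo_parse_date() -> None:
    print(parse_date("2022-05-23T00:16:32Z"))  # datetime with tzinfo
    print(parse_date("not a timestamp"))  # returned as-is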
def extract_dates(obj):
"""extract ISO8601 dates from unpacked JSON"""
if isinstance(obj, dict):
new_obj = {} # don't clobber
for k, v in obj.items():
new_obj[k] = extract_dates(v)
obj = new_obj
elif isinstance(obj, (list, tuple)):
obj = [extract_dates(o) for o in obj]
elif isinstance(obj, str):
obj = parse_date(obj)
return obj
def squash_dates(obj):
"""squash datetime objects into ISO8601 strings"""
if isinstance(obj, dict):
obj = dict(obj) # don't clobber
for k, v in obj.items():
obj[k] = squash_dates(v)
elif isinstance(obj, (list, tuple)):
obj = [squash_dates(o) for o in obj]
elif isinstance(obj, datetime):
obj = obj.isoformat()
return obj
def date_default(obj):
"""DEPRECATED: Use jupyter_client.jsonutil.json_default"""
warnings.warn(
"date_default is deprecated since jupyter_client 7.0.0."
" Use jupyter_client.jsonutil.json_default.",
stacklevel=2,
)
return json_default(obj)
def json_default(obj):
"""default function for packing objects in JSON."""
if isinstance(obj, datetime):
obj = _ensure_tzinfo(obj)
return obj.isoformat().replace('+00:00', 'Z')
if isinstance(obj, bytes):
return b2a_base64(obj).decode('ascii')
if isinstance(obj, Iterable):
return list(obj)
if isinstance(obj, numbers.Integral):
return int(obj)
if isinstance(obj, numbers.Real):
return float(obj)
raise TypeError("%r is not JSON serializable" % obj)
# Copy of the old ipykernel's json_clean
# This is temporary; it should be removed when we deprecate support for
# non-valid JSON messages.
def json_clean(obj):
# types that are 'atomic' and ok in json as-is.
atomic_ok = (str, type(None))
# containers that we need to convert into lists
container_to_list = (tuple, set, types.GeneratorType)
# Since bools are a subtype of Integrals, which are a subtype of Reals,
# we have to check them in that order.
if isinstance(obj, bool):
return obj
if isinstance(obj, numbers.Integral):
# cast int to int, in case subclasses override __str__ (e.g. boost enum, #4598)
return int(obj)
if isinstance(obj, numbers.Real):
# cast out-of-range floats to their reprs
if math.isnan(obj) or math.isinf(obj):
return repr(obj)
return float(obj)
if isinstance(obj, atomic_ok):
return obj
if isinstance(obj, bytes):
        # unambiguous binary data is base64-encoded
# (this probably should have happened upstream)
return b2a_base64(obj).decode('ascii')
if isinstance(obj, container_to_list) or (
hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)
):
obj = list(obj)
if isinstance(obj, list):
return [json_clean(x) for x in obj]
if isinstance(obj, dict):
# First, validate that the dict won't lose data in conversion due to
# key collisions after stringification. This can happen with keys like
# True and 'true' or 1 and '1', which collide in JSON.
nkeys = len(obj)
nkeys_collapsed = len(set(map(str, obj)))
if nkeys != nkeys_collapsed:
raise ValueError(
'dict cannot be safely converted to JSON: '
'key collision would lead to dropped values'
)
# If all OK, proceed by making the new dict that will be json-safe
out = {}
for k, v in obj.items():
out[str(k)] = json_clean(v)
return out
if isinstance(obj, datetime):
return obj.strftime(ISO8601)
# we don't understand it, it's probably an unserializable object
raise ValueError("Can't clean for JSON: %r" % obj)

View File

@@ -0,0 +1,88 @@
import os
import signal
import uuid
from jupyter_core.application import base_flags
from jupyter_core.application import JupyterApp
from tornado.ioloop import IOLoop
from traitlets import Unicode
from . import __version__
from .kernelspec import KernelSpecManager
from .kernelspec import NATIVE_KERNEL_NAME
from .manager import KernelManager
class KernelApp(JupyterApp):
"""Launch a kernel by name in a local subprocess."""
version = __version__
description = "Run a kernel locally in a subprocess"
classes = [KernelManager, KernelSpecManager]
aliases = {
"kernel": "KernelApp.kernel_name",
"ip": "KernelManager.ip",
}
flags = {"debug": base_flags["debug"]}
kernel_name = Unicode(NATIVE_KERNEL_NAME, help="The name of a kernel type to start").tag(
config=True
)
def initialize(self, argv=None):
super().initialize(argv)
cf_basename = "kernel-%s.json" % uuid.uuid4()
self.config.setdefault("KernelManager", {}).setdefault(
"connection_file", os.path.join(self.runtime_dir, cf_basename)
)
self.km = KernelManager(kernel_name=self.kernel_name, config=self.config)
self.loop = IOLoop.current()
self.loop.add_callback(self._record_started)
def setup_signals(self) -> None:
"""Shutdown on SIGTERM or SIGINT (Ctrl-C)"""
if os.name == "nt":
return
def shutdown_handler(signo, frame):
self.loop.add_callback_from_signal(self.shutdown, signo)
for sig in [signal.SIGTERM, signal.SIGINT]:
signal.signal(sig, shutdown_handler)
def shutdown(self, signo: int) -> None:
self.log.info("Shutting down on signal %d" % signo)
self.km.shutdown_kernel()
self.loop.stop()
def log_connection_info(self) -> None:
cf = self.km.connection_file
self.log.info("Connection file: %s", cf)
self.log.info("To connect a client: --existing %s", os.path.basename(cf))
def _record_started(self) -> None:
"""For tests, create a file to indicate that we've started
Do not rely on this except in our own tests!
"""
fn = os.environ.get("JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE")
if fn is not None:
with open(fn, "wb"):
pass
def start(self) -> None:
self.log.info("Starting kernel %r", self.kernel_name)
try:
self.km.start_kernel()
self.log_connection_info()
self.setup_signals()
self.loop.start()
finally:
self.km.cleanup_resources()
main = KernelApp.launch_instance
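
# Illustrative usage (editor's addition): this app is normally installed as
# the ``jupyter kernel`` entry point, e.g.::
#
#     jupyter kernel --kernel python3
#
# It can also be launched programmatically; "python3" must name an
# installed kernelspec.
def _demo_kernel_app() -> None:
    KernelApp.launch_instance(argv=["--kernel", "python3"])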

View File

@@ -0,0 +1,447 @@
"""Tools for managing kernel specs"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import io
import json
import os
import re
import shutil
import warnings
from jupyter_core.paths import jupyter_data_dir
from jupyter_core.paths import jupyter_path
from jupyter_core.paths import SYSTEM_JUPYTER_PATH
from traitlets import Bool
from traitlets import CaselessStrEnum
from traitlets import Dict
from traitlets import HasTraits
from traitlets import List
from traitlets import observe
from traitlets import Set
from traitlets import Type
from traitlets import Unicode
from traitlets.config import LoggingConfigurable
from .provisioning import KernelProvisionerFactory as KPF
pjoin = os.path.join
NATIVE_KERNEL_NAME = "python3"
class KernelSpec(HasTraits):
argv = List()
name = Unicode()
mimetype = Unicode()
display_name = Unicode()
language = Unicode()
env = Dict()
resource_dir = Unicode()
interrupt_mode = CaselessStrEnum(["message", "signal"], default_value="signal")
metadata = Dict()
@classmethod
def from_resource_dir(cls, resource_dir):
"""Create a KernelSpec object by reading kernel.json
Pass the path to the *directory* containing kernel.json.
"""
kernel_file = pjoin(resource_dir, "kernel.json")
with io.open(kernel_file, "r", encoding="utf-8") as f:
kernel_dict = json.load(f)
return cls(resource_dir=resource_dir, **kernel_dict)
def to_dict(self):
d = dict(
argv=self.argv,
env=self.env,
display_name=self.display_name,
language=self.language,
interrupt_mode=self.interrupt_mode,
metadata=self.metadata,
)
return d
def to_json(self):
"""Serialise this kernelspec to a JSON object.
Returns a string.
"""
return json.dumps(self.to_dict())
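
# Illustrative example (editor's addition): build a spec in memory and
# serialize it. The argv shown is a typical IPython kernel launch line,
# included here only as a placeholder.
def _demo_kernel_spec() -> None:
    spec = KernelSpec(
        argv=["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
        display_name="Demo Python",
        language="python",
    )
    print(spec.to_json())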
_kernel_name_pat = re.compile(r"^[a-z0-9._\-]+$", re.IGNORECASE)
def _is_valid_kernel_name(name):
"""Check that a kernel name is valid."""
# quote is not unicode-safe on Python 2
return _kernel_name_pat.match(name)
_kernel_name_description = (
"Kernel names can only contain ASCII letters and numbers and these separators:"
" - . _ (hyphen, period, and underscore)."
)
def _is_kernel_dir(path):
"""Is ``path`` a kernel directory?"""
return os.path.isdir(path) and os.path.isfile(pjoin(path, "kernel.json"))
def _list_kernels_in(dir):
"""Return a mapping of kernel names to resource directories from dir.
If dir is None or does not exist, returns an empty dict.
"""
if dir is None or not os.path.isdir(dir):
return {}
kernels = {}
for f in os.listdir(dir):
path = pjoin(dir, f)
if not _is_kernel_dir(path):
continue
key = f.lower()
if not _is_valid_kernel_name(key):
warnings.warn(
"Invalid kernelspec directory name (%s): %s" % (_kernel_name_description, path),
stacklevel=3,
)
kernels[key] = path
return kernels
class NoSuchKernel(KeyError):
def __init__(self, name):
self.name = name
def __str__(self):
return "No such kernel named {}".format(self.name)
class KernelSpecManager(LoggingConfigurable):
kernel_spec_class = Type(
KernelSpec,
config=True,
help="""The kernel spec class. This is configurable to allow
subclassing of the KernelSpecManager for customized behavior.
""",
)
ensure_native_kernel = Bool(
True,
config=True,
help="""If there is no Python kernelspec registered and the IPython
kernel is available, ensure it is added to the spec list.
""",
)
data_dir = Unicode()
def _data_dir_default(self):
return jupyter_data_dir()
user_kernel_dir = Unicode()
def _user_kernel_dir_default(self):
return pjoin(self.data_dir, "kernels")
whitelist = Set(
config=True,
help="""Deprecated, use `KernelSpecManager.allowed_kernelspecs`
""",
)
allowed_kernelspecs = Set(
config=True,
help="""List of allowed kernel names.
By default, all installed kernels are allowed.
""",
)
kernel_dirs = List(
help="List of kernel directories to search. Later ones take priority over earlier."
)
_deprecated_aliases = {
"whitelist": ("allowed_kernelspecs", "7.0"),
}
# Method copied from
# https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
@observe(*list(_deprecated_aliases))
def _deprecated_trait(self, change):
"""observer for deprecated traits"""
old_attr = change.name
new_attr, version = self._deprecated_aliases[old_attr]
new_value = getattr(self, new_attr)
if new_value != change.new:
# only warn if different
# protects backward-compatible config from warnings
# if they set the same value under both names
self.log.warning(
(
"{cls}.{old} is deprecated in jupyter_client "
"{version}, use {cls}.{new} instead"
).format(
cls=self.__class__.__name__,
old=old_attr,
new=new_attr,
version=version,
)
)
setattr(self, new_attr, change.new)
def _kernel_dirs_default(self):
dirs = jupyter_path("kernels")
# At some point, we should stop adding .ipython/kernels to the path,
# but the cost to keeping it is very small.
try:
from IPython.paths import get_ipython_dir # type: ignore
except ImportError:
try:
from IPython.utils.path import get_ipython_dir # type: ignore
except ImportError:
# no IPython, no ipython dir
get_ipython_dir = None
if get_ipython_dir is not None:
dirs.append(os.path.join(get_ipython_dir(), "kernels"))
return dirs
def find_kernel_specs(self):
"""Returns a dict mapping kernel names to resource directories."""
d = {}
for kernel_dir in self.kernel_dirs:
kernels = _list_kernels_in(kernel_dir)
for kname, spec in kernels.items():
if kname not in d:
self.log.debug("Found kernel %s in %s", kname, kernel_dir)
d[kname] = spec
if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
try:
from ipykernel.kernelspec import RESOURCES # type: ignore
self.log.debug(
"Native kernel (%s) available from %s",
NATIVE_KERNEL_NAME,
RESOURCES,
)
d[NATIVE_KERNEL_NAME] = RESOURCES
except ImportError:
self.log.warning("Native kernel (%s) is not available", NATIVE_KERNEL_NAME)
if self.allowed_kernelspecs:
# filter if there's an allow list
d = {name: spec for name, spec in d.items() if name in self.allowed_kernelspecs}
return d
# TODO: Caching?
def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
"""Returns a :class:`KernelSpec` instance for a given kernel_name
and resource_dir.
"""
kspec = None
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES, get_kernel_dict
except ImportError:
# It should be impossible to reach this, but let's play it safe
pass
else:
if resource_dir == RESOURCES:
kspec = self.kernel_spec_class(resource_dir=resource_dir, **get_kernel_dict())
if not kspec:
kspec = self.kernel_spec_class.from_resource_dir(resource_dir)
if not KPF.instance(parent=self.parent).is_provisioner_available(kspec):
raise NoSuchKernel(kernel_name)
return kspec
def _find_spec_directory(self, kernel_name):
"""Find the resource directory of a named kernel spec"""
for kernel_dir in [kd for kd in self.kernel_dirs if os.path.isdir(kd)]:
files = os.listdir(kernel_dir)
for f in files:
path = pjoin(kernel_dir, f)
if f.lower() == kernel_name and _is_kernel_dir(path):
return path
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES
except ImportError:
pass
else:
return RESOURCES
def get_kernel_spec(self, kernel_name):
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises :exc:`NoSuchKernel` if the given kernel name is not found.
"""
if not _is_valid_kernel_name(kernel_name):
self.log.warning(
f"Kernelspec name {kernel_name} is invalid: {_kernel_name_description}"
)
resource_dir = self._find_spec_directory(kernel_name.lower())
if resource_dir is None:
self.log.warning(f"Kernelspec name {kernel_name} cannot be found!")
raise NoSuchKernel(kernel_name)
return self._get_kernel_spec_by_name(kernel_name, resource_dir)
def get_all_specs(self):
"""Returns a dict mapping kernel names to kernelspecs.
Returns a dict of the form::
{
'kernel_name': {
'resource_dir': '/path/to/kernel_name',
'spec': {"the spec itself": ...}
},
...
}
"""
d = self.find_kernel_specs()
res = {}
for kname, resource_dir in d.items():
try:
if self.__class__ is KernelSpecManager:
spec = self._get_kernel_spec_by_name(kname, resource_dir)
else:
# avoid calling private methods in subclasses,
# which may have overridden find_kernel_specs
# and get_kernel_spec, but not the newer get_all_specs
spec = self.get_kernel_spec(kname)
res[kname] = {"resource_dir": resource_dir, "spec": spec.to_dict()}
except NoSuchKernel:
pass # The appropriate warning has already been logged
except Exception:
self.log.warning("Error loading kernelspec %r", kname, exc_info=True)
return res
def remove_kernel_spec(self, name):
"""Remove a kernel spec directory by name.
Returns the path that was deleted.
"""
save_native = self.ensure_native_kernel
try:
self.ensure_native_kernel = False
specs = self.find_kernel_specs()
finally:
self.ensure_native_kernel = save_native
spec_dir = specs[name]
self.log.debug("Removing %s", spec_dir)
if os.path.islink(spec_dir):
os.remove(spec_dir)
else:
shutil.rmtree(spec_dir)
return spec_dir
def _get_destination_dir(self, kernel_name, user=False, prefix=None):
if user:
return os.path.join(self.user_kernel_dir, kernel_name)
elif prefix:
return os.path.join(os.path.abspath(prefix), "share", "jupyter", "kernels", kernel_name)
else:
return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name)
def install_kernel_spec(
self, source_dir, kernel_name=None, user=False, replace=None, prefix=None
):
"""Install a kernel spec by copying its directory.
If ``kernel_name`` is not given, the basename of ``source_dir`` will
be used.
If ``user`` is False, it will attempt to install into the systemwide
kernel registry. If the process does not have appropriate permissions,
an :exc:`OSError` will be raised.
If ``prefix`` is given, the kernelspec will be installed to
PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
for installation inside virtual or conda envs.
"""
source_dir = source_dir.rstrip("/\\")
if not kernel_name:
kernel_name = os.path.basename(source_dir)
kernel_name = kernel_name.lower()
if not _is_valid_kernel_name(kernel_name):
raise ValueError(
"Invalid kernel name %r. %s" % (kernel_name, _kernel_name_description)
)
if user and prefix:
raise ValueError("Can't specify both user and prefix. Please choose one or the other.")
if replace is not None:
warnings.warn(
"replace is ignored. Installing a kernelspec always replaces an existing "
"installation",
DeprecationWarning,
stacklevel=2,
)
destination = self._get_destination_dir(kernel_name, user=user, prefix=prefix)
self.log.debug("Installing kernelspec in %s", destination)
kernel_dir = os.path.dirname(destination)
if kernel_dir not in self.kernel_dirs:
self.log.warning(
"Installing to %s, which is not in %s. The kernelspec may not be found.",
kernel_dir,
self.kernel_dirs,
)
if os.path.isdir(destination):
self.log.info("Removing existing kernelspec in %s", destination)
shutil.rmtree(destination)
shutil.copytree(source_dir, destination)
self.log.info("Installed kernelspec %s in %s", kernel_name, destination)
return destination
def install_native_kernel_spec(self, user=False):
"""DEPRECATED: Use ipykernel.kernelspec.install"""
warnings.warn(
"install_native_kernel_spec is deprecated. Use ipykernel.kernelspec import install.",
stacklevel=2,
)
from ipykernel.kernelspec import install
install(self, user=user)
def find_kernel_specs():
"""Returns a dict mapping kernel names to resource directories."""
return KernelSpecManager().find_kernel_specs()
def get_kernel_spec(kernel_name):
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises KeyError if the given kernel name is not found.
"""
return KernelSpecManager().get_kernel_spec(kernel_name)
def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False, prefix=None):
return KernelSpecManager().install_kernel_spec(source_dir, kernel_name, user, replace, prefix)
install_kernel_spec.__doc__ = KernelSpecManager.install_kernel_spec.__doc__
def install_native_kernel_spec(user=False):
return KernelSpecManager().install_native_kernel_spec(user=user)
install_native_kernel_spec.__doc__ = KernelSpecManager.install_native_kernel_spec.__doc__
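
# Illustrative example (editor's addition): enumerate the kernelspecs a
# manager can see, mirroring what ``jupyter kernelspec list`` prints.
def _demo_list_kernel_specs() -> None:
    ksm = KernelSpecManager()
    for name, resource_dir in ksm.find_kernel_specs().items():
        print("%s -> %s" % (name, resource_dir))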

View File

@@ -0,0 +1,328 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import json
import os.path
import sys
import typing as t
from jupyter_core.application import base_aliases
from jupyter_core.application import base_flags
from jupyter_core.application import JupyterApp
from traitlets import Bool
from traitlets import Dict
from traitlets import Instance
from traitlets import List
from traitlets import Unicode
from traitlets.config.application import Application
from . import __version__
from .kernelspec import KernelSpecManager
from .provisioning.factory import KernelProvisionerFactory
class ListKernelSpecs(JupyterApp):
version = __version__
description = """List installed kernel specifications."""
kernel_spec_manager = Instance(KernelSpecManager)
json_output = Bool(
False,
help="output spec name and location as machine-readable json.",
config=True,
)
flags = {
"json": (
{"ListKernelSpecs": {"json_output": True}},
"output spec name and location as machine-readable json.",
),
"debug": base_flags["debug"],
}
def _kernel_spec_manager_default(self):
return KernelSpecManager(parent=self, data_dir=self.data_dir)
def start(self):
paths = self.kernel_spec_manager.find_kernel_specs()
specs = self.kernel_spec_manager.get_all_specs()
if not self.json_output:
if not specs:
print("No kernels available")
return
# pad to width of longest kernel name
name_len = len(sorted(paths, key=lambda name: len(name))[-1])
def path_key(item):
"""sort key function for Jupyter path priority"""
path = item[1]
for idx, prefix in enumerate(self.jupyter_path):
if path.startswith(prefix):
return (idx, path)
# not in jupyter path, artificially added to the front
return (-1, path)
print("Available kernels:")
for kernelname, path in sorted(paths.items(), key=path_key):
print(" %s %s" % (kernelname.ljust(name_len), path))
else:
print(json.dumps({"kernelspecs": specs}, indent=2))
class InstallKernelSpec(JupyterApp):
version = __version__
description = """Install a kernel specification directory.
Given a SOURCE DIRECTORY containing a kernel spec,
jupyter will copy that directory into one of the Jupyter kernel directories.
The default is to install kernelspecs for all users.
`--user` can be specified to install a kernel only for the current user.
"""
examples = """
jupyter kernelspec install /path/to/my_kernel --user
"""
usage = "jupyter kernelspec install SOURCE_DIR [--options]"
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir)
sourcedir = Unicode()
kernel_name = Unicode("", config=True, help="Install the kernel spec with this name")
def _kernel_name_default(self):
return os.path.basename(self.sourcedir)
user = Bool(
False,
config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
""",
)
prefix = Unicode(
"",
config=True,
help="""Specify a prefix to install to, e.g. an env.
The kernelspec will be installed in PREFIX/share/jupyter/kernels/
""",
)
replace = Bool(False, config=True, help="Replace any existing kernel spec with this name.")
aliases = {
"name": "InstallKernelSpec.kernel_name",
"prefix": "InstallKernelSpec.prefix",
}
aliases.update(base_aliases)
flags = {
"user": (
{"InstallKernelSpec": {"user": True}},
"Install to the per-user kernel registry",
),
"replace": (
{"InstallKernelSpec": {"replace": True}},
"Replace any existing kernel spec with this name.",
),
"sys-prefix": (
{"InstallKernelSpec": {"prefix": sys.prefix}},
"Install to Python's sys.prefix. Useful in conda/virtual environments.",
),
"debug": base_flags["debug"],
}
def parse_command_line(self, argv):
super().parse_command_line(argv)
        # accept positional arg as the source directory
if self.extra_args:
self.sourcedir = self.extra_args[0]
else:
print("No source directory specified.")
self.exit(1)
def start(self):
if self.user and self.prefix:
self.exit("Can't specify both user and prefix. Please choose one or the other.")
try:
self.kernel_spec_manager.install_kernel_spec(
self.sourcedir,
kernel_name=self.kernel_name,
user=self.user,
prefix=self.prefix,
replace=self.replace,
)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print(
"Perhaps you want to install with `sudo` or `--user`?",
file=sys.stderr,
)
self.exit(1)
elif e.errno == errno.EEXIST:
print(
"A kernel spec is already present at %s" % e.filename,
file=sys.stderr,
)
self.exit(1)
raise
class RemoveKernelSpec(JupyterApp):
version = __version__
description = """Remove one or more Jupyter kernelspecs by name."""
examples = """jupyter kernelspec remove python2 [my_kernel ...]"""
force = Bool(False, config=True, help="""Force removal, don't prompt for confirmation.""")
spec_names = List(Unicode())
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir, parent=self)
flags = {
"f": ({"RemoveKernelSpec": {"force": True}}, force.help),
}
flags.update(JupyterApp.flags)
def parse_command_line(self, argv):
super().parse_command_line(argv)
        # accept positional args as kernelspec names
if self.extra_args:
self.spec_names = sorted(set(self.extra_args)) # remove duplicates
else:
self.exit("No kernelspec specified.")
def start(self):
self.kernel_spec_manager.ensure_native_kernel = False
spec_paths = self.kernel_spec_manager.find_kernel_specs()
missing = set(self.spec_names).difference(set(spec_paths))
if missing:
self.exit("Couldn't find kernel spec(s): %s" % ", ".join(missing))
if not (self.force or self.answer_yes):
print("Kernel specs to remove:")
for name in self.spec_names:
print(" %s\t%s" % (name.ljust(20), spec_paths[name]))
answer = input("Remove %i kernel specs [y/N]: " % len(self.spec_names))
if not answer.lower().startswith("y"):
return
for kernel_name in self.spec_names:
try:
path = self.kernel_spec_manager.remove_kernel_spec(kernel_name)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
print("Perhaps you want sudo?", file=sys.stderr)
self.exit(1)
else:
raise
self.log.info("Removed %s", path)
class InstallNativeKernelSpec(JupyterApp):
version = __version__
description = """[DEPRECATED] Install the IPython kernel spec directory for this Python."""
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir)
user = Bool(
False,
config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
""",
)
flags = {
"user": (
{"InstallNativeKernelSpec": {"user": True}},
"Install to the per-user kernel registry",
),
"debug": base_flags["debug"],
}
def start(self):
self.log.warning(
"`jupyter kernelspec install-self` is DEPRECATED as of 4.0."
" You probably want `ipython kernel install` to install the IPython kernelspec."
)
try:
from ipykernel import kernelspec
except ImportError:
print("ipykernel not available, can't install its spec.", file=sys.stderr)
self.exit(1)
try:
kernelspec.install(self.kernel_spec_manager, user=self.user)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print(
"Perhaps you want to install with `sudo` or `--user`?",
file=sys.stderr,
)
self.exit(1)
self.exit(e)
class ListProvisioners(JupyterApp):
version = __version__
description = """List available provisioners for use in kernel specifications."""
def start(self):
kfp = KernelProvisionerFactory.instance(parent=self)
print("Available kernel provisioners:")
provisioners = kfp.get_provisioner_entries()
# pad to width of longest provisioner name
name_len = max(len(name) for name in provisioners)
for name in sorted(provisioners):
print(f" {name.ljust(name_len)} {provisioners[name]}")
class KernelSpecApp(Application):
version = __version__
name = "jupyter kernelspec"
description = """Manage Jupyter kernel specifications."""
subcommands = Dict(
{
"list": (ListKernelSpecs, ListKernelSpecs.description.splitlines()[0]),
"install": (
InstallKernelSpec,
InstallKernelSpec.description.splitlines()[0],
),
"uninstall": (RemoveKernelSpec, "Alias for remove"),
"remove": (RemoveKernelSpec, RemoveKernelSpec.description.splitlines()[0]),
"install-self": (
InstallNativeKernelSpec,
InstallNativeKernelSpec.description.splitlines()[0],
),
"provisioners": (ListProvisioners, ListProvisioners.description.splitlines()[0]),
}
)
aliases: t.Dict[str, object] = {}
flags: t.Dict[str, object] = {}
def start(self):
if self.subapp is None:
print("No subcommand specified. Must specify one of: %s" % list(self.subcommands))
print()
self.print_description()
self.print_subcommands()
self.exit(1)
else:
return self.subapp.start()
if __name__ == "__main__":
KernelSpecApp.launch_instance()
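# Illustrative usage (not part of this module): the subcommands above map to
# command lines such as the following; the path and kernel name are
# hypothetical.
#
#   jupyter kernelspec list
#   jupyter kernelspec install /path/to/my_kernel --user
#   jupyter kernelspec remove my_kernel -f
#   jupyter kernelspec provisioners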

View File

@@ -0,0 +1,187 @@
"""Utilities for launching kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import sys
from subprocess import PIPE
from subprocess import Popen
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from traitlets.log import get_logger
def launch_kernel(
cmd: List[str],
stdin: Optional[int] = None,
stdout: Optional[int] = None,
stderr: Optional[int] = None,
env: Optional[Dict[str, str]] = None,
independent: bool = False,
cwd: Optional[str] = None,
**kw: Any,
) -> Popen:
"""Launches a localhost kernel, binding to the specified ports.
Parameters
----------
cmd : list of str
A command sequence, suitable for ``subprocess.Popen``, that launches the kernel.
stdin, stdout, stderr : optional (default None)
Standards streams, as defined in subprocess.Popen.
env: dict, optional
Environment variables passed to the kernel
independent : bool, optional (default False)
If set, the kernel process is guaranteed to survive if this process
dies. If not set, an effort is made to ensure that the kernel is killed
when this process dies. Note that in this case it is still good practice
to kill kernels manually before exiting.
cwd : path, optional
The working dir of the kernel process (default: cwd of this process).
**kw: optional
Additional arguments for Popen
Returns
-------
Popen instance for the kernel subprocess
"""
# Popen will fail (sometimes with a deadlock) if stdin, stdout, and stderr
# are invalid. Unfortunately, there is in general no way to detect whether
# they are valid. The following two blocks redirect them to (temporary)
# pipes in certain important cases.
# If this process has been backgrounded, our stdin is invalid. Since there
# is no compelling reason for the kernel to inherit our stdin anyway, we'll
# play it safe and always redirect.
redirect_in = True
_stdin = PIPE if stdin is None else stdin
# If this process is running under pythonw, we know that stdin, stdout, and
# stderr are all invalid.
redirect_out = sys.executable.endswith("pythonw.exe")
if redirect_out:
blackhole = open(os.devnull, "w")
_stdout = blackhole if stdout is None else stdout
_stderr = blackhole if stderr is None else stderr
else:
_stdout, _stderr = stdout, stderr
env = env if (env is not None) else os.environ.copy()
kwargs = kw.copy()
main_args = dict(
stdin=_stdin,
stdout=_stdout,
stderr=_stderr,
cwd=cwd,
env=env,
)
kwargs.update(main_args)
# Spawn a kernel.
if sys.platform == "win32":
if cwd:
kwargs["cwd"] = cwd
from .win_interrupt import create_interrupt_event
# Create a Win32 event for interrupting the kernel
# and store it in an environment variable.
interrupt_event = create_interrupt_event()
env["JPY_INTERRUPT_EVENT"] = str(interrupt_event)
# deprecated old env name:
env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"]
try:
from _winapi import (
CREATE_NEW_PROCESS_GROUP,
DUPLICATE_SAME_ACCESS,
DuplicateHandle,
GetCurrentProcess,
)
except: # noqa
from _subprocess import (
GetCurrentProcess,
CREATE_NEW_PROCESS_GROUP,
DUPLICATE_SAME_ACCESS,
DuplicateHandle,
)
# create a handle on the parent to be inherited
if independent:
kwargs["creationflags"] = CREATE_NEW_PROCESS_GROUP
else:
pid = GetCurrentProcess()
handle = DuplicateHandle(
pid,
pid,
pid,
0,
True,
DUPLICATE_SAME_ACCESS, # Inheritable by new processes.
)
env["JPY_PARENT_PID"] = str(int(handle))
# Prevent creating new console window on pythonw
if redirect_out:
kwargs["creationflags"] = (
kwargs.setdefault("creationflags", 0) | 0x08000000
) # CREATE_NO_WINDOW
# Avoid closing the above parent and interrupt handles.
# close_fds is True by default on Python >=3.7
# or when no stream is captured on Python <3.7
# (we always capture stdin, so this is already False by default on <3.7)
kwargs["close_fds"] = False
else:
# Create a new session.
# This makes it easier to interrupt the kernel,
# because we want to interrupt the whole process group.
# We don't use setpgrp, which is known to cause problems for kernels starting
# certain interactive subprocesses, such as bash -i.
kwargs["start_new_session"] = True
if not independent:
env["JPY_PARENT_PID"] = str(os.getpid())
try:
# Allow using ~/ in the command or its arguments
cmd = [os.path.expanduser(s) for s in cmd]
proc = Popen(cmd, **kwargs)
except Exception as ex:
try:
msg = "Failed to run command:\n{}\n PATH={!r}\n with kwargs:\n{!r}\n"
# exclude environment variables,
# which may contain access tokens and the like.
without_env = {key: value for key, value in kwargs.items() if key != "env"}
msg = msg.format(cmd, env.get("PATH", os.defpath), without_env)
get_logger().error(msg)
except Exception as ex2: # Don't let a formatting/logger issue lead to the wrong exception
print(f"Failed to run command: '{cmd}' due to exception: {ex}")
print(f"The following exception occurred handling the previous failure: {ex2}")
raise ex
if sys.platform == "win32":
# Attach the interrupt event to the Popen object so it can be used later.
proc.win32_interrupt_event = interrupt_event
# Clean up pipes created to work around Popen bug.
if redirect_in:
if stdin is None:
assert proc.stdin is not None
proc.stdin.close()
return proc
__all__ = [
"launch_kernel",
]
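# Illustrative sketch (not part of this module): launching a kernel process
# directly. This assumes ipykernel is installed and that the connection file
# already exists; both the path and the command are hypothetical.
#
#   from jupyter_client.launcher import launch_kernel
#
#   proc = launch_kernel(
#       ["python", "-m", "ipykernel_launcher", "-f", "/tmp/kernel-abc.json"],
#       independent=False,  # make an effort to kill the kernel if we die
#   )
#   proc.wait()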

View File

@@ -0,0 +1,296 @@
"""Utilities for identifying local IP addresses."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import re
import socket
import subprocess
from subprocess import PIPE
from subprocess import Popen
from typing import Iterable
from typing import List
from warnings import warn
LOCAL_IPS: List = []
PUBLIC_IPS: List = []
LOCALHOST = ""
def _uniq_stable(elems: Iterable) -> List:
"""uniq_stable(elems) -> list
Return a list of the unique elements of an iterable,
maintaining the order in which they first appear.
"""
seen = set()
value = []
for x in elems:
if x not in seen:
value.append(x)
seen.add(x)
return value
def _get_output(cmd):
"""Get output of a command, raising IOError if it fails"""
startupinfo = None
if os.name == "nt":
startupinfo = subprocess.STARTUPINFO() # type:ignore[attr-defined]
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type:ignore[attr-defined]
p = Popen(cmd, stdout=PIPE, stderr=PIPE, startupinfo=startupinfo)
stdout, stderr = p.communicate()
if p.returncode:
raise IOError("Failed to run %s: %s" % (cmd, stderr.decode("utf8", "replace")))
return stdout.decode("utf8", "replace")
def _only_once(f):
"""decorator to only run a function once"""
f.called = False
def wrapped(**kwargs):
if f.called:
return
ret = f(**kwargs)
f.called = True
return ret
return wrapped
def _requires_ips(f):
"""decorator to ensure load_ips has been run before f"""
def ips_loaded(*args, **kwargs):
_load_ips()
return f(*args, **kwargs)
return ips_loaded
# subprocess-parsing ip finders
class NoIPAddresses(Exception):
pass
def _populate_from_list(addrs):
"""populate local and public IPs from flat list of all IPs"""
if not addrs:
raise NoIPAddresses
global LOCALHOST
public_ips = []
local_ips = []
for ip in addrs:
local_ips.append(ip)
if not ip.startswith("127."):
public_ips.append(ip)
elif not LOCALHOST:
LOCALHOST = ip
if not LOCALHOST or LOCALHOST == "127.0.0.1":
LOCALHOST = "127.0.0.1"
local_ips.insert(0, LOCALHOST)
local_ips.extend(["0.0.0.0", ""])
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
_ifconfig_ipv4_pat = re.compile(r"inet\b.*?(\d+\.\d+\.\d+\.\d+)", re.IGNORECASE)
def _load_ips_ifconfig():
"""load ip addresses from `ifconfig` output (posix)"""
try:
out = _get_output("ifconfig")
except OSError:
# no ifconfig, it's usually in /sbin and /sbin is not on everyone's PATH
out = _get_output("/sbin/ifconfig")
lines = out.splitlines()
addrs = []
for line in lines:
m = _ifconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_ip():
"""load ip addresses from `ip addr` output (Linux)"""
out = _get_output(["ip", "-f", "inet", "addr"])
lines = out.splitlines()
addrs = []
for line in lines:
blocks = line.lower().split()
if (len(blocks) >= 2) and (blocks[0] == "inet"):
addrs.append(blocks[1].split("/")[0])
_populate_from_list(addrs)
_ipconfig_ipv4_pat = re.compile(r"ipv4.*?(\d+\.\d+\.\d+\.\d+)$", re.IGNORECASE)
def _load_ips_ipconfig():
"""load ip addresses from `ipconfig` output (Windows)"""
out = _get_output("ipconfig")
lines = out.splitlines()
addrs = []
for line in lines:
m = _ipconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_netifaces():
"""load ip addresses with netifaces"""
import netifaces # type: ignore
global LOCALHOST
local_ips = []
public_ips = []
# list of iface names, 'lo0', 'eth0', etc.
for iface in netifaces.interfaces():
# list of ipv4 addrinfo dicts
ipv4s = netifaces.ifaddresses(iface).get(netifaces.AF_INET, [])
for entry in ipv4s:
addr = entry.get("addr")
if not addr:
continue
if not (iface.startswith("lo") or addr.startswith("127.")):
public_ips.append(addr)
elif not LOCALHOST:
LOCALHOST = addr
local_ips.append(addr)
if not LOCALHOST:
# we never found a loopback interface (can this ever happen?), assume common default
LOCALHOST = "127.0.0.1"
local_ips.insert(0, LOCALHOST)
local_ips.extend(["0.0.0.0", ""])
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
def _load_ips_gethostbyname():
"""load ip addresses with socket.gethostbyname_ex
This can be slow.
"""
global LOCALHOST
try:
LOCAL_IPS[:] = socket.gethostbyname_ex("localhost")[2]
except socket.error:
# assume common default
LOCAL_IPS[:] = ["127.0.0.1"]
try:
hostname = socket.gethostname()
PUBLIC_IPS[:] = socket.gethostbyname_ex(hostname)[2]
# try hostname.local, in case hostname has been short-circuited to loopback
if not hostname.endswith(".local") and all(ip.startswith("127") for ip in PUBLIC_IPS):
PUBLIC_IPS[:] = socket.gethostbyname_ex(socket.gethostname() + ".local")[2]
except socket.error:
pass
finally:
PUBLIC_IPS[:] = _uniq_stable(PUBLIC_IPS)
LOCAL_IPS.extend(PUBLIC_IPS)
# include all-interface aliases: 0.0.0.0 and ''
LOCAL_IPS.extend(["0.0.0.0", ""])
LOCAL_IPS[:] = _uniq_stable(LOCAL_IPS)
LOCALHOST = LOCAL_IPS[0]
def _load_ips_dumb():
"""Fallback in case of unexpected failure"""
global LOCALHOST
LOCALHOST = "127.0.0.1"
LOCAL_IPS[:] = [LOCALHOST, "0.0.0.0", ""]
PUBLIC_IPS[:] = []
@_only_once
def _load_ips(suppress_exceptions=True):
"""load the IPs that point to this machine
This function will only ever be called once.
It will use netifaces to do it quickly if available.
Then it will fall back on parsing the output of ifconfig / ip addr / ipconfig, as appropriate.
Finally, it will fall back on socket.gethostbyname_ex, which can be slow.
"""
try:
# first priority, use netifaces
try:
return _load_ips_netifaces()
except ImportError:
pass
# second priority, parse subprocess output (how reliable is this?)
if os.name == "nt":
try:
return _load_ips_ipconfig()
except (OSError, NoIPAddresses):
pass
else:
try:
return _load_ips_ip()
except (OSError, NoIPAddresses):
pass
try:
return _load_ips_ifconfig()
except (OSError, NoIPAddresses):
pass
# lowest priority, use gethostbyname
return _load_ips_gethostbyname()
except Exception as e:
if not suppress_exceptions:
raise
# unexpected error shouldn't crash, load dumb default values instead.
warn("Unexpected error discovering local network interfaces: %s" % e)
_load_ips_dumb()
@_requires_ips
def local_ips():
"""return the IP addresses that point to this machine"""
return LOCAL_IPS
@_requires_ips
def public_ips():
"""return the IP addresses for this machine that are visible to other machines"""
return PUBLIC_IPS
@_requires_ips
def localhost():
"""return ip for localhost (almost always 127.0.0.1)"""
return LOCALHOST
@_requires_ips
def is_local_ip(ip):
"""does `ip` point to this machine?"""
return ip in LOCAL_IPS
@_requires_ips
def is_public_ip(ip):
"""is `ip` a publicly visible address?"""
return ip in PUBLIC_IPS
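# Illustrative sketch (not part of this module): the decorated helpers above
# trigger the one-time interface scan (via @_requires_ips) on first use.
#
#   from jupyter_client.localinterfaces import is_local_ip, local_ips, localhost
#
#   print(localhost())             # usually '127.0.0.1'
#   print(local_ips())             # loopback first, then real addresses, then '0.0.0.0' and ''
#   print(is_local_ip("8.8.8.8"))  # False on most machines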

View File

@@ -0,0 +1,714 @@
"""Base class to manage a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import functools
import os
import re
import signal
import sys
import typing as t
import uuid
from asyncio.futures import Future
from concurrent.futures import Future as CFuture
from contextlib import contextmanager
from enum import Enum
import zmq
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import DottedObjectName
from traitlets import Float
from traitlets import Instance
from traitlets import observe
from traitlets import observe_compat
from traitlets import Type
from traitlets import Unicode
from traitlets.utils.importstring import import_item
from .connect import ConnectionFileMixin
from .managerabc import KernelManagerABC
from .provisioning import KernelProvisionerBase
from .provisioning import KernelProvisionerFactory as KPF
from .utils import ensure_async
from .utils import run_sync
from jupyter_client import KernelClient
from jupyter_client import kernelspec
class _ShutdownStatus(Enum):
"""
This is so far used only for testing in order to track the internal state of
the shutdown logic, and verifying which path is taken for which
missbehavior.
"""
Unset = None
ShutdownRequest = "ShutdownRequest"
SigtermRequest = "SigtermRequest"
SigkillRequest = "SigkillRequest"
F = t.TypeVar('F', bound=t.Callable[..., t.Any])
def in_pending_state(method: F) -> F:
"""Sets the kernel to a pending state by
creating a fresh Future for the KernelManager's `ready`
attribute. Once the method is finished, the Future's result is set.
"""
@t.no_type_check
@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
# Create a future for the decorated method
try:
self._ready = Future()
except RuntimeError:
# No event loop running, use concurrent future
self._ready = CFuture()
try:
# call wrapped method, await, and set the result or exception.
out = await method(self, *args, **kwargs)
# Add a small sleep to ensure tests can capture the state before done
await asyncio.sleep(0.01)
self._ready.set_result(None)
return out
except Exception as e:
self._ready.set_exception(e)
self.log.exception(self._ready.exception())
raise e
return t.cast(F, wrapper)
class KernelManager(ConnectionFileMixin):
"""Manages a single kernel in a subprocess on this host.
This version starts kernels with Popen.
"""
_ready: t.Union[Future, CFuture]
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self._shutdown_status = _ShutdownStatus.Unset
# Create a place holder future.
try:
asyncio.get_running_loop()
self._ready = Future()
except RuntimeError:
# No event loop running, use concurrent future
self._ready = CFuture()
_created_context: Bool = Bool(False)
# The PyZMQ Context to use for communication with the kernel.
context: Instance = Instance(zmq.Context)
@default("context") # type:ignore[misc]
def _context_default(self) -> zmq.Context:
self._created_context = True
return zmq.Context()
# the class to create with our `client` method
client_class: DottedObjectName = DottedObjectName(
"jupyter_client.blocking.BlockingKernelClient"
)
client_factory: Type = Type(klass="jupyter_client.KernelClient")
@default("client_factory") # type:ignore[misc]
def _client_factory_default(self) -> Type:
return import_item(self.client_class)
@observe("client_class") # type:ignore[misc]
def _client_class_changed(self, change: t.Dict[str, DottedObjectName]) -> None:
self.client_factory = import_item(str(change["new"]))
kernel_id: str = Unicode(None, allow_none=True)
# The kernel provisioner with which this KernelManager is communicating.
# This will generally be a LocalProvisioner instance unless the kernelspec
# indicates otherwise.
provisioner: t.Optional[KernelProvisionerBase] = None
kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager)
@default("kernel_spec_manager") # type:ignore[misc]
def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager:
return kernelspec.KernelSpecManager(data_dir=self.data_dir)
@observe("kernel_spec_manager") # type:ignore[misc]
@observe_compat # type:ignore[misc]
def _kernel_spec_manager_changed(self, change: t.Dict[str, Instance]) -> None:
self._kernel_spec = None
shutdown_wait_time: Float = Float(
5.0,
config=True,
help="Time to wait for a kernel to terminate before killing it, "
"in seconds. When a shutdown request is initiated, the kernel "
"will be immediately sent an interrupt (SIGINT), followed"
"by a shutdown_request message, after 1/2 of `shutdown_wait_time`"
"it will be sent a terminate (SIGTERM) request, and finally at "
"the end of `shutdown_wait_time` will be killed (SIGKILL). terminate "
"and kill may be equivalent on windows. Note that this value can be"
"overridden by the in-use kernel provisioner since shutdown times may"
"vary by provisioned environment.",
)
kernel_name: Unicode = Unicode(kernelspec.NATIVE_KERNEL_NAME)
@observe("kernel_name") # type:ignore[misc]
def _kernel_name_changed(self, change: t.Dict[str, Unicode]) -> None:
self._kernel_spec = None
if change["new"] == "python":
self.kernel_name = kernelspec.NATIVE_KERNEL_NAME
_kernel_spec: t.Optional[kernelspec.KernelSpec] = None
@property
def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]:
if self._kernel_spec is None and self.kernel_name != "":
self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name)
return self._kernel_spec
cache_ports: Bool = Bool(
help="True if the MultiKernelManager should cache ports for this KernelManager instance"
)
@default("cache_ports") # type:ignore[misc]
def _default_cache_ports(self) -> bool:
return self.transport == "tcp"
@property
def ready(self) -> t.Union[CFuture, Future]:
"""A future that resolves when the kernel process has started for the first time"""
return self._ready
@property
def ipykernel(self) -> bool:
return self.kernel_name in {"python", "python2", "python3"}
# Protected traits
_launch_args: Any = Any()
_control_socket: Any = Any()
_restarter: Any = Any()
autorestart: Bool = Bool(
True, config=True, help="""Should we autorestart the kernel if it dies."""
)
shutting_down: bool = False
def __del__(self) -> None:
self._close_control_socket()
self.cleanup_connection_file()
# --------------------------------------------------------------------------
# Kernel restarter
# --------------------------------------------------------------------------
def start_restarter(self) -> None:
pass
def stop_restarter(self) -> None:
pass
def add_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
"""register a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.add_callback(callback, event)
def remove_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
"""unregister a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.remove_callback(callback, event)
# --------------------------------------------------------------------------
# create a Client connected to our Kernel
# --------------------------------------------------------------------------
def client(self, **kwargs: Any) -> KernelClient:
"""Create a client configured to connect to our kernel"""
kw = {}
kw.update(self.get_connection_info(session=True))
kw.update(
dict(
connection_file=self.connection_file,
parent=self,
)
)
# add kwargs last, for manual overrides
kw.update(kwargs)
return self.client_factory(**kw)
# --------------------------------------------------------------------------
# Kernel management
# --------------------------------------------------------------------------
def format_kernel_cmd(self, extra_arguments: t.Optional[t.List[str]] = None) -> t.List[str]:
"""replace templated args (e.g. {connection_file})"""
extra_arguments = extra_arguments or []
assert self.kernel_spec is not None
cmd = self.kernel_spec.argv + extra_arguments
if cmd and cmd[0] in {
"python",
"python%i" % sys.version_info[0],
"python%i.%i" % sys.version_info[:2],
}:
# executable is 'python' or 'python3', use sys.executable.
# These will typically be the same,
# but if the current process is in an env
# and has been launched by abspath without
# activating the env, python on PATH may not be sys.executable,
# but it should be.
cmd[0] = sys.executable
# Make sure to use the realpath for the connection_file
# On windows, when running with the store python, the connection_file path
# is not usable by non python kernels because the path is being rerouted when
# inside of a store app.
# See this bug here: https://bugs.python.org/issue41196
ns = dict(
connection_file=os.path.realpath(self.connection_file),
prefix=sys.prefix,
)
if self.kernel_spec:
ns["resource_dir"] = self.kernel_spec.resource_dir
ns.update(self._launch_args)
pat = re.compile(r"\{([A-Za-z0-9_]+)\}")
def from_ns(match):
"""Get the key out of ns if it's there, otherwise no change."""
return ns.get(match.group(1), match.group())
return [pat.sub(from_ns, arg) for arg in cmd]
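# Illustrative note (hypothetical argv and path): given a typical kernelspec
# argv, format_kernel_cmd substitutes the {connection_file} template and
# swaps a bare "python" for sys.executable, e.g.:
#
#   argv:   ["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"]
#   result: [sys.executable, "-m", "ipykernel_launcher", "-f", "/run/user/1000/kernel-abc.json"]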
async def _async_launch_kernel(self, kernel_cmd: t.List[str], **kw: Any) -> None:
"""actually launch the kernel
override in a subclass to launch kernel subprocesses differently
Note that provisioners can now be used to customize kernel environments
and
"""
assert self.provisioner is not None
connection_info = await self.provisioner.launch_kernel(kernel_cmd, **kw)
assert self.provisioner.has_process
# Provisioner provides the connection information. Load into kernel manager and write file.
self._force_connection_info(connection_info)
_launch_kernel = run_sync(_async_launch_kernel)
# Control socket used for polite kernel shutdown
def _connect_control_socket(self) -> None:
if self._control_socket is None:
self._control_socket = self._create_connected_socket("control")
self._control_socket.linger = 100
def _close_control_socket(self) -> None:
if self._control_socket is None:
return
self._control_socket.close()
self._control_socket = None
async def _async_pre_start_kernel(self, **kw: Any) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]:
"""Prepares a kernel for startup in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
keyword arguments that are passed down to build the kernel_cmd
and launching the kernel (e.g. Popen kwargs).
"""
self.shutting_down = False
self.kernel_id = self.kernel_id or kw.pop('kernel_id', str(uuid.uuid4()))
# save kwargs for use in restart
self._launch_args = kw.copy()
if self.provisioner is None: # will not be None on restarts
self.provisioner = KPF.instance(parent=self.parent).create_provisioner_instance(
self.kernel_id,
self.kernel_spec,
parent=self,
)
kw = await self.provisioner.pre_launch(**kw)
kernel_cmd = kw.pop('cmd')
return kernel_cmd, kw
pre_start_kernel = run_sync(_async_pre_start_kernel)
async def _async_post_start_kernel(self, **kw: Any) -> None:
"""Performs any post startup tasks relative to the kernel.
Parameters
----------
`**kw` : optional
keyword arguments that were used in the kernel process's launch.
"""
self.start_restarter()
self._connect_control_socket()
assert self.provisioner is not None
await self.provisioner.post_launch(**kw)
post_start_kernel = run_sync(_async_post_start_kernel)
@in_pending_state
async def _async_start_kernel(self, **kw: Any) -> None:
"""Starts a kernel on this host in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
keyword arguments that are passed down to build the kernel_cmd
and launching the kernel (e.g. Popen kwargs).
"""
kernel_cmd, kw = await ensure_async(self.pre_start_kernel(**kw))
# launch the kernel subprocess
self.log.debug("Starting kernel: %s", kernel_cmd)
await ensure_async(self._launch_kernel(kernel_cmd, **kw))
await ensure_async(self.post_start_kernel(**kw))
start_kernel = run_sync(_async_start_kernel)
async def _async_request_shutdown(self, restart: bool = False) -> None:
"""Send a shutdown request via control channel"""
content = dict(restart=restart)
msg = self.session.msg("shutdown_request", content=content)
# ensure control socket is connected
self._connect_control_socket()
self.session.send(self._control_socket, msg)
assert self.provisioner is not None
await self.provisioner.shutdown_requested(restart=restart)
self._shutdown_status = _ShutdownStatus.ShutdownRequest
request_shutdown = run_sync(_async_request_shutdown)
async def _async_finish_shutdown(
self,
waittime: t.Optional[float] = None,
pollinterval: float = 0.1,
restart: t.Optional[bool] = False,
) -> None:
"""Wait for kernel shutdown, then kill process if it doesn't shutdown.
This does not send shutdown requests - use :meth:`request_shutdown`
first.
"""
if waittime is None:
waittime = max(self.shutdown_wait_time, 0)
if self.provisioner: # Allow provisioner to override
waittime = self.provisioner.get_shutdown_wait_time(recommended=waittime)
try:
await asyncio.wait_for(
self._async_wait(pollinterval=pollinterval), timeout=waittime / 2
)
except asyncio.TimeoutError:
self.log.debug("Kernel is taking too long to finish, terminating")
self._shutdown_status = _ShutdownStatus.SigtermRequest
await ensure_async(self._send_kernel_sigterm())
try:
await asyncio.wait_for(
self._async_wait(pollinterval=pollinterval), timeout=waittime / 2
)
except asyncio.TimeoutError:
self.log.debug("Kernel is taking too long to finish, killing")
self._shutdown_status = _ShutdownStatus.SigkillRequest
await ensure_async(self._kill_kernel(restart=restart))
else:
# Process is no longer alive, wait and clear
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.wait()
finish_shutdown = run_sync(_async_finish_shutdown)
async def _async_cleanup_resources(self, restart: bool = False) -> None:
"""Clean up resources when the kernel is shut down"""
if not restart:
self.cleanup_connection_file()
self.cleanup_ipc_files()
self._close_control_socket()
self.session.parent = None
if self._created_context and not restart:
self.context.destroy(linger=100)
if self.provisioner:
await self.provisioner.cleanup(restart=restart)
cleanup_resources = run_sync(_async_cleanup_resources)
@in_pending_state
async def _async_shutdown_kernel(self, now: bool = False, restart: bool = False) -> None:
"""Attempts to stop the kernel process cleanly.
This attempts to shut down the kernel cleanly by:
1. Sending it a shutdown message over the control channel.
2. If that fails, the kernel is shutdown forcibly by sending it
a signal.
Parameters
----------
now : bool
Should the kernel be forcibly killed *now*? This skips the
first, graceful shutdown attempt.
restart: bool
Will this kernel be restarted after it is shut down? When this
is True, connection files will not be cleaned up.
"""
self.shutting_down = True # Used by restarter to prevent race condition
# Stop monitoring for restarting while we shutdown.
self.stop_restarter()
if self.has_kernel:
await ensure_async(self.interrupt_kernel())
if now:
await ensure_async(self._kill_kernel())
else:
await ensure_async(self.request_shutdown(restart=restart))
# Don't send any additional kernel kill messages immediately, to give
# the kernel a chance to properly execute shutdown actions. Wait for at
# most 1s, checking every 0.1s.
await ensure_async(self.finish_shutdown(restart=restart))
await ensure_async(self.cleanup_resources(restart=restart))
shutdown_kernel = run_sync(_async_shutdown_kernel)
async def _async_restart_kernel(
self, now: bool = False, newports: bool = False, **kw: Any
) -> None:
"""Restarts a kernel with the arguments that were used to launch it.
Parameters
----------
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
In all cases the kernel is restarted; the only difference is whether
it is given a chance to perform a clean shutdown or not.
newports : bool, optional
If the old kernel was launched with random ports, this flag decides
whether the same ports and connection file will be used again.
If False, the same ports and connection file are used. This is
the default. If True, new random port numbers are chosen and a
new connection file is written. It is still possible that the newly
chosen random port numbers happen to be the same as the old ones.
`**kw` : optional
Any options specified here will overwrite those used to launch the
kernel.
"""
if self._launch_args is None:
raise RuntimeError("Cannot restart the kernel. No previous call to 'start_kernel'.")
# Stop currently running kernel.
await ensure_async(self.shutdown_kernel(now=now, restart=True))
if newports:
self.cleanup_random_ports()
# Start new kernel.
self._launch_args.update(kw)
await ensure_async(self.start_kernel(**self._launch_args))
restart_kernel = run_sync(_async_restart_kernel)
@property
def has_kernel(self) -> bool:
"""Has a kernel process been started that we are actively managing."""
return self.provisioner is not None and self.provisioner.has_process
async def _async_send_kernel_sigterm(self, restart: bool = False) -> None:
"""similar to _kill_kernel, but with sigterm (not sigkill), but do not block"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.terminate(restart=restart)
_send_kernel_sigterm = run_sync(_async_send_kernel_sigterm)
async def _async_kill_kernel(self, restart: bool = False) -> None:
"""Kill the running kernel.
This is a private method, callers should use shutdown_kernel(now=True).
"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.kill(restart=restart)
# Wait until the kernel terminates.
try:
await asyncio.wait_for(self._async_wait(), timeout=5.0)
except asyncio.TimeoutError:
# Wait timed out, just log warning but continue - not much more we can do.
self.log.warning("Wait for final termination of kernel timed out - continuing...")
pass
else:
# Process is no longer alive, wait and clear
if self.has_kernel:
await self.provisioner.wait()
_kill_kernel = run_sync(_async_kill_kernel)
async def _async_interrupt_kernel(self) -> None:
"""Interrupts the kernel by sending it a signal.
Unlike ``signal_kernel``, this operation is well supported on all
platforms.
"""
if self.has_kernel:
assert self.kernel_spec is not None
interrupt_mode = self.kernel_spec.interrupt_mode
if interrupt_mode == "signal":
await ensure_async(self.signal_kernel(signal.SIGINT))
elif interrupt_mode == "message":
msg = self.session.msg("interrupt_request", content={})
self._connect_control_socket()
self.session.send(self._control_socket, msg)
else:
raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
interrupt_kernel = run_sync(_async_interrupt_kernel)
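# Illustrative note: the interrupt mode is read from the kernelspec. A
# kernel.json opting into message-based interrupts (hypothetical example)
# would look like:
#
#   {
#     "argv": ["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
#     "display_name": "Python 3",
#     "interrupt_mode": "message"
#   }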
async def _async_signal_kernel(self, signum: int) -> None:
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, this function is
only useful on Unix systems.
"""
if self.has_kernel:
assert self.provisioner is not None
await self.provisioner.send_signal(signum)
else:
raise RuntimeError("Cannot signal kernel. No kernel is running!")
signal_kernel = run_sync(_async_signal_kernel)
async def _async_is_alive(self) -> bool:
"""Is the kernel process still running?"""
if self.has_kernel:
assert self.provisioner is not None
ret = await self.provisioner.poll()
if ret is None:
return True
return False
is_alive = run_sync(_async_is_alive)
async def _async_wait(self, pollinterval: float = 0.1) -> None:
# Use busy loop at 100ms intervals, polling until the process is
# not alive. If we find the process is no longer alive, complete
# its cleanup via the blocking wait(). Callers are responsible for
# issuing calls to wait() using a timeout (see _kill_kernel()).
while await ensure_async(self.is_alive()):
await asyncio.sleep(pollinterval)
class AsyncKernelManager(KernelManager):
# the class to create with our `client` method
client_class: DottedObjectName = DottedObjectName(
"jupyter_client.asynchronous.AsyncKernelClient"
)
client_factory: Type = Type(klass="jupyter_client.asynchronous.AsyncKernelClient")
_launch_kernel = KernelManager._async_launch_kernel
start_kernel = KernelManager._async_start_kernel
pre_start_kernel = KernelManager._async_pre_start_kernel
post_start_kernel = KernelManager._async_post_start_kernel
request_shutdown = KernelManager._async_request_shutdown
finish_shutdown = KernelManager._async_finish_shutdown
cleanup_resources = KernelManager._async_cleanup_resources
shutdown_kernel = KernelManager._async_shutdown_kernel
restart_kernel = KernelManager._async_restart_kernel
_send_kernel_sigterm = KernelManager._async_send_kernel_sigterm
_kill_kernel = KernelManager._async_kill_kernel
interrupt_kernel = KernelManager._async_interrupt_kernel
signal_kernel = KernelManager._async_signal_kernel
is_alive = KernelManager._async_is_alive
KernelManagerABC.register(KernelManager)
def start_new_kernel(
startup_timeout: float = 60, kernel_name: str = "python", **kwargs: Any
) -> t.Tuple[KernelManager, KernelClient]:
"""Start a new kernel, and return its Manager and Client"""
km = KernelManager(kernel_name=kernel_name)
km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
raise
return km, kc
async def start_new_async_kernel(
startup_timeout: float = 60, kernel_name: str = "python", **kwargs: Any
) -> t.Tuple[AsyncKernelManager, KernelClient]:
"""Start a new kernel, and return its Manager and Client"""
km = AsyncKernelManager(kernel_name=kernel_name)
await km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
await kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
await km.shutdown_kernel()
raise
return (km, kc)
@contextmanager
def run_kernel(**kwargs: Any) -> t.Iterator[KernelClient]:
"""Context manager to create a kernel in a subprocess.
The kernel is shut down when the context exits.
Returns
-------
kernel_client: connected KernelClient instance
"""
km, kc = start_new_kernel(**kwargs)
try:
yield kc
finally:
kc.stop_channels()
km.shutdown_kernel(now=True)
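# Illustrative sketch (not part of this module): using run_kernel as a
# context manager. Assumes an ipykernel-based "python3" kernelspec is
# installed.
#
#   from jupyter_client.manager import run_kernel
#
#   with run_kernel(kernel_name="python3") as kc:
#       reply = kc.execute_interactive("print('hello')", timeout=30)
#       print(reply["content"]["status"])  # 'ok' on success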

View File

@@ -0,0 +1,49 @@
"""Abstract base class for kernel managers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
class KernelManagerABC(object, metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.kernelmanager.KernelManager`
"""
@abc.abstractproperty
def kernel(self):
pass
# --------------------------------------------------------------------------
# Kernel management
# --------------------------------------------------------------------------
@abc.abstractmethod
def start_kernel(self, **kw):
pass
@abc.abstractmethod
def shutdown_kernel(self, now=False, restart=False):
pass
@abc.abstractmethod
def restart_kernel(self, now=False, **kw):
pass
@abc.abstractproperty
def has_kernel(self):
pass
@abc.abstractmethod
def interrupt_kernel(self):
pass
@abc.abstractmethod
def signal_kernel(self, signum):
pass
@abc.abstractmethod
def is_alive(self):
pass

View File

@@ -0,0 +1,549 @@
"""A kernel manager for multiple kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
import socket
import typing as t
import uuid
import zmq
from traitlets import Any
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import DottedObjectName
from traitlets import Instance
from traitlets import observe
from traitlets import Unicode
from traitlets.config.configurable import LoggingConfigurable
from traitlets.utils.importstring import import_item
from .kernelspec import KernelSpecManager
from .kernelspec import NATIVE_KERNEL_NAME
from .manager import KernelManager
from .utils import ensure_async
from .utils import run_sync
class DuplicateKernelError(Exception):
pass
def kernel_method(f: t.Callable) -> t.Callable:
"""decorator for proxying MKM.method(kernel_id) to individual KMs by ID"""
def wrapped(
self: Any, kernel_id: str, *args: Any, **kwargs: Any
) -> t.Union[t.Callable, t.Awaitable]:
# get the kernel
km = self.get_kernel(kernel_id)
method = getattr(km, f.__name__)
# call the kernel's method
r = method(*args, **kwargs)
# last thing, call anything defined in the actual class method
# such as logging messages
f(self, kernel_id, *args, **kwargs)
# return the method result
return r
return wrapped
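# Illustrative note: a method decorated with @kernel_method takes a kernel_id
# as its first argument and forwards the call to the matching KernelManager,
# e.g. (kernel id is hypothetical):
#
#   mkm.signal_kernel("8f0c7e2a", signal.SIGINT)
#   # ... dispatches to mkm.get_kernel("8f0c7e2a").signal_kernel(signal.SIGINT)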
class MultiKernelManager(LoggingConfigurable):
"""A class for managing multiple kernels."""
default_kernel_name = Unicode(
NATIVE_KERNEL_NAME, help="The name of the default kernel to start"
).tag(config=True)
kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.IOLoopKernelManager",
help="""The kernel manager class. This is configurable to allow
subclassing of the KernelManager for customized behavior.
""",
).tag(config=True)
@observe("kernel_manager_class")
def _kernel_manager_class_changed(self, change):
self.kernel_manager_factory = self._create_kernel_manager_factory()
kernel_manager_factory = Any(help="this is kernel_manager_class after import")
@default("kernel_manager_factory")
def _kernel_manager_factory_default(self):
return self._create_kernel_manager_factory()
def _create_kernel_manager_factory(self) -> t.Callable:
kernel_manager_ctor = import_item(self.kernel_manager_class)
def create_kernel_manager(*args: Any, **kwargs: Any) -> KernelManager:
if self.shared_context:
if self.context.closed:
# recreate context if closed
self.context = self._context_default()
kwargs.setdefault("context", self.context)
km = kernel_manager_ctor(*args, **kwargs)
return km
return create_kernel_manager
shared_context = Bool(
True,
help="Share a single zmq.Context to talk to all my kernels",
).tag(config=True)
context = Instance("zmq.Context")
_created_context = Bool(False)
_pending_kernels = Dict()
@property
def _starting_kernels(self):
"""A shim for backwards compatibility."""
return self._pending_kernels
@default("context") # type:ignore[misc]
def _context_default(self) -> zmq.Context:
self._created_context = True
return zmq.Context()
connection_dir = Unicode("")
_kernels = Dict()
def __del__(self):
"""Handle garbage collection. Destroy context if applicable."""
if self._created_context and self.context and not self.context.closed:
if self.log:
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__
except AttributeError:
pass
else:
super_del()
def list_kernel_ids(self) -> t.List[str]:
"""Return a list of the kernel ids of the active kernels."""
# Create a copy so we can iterate over kernels in operations
# that delete keys.
return list(self._kernels.keys())
def __len__(self) -> int:
"""Return the number of running kernels."""
return len(self.list_kernel_ids())
def __contains__(self, kernel_id: str) -> bool:
return kernel_id in self._kernels
def pre_start_kernel(
self, kernel_name: t.Optional[str], kwargs: Any
) -> t.Tuple[KernelManager, str, str]:
# kwargs is received as a plain dict (not **kwargs) so mutations here are visible to the caller.
kernel_id = kwargs.pop("kernel_id", self.new_kernel_id(**kwargs))
if kernel_id in self:
raise DuplicateKernelError("Kernel already exists: %s" % kernel_id)
if kernel_name is None:
kernel_name = self.default_kernel_name
# kernel_manager_factory is the constructor for the KernelManager
# subclass we are using. It can be configured as any Configurable,
# including things like its transport and ip.
constructor_kwargs = {}
if self.kernel_spec_manager:
constructor_kwargs["kernel_spec_manager"] = self.kernel_spec_manager
km = self.kernel_manager_factory(
connection_file=os.path.join(self.connection_dir, "kernel-%s.json" % kernel_id),
parent=self,
log=self.log,
kernel_name=kernel_name,
**constructor_kwargs,
)
return km, kernel_name, kernel_id
async def _add_kernel_when_ready(
self, kernel_id: str, km: KernelManager, kernel_awaitable: t.Awaitable
) -> None:
try:
await kernel_awaitable
self._kernels[kernel_id] = km
self._pending_kernels.pop(kernel_id, None)
except Exception as e:
self.log.exception(e)
async def _remove_kernel_when_ready(
self, kernel_id: str, kernel_awaitable: t.Awaitable
) -> None:
try:
await kernel_awaitable
self.remove_kernel(kernel_id)
self._pending_kernels.pop(kernel_id, None)
except Exception as e:
self.log.exception(e)
def _using_pending_kernels(self):
"""Returns a boolean; a clearer method for determining if
this multikernelmanager is using pending kernels or not
"""
return getattr(self, 'use_pending_kernels', False)
async def _async_start_kernel(self, kernel_name: t.Optional[str] = None, **kwargs: Any) -> str:
"""Start a new kernel.
The caller can pick a kernel_id by passing one in as a keyword arg,
otherwise one will be generated using new_kernel_id().
The kernel ID for the newly started kernel is returned.
"""
km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs)
if not isinstance(km, KernelManager):
self.log.warning(
"Kernel manager class ({km_class}) is not an instance of 'KernelManager'!".format(
km_class=self.kernel_manager_class
)
)
kwargs['kernel_id'] = kernel_id # Make kernel_id available to manager and provisioner
starter = ensure_async(km.start_kernel(**kwargs))
task = asyncio.create_task(self._add_kernel_when_ready(kernel_id, km, starter))
self._pending_kernels[kernel_id] = task
# Handling a Pending Kernel
if self._using_pending_kernels():
# If using pending kernels, do not block
# on the kernel start.
self._kernels[kernel_id] = km
else:
await task
# raise an exception if one occurred during kernel startup.
if km.ready.exception():
raise km.ready.exception() # type: ignore
return kernel_id
start_kernel = run_sync(_async_start_kernel)
async def _async_shutdown_kernel(
self,
kernel_id: str,
now: t.Optional[bool] = False,
restart: t.Optional[bool] = False,
) -> None:
"""Shutdown a kernel by its kernel uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to shutdown.
now : bool
Should the kernel be shut down forcibly using a signal?
restart : bool
Will the kernel be restarted?
"""
self.log.info("Kernel shutdown: %s" % kernel_id)
# If the kernel is still starting, wait for it to be ready.
if kernel_id in self._pending_kernels:
task = self._pending_kernels[kernel_id]
try:
await task
km = self.get_kernel(kernel_id)
await t.cast(asyncio.Future, km.ready)
except asyncio.CancelledError:
pass
except Exception:
self.remove_kernel(kernel_id)
return
km = self.get_kernel(kernel_id)
# If a pending kernel raised an exception, remove it.
if not km.ready.cancelled() and km.ready.exception():
self.remove_kernel(kernel_id)
return
stopper = ensure_async(km.shutdown_kernel(now, restart))
fut = asyncio.ensure_future(self._remove_kernel_when_ready(kernel_id, stopper))
self._pending_kernels[kernel_id] = fut
# Await the kernel if not using pending kernels.
if not self._using_pending_kernels():
await fut
# raise an exception if one occurred during kernel shutdown.
if km.ready.exception():
raise km.ready.exception() # type: ignore
shutdown_kernel = run_sync(_async_shutdown_kernel)
@kernel_method
def request_shutdown(self, kernel_id: str, restart: t.Optional[bool] = False) -> None:
"""Ask a kernel to shut down by its kernel uuid"""
@kernel_method
def finish_shutdown(
self,
kernel_id: str,
waittime: t.Optional[float] = None,
pollinterval: t.Optional[float] = 0.1,
) -> None:
"""Wait for a kernel to finish shutting down, and kill it if it doesn't"""
self.log.info("Kernel shutdown: %s" % kernel_id)
@kernel_method
def cleanup_resources(self, kernel_id: str, restart: bool = False) -> None:
"""Clean up a kernel's resources"""
def remove_kernel(self, kernel_id: str) -> KernelManager:
"""remove a kernel from our mapping.
Mainly so that a kernel can be removed if it is already dead,
without having to call shutdown_kernel.
The kernel object is returned, or `None` if not found.
"""
return self._kernels.pop(kernel_id, None)
async def _async_shutdown_all(self, now: bool = False) -> None:
"""Shutdown all kernels."""
kids = self.list_kernel_ids()
kids += list(self._pending_kernels)
kms = list(self._kernels.values())
futs = [ensure_async(self.shutdown_kernel(kid, now=now)) for kid in set(kids)]
await asyncio.gather(*futs)
# If using pending kernels, the kernels will not have been fully shut down.
if self._using_pending_kernels():
for km in kms:
try:
await km.ready
except asyncio.CancelledError:
self._pending_kernels[km.kernel_id].cancel()
except Exception:
# Will have been logged in _add_kernel_when_ready
pass
shutdown_all = run_sync(_async_shutdown_all)
def interrupt_kernel(self, kernel_id: str) -> None:
"""Interrupt (SIGINT) the kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
kernel = self.get_kernel(kernel_id)
if not kernel.ready.done():
raise RuntimeError("Kernel is in a pending state. Cannot interrupt.")
out = kernel.interrupt_kernel()
self.log.info("Kernel interrupted: %s" % kernel_id)
return out
@kernel_method
def signal_kernel(self, kernel_id: str, signum: int) -> None:
"""Sends a signal to the kernel by its uuid.
Note that since only SIGTERM is supported on Windows, this function
is only useful on Unix systems.
Parameters
==========
kernel_id : uuid
The id of the kernel to signal.
signum : int
Signal number to send kernel.
"""
self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum))
async def _async_restart_kernel(self, kernel_id: str, now: bool = False) -> None:
"""Restart a kernel by its uuid, keeping the same ports.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
In all cases the kernel is restarted; the only difference is whether
it is given a chance to perform a clean shutdown or not.
"""
kernel = self.get_kernel(kernel_id)
if self._using_pending_kernels():
if not kernel.ready.done():
raise RuntimeError("Kernel is in a pending state. Cannot restart.")
out = await ensure_async(kernel.restart_kernel(now=now))
self.log.info("Kernel restarted: %s" % kernel_id)
return out
restart_kernel = run_sync(_async_restart_kernel)
@kernel_method
def is_alive(self, kernel_id: str) -> bool:
"""Is the kernel alive.
This calls KernelManager.is_alive() which calls Popen.poll on the
actual kernel subprocess.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
def _check_kernel_id(self, kernel_id: str) -> None:
"""check that a kernel id is valid"""
if kernel_id not in self:
raise KeyError("Kernel with id not found: %s" % kernel_id)
def get_kernel(self, kernel_id: str) -> KernelManager:
"""Get the single KernelManager object for a kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
self._check_kernel_id(kernel_id)
return self._kernels[kernel_id]
@kernel_method
def add_restart_callback(
self, kernel_id: str, callback: t.Callable, event: str = "restart"
) -> None:
"""add a callback for the KernelRestarter"""
@kernel_method
def remove_restart_callback(
self, kernel_id: str, callback: t.Callable, event: str = "restart"
) -> None:
"""remove a callback for the KernelRestarter"""
@kernel_method
def get_connection_info(self, kernel_id: str) -> t.Dict[str, t.Any]:
"""Return a dictionary of connection data for a kernel.
Parameters
==========
kernel_id : uuid
The id of the kernel.
Returns
=======
connection_dict : dict
A dict of the information needed to connect to a kernel.
This includes the ip address and the integer port
numbers of the different channels (stdin_port, iopub_port,
shell_port, hb_port).
"""
@kernel_method
def connect_iopub(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket:
"""Return a zmq Socket connected to the iopub channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_shell(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket:
"""Return a zmq Socket connected to the shell channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_control(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket:
"""Return a zmq Socket connected to the control channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_stdin(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket:
"""Return a zmq Socket connected to the stdin channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_hb(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket:
"""Return a zmq Socket connected to the hb channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
def new_kernel_id(self, **kwargs: Any) -> str:
"""
Returns the id to associate with the kernel for this request. Subclasses may override
this method to substitute other sources of kernel ids.
:param kwargs:
:return: string-ized version 4 uuid
"""
return str(uuid.uuid4())
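# Illustrative sketch (not part of this class): a subclass that honors a
# caller-supplied id instead of always generating a uuid. The "requested_id"
# keyword is hypothetical.
#
#   class NamedKernelIdManager(MultiKernelManager):
#       def new_kernel_id(self, **kwargs):
#           return kwargs.get("requested_id") or super().new_kernel_id(**kwargs)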
class AsyncMultiKernelManager(MultiKernelManager):
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.AsyncIOLoopKernelManager",
config=True,
help="""The kernel manager class. This is configurable to allow
subclassing of the AsyncKernelManager for customized behavior.
""",
)
use_pending_kernels = Bool(
False,
help="""Whether to make kernels available before the process has started. The
kernel has a `.ready` future which can be awaited before connecting""",
).tag(config=True)
start_kernel = MultiKernelManager._async_start_kernel
restart_kernel = MultiKernelManager._async_restart_kernel
shutdown_kernel = MultiKernelManager._async_shutdown_kernel
shutdown_all = MultiKernelManager._async_shutdown_all
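# Illustrative sketch (not part of this module): with pending kernels enabled,
# start_kernel returns as soon as an id is assigned; await the kernel's
# `ready` future before connecting. Assumes a "python3" kernelspec.
#
#   import asyncio
#   from jupyter_client.multikernelmanager import AsyncMultiKernelManager
#
#   async def main():
#       mkm = AsyncMultiKernelManager(use_pending_kernels=True)
#       kernel_id = await mkm.start_kernel(kernel_name="python3")
#       await mkm.get_kernel(kernel_id).ready   # wait for the process to be up
#       await mkm.shutdown_kernel(kernel_id, now=True)
#
#   asyncio.run(main())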

View File

@@ -0,0 +1,3 @@
from .factory import KernelProvisionerFactory # noqa
from .local_provisioner import LocalProvisioner # noqa
from .provisioner_base import KernelProvisionerBase # noqa

View File

@@ -0,0 +1,201 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import glob
from os import getenv
from os import path
from typing import Any
from typing import Dict
from typing import List
from entrypoints import EntryPoint
from entrypoints import get_group_all
from entrypoints import get_single
from entrypoints import NoSuchEntryPoint
from traitlets.config import default
from traitlets.config import SingletonConfigurable
from traitlets.config import Unicode
from .provisioner_base import KernelProvisionerBase
class KernelProvisionerFactory(SingletonConfigurable):
"""
:class:`KernelProvisionerFactory` is responsible for creating provisioner instances.
A singleton instance, `KernelProvisionerFactory` is also used by the :class:`KernelSpecManager`
to validate `kernel_provisioner` references found in kernel specifications to confirm their
availability (in cases where the kernel specification references a kernel provisioner that has
not been installed into the current Python environment).
Its `default_provisioner_name` attribute can be used to specify the default provisioner
to use when a kernel_spec is found to not reference a provisioner. Its value defaults to
`"local-provisioner"` which identifies the local provisioner implemented by
:class:`LocalProvisioner`.
"""
GROUP_NAME = 'jupyter_client.kernel_provisioners'
provisioners: Dict[str, EntryPoint] = {}
default_provisioner_name_env = "JUPYTER_DEFAULT_PROVISIONER_NAME"
default_provisioner_name = Unicode(
config=True,
help="""Indicates the name of the provisioner to use when no kernel_provisioner
entry is present in the kernelspec.""",
)
@default('default_provisioner_name')
def default_provisioner_name_default(self):
return getenv(self.default_provisioner_name_env, "local-provisioner")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
for ep in KernelProvisionerFactory._get_all_provisioners():
self.provisioners[ep.name] = ep
def is_provisioner_available(self, kernel_spec: Any) -> bool:
"""
Reads the associated ``kernel_spec`` to determine the provisioner and returns whether it
exists as an entry_point (True) or not (False). If the referenced provisioner is not
in the current cache or cannot be loaded via entry_points, a warning message is issued
indicating it is not available.
"""
is_available: bool = True
provisioner_cfg = self._get_provisioner_config(kernel_spec)
provisioner_name = str(provisioner_cfg.get('provisioner_name'))
if not self._check_availability(provisioner_name):
is_available = False
self.log.warning(
f"Kernel '{kernel_spec.display_name}' is referencing a kernel "
f"provisioner ('{provisioner_name}') that is not available. "
f"Ensure the appropriate package has been installed and retry."
)
return is_available
def create_provisioner_instance(
self, kernel_id: str, kernel_spec: Any, parent: Any
) -> KernelProvisionerBase:
"""
Reads the associated ``kernel_spec`` to see if it has a `kernel_provisioner` stanza.
If one exists, an instance of that provisioner is created. If a kernel provisioner is not
specified in the kernel specification, a default provisioner stanza is fabricated
and instantiated corresponding to the current value of the `default_provisioner_name` trait.
The instantiated provisioner is returned.
If the provisioner is found to not exist (not registered via entry_points),
`ModuleNotFoundError` is raised.
"""
provisioner_cfg = self._get_provisioner_config(kernel_spec)
provisioner_name = str(provisioner_cfg.get('provisioner_name'))
if not self._check_availability(provisioner_name):
raise ModuleNotFoundError(
f"Kernel provisioner '{provisioner_name}' has not been registered."
)
self.log.debug(
f"Instantiating kernel '{kernel_spec.display_name}' with "
f"kernel provisioner: {provisioner_name}"
)
provisioner_class = self.provisioners[provisioner_name].load()
provisioner_config = provisioner_cfg.get('config')
provisioner: KernelProvisionerBase = provisioner_class(
kernel_id=kernel_id, kernel_spec=kernel_spec, parent=parent, **provisioner_config
)
return provisioner
def _check_availability(self, provisioner_name: str) -> bool:
"""
Checks that the given provisioner is available.
If the given provisioner is not in the current set of loaded provisioners an attempt
is made to fetch the named entry point and, if successful, loads it into the cache.
:param provisioner_name:
:return:
"""
is_available = True
if provisioner_name not in self.provisioners:
try:
ep = self._get_provisioner(provisioner_name)
self.provisioners[provisioner_name] = ep # Update cache
except NoSuchEntryPoint:
is_available = False
return is_available
def _get_provisioner_config(self, kernel_spec: Any) -> Dict[str, Any]:
"""
Return the kernel_provisioner stanza from the kernel_spec.
Checks the kernel_spec's metadata dictionary for a kernel_provisioner entry.
If found, it is returned; otherwise a default stanza is fabricated using the
current value of the `default_provisioner_name` trait and returned.
Parameters
----------
kernel_spec : Any - this is a KernelSpec type but listed as Any to avoid circular import
The kernel specification object from which the provisioner dictionary is derived.
Returns
-------
dict
The provisioner portion of the kernel_spec. If one does not exist, it will contain
the default information. If no `config` sub-dictionary exists, an empty `config`
dictionary will be added.
"""
env_provisioner = kernel_spec.metadata.get('kernel_provisioner', {})
if 'provisioner_name' in env_provisioner: # If no provisioner_name, return default
if (
'config' not in env_provisioner
): # if provisioner_name, but no config stanza, add one
env_provisioner.update({"config": {}})
return env_provisioner # Return what we found (plus config stanza if necessary)
return {"provisioner_name": self.default_provisioner_name, "config": {}}
def get_provisioner_entries(self) -> Dict[str, str]:
"""
Returns a dictionary of provisioner entries.
The key is the provisioner name for its entry point. The value is the colon-separated
string of the entry point's module name and object name.
"""
entries = {}
for name, ep in self.provisioners.items():
entries[name] = f"{ep.module_name}:{ep.object_name}"
return entries
@staticmethod
def _get_all_provisioners() -> List[EntryPoint]:
"""Wrapper around entrypoints.get_group_all() - primarily to facilitate testing."""
return get_group_all(KernelProvisionerFactory.GROUP_NAME)
def _get_provisioner(self, name: str) -> EntryPoint:
"""Wrapper around entrypoints.get_single() - primarily to facilitate testing."""
try:
ep = get_single(KernelProvisionerFactory.GROUP_NAME, name)
except NoSuchEntryPoint:
# Check if the entrypoint name is 'local-provisioner'. Although this should never
# happen, we have seen cases where the previous distribution of jupyter_client has
# remained which doesn't include kernel-provisioner entrypoints (so 'local-provisioner'
is deemed not found even though its definition is in THIS package). In such cases,
# the entrypoints package uses what it first finds - which is the older distribution
# resulting in a violation of a supposed invariant condition. To address this scenario,
# we will log a warning message indicating this situation, then build the entrypoint
# instance ourselves - since we have that information.
if name == 'local-provisioner':
distros = glob.glob(f"{path.dirname(path.dirname(__file__))}-*")
self.log.warning(
f"Kernel Provisioning: The 'local-provisioner' is not found. This is likely "
f"due to the presence of multiple jupyter_client distributions and a previous "
f"distribution is being used as the source for entrypoints - which does not "
f"include 'local-provisioner'. That distribution should be removed such that "
f"only the version-appropriate distribution remains (version >= 7). Until "
f"then, a 'local-provisioner' entrypoint will be automatically constructed "
f"and used.\nThe candidate distribution locations are: {distros}"
)
ep = EntryPoint(
'local-provisioner', 'jupyter_client.provisioning', 'LocalProvisioner'
)
else:
raise
return ep
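# --- Registration sketch (illustrative; not part of the original module) ---
# A kernelspec opts into a provisioner via its metadata stanza, and third-party
# provisioners register under the GROUP_NAME entry-point group declared above.
# The provisioner and package names below are hypothetical.
#
# kernel.json (metadata excerpt):
#   "metadata": {
#     "kernel_provisioner": {
#       "provisioner_name": "my-provisioner",
#       "config": {}
#     }
#   }
#
# setup.cfg of the provisioner package:
#   [options.entry_points]
#   jupyter_client.kernel_provisioners =
#       my-provisioner = my_package.provisioning:MyProvisioner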

View File

@@ -0,0 +1,236 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
import signal
import sys
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from ..connect import KernelConnectionInfo
from ..connect import LocalPortCache
from ..launcher import launch_kernel
from ..localinterfaces import is_local_ip
from ..localinterfaces import local_ips
from .provisioner_base import KernelProvisionerBase
class LocalProvisioner(KernelProvisionerBase):
"""
:class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
and is the out-of-the-box default implementation used when no kernel provisioner is
specified in the kernel specification (``kernel.json``). It provides functional
parity with existing applications by launching the kernel locally and using
:class:`subprocess.Popen` to manage its lifecycle.
This class is intended to be subclassed for customizing local kernel environments
and to serve as a reference implementation for other custom provisioners.
"""
process = None
_exit_future = None
pid = None
pgid = None
ip = None
ports_cached = False
@property
def has_process(self) -> bool:
return self.process is not None
async def poll(self) -> Optional[int]:
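# Note: when no process exists, 0 (not None) is returned, which callers
# interpret as "not running".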
ret = 0
if self.process:
ret = self.process.poll()
return ret
async def wait(self) -> Optional[int]:
ret = 0
if self.process:
# Use busy loop at 100ms intervals, polling until the process is
# not alive. If we find the process is no longer alive, complete
# its cleanup via the blocking wait(). Callers are responsible for
# issuing calls to wait() using a timeout (see kill()).
while await self.poll() is None:
await asyncio.sleep(0.1)
# Process is no longer alive, wait and clear
ret = self.process.wait()
# Make sure all the fds get closed.
for attr in ['stdout', 'stderr', 'stdin']:
fid = getattr(self.process, attr)
if fid:
fid.close()
self.process = None # allow has_process to now return False
return ret
async def send_signal(self, signum: int) -> None:
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, we will
check if the desired signal is for interrupt and apply the
applicable code on Windows in that case.
"""
if self.process:
if signum == signal.SIGINT and sys.platform == 'win32':
from ..win_interrupt import send_interrupt
send_interrupt(self.process.win32_interrupt_event)
return
# Prefer process-group over process
if self.pgid and hasattr(os, "killpg"):
try:
os.killpg(self.pgid, signum)
return
except OSError:
pass # We'll retry sending the signal to only the process below
# If we're here, send the signal to the process and let caller handle exceptions
self.process.send_signal(signum)
return
async def kill(self, restart: bool = False) -> None:
if self.process:
if hasattr(signal, "SIGKILL"):
# If available, give preference to signalling the process-group over `kill()`.
try:
await self.send_signal(signal.SIGKILL)
return
except OSError:
pass
try:
self.process.kill()
except OSError as e:
LocalProvisioner._tolerate_no_process(e)
async def terminate(self, restart: bool = False) -> None:
if self.process:
if hasattr(signal, "SIGTERM"):
# If available, give preference to signalling the process group over `terminate()`.
try:
await self.send_signal(signal.SIGTERM)
return
except OSError:
pass
try:
self.process.terminate()
except OSError as e:
LocalProvisioner._tolerate_no_process(e)
@staticmethod
def _tolerate_no_process(os_error: OSError) -> None:
# In Windows, we will get an Access Denied error if the process
# has already terminated. Ignore it.
if sys.platform == 'win32':
if os_error.winerror != 5:
raise
# On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
# the process has already terminated. Ignore it.
else:
from errno import ESRCH
if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
raise
async def cleanup(self, restart: bool = False) -> None:
if self.ports_cached and not restart:
# provisioner is about to be destroyed, return cached ports
lpc = LocalPortCache.instance()
ports = (
self.connection_info['shell_port'],
self.connection_info['iopub_port'],
self.connection_info['stdin_port'],
self.connection_info['hb_port'],
self.connection_info['control_port'],
)
for port in ports:
lpc.return_port(port)
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
"""Perform any steps in preparation for kernel process launch.
This includes applying additional substitutions to the kernel launch command and env.
It also includes preparation of launch parameters.
Returns the updated kwargs.
"""
# This should be considered temporary until a better division of labor can be defined.
km = self.parent
if km:
if km.transport == 'tcp' and not is_local_ip(km.ip):
raise RuntimeError(
"Can only launch a kernel on a local interface. "
"This one is not: %s. "
"Make sure that the '*_address' attributes are "
"configured properly. "
"Currently valid addresses are: %s" % (km.ip, local_ips())
)
# build the Popen cmd
extra_arguments = kwargs.pop('extra_arguments', [])
# write connection file / get default ports
# TODO - change when handshake pattern is adopted
if km.cache_ports and not self.ports_cached:
lpc = LocalPortCache.instance()
km.shell_port = lpc.find_available_port(km.ip)
km.iopub_port = lpc.find_available_port(km.ip)
km.stdin_port = lpc.find_available_port(km.ip)
km.hb_port = lpc.find_available_port(km.ip)
km.control_port = lpc.find_available_port(km.ip)
self.ports_cached = True
km.write_connection_file()
self.connection_info = km.get_connection_info()
kernel_cmd = km.format_kernel_cmd(
extra_arguments=extra_arguments
) # This needs to remain here for b/c
else:
extra_arguments = kwargs.pop('extra_arguments', [])
kernel_cmd = self.kernel_spec.argv + extra_arguments
return await super().pre_launch(cmd=kernel_cmd, **kwargs)
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
self.process = launch_kernel(cmd, **scrubbed_kwargs)
pgid = None
if hasattr(os, "getpgid"):
try:
pgid = os.getpgid(self.process.pid)
except OSError:
pass
self.pid = self.process.pid
self.pgid = pgid
return self.connection_info
@staticmethod
def _scrub_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Remove any keyword arguments that Popen does not tolerate."""
keywords_to_scrub: List[str] = ['extra_arguments', 'kernel_id']
scrubbed_kwargs = kwargs.copy()
for kw in keywords_to_scrub:
scrubbed_kwargs.pop(kw, None)
return scrubbed_kwargs
async def get_provisioner_info(self) -> Dict:
"""Captures the base information necessary for persistence relative to this instance."""
provisioner_info = await super().get_provisioner_info()
provisioner_info.update({'pid': self.pid, 'pgid': self.pgid, 'ip': self.ip})
return provisioner_info
async def load_provisioner_info(self, provisioner_info: Dict) -> None:
"""Loads the base information necessary for persistence relative to this instance."""
await super().load_provisioner_info(provisioner_info)
self.pid = provisioner_info['pid']
self.pgid = provisioner_info['pgid']
self.ip = provisioner_info['ip']
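# --- Usage sketch (illustrative; not part of the original module) ----------
# The provisioner is not normally instantiated directly; a KernelManager
# drives its lifecycle. The kernelspec name "python3" is an assumption about
# the local environment.
if __name__ == "__main__":
    from jupyter_client.manager import KernelManager

    km = KernelManager(kernel_name="python3")
    km.start_kernel()  # pre_launch()/launch_kernel() above run under the covers
    print("kernel pid:", km.provisioner.pid)  # captured in launch_kernel()
    km.shutdown_kernel(now=True)  # kill() and cleanup() release cached ports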

View File

@@ -0,0 +1,262 @@
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from abc import ABC
from abc import ABCMeta
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from traitlets.config import Instance
from traitlets.config import LoggingConfigurable
from traitlets.config import Unicode
from ..connect import KernelConnectionInfo
class KernelProvisionerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore
pass
class KernelProvisionerBase(ABC, LoggingConfigurable, metaclass=KernelProvisionerMeta):
"""
Abstract base class defining methods for KernelProvisioner classes.
A majority of methods are abstract (requiring implementations via a subclass) while
some are optional and others provide implementations common to all instances.
Subclasses should be aware of which methods require a call to the superclass.
Many of these methods model those of :class:`subprocess.Popen` for parity with
previous versions where the kernel process was managed directly.
"""
# The kernel specification associated with this provisioner
kernel_spec: Any = Instance('jupyter_client.kernelspec.KernelSpec', allow_none=True)
kernel_id: str = Unicode(None, allow_none=True)
connection_info: KernelConnectionInfo = {}
@property
@abstractmethod
def has_process(self) -> bool:
"""
Returns true if this provisioner is currently managing a process.
This property is asserted to be True immediately following a call to
the provisioner's :meth:`launch_kernel` method.
"""
pass
@abstractmethod
async def poll(self) -> Optional[int]:
"""
Checks if kernel process is still running.
If running, None is returned, otherwise the process's integer-valued exit code is returned.
This method is called from :meth:`KernelManager.is_alive`.
"""
pass
@abstractmethod
async def wait(self) -> Optional[int]:
"""
Waits for kernel process to terminate.
This method is called from `KernelManager.finish_shutdown()` and
`KernelManager.kill_kernel()` when terminating a kernel gracefully or
immediately, respectively.
"""
pass
@abstractmethod
async def send_signal(self, signum: int) -> None:
"""
Sends signal identified by signum to the kernel process.
This method is called from `KernelManager.signal_kernel()` to send the
kernel process a signal.
"""
pass
@abstractmethod
async def kill(self, restart: bool = False) -> None:
"""
Kill the kernel process.
This is typically accomplished via a SIGKILL signal, which cannot be caught.
This method is called from `KernelManager.kill_kernel()` when terminating
a kernel immediately.
restart is True if this operation will precede a subsequent launch_kernel request.
"""
pass
@abstractmethod
async def terminate(self, restart: bool = False) -> None:
"""
Terminates the kernel process.
This is typically accomplished via a SIGTERM signal, which can be caught, allowing
the kernel provisioner to perform possible cleanup of resources. This method is
called indirectly from `KernelManager.finish_shutdown()` during a kernel's
graceful termination.
restart is True if this operation precedes a subsequent launch_kernel request.
"""
pass
@abstractmethod
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
"""
Launch the kernel process and return its connection information.
This method is called from `KernelManager.launch_kernel()` during the
kernel manager's start kernel sequence.
"""
pass
@abstractmethod
async def cleanup(self, restart: bool = False) -> None:
"""
Cleanup any resources allocated on behalf of the kernel provisioner.
This method is called from `KernelManager.cleanup_resources()` as part of
its shutdown kernel sequence.
restart is True if this operation precedes a subsequent launch_kernel request.
"""
pass
async def shutdown_requested(self, restart: bool = False) -> None:
"""
Allows the provisioner to determine if the kernel's shutdown has been requested.
This method is called from `KernelManager.request_shutdown()` as part of
its shutdown sequence.
This method is optional and is primarily used in scenarios where the provisioner
may need to perform other operations in preparation for a kernel's shutdown.
"""
pass
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
"""
Perform any steps in preparation for kernel process launch.
This includes applying additional substitutions to the kernel launch command
and environment. It also includes preparation of launch parameters.
NOTE: Subclass implementations are advised to call this method as it applies
environment variable substitutions from the local environment and calls the
provisioner's :meth:`_finalize_env()` method to allow each provisioner the
ability to cleanup the environment variables that will be used by the kernel.
This method is called from `KernelManager.pre_start_kernel()` as part of its
start kernel sequence.
Returns the (potentially updated) keyword arguments that are passed to
:meth:`launch_kernel()`.
"""
env = kwargs.pop('env', os.environ).copy()
env.update(self.__apply_env_substitutions(env))
self._finalize_env(env)
kwargs['env'] = env
return kwargs
async def post_launch(self, **kwargs: Any) -> None:
"""
Perform any steps following the kernel process launch.
This method is called from `KernelManager.post_start_kernel()` as part of its
start kernel sequence.
"""
pass
async def get_provisioner_info(self) -> Dict[str, Any]:
"""
Captures the base information necessary for persistence relative to this instance.
This enables applications that subclass `KernelManager` to persist a kernel provisioner's
relevant information to accomplish functionality like disaster recovery or high availability
by calling this method via the kernel manager's `provisioner` attribute.
NOTE: The superclass method must always be called first to ensure proper serialization.
"""
provisioner_info: Dict[str, Any] = {}
provisioner_info['kernel_id'] = self.kernel_id
provisioner_info['connection_info'] = self.connection_info
return provisioner_info
async def load_provisioner_info(self, provisioner_info: Dict) -> None:
"""
Loads the base information necessary for persistence relative to this instance.
The inverse of `get_provisioner_info()`, this enables applications that subclass
`KernelManager` to re-establish communication with a provisioner that is managing
a (presumably) remote kernel from an entirely different process than the original
provisioner.
NOTE: The superclass method must always be called first to ensure proper deserialization.
"""
self.kernel_id = provisioner_info['kernel_id']
self.connection_info = provisioner_info['connection_info']
def get_shutdown_wait_time(self, recommended: float = 5.0) -> float:
"""
Returns the time allowed for a complete shutdown. This may vary by provisioner.
This method is called from `KernelManager.finish_shutdown()` during the graceful
phase of its kernel shutdown sequence.
The recommended value will typically be what is configured in the kernel manager.
"""
return recommended
def get_stable_start_time(self, recommended: float = 10.0) -> float:
"""
Returns the expected upper bound for a kernel (re-)start to complete.
This may vary by provisioner.
The recommended value will typically be what is configured in the kernel restarter.
"""
return recommended
def _finalize_env(self, env: Dict[str, str]) -> None:
"""
Ensures env is appropriate prior to launch.
This method is called from `KernelProvisionerBase.pre_launch()` during the kernel's
start sequence.
NOTE: Subclasses should be sure to call super()._finalize_env(env)
"""
if self.kernel_spec.language and self.kernel_spec.language.lower().startswith("python"):
# Don't allow PYTHONEXECUTABLE to be passed to kernel process.
# If set, it can bork all the things.
env.pop('PYTHONEXECUTABLE', None)
def __apply_env_substitutions(self, substitution_values: Dict[str, str]) -> Dict[str, str]:
"""
Walks entries in the kernelspec's env stanza and applies substitutions from current env.
This method is called from `KernelProvisionerBase.pre_launch()` during the kernel's
start sequence.
Returns the substituted list of env entries.
NOTE: This method is private and is not intended to be overridden by provisioners.
"""
substituted_env = {}
if self.kernel_spec:
from string import Template
# For each templated env entry, fill any templated references
# matching names of env variables with those values and build
# new dict with substitutions.
templated_env = self.kernel_spec.env
for k, v in templated_env.items():
substituted_env.update({k: Template(v).safe_substitute(substitution_values)})
return substituted_env
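# --- Subclassing sketch (illustrative; not part of the original module) ----
# The minimal surface a concrete provisioner must implement. The class name
# and container-based wording are hypothetical; bodies are placeholders.
#
# class ContainerProvisioner(KernelProvisionerBase):
#     @property
#     def has_process(self) -> bool:
#         return self.container_id is not None  # hypothetical state
#
#     async def poll(self) -> Optional[int]: ...
#     async def wait(self) -> Optional[int]: ...
#     async def send_signal(self, signum: int) -> None: ...
#     async def kill(self, restart: bool = False) -> None: ...
#     async def terminate(self, restart: bool = False) -> None: ...
#     async def cleanup(self, restart: bool = False) -> None: ...
#
#     async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
#         ...  # start the container, then return self.connection_info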

View File

@@ -0,0 +1,162 @@
"""A basic kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
It is an incomplete base class, and must be subclassed.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import time
from traitlets import Bool
from traitlets import default
from traitlets import Dict
from traitlets import Float
from traitlets import Instance
from traitlets import Integer
from traitlets.config.configurable import LoggingConfigurable
class KernelRestarter(LoggingConfigurable):
"""Monitor and autorestart a kernel."""
kernel_manager = Instance("jupyter_client.KernelManager")
debug = Bool(
False,
config=True,
help="""Whether to include every poll event in debugging output.
Has to be set explicitly, because there will be *a lot* of output.
""",
)
time_to_dead = Float(3.0, config=True, help="""Kernel heartbeat interval in seconds.""")
stable_start_time = Float(
10.0,
config=True,
help="""The time in seconds to consider the kernel to have completed a stable start up.""",
)
restart_limit = Integer(
5,
config=True,
help="""The number of consecutive autorestarts before the kernel is presumed dead.""",
)
random_ports_until_alive = Bool(
True,
config=True,
help="""Whether to choose new random ports when restarting before the kernel is alive.""",
)
_restarting = Bool(False)
_restart_count = Integer(0)
_initial_startup = Bool(True)
_last_dead = Float()
@default("_last_dead")
def _default_last_dead(self):
return time.time()
callbacks = Dict()
def _callbacks_default(self):
return dict(restart=[], dead=[])
def start(self):
"""Start the polling of the kernel."""
raise NotImplementedError("Must be implemented in a subclass")
def stop(self):
"""Stop the kernel polling."""
raise NotImplementedError("Must be implemented in a subclass")
def add_callback(self, f, event="restart"):
"""register a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
self.callbacks[event].append(f)
def remove_callback(self, f, event="restart"):
"""unregister a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
try:
self.callbacks[event].remove(f)
except ValueError:
pass
def _fire_callbacks(self, event):
"""fire our callbacks for a particular event"""
for callback in self.callbacks[event]:
try:
callback()
except Exception:
self.log.error(
"KernelRestarter: %s callback %r failed",
event,
callback,
exc_info=True,
)
def poll(self):
if self.debug:
self.log.debug("Polling kernel...")
if self.kernel_manager.shutting_down:
self.log.debug("Kernel shutdown in progress...")
return
now = time.time()
if not self.kernel_manager.is_alive():
self._last_dead = now
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count > self.restart_limit:
self.log.warning("KernelRestarter: restart failed")
self._fire_callbacks("dead")
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info(
"KernelRestarter: restarting kernel (%i/%i), %s random ports",
self._restart_count,
self.restart_limit,
"new" if newports else "keep",
)
self._fire_callbacks("restart")
self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
# Since `is_alive` only tests that the kernel process is alive, it does not
# indicate that the kernel has successfully completed startup. To solve this
# correctly, we would need to wait for a kernel info reply, but it is not
# necessarily appropriate to start a kernel client + channels in the
# restarter. Therefore, we use "has been alive continuously for X time" as a
# heuristic for a stable start up.
# See https://github.com/jupyter/jupyter_client/pull/717 for details.
stable_start_time = self.stable_start_time
if self.kernel_manager.provisioner:
stable_start_time = self.kernel_manager.provisioner.get_stable_start_time(
recommended=stable_start_time
)
if self._initial_startup and now - self._last_dead >= stable_start_time:
self._initial_startup = False
if self._restarting and now - self._last_dead >= stable_start_time:
self.log.debug("KernelRestarter: restart apparently succeeded")
self._restarting = False
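# --- Usage sketch (illustrative; not part of the original module) ----------
# Concrete subclasses (e.g. the IOLoopKernelRestarter shipped with this
# package) implement start()/stop(); callers hook the two events like so:
#
#   restarter.add_callback(lambda: print("kernel died, restarting"), event="restart")
#   restarter.add_callback(lambda: print("restart failed, giving up"), event="dead")
#   restarter.start()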

View File

@@ -0,0 +1,122 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import queue
import signal
import sys
import time
from jupyter_core.application import base_aliases
from jupyter_core.application import base_flags
from jupyter_core.application import JupyterApp
from traitlets import Any
from traitlets import Dict
from traitlets import Float
from traitlets.config import catch_config_error
from . import __version__
from .consoleapp import app_aliases
from .consoleapp import app_flags
from .consoleapp import JupyterConsoleApp
OUTPUT_TIMEOUT = 10
# copy flags from mixin:
flags = dict(base_flags)
# start with mixin frontend flags:
frontend_flags_dict = dict(app_flags)
# update full dict with frontend flags:
flags.update(frontend_flags_dict)
# copy flags from mixin
aliases = dict(base_aliases)
# start with mixin frontend flags
frontend_aliases_dict = dict(app_aliases)
# load updated frontend flags into full dict
aliases.update(frontend_aliases_dict)
# get flags&aliases into sets, and remove a couple that
# shouldn't be scrubbed from backend flags:
frontend_aliases = set(frontend_aliases_dict.keys())
frontend_flags = set(frontend_flags_dict.keys())
class RunApp(JupyterApp, JupyterConsoleApp):
version = __version__
name = "jupyter run"
description = """Run Jupyter kernel code."""
flags = Dict(flags)
aliases = Dict(aliases)
frontend_aliases = Any(frontend_aliases)
frontend_flags = Any(frontend_flags)
kernel_timeout = Float(
60,
config=True,
help="""Timeout for giving up on a kernel (in seconds).
On first connect and restart, the console tests whether the
kernel is running and responsive by sending kernel_info_requests.
This sets the timeout in seconds for how long the kernel can take
before being presumed dead.
""",
)
def parse_command_line(self, argv=None):
super().parse_command_line(argv)
self.build_kernel_argv(self.extra_args)
self.filenames_to_run = self.extra_args[:]
@catch_config_error
def initialize(self, argv=None):
self.log.debug("jupyter run: initialize...")
super().initialize(argv)
JupyterConsoleApp.initialize(self)
signal.signal(signal.SIGINT, self.handle_sigint)
self.init_kernel_info()
def handle_sigint(self, *args):
if self.kernel_manager:
self.kernel_manager.interrupt_kernel()
else:
self.log.error("Cannot interrupt kernels we didn't start.\n")
def init_kernel_info(self):
"""Wait for a kernel to be ready, and store kernel info"""
timeout = self.kernel_timeout
tic = time.time()
self.kernel_client.hb_channel.unpause()
msg_id = self.kernel_client.kernel_info()
while True:
try:
reply = self.kernel_client.get_shell_msg(timeout=1)
except queue.Empty as e:
if (time.time() - tic) > timeout:
raise RuntimeError("Kernel didn't respond to kernel_info_request") from e
else:
if reply["parent_header"].get("msg_id") == msg_id:
self.kernel_info = reply["content"]
return
def start(self):
self.log.debug("jupyter run: starting...")
super().start()
if self.filenames_to_run:
for filename in self.filenames_to_run:
self.log.debug("jupyter run: executing `%s`" % filename)
with open(filename) as fp:
code = fp.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply["content"]["status"] == "ok" else 1
if return_code:
raise Exception("jupyter-run error running '%s'" % filename)
else:
code = sys.stdin.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply["content"]["status"] == "ok" else 1
if return_code:
raise Exception("jupyter-run error running 'stdin'")
main = launch_new_instance = RunApp.launch_instance
if __name__ == "__main__":
main()
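# Typical invocations (the file name is illustrative):
#   $ jupyter run script.py               # execute a file on a fresh kernel
#   $ echo "print('hi')" | jupyter run    # or read the code from stdin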

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
from jupyter_client.ssh.tunnel import * # noqa

View File

@@ -0,0 +1,97 @@
#
# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
# Edits Copyright (C) 2010 The IPython Team
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
"""
Sample script showing how to do local port forwarding over paramiko.
This script connects to the requested SSH server and sets up local port
forwarding (the openssh -L option) from a local port through a tunneled
connection to a destination reachable from the SSH server machine.
"""
import logging
import select
import socketserver
import typing as t
logger = logging.getLogger("ssh")
class ForwardServer(socketserver.ThreadingTCPServer):
daemon_threads = True
allow_reuse_address = True
class Handler(socketserver.BaseRequestHandler):
@t.no_type_check
def handle(self):
try:
chan = self.ssh_transport.open_channel(
"direct-tcpip",
(self.chain_host, self.chain_port),
self.request.getpeername(),
)
except Exception as e:
logger.debug(
"Incoming request to %s:%d failed: %s" % (self.chain_host, self.chain_port, repr(e))
)
return
if chan is None:
logger.debug(
"Incoming request to %s:%d was rejected by the SSH server."
% (self.chain_host, self.chain_port)
)
return
logger.debug(
"Connected! Tunnel open %r -> %r -> %r"
% (
self.request.getpeername(),
chan.getpeername(),
(self.chain_host, self.chain_port),
)
)
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
data = self.request.recv(1024)
if len(data) == 0:
break
chan.send(data)
if chan in r:
data = chan.recv(1024)
if len(data) == 0:
break
self.request.send(data)
chan.close()
self.request.close()
logger.debug("Tunnel closed ")
def forward_tunnel(local_port, remote_host, remote_port, transport):
# this is a little convoluted, but lets me configure things for the Handler
# object. (SocketServer doesn't give Handlers any way to access the outer
# server normally.)
class SubHandler(Handler):
chain_host = remote_host
chain_port = remote_port
ssh_transport = transport
ForwardServer(("127.0.0.1", local_port), SubHandler).serve_forever()
__all__ = ["forward_tunnel"]
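# --- Usage sketch (illustrative; host names and ports are made up) ----------
if __name__ == "__main__":
    import paramiko

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    client.connect("gateway.example.com", username="me")
    # Forward local port 9000 to db.internal:5432 as seen from the gateway.
    # Note that forward_tunnel() blocks in serve_forever() until interrupted.
    forward_tunnel(9000, "db.internal", 5432, client.get_transport())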

View File

@@ -0,0 +1,415 @@
"""Basic ssh tunnel utilities, and convenience functions for tunneling
zeromq connections.
"""
# Copyright (C) 2010-2011 IPython Development Team
# Copyright (C) 2011- PyZMQ Developers
#
# Redistributed from IPython under the terms of the BSD License.
import atexit
import os
import re
import signal
import socket
import sys
import warnings
from getpass import getpass
from getpass import getuser
from multiprocessing import Process
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
import paramiko
SSHException = paramiko.ssh_exception.SSHException
except ImportError:
paramiko = None # type:ignore[assignment]
class SSHException(Exception): # type: ignore
pass
else:
from .forward import forward_tunnel
try:
import pexpect # type: ignore
except ImportError:
pexpect = None
def select_random_ports(n):
"""Select and return n random ports that are available."""
ports = []
sockets = []
for _ in range(n):
sock = socket.socket()
sock.bind(("", 0))
ports.append(sock.getsockname()[1])
sockets.append(sock)
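# The ports are only reserved while the sockets above remain open; a small
# race window exists between closing them and the caller binding the ports.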
for sock in sockets:
sock.close()
return ports
# -----------------------------------------------------------------------------
# Check for passwordless login
# -----------------------------------------------------------------------------
_password_pat = re.compile(rb"pass(word|phrase):", re.IGNORECASE)
def try_passwordless_ssh(server, keyfile, paramiko=None):
"""Attempt to make an ssh connection without a password.
This is mainly used for requiring password input only once
when many tunnels may be connected to the same server.
If paramiko is None, the default for the platform is chosen.
"""
if paramiko is None:
paramiko = sys.platform == "win32"
if not paramiko:
f = _try_passwordless_openssh
else:
f = _try_passwordless_paramiko
return f(server, keyfile)
def _try_passwordless_openssh(server, keyfile):
"""Try passwordless login with shell ssh command."""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko")
cmd = "ssh -f " + server
if keyfile:
cmd += " -i " + keyfile
cmd += " exit"
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop("SSH_ASKPASS", None)
ssh_newkey = "Are you sure you want to continue connecting"
p = pexpect.spawn(cmd, env=env)
while True:
try:
i = p.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
raise SSHException("The authenticity of the host can't be established.")
except pexpect.TIMEOUT:
continue
except pexpect.EOF:
return True
else:
return False
def _try_passwordless_paramiko(server, keyfile):
"""Try passwordless login with paramiko."""
if paramiko is None:
msg = "Paramiko unavailable, "
if sys.platform == "win32":
msg += "Paramiko is required for ssh tunneled connections on Windows."
else:
msg += "use OpenSSH."
raise ImportError(msg)
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(server, port, username=username, key_filename=keyfile, look_for_keys=True)
except paramiko.AuthenticationException:
return False
else:
client.close()
return True
def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60):
"""Connect a socket to an address via an ssh tunnel.
This is a wrapper for socket.connect(addr), when addr is not accessible
from the local machine. It simply creates an ssh tunnel using the remaining args,
and calls socket.connect('tcp://localhost:lport') where lport is the randomly
selected local port of the tunnel.
"""
new_url, tunnel = open_tunnel(
addr,
server,
keyfile=keyfile,
password=password,
paramiko=paramiko,
timeout=timeout,
)
socket.connect(new_url)
return tunnel
def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60):
"""Open a tunneled connection from a 0MQ url.
For use inside tunnel_connection.
Returns
-------
(url, tunnel) : (str, object)
The 0MQ url that has been forwarded, and the tunnel object
"""
lport = select_random_ports(1)[0]
transport, addr = addr.split("://")
ip, rport = addr.split(":")
rport = int(rport)
if paramiko is None:
paramiko = sys.platform == "win32"
if paramiko:
tunnelf = paramiko_tunnel
else:
tunnelf = openssh_tunnel
tunnel = tunnelf(
lport,
rport,
server,
remoteip=ip,
keyfile=keyfile,
password=password,
timeout=timeout,
)
return "tcp://127.0.0.1:%i" % lport, tunnel
def openssh_tunnel(
lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60
):
"""Create an ssh tunnel using command-line ssh that connects port lport
on this machine to localhost:rport on server. The tunnel
will automatically close when not in use, remaining open
for a minimum of timeout seconds for an initial connection.
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to private key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko_tunnel")
ssh = "ssh "
if keyfile:
ssh += "-i " + keyfile
if ":" in server:
server, port = server.split(":")
ssh += " -p %s" % port
cmd = "%s -O check %s" % (ssh, server)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")]) # noqa
cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % (
ssh,
lport,
remoteip,
rport,
server,
)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1))
return pid
cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % (
ssh,
lport,
remoteip,
rport,
server,
timeout,
)
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop("SSH_ASKPASS", None)
ssh_newkey = "Are you sure you want to continue connecting"
tunnel = pexpect.spawn(cmd, env=env)
failed = False
while True:
try:
i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1)
if i == 0:
raise SSHException("The authenticity of the host can't be established.")
except pexpect.TIMEOUT:
continue
except pexpect.EOF as e:
if tunnel.exitstatus:
print(tunnel.exitstatus)
print(tunnel.before)
print(tunnel.after)
raise RuntimeError("tunnel '%s' failed to start" % (cmd)) from e
else:
return tunnel.pid
else:
if failed:
print("Password rejected, try again")
password = None
if password is None:
password = getpass("%s's password: " % (server))
tunnel.sendline(password)
failed = True
def _stop_tunnel(cmd):
pexpect.run(cmd)
def _split_server(server):
if "@" in server:
username, server = server.split("@", 1)
else:
username = getuser()
if ":" in server:
server, port = server.split(":")
port = int(port)
else:
port = 22
return username, server, port
def paramiko_tunnel(
lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60
):
"""launch a tunnel with paramiko in a subprocess. This should only be used
when shell ssh is unavailable (e.g. Windows).
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
If you are familiar with ssh tunnels, this creates the tunnel:
ssh server -L localhost:lport:remoteip:rport
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to private key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if paramiko is None:
raise ImportError("Paramiko not available")
if password is None:
if not _try_passwordless_paramiko(server, keyfile):
password = getpass("%s's password: " % (server))
p = Process(
target=_paramiko_tunnel,
args=(lport, rport, server, remoteip),
kwargs=dict(keyfile=keyfile, password=password),
)
p.daemon = True
p.start()
return p
def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
"""Function for actually starting a paramiko tunnel, to be passed
to multiprocessing.Process(target=this), and not called directly.
"""
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(
server,
port,
username=username,
key_filename=keyfile,
look_for_keys=True,
password=password,
)
# except paramiko.AuthenticationException:
# if password is None:
# password = getpass("%s@%s's password: "%(username, server))
# client.connect(server, port, username=username, password=password)
# else:
# raise
except Exception as e:
print("*** Failed to connect to %s:%d: %r" % (server, port, e))
sys.exit(1)
# Don't let SIGINT kill the tunnel subprocess
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
forward_tunnel(lport, remoteip, rport, client.get_transport())
except KeyboardInterrupt:
print("SIGINT: Port forwarding stopped cleanly")
sys.exit(0)
except Exception as e:
print("Port forwarding stopped uncleanly: %s" % e)
sys.exit(255)
if sys.platform == "win32":
ssh_tunnel = paramiko_tunnel
else:
ssh_tunnel = openssh_tunnel
__all__ = [
"tunnel_connection",
"ssh_tunnel",
"openssh_tunnel",
"paramiko_tunnel",
"try_passwordless_ssh",
]
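# --- Usage sketch (illustrative; addresses and server are made up) ----------
if __name__ == "__main__":
    import zmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    # Reach tcp://10.0.0.5:5555 (visible only from login.example.com) by
    # routing the connection through a freshly opened local ssh tunnel.
    tunnel_connection(sock, "tcp://10.0.0.5:5555", "me@login.example.com")
    sock.send(b"ping")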

View File

@@ -0,0 +1,62 @@
import asyncio
import os
import sys
import pytest
from jupyter_core import paths
from .utils import test_env
try:
import resource
except ImportError:
# Windows
resource = None
pjoin = os.path.join
# Handle resource limit
# Ensure a minimal soft limit of DEFAULT_SOFT if the current hard limit is at least that much.
if resource is not None:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
DEFAULT_SOFT = 4096
if hard >= DEFAULT_SOFT:
soft = DEFAULT_SOFT
if hard < soft:
hard = soft
resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))
if os.name == "nt" and sys.version_info >= (3, 7):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@pytest.fixture
def event_loop():
# Make sure we test against a selector event loop
# since pyzmq doesn't like the proactor loop.
# This fixture is picked up by pytest-asyncio
if os.name == "nt" and sys.version_info >= (3, 7):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
loop = asyncio.SelectorEventLoop()
try:
yield loop
finally:
loop.close()
@pytest.fixture(autouse=True)
def env():
env_patch = test_env()
env_patch.start()
yield
env_patch.stop()
@pytest.fixture()
def kernel_dir():
return pjoin(paths.jupyter_data_dir(), 'kernels')

View File

@@ -0,0 +1,39 @@
"""Test kernel for signalling subprocesses"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import time
from ipykernel.displayhook import ZMQDisplayHook
from ipykernel.kernelapp import IPKernelApp
from ipykernel.kernelbase import Kernel
class ProblemTestKernel(Kernel):
"""Kernel for testing kernel problems"""
implementation = "problemtest"
implementation_version = "0.0"
banner = ""
class ProblemTestApp(IPKernelApp):
kernel_class = ProblemTestKernel
def init_io(self):
# Overridden to disable stdout/stderr capture
self.displayhook = ZMQDisplayHook(self.session, self.iopub_socket)
def init_sockets(self):
if os.environ.get("FAIL_ON_START") == "1":
# Simulates e.g. a port binding issue (Address already in use)
raise RuntimeError("Failed for testing purposes")
return super().init_sockets()
if __name__ == "__main__":
# make startup artificially slow,
# so that we exercise client logic for slow-starting kernels
startup_delay = int(os.environ.get("STARTUP_DELAY", "2"))
time.sleep(startup_delay)
ProblemTestApp.launch_instance()
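# How the test suite drives this kernel (values are examples):
#   FAIL_ON_START=1  -> init_sockets() raises, simulating a startup failure
#   STARTUP_DELAY=5  -> stretches the sleep above to exercise slow starts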

View File

@@ -0,0 +1,77 @@
"""Test kernel for signalling subprocesses"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import signal
import time
from subprocess import PIPE
from subprocess import Popen
from ipykernel.displayhook import ZMQDisplayHook
from ipykernel.kernelapp import IPKernelApp
from ipykernel.kernelbase import Kernel
class SignalTestKernel(Kernel):
"""Kernel for testing subprocess signaling"""
implementation = "signaltest"
implementation_version = "0.0"
banner = ""
def __init__(self, **kwargs):
kwargs.pop("user_ns", None)
super().__init__(**kwargs)
self.children = []
if os.environ.get("NO_SIGTERM_REPLY", None) == "1":
signal.signal(signal.SIGTERM, signal.SIG_IGN)
async def shutdown_request(self, stream, ident, parent):
if os.environ.get("NO_SHUTDOWN_REPLY") != "1":
await super().shutdown_request(stream, ident, parent)
def do_execute(
self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
):
code = code.strip()
reply = {
"status": "ok",
"user_expressions": {},
}
if code == "start":
child = Popen(["bash", "-i", "-c", "sleep 30"], stderr=PIPE)
self.children.append(child)
reply["user_expressions"]["pid"] = self.children[-1].pid
elif code == "check":
reply["user_expressions"]["poll"] = [child.poll() for child in self.children]
elif code == "env":
reply["user_expressions"]["env"] = os.getenv("TEST_VARS", "")
elif code == "sleep":
try:
time.sleep(10)
except KeyboardInterrupt:
reply["user_expressions"]["interrupted"] = True
else:
reply["user_expressions"]["interrupted"] = False
else:
reply["status"] = "error"
reply["ename"] = "Error"
reply["evalue"] = code
reply["traceback"] = ["no such command: %s" % code]
return reply
class SignalTestApp(IPKernelApp):
kernel_class = SignalTestKernel
def init_io(self):
# Overridden to disable stdout/stderr capture
self.displayhook = ZMQDisplayHook(self.session, self.iopub_socket)
if __name__ == "__main__":
# make startup artificially slow,
# so that we exercise client logic for slow-starting kernels
time.sleep(2)
SignalTestApp.launch_instance()
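# Sketch of driving this kernel from a client (setup elided; the kernelspec
# name "signaltest" is an assumption of the test environment):
#   kc.execute_interactive("start")  # spawn a 30s bash sleep as a child
#   kc.execute_interactive("check")  # report poll() results for each child
#   kc.execute_interactive("sleep")  # interruptible 10s sleep in the kernel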

View File

@@ -0,0 +1,457 @@
"""Tests for adapting Jupyter msg spec versions"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import copy
import json
from unittest import TestCase
from jupyter_client.adapter import adapt
from jupyter_client.adapter import code_to_line
from jupyter_client.adapter import V4toV5
from jupyter_client.session import Session
def test_default_version():
s = Session()
msg = s.msg("msg_type")
msg["header"].pop("version")
original = copy.deepcopy(msg)
adapted = adapt(original)
assert adapted["header"]["version"] == V4toV5.version
def test_code_to_line_no_code():
line, pos = code_to_line("", 0)
assert line == ""
assert pos == 0
class AdapterTest(TestCase):
def setUp(self):
self.session = Session()
def adapt(self, msg, version=None):
original = copy.deepcopy(msg)
adapted = adapt(msg, version or self.to_version)
return original, adapted
def check_header(self, msg):
pass
class V4toV5TestCase(AdapterTest):
from_version = 4
to_version = 5
def msg(self, msg_type, content):
"""Create a v4 msg (same as v5, minus version header)"""
msg = self.session.msg(msg_type, content)
msg["header"].pop("version")
return msg
def test_same_version(self):
msg = self.msg("execute_result", content={"status": "ok"})
original, adapted = self.adapt(msg, self.from_version)
self.assertEqual(original, adapted)
def test_no_adapt(self):
msg = self.msg("input_reply", {"value": "some text"})
v4, v5 = self.adapt(msg)
self.assertEqual(v5["header"]["version"], V4toV5.version)
v5["header"].pop("version")
self.assertEqual(v4, v5)
def test_rename_type(self):
for v5_type, v4_type in [
("execute_result", "pyout"),
("execute_input", "pyin"),
("error", "pyerr"),
]:
msg = self.msg(v4_type, {"key": "value"})
v4, v5 = self.adapt(msg)
self.assertEqual(v5["header"]["version"], V4toV5.version)
self.assertEqual(v5["header"]["msg_type"], v5_type)
self.assertEqual(v4["content"], v5["content"])
def test_execute_request(self):
msg = self.msg(
"execute_request",
{
"code": "a=5",
"silent": False,
"user_expressions": {"a": "apple"},
"user_variables": ["b"],
},
)
v4, v5 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], v5["header"]["msg_type"])
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v5c["user_expressions"], {"a": "apple", "b": "b"})
self.assertNotIn("user_variables", v5c)
self.assertEqual(v5c["code"], v4c["code"])
def test_execute_reply(self):
msg = self.msg(
"execute_reply",
{
"status": "ok",
"execution_count": 7,
"user_variables": {"a": 1},
"user_expressions": {"a+a": 2},
"payload": [{"source": "page", "text": "blah"}],
},
)
v4, v5 = self.adapt(msg)
v5c = v5["content"]
self.assertNotIn("user_variables", v5c)
self.assertEqual(v5c["user_expressions"], {"a": 1, "a+a": 2})
self.assertEqual(v5c["payload"], [{"source": "page", "data": {"text/plain": "blah"}}])
def test_complete_request(self):
msg = self.msg(
"complete_request",
{
"text": "a.is",
"line": "foo = a.is",
"block": None,
"cursor_pos": 10,
},
)
v4, v5 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
for key in ("text", "line", "block"):
self.assertNotIn(key, v5c)
self.assertEqual(v5c["cursor_pos"], v4c["cursor_pos"])
self.assertEqual(v5c["code"], v4c["line"])
def test_complete_reply(self):
msg = self.msg(
"complete_reply",
{
"matched_text": "a.is",
"matches": [
"a.isalnum",
"a.isalpha",
"a.isdigit",
"a.islower",
],
},
)
v4, v5 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v5c["matches"], v4c["matches"])
self.assertEqual(v5c["metadata"], {})
self.assertEqual(v5c["cursor_start"], -4)
self.assertEqual(v5c["cursor_end"], None)
def test_object_info_request(self):
msg = self.msg(
"object_info_request",
{
"oname": "foo",
"detail_level": 1,
},
)
v4, v5 = self.adapt(msg)
self.assertEqual(v5["header"]["msg_type"], "inspect_request")
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v5c["code"], v4c["oname"])
self.assertEqual(v5c["cursor_pos"], len(v4c["oname"]))
self.assertEqual(v5c["detail_level"], v4c["detail_level"])
def test_object_info_reply(self):
msg = self.msg(
"object_info_reply",
{
"name": "foo",
"found": True,
"status": "ok",
"definition": "foo(a=5)",
"docstring": "the docstring",
},
)
v4, v5 = self.adapt(msg)
self.assertEqual(v5["header"]["msg_type"], "inspect_reply")
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(sorted(v5c), ["data", "found", "metadata", "status"])
text = v5c["data"]["text/plain"]
self.assertEqual(text, "\n".join([v4c["definition"], v4c["docstring"]]))
def test_object_info_reply_not_found(self):
msg = self.msg(
"object_info_reply",
{
"name": "foo",
"found": False,
},
)
v4, v5 = self.adapt(msg)
self.assertEqual(v5["header"]["msg_type"], "inspect_reply")
v4["content"]
v5c = v5["content"]
self.assertEqual(
v5c,
{
"status": "ok",
"found": False,
"data": {},
"metadata": {},
},
)
def test_kernel_info_reply(self):
msg = self.msg(
"kernel_info_reply",
{
"language": "python",
"language_version": [2, 8, 0],
"ipython_version": [1, 2, 3],
},
)
v4, v5 = self.adapt(msg)
v4["content"]
v5c = v5["content"]
self.assertEqual(
v5c,
{
"protocol_version": "4.1",
"implementation": "ipython",
"implementation_version": "1.2.3",
"language_info": {
"name": "python",
"version": "2.8.0",
},
"banner": "",
},
)
# iopub channel
def test_display_data(self):
jsondata = dict(a=5)
msg = self.msg(
"display_data",
{
"data": {
"text/plain": "some text",
"application/json": json.dumps(jsondata),
},
"metadata": {"text/plain": {"key": "value"}},
},
)
v4, v5 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v5c["metadata"], v4c["metadata"])
self.assertEqual(v5c["data"]["text/plain"], v4c["data"]["text/plain"])
self.assertEqual(v5c["data"]["application/json"], jsondata)
# stdin channel
def test_input_request(self):
msg = self.msg("input_request", {"prompt": "$>"})
v4, v5 = self.adapt(msg)
self.assertEqual(v5["content"]["prompt"], v4["content"]["prompt"])
self.assertFalse(v5["content"]["password"])
class V5toV4TestCase(AdapterTest):
from_version = 5
to_version = 4
def msg(self, msg_type, content):
return self.session.msg(msg_type, content)
def test_same_version(self):
msg = self.msg("execute_result", content={"status": "ok"})
original, adapted = self.adapt(msg, self.from_version)
self.assertEqual(original, adapted)
def test_no_adapt(self):
msg = self.msg("input_reply", {"value": "some text"})
v5, v4 = self.adapt(msg)
self.assertNotIn("version", v4["header"])
v5["header"].pop("version")
self.assertEqual(v4, v5)
def test_rename_type(self):
for v5_type, v4_type in [
("execute_result", "pyout"),
("execute_input", "pyin"),
("error", "pyerr"),
]:
msg = self.msg(v5_type, {"key": "value"})
v5, v4 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], v4_type)
assert "version" not in v4["header"]
self.assertEqual(v4["content"], v5["content"])
def test_execute_request(self):
msg = self.msg(
"execute_request",
{
"code": "a=5",
"silent": False,
"user_expressions": {"a": "apple"},
},
)
v5, v4 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], v5["header"]["msg_type"])
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v4c["user_variables"], [])
self.assertEqual(v5c["code"], v4c["code"])
def test_complete_request(self):
msg = self.msg(
"complete_request",
{
"code": "def foo():\n a.is\nfoo()",
"cursor_pos": 19,
},
)
v5, v4 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
self.assertNotIn("code", v4c)
self.assertEqual(v4c["line"], v5c["code"].splitlines(True)[1])
self.assertEqual(v4c["cursor_pos"], 8)
self.assertEqual(v4c["text"], "")
self.assertEqual(v4c["block"], None)
def test_complete_reply(self):
msg = self.msg(
"complete_reply",
{
"cursor_start": 10,
"cursor_end": 14,
"matches": [
"a.isalnum",
"a.isalpha",
"a.isdigit",
"a.islower",
],
"metadata": {},
},
)
v5, v4 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v4c["matched_text"], "a.is")
self.assertEqual(v4c["matches"], v5c["matches"])
def test_inspect_request(self):
msg = self.msg(
"inspect_request",
{
"code": "def foo():\n apple\nbar()",
"cursor_pos": 18,
"detail_level": 1,
},
)
v5, v4 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], "object_info_request")
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v4c["oname"], "apple")
self.assertEqual(v5c["detail_level"], v4c["detail_level"])
def test_inspect_request_token(self):
line = "something(range(10), kwarg=smth) ; xxx.xxx.xxx( firstarg, rand(234,23), kwarg1=2,"
msg = self.msg(
"inspect_request",
{
"code": line,
"cursor_pos": len(line) - 1,
"detail_level": 1,
},
)
v5, v4 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], "object_info_request")
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v4c["oname"], "xxx.xxx.xxx")
self.assertEqual(v5c["detail_level"], v4c["detail_level"])
def test_inspect_reply(self):
msg = self.msg(
"inspect_reply",
{
"name": "foo",
"found": True,
"data": {"text/plain": "some text"},
"metadata": {},
},
)
v5, v4 = self.adapt(msg)
self.assertEqual(v4["header"]["msg_type"], "object_info_reply")
v4c = v4["content"]
v5["content"]
self.assertEqual(sorted(v4c), ["found", "oname"])
self.assertEqual(v4c["found"], False)
def test_kernel_info_reply(self):
msg = self.msg(
"kernel_info_reply",
{
"protocol_version": "5.0",
"implementation": "ipython",
"implementation_version": "1.2.3",
"language_info": {
"name": "python",
"version": "2.8.0",
"mimetype": "text/x-python",
},
"banner": "the banner",
},
)
v5, v4 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
v5c["language_info"]
self.assertEqual(
v4c,
{
"protocol_version": [5, 0],
"language": "python",
"language_version": [2, 8, 0],
"ipython_version": [1, 2, 3],
},
)
# iopub channel
def test_display_data(self):
jsondata = dict(a=5)
msg = self.msg(
"display_data",
{
"data": {
"text/plain": "some text",
"application/json": jsondata,
},
"metadata": {"text/plain": {"key": "value"}},
},
)
v5, v4 = self.adapt(msg)
v4c = v4["content"]
v5c = v5["content"]
self.assertEqual(v5c["metadata"], v4c["metadata"])
self.assertEqual(v5c["data"]["text/plain"], v4c["data"]["text/plain"])
self.assertEqual(v4c["data"]["application/json"], json.dumps(jsondata))
# stdin channel
def test_input_request(self):
msg = self.msg("input_request", {"prompt": "$>", "password": True})
v5, v4 = self.adapt(msg)
self.assertEqual(v5["content"]["prompt"], v4["content"]["prompt"])
self.assertNotIn("password", v4["content"])
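# Direct use of the adapter outside the test harness (sketch):
#   from jupyter_client.adapter import adapt
#   from jupyter_client.session import Session
#
#   msg = Session().msg("kernel_info_request", {})
#   v4_msg = adapt(msg, 4)  # downgrade a v5 message for a v4 kernel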

View File

@@ -0,0 +1,94 @@
"""Tests for the KernelClient"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from unittest import TestCase
import pytest
from IPython.utils.capture import capture_output
from ..manager import start_new_kernel
from .utils import test_env
from jupyter_client.kernelspec import KernelSpecManager
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
from jupyter_client.kernelspec import NoSuchKernel
TIMEOUT = 30
pjoin = os.path.join
class TestKernelClient(TestCase):
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
self.addCleanup(self.env_patch.stop)
try:
KernelSpecManager().get_kernel_spec(NATIVE_KERNEL_NAME)
except NoSuchKernel:
pytest.skip()
self.km, self.kc = start_new_kernel(kernel_name=NATIVE_KERNEL_NAME)
def tearDown(self):
self.env_patch.stop()
self.km.shutdown_kernel()
self.kc.stop_channels()
return super().tearDown()
def test_execute_interactive(self):
kc = self.kc
with capture_output() as io:
reply = kc.execute_interactive("print('hello')", timeout=TIMEOUT)
assert "hello" in io.stdout
assert reply["content"]["status"] == "ok"
def _check_reply(self, reply_type, reply):
self.assertIsInstance(reply, dict)
self.assertEqual(reply["header"]["msg_type"], reply_type + "_reply")
self.assertEqual(reply["parent_header"]["msg_type"], reply_type + "_request")
def test_history(self):
kc = self.kc
msg_id = kc.history(session=0)
self.assertIsInstance(msg_id, str)
reply = kc.history(session=0, reply=True, timeout=TIMEOUT)
self._check_reply("history", reply)
def test_inspect(self):
kc = self.kc
msg_id = kc.inspect("who cares")
self.assertIsInstance(msg_id, str)
reply = kc.inspect("code", reply=True, timeout=TIMEOUT)
self._check_reply("inspect", reply)
def test_complete(self):
kc = self.kc
msg_id = kc.complete("who cares")
self.assertIsInstance(msg_id, str)
reply = kc.complete("code", reply=True, timeout=TIMEOUT)
self._check_reply("complete", reply)
def test_kernel_info(self):
kc = self.kc
msg_id = kc.kernel_info()
self.assertIsInstance(msg_id, str)
reply = kc.kernel_info(reply=True, timeout=TIMEOUT)
self._check_reply("kernel_info", reply)
def test_comm_info(self):
kc = self.kc
msg_id = kc.comm_info()
self.assertIsInstance(msg_id, str)
reply = kc.comm_info(reply=True, timeout=TIMEOUT)
self._check_reply("comm_info", reply)
def test_shutdown(self):
kc = self.kc
reply = kc.shutdown(reply=True, timeout=TIMEOUT)
self._check_reply("shutdown", reply)
def test_shutdown_id(self):
kc = self.kc
msg_id = kc.shutdown()
self.assertIsInstance(msg_id, str)
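# Illustrative extra check (not in the original suite): while channels are up,
# the blocking client should report the kernel alive.
def test_is_alive(self):
    assert self.kc.is_alive()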

View File

@@ -0,0 +1,237 @@
"""Tests for kernel connection utilities"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import os
from tempfile import TemporaryDirectory
from jupyter_core.application import JupyterApp
from jupyter_core.paths import jupyter_runtime_dir
from jupyter_client import connect
from jupyter_client import KernelClient
from jupyter_client.consoleapp import JupyterConsoleApp
from jupyter_client.session import Session
class TemporaryWorkingDirectory(TemporaryDirectory):
"""
Creates a temporary directory and sets the cwd to that directory.
Automatically reverts to previous cwd upon cleanup.
Usage example:
with TemporaryWorkingDirectory() as tmpdir:
...
"""
def __enter__(self):
self.old_wd = os.getcwd()
os.chdir(self.name)
return super().__enter__()
def __exit__(self, exc, value, tb):
os.chdir(self.old_wd)
return super().__exit__(exc, value, tb)
class DummyConsoleApp(JupyterApp, JupyterConsoleApp):
def initialize(self, argv=None):
JupyterApp.initialize(self, argv=argv or [])
self.init_connection_file()
class DummyConfigurable(connect.ConnectionFileMixin):
def initialize(self):
pass
sample_info = dict(
ip="1.2.3.4",
transport="ipc",
shell_port=1,
hb_port=2,
iopub_port=3,
stdin_port=4,
control_port=5,
key=b"abc123",
signature_scheme="hmac-md5",
kernel_name="python",
)
sample_info_kn = dict(
ip="1.2.3.4",
transport="ipc",
shell_port=1,
hb_port=2,
iopub_port=3,
stdin_port=4,
control_port=5,
key=b"abc123",
signature_scheme="hmac-md5",
kernel_name="test",
)
def test_write_connection_file():
with TemporaryDirectory() as d:
cf = os.path.join(d, "kernel.json")
connect.write_connection_file(cf, **sample_info)
assert os.path.exists(cf)
with open(cf, "r") as f:
info = json.load(f)
info["key"] = info["key"].encode()
assert info == sample_info
def test_load_connection_file_session():
"""test load_connection_file() after"""
session = Session()
app = DummyConsoleApp(session=Session())
app.initialize(argv=[])
session = app.session
with TemporaryDirectory() as d:
cf = os.path.join(d, "kernel.json")
connect.write_connection_file(cf, **sample_info)
app.connection_file = cf
app.load_connection_file()
assert session.key == sample_info["key"]
assert session.signature_scheme == sample_info["signature_scheme"]
def test_load_connection_file_session_with_kn():
"""test load_connection_file() after"""
session = Session()
app = DummyConsoleApp(session=Session())
app.initialize(argv=[])
session = app.session
with TemporaryDirectory() as d:
cf = os.path.join(d, "kernel.json")
connect.write_connection_file(cf, **sample_info_kn)
app.connection_file = cf
app.load_connection_file()
assert session.key == sample_info_kn["key"]
assert session.signature_scheme == sample_info_kn["signature_scheme"]
def test_app_load_connection_file():
"""test `ipython console --existing` loads a connection file"""
with TemporaryDirectory() as d:
cf = os.path.join(d, "kernel.json")
connect.write_connection_file(cf, **sample_info)
app = DummyConsoleApp(connection_file=cf)
app.initialize(argv=[])
for attr, expected in sample_info.items():
if attr in ("key", "signature_scheme"):
continue
value = getattr(app, attr)
assert value == expected, "app.%s = %s != %s" % (attr, value, expected)
def test_load_connection_info():
client = KernelClient()
info = {
"control_port": 53702,
"hb_port": 53705,
"iopub_port": 53703,
"ip": "0.0.0.0",
"key": "secret",
"shell_port": 53700,
"signature_scheme": "hmac-sha256",
"stdin_port": 53701,
"transport": "tcp",
}
client.load_connection_info(info)
assert client.control_port == info["control_port"]
assert client.session.key.decode("ascii") == info["key"]
assert client.ip == info["ip"]
def test_find_connection_file():
with TemporaryDirectory() as d:
cf = "kernel.json"
app = DummyConsoleApp(runtime_dir=d, connection_file=cf)
app.initialize()
security_dir = app.runtime_dir
profile_cf = os.path.join(security_dir, cf)
with open(profile_cf, "w") as f:
f.write("{}")
for query in (
"kernel.json",
"kern*",
"*ernel*",
"k*",
):
assert connect.find_connection_file(query, path=security_dir) == profile_cf
def test_find_connection_file_local():
with TemporaryWorkingDirectory():
cf = "test.json"
abs_cf = os.path.abspath(cf)
with open(cf, "w") as f:
f.write("{}")
for query in (
"test.json",
"test",
abs_cf,
os.path.join(".", "test.json"),
):
assert connect.find_connection_file(query, path=[".", jupyter_runtime_dir()]) == abs_cf
def test_find_connection_file_relative():
with TemporaryWorkingDirectory():
cf = "test.json"
os.mkdir("subdir")
cf = os.path.join("subdir", "test.json")
abs_cf = os.path.abspath(cf)
with open(cf, "w") as f:
f.write("{}")
for query in (
os.path.join(".", "subdir", "test.json"),
os.path.join("subdir", "test.json"),
abs_cf,
):
assert connect.find_connection_file(query, path=[".", jupyter_runtime_dir()]) == abs_cf
def test_find_connection_file_abspath():
with TemporaryDirectory():
cf = "absolute.json"
abs_cf = os.path.abspath(cf)
with open(cf, "w") as f:
f.write("{}")
assert connect.find_connection_file(abs_cf, path=jupyter_runtime_dir()) == abs_cf
os.remove(abs_cf)
def test_mixin_record_random_ports():
with TemporaryDirectory() as d:
dc = DummyConfigurable(data_dir=d, kernel_name="via-tcp", transport="tcp")
dc.write_connection_file()
assert dc._connection_file_written
assert os.path.exists(dc.connection_file)
assert dc._random_port_names == connect.port_names
def test_mixin_cleanup_random_ports():
with TemporaryDirectory() as d:
dc = DummyConfigurable(data_dir=d, kernel_name="via-tcp", transport="tcp")
dc.write_connection_file()
filename = dc.connection_file
dc.cleanup_random_ports()
assert not os.path.exists(filename)
for name in dc._random_port_names:
assert getattr(dc, name) == 0
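# Illustrative round trip (not in the original suite): a file written by
# write_connection_file() can be loaded straight into a fresh client.
def test_write_then_load_connection_file():
    with TemporaryDirectory() as d:
        cf = os.path.join(d, "kernel.json")
        connect.write_connection_file(cf, **sample_info)
        client = KernelClient()
        client.load_connection_file(cf)
        assert client.shell_port == sample_info["shell_port"]
        assert client.ip == sample_info["ip"]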

View File

@@ -0,0 +1,131 @@
"""Test suite for our JSON utilities."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import datetime
import json
import numbers
from datetime import timedelta
from unittest import mock
import pytest
from dateutil.tz import tzlocal
from dateutil.tz import tzoffset
from jupyter_client import jsonutil
from jupyter_client.session import utcnow
REFERENCE_DATETIME = datetime.datetime(2013, 7, 3, 16, 34, 52, 249482, tzlocal())
class MyInt(object):
def __int__(self):
return 389
numbers.Integral.register(MyInt)
class MyFloat(object):
def __float__(self):
return 3.14
numbers.Real.register(MyFloat)
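# Registering with the numbers ABCs makes these duck-typed objects look like
# real ints/floats to json_default, as exercised in test_json_default below.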
def test_extract_date_from_naive():
ref = REFERENCE_DATETIME
timestamp = "2013-07-03T16:34:52.249482"
with pytest.deprecated_call(match="Interpreting naive datetime as local"):
extracted = jsonutil.extract_dates(timestamp)
assert isinstance(extracted, datetime.datetime)
assert extracted.tzinfo is not None
assert extracted.tzinfo.utcoffset(ref) == tzlocal().utcoffset(ref)
assert extracted == ref
def test_extract_dates():
ref = REFERENCE_DATETIME
timestamps = [
"2013-07-03T16:34:52.249482Z",
"2013-07-03T16:34:52.249482-0800",
"2013-07-03T16:34:52.249482+0800",
"2013-07-03T16:34:52.249482-08:00",
"2013-07-03T16:34:52.249482+08:00",
]
extracted = jsonutil.extract_dates(timestamps)
for dt in extracted:
assert isinstance(dt, datetime.datetime)
assert dt.tzinfo is not None
assert extracted[0].tzinfo.utcoffset(ref) == timedelta(0)
assert extracted[1].tzinfo.utcoffset(ref) == timedelta(hours=-8)
assert extracted[2].tzinfo.utcoffset(ref) == timedelta(hours=8)
assert extracted[3].tzinfo.utcoffset(ref) == timedelta(hours=-8)
assert extracted[4].tzinfo.utcoffset(ref) == timedelta(hours=8)
def test_parse_ms_precision():
base = "2013-07-03T16:34:52"
digits = "1234567890"
parsed = jsonutil.parse_date(base + "Z")
assert isinstance(parsed, datetime.datetime)
for i in range(len(digits)):
ts = base + "." + digits[:i]
parsed = jsonutil.parse_date(ts + "Z")
if 1 <= i <= 6:
assert isinstance(parsed, datetime.datetime)
else:
assert isinstance(parsed, str)
def test_json_default_date():
naive = datetime.datetime.now()
local = tzoffset("Local", -8 * 3600)
other = tzoffset("Other", 2 * 3600)
data = dict(naive=naive, utc=utcnow(), withtz=naive.replace(tzinfo=other))
with mock.patch.object(jsonutil, "tzlocal", lambda: local):
with pytest.deprecated_call(match="Please add timezone info"):
jsondata = json.dumps(data, default=jsonutil.json_default)
assert "Z" in jsondata
assert jsondata.count("Z") == 1
extracted = jsonutil.extract_dates(json.loads(jsondata))
for dt in extracted.values():
assert isinstance(dt, datetime.datetime)
assert dt.tzinfo is not None
def test_json_default():
# list of input/expected output. Use None for the expected output if it
# can be the same as the input.
pairs = [
(1, None), # start with scalars
(1.123, None),
(1.0, None),
('a', None),
(True, None),
(False, None),
(None, None),
({"key": b"\xFF"}, {"key": "/w==\n"}),
# Containers
([1, 2], None),
((1, 2), [1, 2]),
(set([1, 2]), [1, 2]),
(dict(x=1), None),
({'x': 1, 'y': [1, 2, 3], '1': 'int'}, None),
# More exotic objects
((x for x in range(3)), [0, 1, 2]),
(iter([1, 2]), [1, 2]),
(MyFloat(), 3.14),
(MyInt(), 389),
]
for val, jval in pairs:
if jval is None:
jval = val
out = json.loads(json.dumps(val, default=jsonutil.json_default))
# validate our cleanup
assert out == jval
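# Illustrative round trip (not in the original suite): an aware datetime dumped
# through json_default should come back equal via extract_dates.
def test_date_round_trip():
    ref = REFERENCE_DATETIME
    dumped = json.dumps({"at": ref}, default=jsonutil.json_default)
    loaded = jsonutil.extract_dates(json.loads(dumped))
    assert loaded["at"] == ref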

View File

@@ -0,0 +1,64 @@
import os
import shutil
import sys
import time
from subprocess import PIPE
from subprocess import Popen
from tempfile import mkdtemp
def _launch(extra_env):
env = os.environ.copy()
env.update(extra_env)
return Popen(
[sys.executable, "-c", "from jupyter_client.kernelapp import main; main()"],
env=env,
stderr=PIPE,
)
WAIT_TIME = 10
POLL_FREQ = 10
def test_kernelapp_lifecycle():
# Check that 'jupyter kernel' starts and terminates OK.
runtime_dir = mkdtemp()
startup_dir = mkdtemp()
started = os.path.join(startup_dir, "started")
try:
p = _launch(
{
"JUPYTER_RUNTIME_DIR": runtime_dir,
"JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE": started,
}
)
# Wait for start
for _ in range(WAIT_TIME * POLL_FREQ):
if os.path.isfile(started):
break
time.sleep(1 / POLL_FREQ)
else:
raise AssertionError("No started file created in {} seconds".format(WAIT_TIME))
# Connection file should be there by now
for _ in range(WAIT_TIME * POLL_FREQ):
files = os.listdir(runtime_dir)
if files:
break
time.sleep(1 / POLL_FREQ)
else:
raise AssertionError("No connection file created in {} seconds".format(WAIT_TIME))
assert len(files) == 1
cf = files[0]
assert cf.startswith("kernel")
assert cf.endswith(".json")
# Send SIGTERM to shut down
time.sleep(1)
p.terminate()
_, stderr = p.communicate(timeout=WAIT_TIME)
assert cf in stderr.decode("utf-8", "replace")
finally:
shutil.rmtree(runtime_dir)
shutil.rmtree(startup_dir)
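# The two polling loops above share one pattern; a reusable sketch
# (illustrative only, the helper name is hypothetical):
def wait_for(condition, timeout=WAIT_TIME, poll_freq=POLL_FREQ):
    """Poll `condition` until it returns something truthy, or fail."""
    for _ in range(timeout * poll_freq):
        result = condition()
        if result:
            return result
        time.sleep(1 / poll_freq)
    raise AssertionError("Condition not met in {} seconds".format(timeout))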

View File

@@ -0,0 +1,585 @@
"""Tests for the KernelManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import concurrent.futures
import json
import os
import signal
import sys
import time
from subprocess import PIPE
import pytest
from jupyter_core import paths
from traitlets.config.loader import Config
from ..manager import _ShutdownStatus
from ..manager import start_new_async_kernel
from ..manager import start_new_kernel
from .utils import AsyncKMSubclass
from .utils import SyncKMSubclass
from jupyter_client import AsyncKernelManager
from jupyter_client import KernelManager
pjoin = os.path.join
TIMEOUT = 30
@pytest.fixture(params=["tcp", "ipc"])
def transport(request):
if sys.platform == "win32" and request.param == "ipc": #
pytest.skip("Transport 'ipc' not supported on Windows.")
return request.param
@pytest.fixture
def config(transport):
c = Config()
c.KernelManager.transport = transport
if transport == "ipc":
c.KernelManager.ip = "test"
return c
def _install_kernel(name="signaltest", extra_env=None):
if extra_env is None:
extra_env = {}
kernel_dir = pjoin(paths.jupyter_data_dir(), "kernels", name)
os.makedirs(kernel_dir)
with open(pjoin(kernel_dir, "kernel.json"), "w") as f:
f.write(
json.dumps(
{
"argv": [
sys.executable,
"-m",
"jupyter_client.tests.signalkernel",
"-f",
"{connection_file}",
],
"display_name": "Signal Test Kernel",
"env": {"TEST_VARS": "${TEST_VARS}:test_var_2", **extra_env},
}
)
)
@pytest.fixture
def install_kernel():
return _install_kernel()
def install_kernel_dont_shutdown():
_install_kernel("signaltest-no-shutdown", {"NO_SHUTDOWN_REPLY": "1"})
def install_kernel_dont_terminate():
return _install_kernel(
"signaltest-no-terminate", {"NO_SHUTDOWN_REPLY": "1", "NO_SIGTERM_REPLY": "1"}
)
@pytest.fixture
def start_kernel():
km, kc = start_new_kernel(kernel_name="signaltest")
yield km, kc
kc.stop_channels()
km.shutdown_kernel()
assert km.context.closed
@pytest.fixture
def km(config):
km = KernelManager(config=config)
return km
@pytest.fixture
def km_subclass(config):
km = SyncKMSubclass(config=config)
return km
@pytest.fixture
def zmq_context():
import zmq
ctx = zmq.Context()
yield ctx
ctx.term()
@pytest.fixture(params=[AsyncKernelManager, AsyncKMSubclass])
def async_km(request, config):
km = request.param(config=config)
return km
@pytest.fixture
def async_km_subclass(config):
km = AsyncKMSubclass(config=config)
return km
@pytest.fixture
async def start_async_kernel():
km, kc = await start_new_async_kernel(kernel_name="signaltest")
yield km, kc
kc.stop_channels()
await km.shutdown_kernel()
assert km.context.closed
class TestKernelManagerShutDownGracefully:
parameters = (
"name, install, expected",
[
("signaltest", _install_kernel, _ShutdownStatus.ShutdownRequest),
(
"signaltest-no-shutdown",
install_kernel_dont_shutdown,
_ShutdownStatus.SigtermRequest,
),
(
"signaltest-no-terminate",
install_kernel_dont_terminate,
_ShutdownStatus.SigkillRequest,
),
],
)
@pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
@pytest.mark.parametrize(*parameters)
def test_signal_kernel_subprocesses(self, name, install, expected):
# ipykernel doesn't support 3.6 and this test uses async shutdown_request
if expected == _ShutdownStatus.ShutdownRequest and sys.version_info < (3, 7):
pytest.skip()
install()
km, kc = start_new_kernel(kernel_name=name)
assert km._shutdown_status == _ShutdownStatus.Unset
assert km.is_alive()
# kc.execute("1")
kc.stop_channels()
km.shutdown_kernel()
if expected == _ShutdownStatus.ShutdownRequest:
expected = [expected, _ShutdownStatus.SigtermRequest]
else:
expected = [expected]
assert km._shutdown_status in expected
@pytest.mark.asyncio
@pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
@pytest.mark.parametrize(*parameters)
async def test_async_signal_kernel_subprocesses(self, name, install, expected):
install()
km, kc = await start_new_async_kernel(kernel_name=name)
assert km._shutdown_status == _ShutdownStatus.Unset
assert await km.is_alive()
# kc.execute("1")
kc.stop_channels()
await km.shutdown_kernel()
if expected == _ShutdownStatus.ShutdownRequest:
expected = [expected, _ShutdownStatus.SigtermRequest]
else:
expected = [expected]
assert km._shutdown_status in expected
class TestKernelManager:
def test_lifecycle(self, km):
km.start_kernel(stdout=PIPE, stderr=PIPE)
kc = km.client()
assert km.is_alive()
is_done = km.ready.done()
assert is_done
km.restart_kernel(now=True)
assert km.is_alive()
km.interrupt_kernel()
assert isinstance(km, KernelManager)
kc.stop_channels()
km.shutdown_kernel(now=True)
assert km.context.closed
def test_get_connect_info(self, km):
cinfo = km.get_connection_info()
keys = sorted(cinfo.keys())
expected = sorted(
[
"ip",
"transport",
"hb_port",
"shell_port",
"stdin_port",
"iopub_port",
"control_port",
"key",
"signature_scheme",
]
)
assert keys == expected
@pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
def test_signal_kernel_subprocesses(self, install_kernel, start_kernel):
km, kc = start_kernel
def execute(cmd):
request_id = kc.execute(cmd)
while True:
reply = kc.get_shell_msg(TIMEOUT)
if reply["parent_header"]["msg_id"] == request_id:
break
content = reply["content"]
assert content["status"] == "ok"
return content
N = 5
for i in range(N):
execute("start")
time.sleep(1) # make sure subprocs stay up
reply = execute("check")
assert reply["user_expressions"]["poll"] == [None] * N
# start a job on the kernel to be interrupted
kc.execute("sleep")
time.sleep(1) # ensure sleep message has been handled before we interrupt
km.interrupt_kernel()
reply = kc.get_shell_msg(TIMEOUT)
content = reply["content"]
assert content["status"] == "ok"
assert content["user_expressions"]["interrupted"]
# wait up to 10s for subprocesses to handle signal
for i in range(100):
reply = execute("check")
if reply["user_expressions"]["poll"] != [-signal.SIGINT] * N:
time.sleep(0.1)
else:
break
# verify that subprocesses were interrupted
assert reply["user_expressions"]["poll"] == [-signal.SIGINT] * N
def test_start_new_kernel(self, install_kernel, start_kernel):
km, kc = start_kernel
assert km.is_alive()
assert kc.is_alive()
assert km.context.closed is False
def _env_test_body(self, kc):
def execute(cmd):
request_id = kc.execute(cmd)
while True:
reply = kc.get_shell_msg(TIMEOUT)
if reply["parent_header"]["msg_id"] == request_id:
break
content = reply["content"]
assert content["status"] == "ok"
return content
reply = execute("env")
assert reply is not None
assert reply["user_expressions"]["env"] == "test_var_1:test_var_2"
def test_templated_kspec_env(self, install_kernel, start_kernel):
km, kc = start_kernel
assert km.is_alive()
assert kc.is_alive()
assert km.context.closed is False
self._env_test_body(kc)
def test_cleanup_context(self, km):
assert km.context is not None
km.cleanup_resources(restart=False)
assert km.context.closed
def test_no_cleanup_shared_context(self, zmq_context):
"""kernel manager does not terminate shared context"""
km = KernelManager(context=zmq_context)
assert km.context == zmq_context
assert km.context is not None
km.cleanup_resources(restart=False)
assert km.context.closed is False
assert zmq_context.closed is False
def test_subclass_callables(self, km_subclass):
km_subclass.reset_counts()
km_subclass.start_kernel(stdout=PIPE, stderr=PIPE)
assert km_subclass.call_count("start_kernel") == 1
assert km_subclass.call_count("_launch_kernel") == 1
is_alive = km_subclass.is_alive()
assert is_alive
km_subclass.reset_counts()
km_subclass.restart_kernel(now=True)
assert km_subclass.call_count("restart_kernel") == 1
assert km_subclass.call_count("shutdown_kernel") == 1
assert km_subclass.call_count("interrupt_kernel") == 1
assert km_subclass.call_count("_kill_kernel") == 1
assert km_subclass.call_count("cleanup_resources") == 1
assert km_subclass.call_count("start_kernel") == 1
assert km_subclass.call_count("_launch_kernel") == 1
assert km_subclass.call_count("signal_kernel") == 1
is_alive = km_subclass.is_alive()
assert is_alive
assert km_subclass.call_count("is_alive") >= 1
km_subclass.reset_counts()
km_subclass.interrupt_kernel()
assert km_subclass.call_count("interrupt_kernel") == 1
assert km_subclass.call_count("signal_kernel") == 1
assert isinstance(km_subclass, KernelManager)
km_subclass.reset_counts()
km_subclass.shutdown_kernel(now=False)
assert km_subclass.call_count("shutdown_kernel") == 1
assert km_subclass.call_count("interrupt_kernel") == 1
assert km_subclass.call_count("request_shutdown") == 1
assert km_subclass.call_count("finish_shutdown") == 1
assert km_subclass.call_count("cleanup_resources") == 1
assert km_subclass.call_count("signal_kernel") == 1
assert km_subclass.call_count("is_alive") >= 1
is_alive = km_subclass.is_alive()
assert is_alive is False
assert km_subclass.call_count("is_alive") >= 1
assert km_subclass.context.closed
class TestParallel:
@pytest.mark.timeout(TIMEOUT)
def test_start_sequence_kernels(self, config, install_kernel):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_signaltest_lifecycle(config)
self._run_signaltest_lifecycle(config)
self._run_signaltest_lifecycle(config)
@pytest.mark.timeout(TIMEOUT + 10)
def test_start_parallel_thread_kernels(self, config, install_kernel):
if config.KernelManager.transport == "ipc": # FIXME
pytest.skip("IPC transport is currently not working for this test!")
self._run_signaltest_lifecycle(config)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor:
future1 = thread_executor.submit(self._run_signaltest_lifecycle, config)
future2 = thread_executor.submit(self._run_signaltest_lifecycle, config)
future1.result()
future2.result()
@pytest.mark.timeout(TIMEOUT)
@pytest.mark.skipif(
(sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)),
reason='"Bad file descriptor" error',
)
def test_start_parallel_process_kernels(self, config, install_kernel):
if config.KernelManager.transport == "ipc": # FIXME
pytest.skip("IPC transport is currently not working for this test!")
self._run_signaltest_lifecycle(config)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
future1 = thread_executor.submit(self._run_signaltest_lifecycle, config)
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor:
future2 = process_executor.submit(self._run_signaltest_lifecycle, config)
future2.result()
future1.result()
@pytest.mark.timeout(TIMEOUT)
@pytest.mark.skipif(
(sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)),
reason='"Bad file descriptor" error',
)
def test_start_sequence_process_kernels(self, config, install_kernel):
if config.KernelManager.transport == "ipc": # FIXME
pytest.skip("IPC transport is currently not working for this test!")
self._run_signaltest_lifecycle(config)
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool_executor:
future = pool_executor.submit(self._run_signaltest_lifecycle, config)
future.result()
def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs):
km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
raise
return kc
def _run_signaltest_lifecycle(self, config=None):
km = KernelManager(config=config, kernel_name="signaltest")
kc = self._prepare_kernel(km, stdout=PIPE, stderr=PIPE)
def execute(cmd):
request_id = kc.execute(cmd)
while True:
reply = kc.get_shell_msg(TIMEOUT)
if reply["parent_header"]["msg_id"] == request_id:
break
content = reply["content"]
assert content["status"] == "ok"
return content
execute("start")
assert km.is_alive()
execute("check")
assert km.is_alive()
km.restart_kernel(now=True)
assert km.is_alive()
execute("check")
km.shutdown_kernel()
assert km.context.closed
kc.stop_channels()
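# Illustrative aside (not in the original suite): jupyter_client also ships a
# run_kernel() context manager that bundles the start/wait/stop boilerplate in
# _prepare_kernel above. A minimal sketch against the installed "signaltest"
# kernel:
def run_signaltest_once():
    from jupyter_client.manager import run_kernel
    with run_kernel(kernel_name="signaltest", stdout=PIPE, stderr=PIPE) as kc:
        kc.execute("start")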
@pytest.mark.asyncio
class TestAsyncKernelManager:
async def test_lifecycle(self, async_km):
await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
is_alive = await async_km.is_alive()
assert is_alive
is_ready = async_km.ready.done()
assert is_ready
await async_km.restart_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive
await async_km.interrupt_kernel()
assert isinstance(async_km, AsyncKernelManager)
await async_km.shutdown_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive is False
assert async_km.context.closed
async def test_get_connect_info(self, async_km):
cinfo = async_km.get_connection_info()
keys = sorted(cinfo.keys())
expected = sorted(
[
"ip",
"transport",
"hb_port",
"shell_port",
"stdin_port",
"iopub_port",
"control_port",
"key",
"signature_scheme",
]
)
assert keys == expected
@pytest.mark.timeout(10)
@pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
async def test_signal_kernel_subprocesses(self, install_kernel, start_async_kernel):
km, kc = start_async_kernel
async def execute(cmd):
request_id = kc.execute(cmd)
while True:
reply = await kc.get_shell_msg(TIMEOUT)
if reply["parent_header"]["msg_id"] == request_id:
break
content = reply["content"]
assert content["status"] == "ok"
return content
# Ensure that shutdown_kernel and stop_channels are called at the end of the test.
# Note: we cannot use addCleanup(<func>) for these since it doesn't properly handle
# coroutines - which km.shutdown_kernel now is.
N = 5
for i in range(N):
await execute("start")
await asyncio.sleep(1) # make sure subprocs stay up
reply = await execute("check")
assert reply["user_expressions"]["poll"] == [None] * N
# start a job on the kernel to be interrupted
request_id = kc.execute("sleep")
await asyncio.sleep(1) # ensure sleep message has been handled before we interrupt
await km.interrupt_kernel()
while True:
reply = await kc.get_shell_msg(TIMEOUT)
if reply["parent_header"]["msg_id"] == request_id:
break
content = reply["content"]
assert content["status"] == "ok"
assert content["user_expressions"]["interrupted"] is True
# wait up to 5s for subprocesses to handle signal
for i in range(50):
reply = await execute("check")
if reply["user_expressions"]["poll"] != [-signal.SIGINT] * N:
await asyncio.sleep(0.1)
else:
break
# verify that subprocesses were interrupted
assert reply["user_expressions"]["poll"] == [-signal.SIGINT] * N
@pytest.mark.timeout(10)
async def test_start_new_async_kernel(self, install_kernel, start_async_kernel):
km, kc = start_async_kernel
is_alive = await km.is_alive()
assert is_alive
is_alive = await kc.is_alive()
assert is_alive
async def test_subclass_callables(self, async_km_subclass):
async_km_subclass.reset_counts()
await async_km_subclass.start_kernel(stdout=PIPE, stderr=PIPE)
assert async_km_subclass.call_count("start_kernel") == 1
assert async_km_subclass.call_count("_launch_kernel") == 1
is_alive = await async_km_subclass.is_alive()
assert is_alive
assert async_km_subclass.call_count("is_alive") >= 1
async_km_subclass.reset_counts()
await async_km_subclass.restart_kernel(now=True)
assert async_km_subclass.call_count("restart_kernel") == 1
assert async_km_subclass.call_count("shutdown_kernel") == 1
assert async_km_subclass.call_count("interrupt_kernel") == 1
assert async_km_subclass.call_count("_kill_kernel") == 1
assert async_km_subclass.call_count("cleanup_resources") == 1
assert async_km_subclass.call_count("start_kernel") == 1
assert async_km_subclass.call_count("_launch_kernel") == 1
assert async_km_subclass.call_count("signal_kernel") == 1
is_alive = await async_km_subclass.is_alive()
assert is_alive
assert async_km_subclass.call_count("is_alive") >= 1
async_km_subclass.reset_counts()
await async_km_subclass.interrupt_kernel()
assert async_km_subclass.call_count("interrupt_kernel") == 1
assert async_km_subclass.call_count("signal_kernel") == 1
assert isinstance(async_km_subclass, AsyncKernelManager)
async_km_subclass.reset_counts()
await async_km_subclass.shutdown_kernel(now=False)
assert async_km_subclass.call_count("shutdown_kernel") == 1
assert async_km_subclass.call_count("interrupt_kernel") == 1
assert async_km_subclass.call_count("request_shutdown") == 1
assert async_km_subclass.call_count("finish_shutdown") == 1
assert async_km_subclass.call_count("cleanup_resources") == 1
assert async_km_subclass.call_count("signal_kernel") == 1
assert async_km_subclass.call_count("is_alive") >= 1
is_alive = await async_km_subclass.is_alive()
assert is_alive is False
assert async_km_subclass.call_count("is_alive") >= 1
assert async_km_subclass.context.closed

View File

@@ -0,0 +1,200 @@
"""Tests for the KernelSpecManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import copy
import json
import os
import sys
import tempfile
import unittest
from io import StringIO
from logging import StreamHandler
from os.path import join as pjoin
from subprocess import PIPE
from subprocess import Popen
from subprocess import STDOUT
from tempfile import TemporaryDirectory
import pytest
from jupyter_core import paths
from .utils import install_kernel
from .utils import sample_kernel_json
from .utils import test_env
from jupyter_client import kernelspec
class KernelSpecTests(unittest.TestCase):
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
self.sample_kernel_dir = install_kernel(
pjoin(paths.jupyter_data_dir(), "kernels"), name="sample"
)
self.ksm = kernelspec.KernelSpecManager()
td2 = TemporaryDirectory()
self.addCleanup(td2.cleanup)
self.installable_kernel = td2.name
with open(pjoin(self.installable_kernel, "kernel.json"), "w") as f:
json.dump(sample_kernel_json, f)
def tearDown(self):
self.env_patch.stop()
def test_find_kernel_specs(self):
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels["sample"], self.sample_kernel_dir)
def test_allowed_kernel_names(self):
ksm = kernelspec.KernelSpecManager()
ksm.allowed_kernelspecs = ["foo"]
kernels = ksm.find_kernel_specs()
assert not len(kernels)
def test_deprecated_whitelist(self):
ksm = kernelspec.KernelSpecManager()
ksm.whitelist = ["bar"]
kernels = ksm.find_kernel_specs()
assert not len(kernels)
def test_get_kernel_spec(self):
ks = self.ksm.get_kernel_spec("SAMPLE") # Case insensitive
self.assertEqual(ks.resource_dir, self.sample_kernel_dir)
self.assertEqual(ks.argv, sample_kernel_json["argv"])
self.assertEqual(ks.display_name, sample_kernel_json["display_name"])
self.assertEqual(ks.env, {})
self.assertEqual(ks.metadata, {})
def test_find_all_specs(self):
kernels = self.ksm.get_all_specs()
self.assertEqual(kernels["sample"]["resource_dir"], self.sample_kernel_dir)
self.assertIsNotNone(kernels["sample"]["spec"])
def test_kernel_spec_priority(self):
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
sample_kernel = install_kernel(td.name, name="sample")
self.ksm.kernel_dirs.append(td.name)
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels["sample"], self.sample_kernel_dir)
self.ksm.kernel_dirs.insert(0, td.name)
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels["sample"], sample_kernel)
def test_install_kernel_spec(self):
self.ksm.install_kernel_spec(self.installable_kernel, kernel_name="tstinstalled", user=True)
self.assertIn("tstinstalled", self.ksm.find_kernel_specs())
# install again works
self.ksm.install_kernel_spec(self.installable_kernel, kernel_name="tstinstalled", user=True)
def test_install_kernel_spec_prefix(self):
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
capture = StringIO()
handler = StreamHandler(capture)
self.ksm.log.addHandler(handler)
self.ksm.install_kernel_spec(
self.installable_kernel, kernel_name="tstinstalled", prefix=td.name
)
captured = capture.getvalue()
self.ksm.log.removeHandler(handler)
self.assertIn("may not be found", captured)
self.assertNotIn("tstinstalled", self.ksm.find_kernel_specs())
# add prefix to path, so we find the spec
self.ksm.kernel_dirs.append(pjoin(td.name, "share", "jupyter", "kernels"))
self.assertIn("tstinstalled", self.ksm.find_kernel_specs())
# Run it again, no warning this time because we've added it to the path
capture = StringIO()
handler = StreamHandler(capture)
self.ksm.log.addHandler(handler)
self.ksm.install_kernel_spec(
self.installable_kernel, kernel_name="tstinstalled", prefix=td.name
)
captured = capture.getvalue()
self.ksm.log.removeHandler(handler)
self.assertNotIn("may not be found", captured)
@pytest.mark.skipif(
os.name == "nt" or os.access("/usr/local/share", os.W_OK),
reason="needs Unix system without root privileges",
)
def test_cant_install_kernel_spec(self):
with self.assertRaises(OSError):
self.ksm.install_kernel_spec(
self.installable_kernel, kernel_name="tstinstalled", user=False
)
def test_remove_kernel_spec(self):
path = self.ksm.remove_kernel_spec("sample")
self.assertEqual(path, self.sample_kernel_dir)
def test_remove_kernel_spec_app(self):
p = Popen(
[
sys.executable,
"-m",
"jupyter_client.kernelspecapp",
"remove",
"sample",
"-f",
],
stdout=PIPE,
stderr=STDOUT,
env=os.environ,
)
out, _ = p.communicate()
self.assertEqual(p.returncode, 0, out.decode("utf8", "replace"))
def test_validate_kernel_name(self):
for good in [
"julia-0.4",
"ipython",
"R",
"python_3",
"Haskell-1-2-3",
]:
assert kernelspec._is_valid_kernel_name(good)
for bad in [
"has space",
"ünicode",
"%percent",
"question?",
]:
assert not kernelspec._is_valid_kernel_name(bad)
def test_subclass(self):
"""Test get_all_specs in subclasses that override find_kernel_specs"""
ksm = self.ksm
resource_dir = tempfile.gettempdir()
native_name = kernelspec.NATIVE_KERNEL_NAME
native_kernel = ksm.get_kernel_spec(native_name)
class MyKSM(kernelspec.KernelSpecManager):
def get_kernel_spec(self, name):
spec = copy.copy(native_kernel)
if name == "fake":
spec.name = name
spec.resource_dir = resource_dir
elif name == native_name:
pass
else:
raise KeyError(name)
return spec
def find_kernel_specs(self):
return {
"fake": resource_dir,
native_name: native_kernel.resource_dir,
}
# ensure that get_all_specs doesn't raise if only
# find_kernel_specs and get_kernel_spec are defined
myksm = MyKSM()
specs = myksm.get_all_specs()
assert sorted(specs) == ["fake", native_name]
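# Note (illustrative): get_all_specs() falls back to calling get_kernel_spec()
# for each name returned by find_kernel_specs(), which is why overriding only
# those two methods is enough for MyKSM.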

View File

@@ -0,0 +1,15 @@
# -----------------------------------------------------------------------------
# Copyright (c) The Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
# -----------------------------------------------------------------------------
from .. import localinterfaces
def test_load_ips():
# Override the machinery that skips it if it was called before
localinterfaces._load_ips.called = False
# Just check this doesn't error
localinterfaces._load_ips(suppress_exceptions=False)
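# Illustrative follow-up (not in the original suite): after a successful load,
# the loopback address should be among the local IPs.
def test_local_ips_contain_loopback():
    localinterfaces._load_ips(suppress_exceptions=False)
    assert "127.0.0.1" in localinterfaces.local_ips()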

View File

@@ -0,0 +1,34 @@
"""Tests for KernelManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import tempfile
from unittest import mock
from jupyter_client.kernelspec import KernelSpec
from jupyter_client.manager import KernelManager
def test_connection_file_real_path():
"""Verify realpath is used when formatting connection file"""
with mock.patch("os.path.realpath") as patched_realpath:
patched_realpath.return_value = "foobar"
km = KernelManager(
connection_file=os.path.join(tempfile.gettempdir(), "kernel-test.json"),
kernel_name="test_kernel",
)
# KernelSpec and launch args have to be mocked as we don't have an actual kernel on disk
km._kernel_spec = KernelSpec(
resource_dir="test",
**{
"argv": ["python.exe", "-m", "test_kernel", "-f", "{connection_file}"],
"env": {},
"display_name": "test_kernel",
"language": "python",
"metadata": {},
},
)
km._launch_args = {}
cmds = km.format_kernel_cmd()
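# argv index 4 is the "{connection_file}" slot in the mocked spec above, which
# format_kernel_cmd fills using the patched os.path.realpath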
assert cmds[4] == "foobar"

View File

@@ -0,0 +1,601 @@
"""Tests for the notebook kernel and session manager."""
import asyncio
import concurrent.futures
import os
import sys
import uuid
from asyncio import ensure_future
from subprocess import PIPE
from unittest import TestCase
import pytest
from jupyter_core import paths
from tornado.testing import AsyncTestCase
from tornado.testing import gen_test
from traitlets.config.loader import Config
from ..localinterfaces import localhost
from .utils import AsyncKMSubclass
from .utils import AsyncMKMSubclass
from .utils import install_kernel
from .utils import skip_win32
from .utils import SyncKMSubclass
from .utils import SyncMKMSubclass
from .utils import test_env
from jupyter_client import AsyncKernelManager
from jupyter_client import KernelManager
from jupyter_client.multikernelmanager import AsyncMultiKernelManager
from jupyter_client.multikernelmanager import MultiKernelManager
TIMEOUT = 30
async def now(awaitable):
"""Use this function ensure that this awaitable
happens before other awaitables defined after it.
"""
(out,) = await asyncio.gather(awaitable)
return out
class TestKernelManager(TestCase):
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
super().setUp()
def tearDown(self) -> None:
self.env_patch.stop()
return super().tearDown()
# static so picklable for multiprocessing on Windows
@staticmethod
def _get_tcp_km():
c = Config()
km = MultiKernelManager(config=c)
return km
@staticmethod
def _get_tcp_km_sub():
c = Config()
km = SyncMKMSubclass(config=c)
return km
# static so picklable for multiprocessing on Windows
@staticmethod
def _get_ipc_km():
c = Config()
c.KernelManager.transport = "ipc"
c.KernelManager.ip = "test"
km = MultiKernelManager(config=c)
return km
# static so picklable for multiprocessing on Windows
@staticmethod
def _run_lifecycle(km, test_kid=None):
if test_kid:
kid = km.start_kernel(stdout=PIPE, stderr=PIPE, kernel_id=test_kid)
assert kid == test_kid
else:
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
assert km.is_alive(kid)
assert km.get_kernel(kid).ready.done()
assert kid in km
assert kid in km.list_kernel_ids()
assert len(km) == 1, f"{len(km)} != {1}"
km.restart_kernel(kid, now=True)
assert km.is_alive(kid)
assert kid in km.list_kernel_ids()
km.interrupt_kernel(kid)
k = km.get_kernel(kid)
kc = k.client()
assert isinstance(k, KernelManager)
km.shutdown_kernel(kid, now=True)
assert kid not in km, f"{kid} not in {km}"
kc.stop_channels()
def _run_cinfo(self, km, transport, ip):
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
km.get_kernel(kid)
cinfo = km.get_connection_info(kid)
self.assertEqual(transport, cinfo["transport"])
self.assertEqual(ip, cinfo["ip"])
self.assertTrue("stdin_port" in cinfo)
self.assertTrue("iopub_port" in cinfo)
stream = km.connect_iopub(kid)
stream.close()
self.assertTrue("shell_port" in cinfo)
stream = km.connect_shell(kid)
stream.close()
self.assertTrue("hb_port" in cinfo)
stream = km.connect_hb(kid)
stream.close()
km.shutdown_kernel(kid, now=True)
# static so picklable for multiprocessing on Windows
@classmethod
def test_tcp_lifecycle(cls):
km = cls._get_tcp_km()
cls._run_lifecycle(km)
def test_tcp_lifecycle_with_kernel_id(self):
km = self._get_tcp_km()
self._run_lifecycle(km, test_kid=str(uuid.uuid4()))
def test_shutdown_all(self):
km = self._get_tcp_km()
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
km.shutdown_all()
def test_tcp_cinfo(self):
km = self._get_tcp_km()
self._run_cinfo(km, "tcp", localhost())
@skip_win32
def test_ipc_lifecycle(self):
km = self._get_ipc_km()
self._run_lifecycle(km)
@skip_win32
def test_ipc_cinfo(self):
km = self._get_ipc_km()
self._run_cinfo(km, "ipc", "test")
def test_start_sequence_tcp_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_lifecycle(self._get_tcp_km())
self._run_lifecycle(self._get_tcp_km())
self._run_lifecycle(self._get_tcp_km())
@skip_win32
def test_start_sequence_ipc_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_lifecycle(self._get_ipc_km())
self._run_lifecycle(self._get_ipc_km())
self._run_lifecycle(self._get_ipc_km())
def tcp_lifecycle_with_loop(self):
# Ensure each thread has an event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.test_tcp_lifecycle()
loop.close()
def test_start_parallel_thread_kernels(self):
self.test_tcp_lifecycle()
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor:
future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
future2 = thread_executor.submit(self.tcp_lifecycle_with_loop)
future1.result()
future2.result()
@pytest.mark.skipif(
(sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)),
reason='"Bad file descriptor" error',
)
def test_start_parallel_process_kernels(self):
self.test_tcp_lifecycle()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor:
# Windows tests need this target to be picklable:
future2 = process_executor.submit(self.test_tcp_lifecycle)
future2.result()
future1.result()
def test_subclass_callables(self):
km = self._get_tcp_km_sub()
km.reset_counts()
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
assert km.call_count("start_kernel") == 1
assert isinstance(km.get_kernel(kid), SyncKMSubclass)
assert km.get_kernel(kid).call_count("start_kernel") == 1
assert km.get_kernel(kid).call_count("_launch_kernel") == 1
assert km.is_alive(kid)
assert kid in km
assert kid in km.list_kernel_ids()
assert len(km) == 1, f"{len(km)} != {1}"
km.get_kernel(kid).reset_counts()
km.reset_counts()
km.restart_kernel(kid, now=True)
assert km.call_count("restart_kernel") == 1
assert km.call_count("get_kernel") == 1
assert km.get_kernel(kid).call_count("restart_kernel") == 1
assert km.get_kernel(kid).call_count("shutdown_kernel") == 1
assert km.get_kernel(kid).call_count("interrupt_kernel") == 1
assert km.get_kernel(kid).call_count("_kill_kernel") == 1
assert km.get_kernel(kid).call_count("cleanup_resources") == 1
assert km.get_kernel(kid).call_count("start_kernel") == 1
assert km.get_kernel(kid).call_count("_launch_kernel") == 1
assert km.is_alive(kid)
assert kid in km.list_kernel_ids()
km.get_kernel(kid).reset_counts()
km.reset_counts()
km.interrupt_kernel(kid)
assert km.call_count("interrupt_kernel") == 1
assert km.call_count("get_kernel") == 1
assert km.get_kernel(kid).call_count("interrupt_kernel") == 1
km.get_kernel(kid).reset_counts()
km.reset_counts()
k = km.get_kernel(kid)
assert isinstance(k, SyncKMSubclass)
assert km.call_count("get_kernel") == 1
km.get_kernel(kid).reset_counts()
km.reset_counts()
km.shutdown_all(now=True)
assert km.call_count("shutdown_kernel") == 1
assert km.call_count("remove_kernel") == 1
assert km.call_count("request_shutdown") == 0
assert km.call_count("finish_shutdown") == 0
assert km.call_count("cleanup_resources") == 0
assert kid not in km, f"{kid} not in {km}"
class TestAsyncKernelManager(AsyncTestCase):
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
super().setUp()
def tearDown(self) -> None:
self.env_patch.stop()
return super().tearDown()
# static so picklable for multiprocessing on Windows
@staticmethod
def _get_tcp_km():
c = Config()
km = AsyncMultiKernelManager(config=c)
return km
@staticmethod
def _get_tcp_km_sub():
c = Config()
km = AsyncMKMSubclass(config=c)
return km
# static so picklable for multiprocessing on Windows
@staticmethod
def _get_ipc_km():
c = Config()
c.KernelManager.transport = "ipc"
c.KernelManager.ip = "test"
km = AsyncMultiKernelManager(config=c)
return km
@staticmethod
def _get_pending_kernels_km():
c = Config()
c.AsyncMultiKernelManager.use_pending_kernels = True
km = AsyncMultiKernelManager(config=c)
return km
# static so picklable for multiprocessing on Windows
@staticmethod
async def _run_lifecycle(km, test_kid=None):
if test_kid:
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE, kernel_id=test_kid)
assert kid == test_kid
else:
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
assert await km.is_alive(kid)
assert kid in km
assert kid in km.list_kernel_ids()
assert len(km) == 1, f"{len(km)} != {1}"
await km.restart_kernel(kid, now=True)
assert await km.is_alive(kid)
assert kid in km.list_kernel_ids()
await km.interrupt_kernel(kid)
k = km.get_kernel(kid)
assert isinstance(k, AsyncKernelManager)
await km.shutdown_kernel(kid, now=True)
assert kid not in km, f"{kid} not in {km}"
async def _run_cinfo(self, km, transport, ip):
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
km.get_kernel(kid)
cinfo = km.get_connection_info(kid)
self.assertEqual(transport, cinfo["transport"])
self.assertEqual(ip, cinfo["ip"])
self.assertTrue("stdin_port" in cinfo)
self.assertTrue("iopub_port" in cinfo)
stream = km.connect_iopub(kid)
stream.close()
self.assertTrue("shell_port" in cinfo)
stream = km.connect_shell(kid)
stream.close()
self.assertTrue("hb_port" in cinfo)
stream = km.connect_hb(kid)
stream.close()
await km.shutdown_kernel(kid, now=True)
self.assertNotIn(kid, km)
@gen_test
async def test_tcp_lifecycle(self):
await self.raw_tcp_lifecycle()
@gen_test
async def test_tcp_lifecycle_with_kernel_id(self):
await self.raw_tcp_lifecycle(test_kid=str(uuid.uuid4()))
@gen_test
async def test_shutdown_all(self):
km = self._get_tcp_km()
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
await km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
await km.shutdown_all()
@gen_test(timeout=20)
async def test_use_after_shutdown_all(self):
km = self._get_tcp_km()
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
await km.shutdown_all()
self.assertNotIn(kid, km)
# Start another kernel
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
await km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
await km.shutdown_all()
@gen_test(timeout=20)
async def test_shutdown_all_while_starting(self):
km = self._get_tcp_km()
kid_future = asyncio.ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
# This is relying on the ordering of the asyncio queue, not sure if guaranteed or not:
kid, _ = await asyncio.gather(kid_future, km.shutdown_all())
self.assertNotIn(kid, km)
# Start another kernel
kid = await ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
self.assertIn(kid, km)
self.assertEqual(len(km), 1)
await km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
await km.shutdown_all()
@gen_test
async def test_use_pending_kernels(self):
km = self._get_pending_kernels_km()
kid = await ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
kernel = km.get_kernel(kid)
assert not kernel.ready.done()
assert kid in km
assert kid in km.list_kernel_ids()
assert len(km) == 1, f"{len(km)} != {1}"
# Wait for the kernel to start.
await kernel.ready
await km.restart_kernel(kid, now=True)
out = await km.is_alive(kid)
assert out
assert kid in km.list_kernel_ids()
await km.interrupt_kernel(kid)
k = km.get_kernel(kid)
assert isinstance(k, AsyncKernelManager)
await ensure_future(km.shutdown_kernel(kid, now=True))
# Wait for the kernel to shutdown
await kernel.ready
assert kid not in km, f"{kid} not in {km}"
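# Note (illustrative): with use_pending_kernels enabled, start_kernel() returns
# before the kernel is ready; kernel.ready is the future to await, and the same
# future is reused to signal shutdown completion above.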
@gen_test
async def test_use_pending_kernels_early_restart(self):
km = self._get_pending_kernels_km()
kid = await ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
kernel = km.get_kernel(kid)
assert not kernel.ready.done()
with pytest.raises(RuntimeError):
await km.restart_kernel(kid, now=True)
await kernel.ready
await ensure_future(km.shutdown_kernel(kid, now=True))
# Wait for the kernel to shutdown
await kernel.ready
assert kid not in km, f"{kid} not in {km}"
@gen_test
async def test_use_pending_kernels_early_shutdown(self):
km = self._get_pending_kernels_km()
kid = await ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
kernel = km.get_kernel(kid)
assert not kernel.ready.done()
# Try shutting down while the kernel is pending
await ensure_future(km.shutdown_kernel(kid, now=True))
# Wait for the kernel to shutdown
await kernel.ready
assert kid not in km, f"{kid} not in {km}"
@gen_test
async def test_use_pending_kernels_early_interrupt(self):
km = self._get_pending_kernels_km()
kid = await ensure_future(km.start_kernel(stdout=PIPE, stderr=PIPE))
kernel = km.get_kernel(kid)
assert not kernel.ready.done()
with pytest.raises(RuntimeError):
await km.interrupt_kernel(kid)
# Now wait for the kernel to be ready.
await kernel.ready
await ensure_future(km.shutdown_kernel(kid, now=True))
# Wait for the kernel to shutdown
await kernel.ready
assert kid not in km, f"{kid} not in {km}"
@gen_test
async def test_tcp_cinfo(self):
km = self._get_tcp_km()
await self._run_cinfo(km, "tcp", localhost())
@skip_win32
@gen_test
async def test_ipc_lifecycle(self):
km = self._get_ipc_km()
await self._run_lifecycle(km)
@skip_win32
@gen_test
async def test_ipc_cinfo(self):
km = self._get_ipc_km()
await self._run_cinfo(km, "ipc", "test")
@gen_test
async def test_start_sequence_tcp_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
await self._run_lifecycle(self._get_tcp_km())
await self._run_lifecycle(self._get_tcp_km())
await self._run_lifecycle(self._get_tcp_km())
@skip_win32
@gen_test
async def test_start_sequence_ipc_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
await self._run_lifecycle(self._get_ipc_km())
await self._run_lifecycle(self._get_ipc_km())
await self._run_lifecycle(self._get_ipc_km())
def tcp_lifecycle_with_loop(self):
# Ensure each thread has an event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.raw_tcp_lifecycle())
loop.close()
# static so picklable for multiprocessing on Windows
@classmethod
async def raw_tcp_lifecycle(cls, test_kid=None):
# Since @gen_test creates an event loop, we need a raw form of
# test_tcp_lifecycle that assumes the loop already exists.
km = cls._get_tcp_km()
await cls._run_lifecycle(km, test_kid=test_kid)
# static so picklable for multiprocessing on Windows
@classmethod
def raw_tcp_lifecycle_sync(cls, test_kid=None):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(cls.raw_tcp_lifecycle(test_kid=test_kid))
loop.close()
@gen_test
async def test_start_parallel_thread_kernels(self):
await self.raw_tcp_lifecycle()
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor:
future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
future2 = thread_executor.submit(self.tcp_lifecycle_with_loop)
future1.result()
future2.result()
@gen_test
async def test_start_parallel_process_kernels(self):
await self.raw_tcp_lifecycle()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
future1 = thread_executor.submit(self.tcp_lifecycle_with_loop)
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor:
# Windows tests need this target to be picklable:
future2 = process_executor.submit(self.raw_tcp_lifecycle_sync)
future2.result()
future1.result()
@gen_test
async def test_subclass_callables(self):
mkm = self._get_tcp_km_sub()
mkm.reset_counts()
kid = await mkm.start_kernel(stdout=PIPE, stderr=PIPE)
assert mkm.call_count("start_kernel") == 1
assert isinstance(mkm.get_kernel(kid), AsyncKMSubclass)
assert mkm.get_kernel(kid).call_count("start_kernel") == 1
assert mkm.get_kernel(kid).call_count("_launch_kernel") == 1
assert await mkm.is_alive(kid)
assert kid in mkm
assert kid in mkm.list_kernel_ids()
assert len(mkm) == 1, f"{len(mkm)} != {1}"
mkm.get_kernel(kid).reset_counts()
mkm.reset_counts()
await mkm.restart_kernel(kid, now=True)
assert mkm.call_count("restart_kernel") == 1
assert mkm.call_count("get_kernel") == 1
assert mkm.get_kernel(kid).call_count("restart_kernel") == 1
assert mkm.get_kernel(kid).call_count("shutdown_kernel") == 1
assert mkm.get_kernel(kid).call_count("interrupt_kernel") == 1
assert mkm.get_kernel(kid).call_count("_kill_kernel") == 1
assert mkm.get_kernel(kid).call_count("cleanup_resources") == 1
assert mkm.get_kernel(kid).call_count("start_kernel") == 1
assert mkm.get_kernel(kid).call_count("_launch_kernel") == 1
assert await mkm.is_alive(kid)
assert kid in mkm.list_kernel_ids()
mkm.get_kernel(kid).reset_counts()
mkm.reset_counts()
await mkm.interrupt_kernel(kid)
assert mkm.call_count("interrupt_kernel") == 1
assert mkm.call_count("get_kernel") == 1
assert mkm.get_kernel(kid).call_count("interrupt_kernel") == 1
mkm.get_kernel(kid).reset_counts()
mkm.reset_counts()
k = mkm.get_kernel(kid)
assert isinstance(k, AsyncKMSubclass)
assert mkm.call_count("get_kernel") == 1
mkm.get_kernel(kid).reset_counts()
mkm.reset_counts()
await mkm.shutdown_all(now=True)
assert mkm.call_count("shutdown_kernel") == 1
assert mkm.call_count("remove_kernel") == 1
assert mkm.call_count("request_shutdown") == 0
assert mkm.call_count("finish_shutdown") == 0
assert mkm.call_count("cleanup_resources") == 0
assert kid not in mkm, f"{kid} not in {mkm}"
@gen_test
async def test_bad_kernelspec(self):
km = self._get_tcp_km()
install_kernel(
os.path.join(paths.jupyter_data_dir(), "kernels"),
argv=["non_existent_executable"],
name="bad",
)
with pytest.raises(FileNotFoundError):
await ensure_future(km.start_kernel(kernel_name="bad", stdout=PIPE, stderr=PIPE))
@gen_test
async def test_bad_kernelspec_pending(self):
km = self._get_pending_kernels_km()
install_kernel(
os.path.join(paths.jupyter_data_dir(), "kernels"),
argv=["non_existent_executable"],
name="bad",
)
kernel_id = await ensure_future(
km.start_kernel(kernel_name="bad", stdout=PIPE, stderr=PIPE)
)
with pytest.raises(FileNotFoundError):
await km.get_kernel(kernel_id).ready
assert kernel_id in km.list_kernel_ids()
await ensure_future(km.shutdown_kernel(kernel_id))
assert kernel_id not in km.list_kernel_ids()
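# Note (illustrative): with pending kernels, a broken kernelspec does not fail
# start_kernel() itself; the FileNotFoundError only surfaces when awaiting
# kernel.ready, and the entry must still be shut down to be removed.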

View File

@@ -0,0 +1,350 @@
"""Test Provisioning"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import signal
import sys
from subprocess import PIPE
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import pytest
from entrypoints import EntryPoint
from entrypoints import NoSuchEntryPoint
from jupyter_core import paths
from traitlets import Int
from traitlets import Unicode
from ..connect import KernelConnectionInfo
from ..kernelspec import KernelSpecManager
from ..kernelspec import NoSuchKernel
from ..launcher import launch_kernel
from ..manager import AsyncKernelManager
from ..provisioning import KernelProvisionerBase
from ..provisioning import KernelProvisionerFactory
from ..provisioning import LocalProvisioner
pjoin = os.path.join
class SubclassedTestProvisioner(LocalProvisioner):
config_var_1: int = Int(config=True)
config_var_2: str = Unicode(config=True)
class CustomTestProvisioner(KernelProvisionerBase):
process = None
pid = None
pgid = None
config_var_1: int = Int(config=True)
config_var_2: str = Unicode(config=True)
@property
def has_process(self) -> bool:
return self.process is not None
async def poll(self) -> Optional[int]:
ret = 0
if self.process:
ret = self.process.poll()
return ret
async def wait(self) -> Optional[int]:
ret = 0
if self.process:
while await self.poll() is None:
await asyncio.sleep(0.1)
# Process is no longer alive, wait and clear
ret = self.process.wait()
# Make sure all the fds get closed.
for attr in ['stdout', 'stderr', 'stdin']:
fid = getattr(self.process, attr)
if fid:
fid.close()
self.process = None
return ret
async def send_signal(self, signum: int) -> None:
if self.process:
if signum == signal.SIGINT and sys.platform == 'win32':
from ..win_interrupt import send_interrupt
send_interrupt(self.process.win32_interrupt_event)
return
# Prefer process-group over process
if self.pgid and hasattr(os, "killpg"):
try:
os.killpg(self.pgid, signum)
return
except OSError:
pass
return self.process.send_signal(signum)
async def kill(self, restart=False) -> None:
if self.process:
self.process.kill()
async def terminate(self, restart=False) -> None:
if self.process:
self.process.terminate()
async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
km = self.parent
if km:
# save kwargs for use in restart
km._launch_args = kwargs.copy()
# build the Popen cmd
extra_arguments = kwargs.pop('extra_arguments', [])
# write connection file / get default ports
km.write_connection_file()
self.connection_info = km.get_connection_info()
kernel_cmd = km.format_kernel_cmd(
extra_arguments=extra_arguments
) # This needs to remain here for backwards compatibility
return await super().pre_launch(cmd=kernel_cmd, **kwargs)
async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
scrubbed_kwargs = kwargs
self.process = launch_kernel(cmd, **scrubbed_kwargs)
pgid = None
if hasattr(os, "getpgid"):
try:
pgid = os.getpgid(self.process.pid)
except OSError:
pass
self.pid = self.process.pid
self.pgid = pgid
return self.connection_info
async def cleanup(self, restart=False) -> None:
pass
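# Lifecycle sketch (our reading of the KernelProvisionerBase contract, not an
# authoritative spec): a kernel manager drives a provisioner roughly as
#
#     kwargs = await provisioner.pre_launch(**kwargs)   # build cmd + connection info
#     info = await provisioner.launch_kernel(kwargs.pop('cmd'), **kwargs)
#     ...                                               # poll()/wait()/send_signal()
#     await provisioner.terminate()                     # or kill()
#     await provisioner.cleanup()
#
# CustomTestProvisioner above implements each of these hooks against a local Popen.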
class NewTestProvisioner(CustomTestProvisioner):
pass
def build_kernelspec(name: str, provisioner: Optional[str] = None) -> None:
spec = {
'argv': [
sys.executable,
'-m',
'jupyter_client.tests.signalkernel',
'-f',
'{connection_file}',
],
'display_name': f"Signal Test Kernel w {provisioner}",
'env': {'TEST_VARS': '${TEST_VARS}:test_var_2'},
'metadata': {},
}
if provisioner:
kernel_provisioner = {'kernel_provisioner': {'provisioner_name': provisioner}}
spec['metadata'].update(kernel_provisioner)
if provisioner != 'local-provisioner':
spec['metadata']['kernel_provisioner']['config'] = {
'config_var_1': 42,
'config_var_2': name,
}
kernel_dir = pjoin(paths.jupyter_data_dir(), 'kernels', name)
os.makedirs(kernel_dir)
with open(pjoin(kernel_dir, 'kernel.json'), 'w') as f:
f.write(json.dumps(spec))
def new_provisioner():
build_kernelspec('new_provisioner', 'new-test-provisioner')
def custom_provisioner():
build_kernelspec('custom_provisioner', 'custom-test-provisioner')
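# For reference, build_kernelspec('custom_provisioner', 'custom-test-provisioner')
# writes a kernel.json along these lines (values follow from the spec dict above):
#
#     {
#       "argv": ["/path/to/python", "-m", "jupyter_client.tests.signalkernel",
#                "-f", "{connection_file}"],
#       "display_name": "Signal Test Kernel w custom-test-provisioner",
#       "env": {"TEST_VARS": "${TEST_VARS}:test_var_2"},
#       "metadata": {"kernel_provisioner": {
#           "provisioner_name": "custom-test-provisioner",
#           "config": {"config_var_1": 42, "config_var_2": "custom_provisioner"}}}
#     }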
@pytest.fixture
def all_provisioners():
build_kernelspec('no_provisioner')
build_kernelspec('missing_provisioner', 'missing-provisioner')
build_kernelspec('default_provisioner', 'local-provisioner')
build_kernelspec('subclassed_provisioner', 'subclassed-test-provisioner')
custom_provisioner()
@pytest.fixture(
params=[
'no_provisioner',
'default_provisioner',
'missing_provisioner',
'custom_provisioner',
'subclassed_provisioner',
]
)
def akm(request, all_provisioners):
return AsyncKernelManager(kernel_name=request.param)
initial_provisioner_map = {
'local-provisioner': ('jupyter_client.provisioning', 'LocalProvisioner'),
'subclassed-test-provisioner': (
'jupyter_client.tests.test_provisioning',
'SubclassedTestProvisioner',
),
'custom-test-provisioner': ('jupyter_client.tests.test_provisioning', 'CustomTestProvisioner'),
}
def mock_get_all_provisioners() -> List[EntryPoint]:
result = []
for name, epstr in initial_provisioner_map.items():
result.append(EntryPoint(name, epstr[0], epstr[1]))
return result
def mock_get_provisioner(factory, name) -> EntryPoint:
if name == 'new-test-provisioner':
return EntryPoint(
'new-test-provisioner', 'jupyter_client.tests.test_provisioning', 'NewTestProvisioner'
)
if name in initial_provisioner_map:
return EntryPoint(name, initial_provisioner_map[name][0], initial_provisioner_map[name][1])
raise NoSuchEntryPoint(KernelProvisionerFactory.GROUP_NAME, name)
@pytest.fixture
def kpf(monkeypatch):
"""Setup the Kernel Provisioner Factory, mocking the entrypoint fetch calls."""
monkeypatch.setattr(
KernelProvisionerFactory, '_get_all_provisioners', mock_get_all_provisioners
)
monkeypatch.setattr(KernelProvisionerFactory, '_get_provisioner', mock_get_provisioner)
factory = KernelProvisionerFactory.instance()
return factory
class TestDiscovery:
def test_find_all_specs(self, kpf, all_provisioners):
ksm = KernelSpecManager()
kernels = ksm.get_all_specs()
# Ensure specs for initial provisioners exist,
# and missing_provisioner & new_provisioner don't
assert 'no_provisioner' in kernels
assert 'default_provisioner' in kernels
assert 'subclassed_provisioner' in kernels
assert 'custom_provisioner' in kernels
assert 'missing_provisioner' not in kernels
assert 'new_provisioner' not in kernels
def test_get_missing(self, all_provisioners):
ksm = KernelSpecManager()
with pytest.raises(NoSuchKernel):
ksm.get_kernel_spec('missing_provisioner')
def test_get_new(self, kpf):
new_provisioner() # Introduce provisioner after initialization of KPF
ksm = KernelSpecManager()
kernel = ksm.get_kernel_spec('new_provisioner')
assert 'new-test-provisioner' == kernel.metadata['kernel_provisioner']['provisioner_name']
class TestRuntime:
async def akm_test(self, kernel_mgr):
"""Starts a kernel, validates the associated provisioner's config, shuts down kernel"""
assert kernel_mgr.provisioner is None
if kernel_mgr.kernel_name == 'missing_provisioner':
with pytest.raises(NoSuchKernel):
await kernel_mgr.start_kernel()
else:
await kernel_mgr.start_kernel()
TestRuntime.validate_provisioner(kernel_mgr)
await kernel_mgr.shutdown_kernel()
assert kernel_mgr.provisioner.has_process is False
@pytest.mark.asyncio
async def test_existing(self, kpf, akm):
await self.akm_test(akm)
@pytest.mark.asyncio
async def test_new(self, kpf):
new_provisioner() # Introduce provisioner after initialization of KPF
new_km = AsyncKernelManager(kernel_name='new_provisioner')
await self.akm_test(new_km)
@pytest.mark.asyncio
async def test_custom_lifecycle(self, kpf):
custom_provisioner()
async_km = AsyncKernelManager(kernel_name='custom_provisioner')
await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
is_alive = await async_km.is_alive()
assert is_alive
await async_km.restart_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive
await async_km.interrupt_kernel()
assert isinstance(async_km, AsyncKernelManager)
await async_km.shutdown_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive is False
assert async_km.context.closed
@pytest.mark.asyncio
async def test_default_provisioner_config(self, kpf, all_provisioners):
kpf.default_provisioner_name = 'custom-test-provisioner'
async_km = AsyncKernelManager(kernel_name='no_provisioner')
await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
is_alive = await async_km.is_alive()
assert is_alive
assert isinstance(async_km.provisioner, CustomTestProvisioner)
assert async_km.provisioner.config_var_1 == 0  # not set in the kernelspec, so the trait default of 0 applies
await async_km.shutdown_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive is False
assert async_km.context.closed
@staticmethod
def validate_provisioner(akm: AsyncKernelManager):
# Ensure the provisioner is managing a process at this point
assert akm.provisioner is not None and akm.provisioner.has_process
# Validate provisioner config
if akm.kernel_name in ['no_provisioner', 'default_provisioner']:
assert not hasattr(akm.provisioner, 'config_var_1')
assert not hasattr(akm.provisioner, 'config_var_2')
else:
assert akm.provisioner.config_var_1 == 42
assert akm.provisioner.config_var_2 == akm.kernel_name
# Validate provisioner class
if akm.kernel_name in ['no_provisioner', 'default_provisioner', 'subclassed_provisioner']:
assert isinstance(akm.provisioner, LocalProvisioner)
if akm.kernel_name == 'subclassed_provisioner':
assert isinstance(akm.provisioner, SubclassedTestProvisioner)
else:
assert not isinstance(akm.provisioner, SubclassedTestProvisioner)
else:
assert isinstance(akm.provisioner, CustomTestProvisioner)
assert not isinstance(akm.provisioner, LocalProvisioner)
if akm.kernel_name == 'new_provisioner':
assert isinstance(akm.provisioner, NewTestProvisioner)

View File

@@ -0,0 +1,29 @@
"""Test the jupyter_client public API
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import jupyter_client
from jupyter_client import connect
from jupyter_client import launcher
def test_kms():
for base in ("", "Async", "Multi"):
KM = base + "KernelManager"
assert KM in dir(jupyter_client)
def test_kcs():
for base in ("", "Blocking", "Async"):
KM = base + "KernelClient"
assert KM in dir(jupyter_client)
def test_launcher():
for name in launcher.__all__:
assert name in dir(jupyter_client)
def test_connect():
for name in connect.__all__:
assert name in dir(jupyter_client)

View File

@@ -0,0 +1,283 @@
"""Tests for the KernelManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import sys
import pytest
from jupyter_core import paths
from traitlets.config.loader import Config
from traitlets.log import get_logger
from jupyter_client.ioloop import AsyncIOLoopKernelManager
from jupyter_client.ioloop import IOLoopKernelManager
pjoin = os.path.join
def _install_kernel(name="problemtest", extra_env=None):
if extra_env is None:
extra_env = {}
kernel_dir = pjoin(paths.jupyter_data_dir(), "kernels", name)
os.makedirs(kernel_dir)
with open(pjoin(kernel_dir, "kernel.json"), "w") as f:
f.write(
json.dumps(
{
"argv": [
sys.executable,
"-m",
"jupyter_client.tests.problemkernel",
"-f",
"{connection_file}",
],
"display_name": "Problematic Test Kernel",
"env": {"TEST_VARS": "${TEST_VARS}:test_var_2", **extra_env},
}
)
)
return name
@pytest.fixture
def install_kernel():
return _install_kernel("problemtest")
@pytest.fixture
def install_fail_kernel():
return _install_kernel("problemtest-fail", extra_env={"FAIL_ON_START": "1"})
@pytest.fixture
def install_slow_fail_kernel():
return _install_kernel(
"problemtest-slow", extra_env={"STARTUP_DELAY": "5", "FAIL_ON_START": "1"}
)
@pytest.fixture(params=["tcp", "ipc"])
def transport(request):
if sys.platform == "win32" and request.param == "ipc": #
pytest.skip("Transport 'ipc' not supported on Windows.")
return request.param
@pytest.fixture
def config(transport):
c = Config()
c.KernelManager.transport = transport
if transport == "ipc":
c.KernelManager.ip = "test"
return c
@pytest.fixture
def debug_logging():
get_logger().setLevel("DEBUG")
@pytest.mark.asyncio
async def test_restart_check(config, install_kernel, debug_logging):
"""Test that the kernel is restarted and recovers"""
# If this test fails, run it with --log-cli-level=DEBUG to inspect
N_restarts = 1
config.KernelRestarter.restart_limit = N_restarts
config.KernelRestarter.debug = True
km = IOLoopKernelManager(kernel_name=install_kernel, config=config)
cbs = 0
restarts = [asyncio.Future() for i in range(N_restarts)]
def cb():
nonlocal cbs
if cbs >= N_restarts:
raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
restarts[cbs].set_result(True)
cbs += 1
try:
km.start_kernel()
km.add_restart_callback(cb, 'restart')
except BaseException:
if km.has_kernel:
km.shutdown_kernel()
raise
try:
for i in range(N_restarts + 1):
kc = km.client()
kc.start_channels()
kc.wait_for_ready(timeout=60)
kc.stop_channels()
if i < N_restarts:
# Kill without cleanup to simulate crash:
await km.provisioner.kill()
await restarts[i]
# Wait for kill + restart
max_wait = 10.0
waited = 0.0
while waited < max_wait and km.is_alive():
await asyncio.sleep(0.1)
waited += 0.1
while waited < max_wait and not km.is_alive():
await asyncio.sleep(0.1)
waited += 0.1
assert cbs == N_restarts
assert km.is_alive()
finally:
km.shutdown_kernel(now=True)
assert km.context.closed
@pytest.mark.asyncio
async def test_restarter_gives_up(config, install_fail_kernel, debug_logging):
"""Test that the restarter gives up after reaching the restart limit"""
# If this test fails, run it with --log-cli-level=DEBUG to inspect
N_restarts = 1
config.KernelRestarter.restart_limit = N_restarts
config.KernelRestarter.debug = True
km = IOLoopKernelManager(kernel_name=install_fail_kernel, config=config)
cbs = 0
restarts = [asyncio.Future() for i in range(N_restarts)]
def cb():
nonlocal cbs
if cbs >= N_restarts:
raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
restarts[cbs].set_result(True)
cbs += 1
died = asyncio.Future()
def on_death():
died.set_result(True)
try:
km.start_kernel()
km.add_restart_callback(cb, 'restart')
km.add_restart_callback(on_death, 'dead')
except BaseException:
if km.has_kernel:
km.shutdown_kernel()
raise
try:
for i in range(N_restarts):
await restarts[i]
assert await died
assert cbs == N_restarts
finally:
km.shutdown_kernel(now=True)
assert km.context.closed
@pytest.mark.asyncio
async def test_async_restart_check(config, install_kernel, debug_logging):
"""Test that the kernel is restarted and recovers"""
# If this test fails, run it with --log-cli-level=DEBUG to inspect
N_restarts = 1
config.KernelRestarter.restart_limit = N_restarts
config.KernelRestarter.debug = True
km = AsyncIOLoopKernelManager(kernel_name=install_kernel, config=config)
cbs = 0
restarts = [asyncio.Future() for i in range(N_restarts)]
def cb():
nonlocal cbs
if cbs >= N_restarts:
raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
restarts[cbs].set_result(True)
cbs += 1
try:
await km.start_kernel()
km.add_restart_callback(cb, 'restart')
except BaseException:
if km.has_kernel:
await km.shutdown_kernel()
raise
try:
for i in range(N_restarts + 1):
kc = km.client()
kc.start_channels()
await kc.wait_for_ready(timeout=60)
kc.stop_channels()
if i < N_restarts:
# Kill without cleanup to simulate crash:
await km.provisioner.kill()
await restarts[i]
# Wait for kill + restart
max_wait = 10.0
waited = 0.0
while waited < max_wait and await km.is_alive():
await asyncio.sleep(0.1)
waited += 0.1
while waited < max_wait and not await km.is_alive():
await asyncio.sleep(0.1)
waited += 0.1
assert cbs == N_restarts
assert await km.is_alive()
finally:
await km.shutdown_kernel(now=True)
assert km.context.closed
@pytest.mark.asyncio
async def test_async_restarter_gives_up(config, install_slow_fail_kernel, debug_logging):
"""Test that the restarter gives up after reaching the restart limit"""
# If this test fails, run it with --log-cli-level=DEBUG to inspect
N_restarts = 2
config.KernelRestarter.restart_limit = N_restarts
config.KernelRestarter.debug = True
config.KernelRestarter.stable_start_time = 30.0
km = AsyncIOLoopKernelManager(kernel_name=install_slow_fail_kernel, config=config)
cbs = 0
restarts = [asyncio.Future() for i in range(N_restarts)]
def cb():
nonlocal cbs
if cbs >= N_restarts:
raise RuntimeError("Kernel restarted more than %d times!" % N_restarts)
restarts[cbs].set_result(True)
cbs += 1
died = asyncio.Future()
def on_death():
died.set_result(True)
try:
await km.start_kernel()
km.add_restart_callback(cb, 'restart')
km.add_restart_callback(on_death, 'dead')
except BaseException:
if km.has_kernel:
await km.shutdown_kernel()
raise
try:
await asyncio.gather(*restarts)
assert await died
assert cbs == N_restarts
finally:
await km.shutdown_kernel(now=True)
assert km.context.closed

View File

@@ -0,0 +1,354 @@
"""test building messages with Session"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import hmac
import os
import platform
import uuid
from datetime import datetime
from unittest import mock
import pytest
import zmq
from tornado import ioloop
from zmq.eventloop.zmqstream import ZMQStream
from zmq.tests import BaseZMQTestCase
from jupyter_client import jsonutil
from jupyter_client import session as ss
def _bad_packer(obj):
raise TypeError("I don't work")
def _bad_unpacker(bytes):
raise TypeError("I don't work either")
class SessionTestCase(BaseZMQTestCase):
def setUp(self):
BaseZMQTestCase.setUp(self)
self.session = ss.Session()
@pytest.fixture
def no_copy_threshold():
"""Disable zero-copy optimizations in pyzmq >= 17"""
with mock.patch.object(zmq, "COPY_THRESHOLD", 1, create=True):
yield
@pytest.mark.usefixtures("no_copy_threshold")
class TestSession(SessionTestCase):
def test_msg(self):
"""message format"""
msg = self.session.msg("execute")
thekeys = set("header parent_header metadata content msg_type msg_id".split())
s = set(msg.keys())
self.assertEqual(s, thekeys)
self.assertTrue(isinstance(msg["content"], dict))
self.assertTrue(isinstance(msg["metadata"], dict))
self.assertTrue(isinstance(msg["header"], dict))
self.assertTrue(isinstance(msg["parent_header"], dict))
self.assertTrue(isinstance(msg["msg_id"], str))
self.assertTrue(isinstance(msg["msg_type"], str))
self.assertEqual(msg["header"]["msg_type"], "execute")
self.assertEqual(msg["msg_type"], "execute")
def test_serialize(self):
msg = self.session.msg("execute", content=dict(a=10, b=1.1))
msg_list = self.session.serialize(msg, ident=b"foo")
ident, msg_list = self.session.feed_identities(msg_list)
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b"foo")
self.assertEqual(new_msg["msg_id"], msg["msg_id"])
self.assertEqual(new_msg["msg_type"], msg["msg_type"])
self.assertEqual(new_msg["header"], msg["header"])
self.assertEqual(new_msg["content"], msg["content"])
self.assertEqual(new_msg["parent_header"], msg["parent_header"])
self.assertEqual(new_msg["metadata"], msg["metadata"])
# ensure floats don't come out as Decimal:
self.assertEqual(type(new_msg["content"]["b"]), type(new_msg["content"]["b"]))
def test_default_secure(self):
self.assertIsInstance(self.session.key, bytes)
self.assertIsInstance(self.session.auth, hmac.HMAC)
def test_send(self):
ctx = zmq.Context()
A = ctx.socket(zmq.PAIR)
B = ctx.socket(zmq.PAIR)
A.bind("inproc://test")
B.connect("inproc://test")
msg = self.session.msg("execute", content=dict(a=10))
self.session.send(A, msg, ident=b"foo", buffers=[b"bar"])
ident, msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b"foo")
self.assertEqual(new_msg["msg_id"], msg["msg_id"])
self.assertEqual(new_msg["msg_type"], msg["msg_type"])
self.assertEqual(new_msg["header"], msg["header"])
self.assertEqual(new_msg["content"], msg["content"])
self.assertEqual(new_msg["parent_header"], msg["parent_header"])
self.assertEqual(new_msg["metadata"], msg["metadata"])
self.assertEqual(new_msg["buffers"], [b"bar"])
content = msg["content"]
header = msg["header"]
header["msg_id"] = self.session.msg_id
parent = msg["parent_header"]
metadata = msg["metadata"]
header["msg_type"]
self.session.send(
A,
None,
content=content,
parent=parent,
header=header,
metadata=metadata,
ident=b"foo",
buffers=[b"bar"],
)
ident, msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b"foo")
self.assertEqual(new_msg["msg_id"], header["msg_id"])
self.assertEqual(new_msg["msg_type"], msg["msg_type"])
self.assertEqual(new_msg["header"], msg["header"])
self.assertEqual(new_msg["content"], msg["content"])
self.assertEqual(new_msg["metadata"], msg["metadata"])
self.assertEqual(new_msg["parent_header"], msg["parent_header"])
self.assertEqual(new_msg["buffers"], [b"bar"])
header["msg_id"] = self.session.msg_id
self.session.send(A, msg, ident=b"foo", buffers=[b"bar"])
ident, new_msg = self.session.recv(B)
self.assertEqual(ident[0], b"foo")
self.assertEqual(new_msg["msg_id"], header["msg_id"])
self.assertEqual(new_msg["msg_type"], msg["msg_type"])
self.assertEqual(new_msg["header"], msg["header"])
self.assertEqual(new_msg["content"], msg["content"])
self.assertEqual(new_msg["metadata"], msg["metadata"])
self.assertEqual(new_msg["parent_header"], msg["parent_header"])
self.assertEqual(new_msg["buffers"], [b"bar"])
# buffers must support the buffer protocol
with self.assertRaises(TypeError):
self.session.send(A, msg, ident=b"foo", buffers=[1])
# buffers must be contiguous
buf = memoryview(os.urandom(16))
with self.assertRaises(ValueError):
self.session.send(A, msg, ident=b"foo", buffers=[buf[::2]])
A.close()
B.close()
ctx.term()
def test_args(self):
"""initialization arguments for Session"""
s = self.session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
self.assertEqual(s.username, os.environ.get("USER", "username"))
s = ss.Session()
self.assertEqual(s.username, os.environ.get("USER", "username"))
self.assertRaises(TypeError, ss.Session, pack="hi")
self.assertRaises(TypeError, ss.Session, unpack="hi")
u = str(uuid.uuid4())
s = ss.Session(username="carrot", session=u)
self.assertEqual(s.session, u)
self.assertEqual(s.username, "carrot")
@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason='Test fails on PyPy')
def test_tracking(self):
"""test tracking messages"""
a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
s = self.session
s.copy_threshold = 1
loop = ioloop.IOLoop(make_current=False)
ZMQStream(a, io_loop=loop)
msg = s.send(a, "hello", track=False)
self.assertTrue(msg["tracker"] is ss.DONE)
msg = s.send(a, "hello", track=True)
self.assertTrue(isinstance(msg["tracker"], zmq.MessageTracker))
M = zmq.Message(b"hi there", track=True)
msg = s.send(a, "hello", buffers=[M], track=True)
t = msg["tracker"]
self.assertTrue(isinstance(t, zmq.MessageTracker))
self.assertRaises(zmq.NotDone, t.wait, 0.1)
del M
t.wait(1)  # the buffer is released, so this should now complete without raising
def test_unique_msg_ids(self):
"""test that messages receive unique ids"""
ids = set()
for i in range(2**12):
h = self.session.msg_header("test")
msg_id = h["msg_id"]
self.assertTrue(msg_id not in ids)
ids.add(msg_id)
def test_feed_identities(self):
"""scrub the front for zmq IDENTITIES"""
content = dict(code="whoda", stuff=object())
self.session.msg("execute", content=content)
def test_session_id(self):
session = ss.Session()
# get bs before us
bs = session.bsession
us = session.session
self.assertEqual(us.encode("ascii"), bs)
session = ss.Session()
# get us before bs
us = session.session
bs = session.bsession
self.assertEqual(us.encode("ascii"), bs)
# change propagates:
session.session = "something else"
bs = session.bsession
us = session.session
self.assertEqual(us.encode("ascii"), bs)
session = ss.Session(session="stuff")
# get us before bs
self.assertEqual(session.bsession, session.session.encode("ascii"))
self.assertEqual(b"stuff", session.bsession)
def test_zero_digest_history(self):
session = ss.Session(digest_history_size=0)
for i in range(11):
session._add_digest(uuid.uuid4().bytes)
self.assertEqual(len(session.digest_history), 0)
def test_cull_digest_history(self):
session = ss.Session(digest_history_size=100)
for i in range(100):
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 100)
# Adding a digest beyond the limit triggers a cull of 10% of the
# history before continuing: 101 - 10 == 91.
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 91)
for i in range(9):
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 100)
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 91)
def test_bad_pack(self):
try:
ss.Session(pack=_bad_packer)
except ValueError as e:
self.assertIn("could not serialize", str(e))
self.assertIn("don't work", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_unpack(self):
try:
ss.Session(unpack=_bad_unpacker)
except ValueError as e:
self.assertIn("could not handle output", str(e))
self.assertIn("don't work either", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_packer(self):
try:
ss.Session(packer=__name__ + "._bad_packer")
except ValueError as e:
self.assertIn("could not serialize", str(e))
self.assertIn("don't work", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_unpacker(self):
try:
ss.Session(unpacker=__name__ + "._bad_unpacker")
except ValueError as e:
self.assertIn("could not handle output", str(e))
self.assertIn("don't work either", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_roundtrip(self):
with self.assertRaises(ValueError):
ss.Session(unpack=lambda b: 5)
def _datetime_test(self, session):
content = dict(t=ss.utcnow())
metadata = dict(t=ss.utcnow())
p = session.msg("msg")
msg = session.msg("msg", content=content, metadata=metadata, parent=p["header"])
smsg = session.serialize(msg)
msg2 = session.deserialize(session.feed_identities(smsg)[1])
assert isinstance(msg2["header"]["date"], datetime)
self.assertEqual(msg["header"], msg2["header"])
self.assertEqual(msg["parent_header"], msg2["parent_header"])
self.assertEqual(msg["parent_header"], msg2["parent_header"])
assert isinstance(msg["content"]["t"], datetime)
assert isinstance(msg["metadata"]["t"], datetime)
assert isinstance(msg2["content"]["t"], str)
assert isinstance(msg2["metadata"]["t"], str)
self.assertEqual(msg["content"], jsonutil.extract_dates(msg2["content"]))
self.assertEqual(msg["content"], jsonutil.extract_dates(msg2["content"]))
def test_datetimes(self):
self._datetime_test(self.session)
def test_datetimes_pickle(self):
session = ss.Session(packer="pickle")
self._datetime_test(session)
def test_datetimes_msgpack(self):
msgpack = pytest.importorskip("msgpack")
session = ss.Session(
pack=msgpack.packb,
unpack=lambda buf: msgpack.unpackb(buf, raw=False),
)
self._datetime_test(session)
def test_send_raw(self):
ctx = zmq.Context()
A = ctx.socket(zmq.PAIR)
B = ctx.socket(zmq.PAIR)
A.bind("inproc://test")
B.connect("inproc://test")
msg = self.session.msg("execute", content=dict(a=10))
msg_list = [
self.session.pack(msg[part])
for part in ["header", "parent_header", "metadata", "content"]
]
self.session.send_raw(A, msg_list, ident=b"foo")
ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(new_msg_list)
self.assertEqual(ident[0], b"foo")
self.assertEqual(new_msg["msg_type"], msg["msg_type"])
self.assertEqual(new_msg["header"], msg["header"])
self.assertEqual(new_msg["parent_header"], msg["parent_header"])
self.assertEqual(new_msg["content"], msg["content"])
self.assertEqual(new_msg["metadata"], msg["metadata"])
A.close()
B.close()
ctx.term()
def test_clone(self):
s = self.session
s._add_digest("initial")
s2 = s.clone()
assert s2.session == s.session
assert s2.digest_history == s.digest_history
assert s2.digest_history is not s.digest_history
digest = "abcdef"
s._add_digest(digest)
assert digest in s.digest_history
assert digest not in s2.digest_history

View File

@@ -0,0 +1,9 @@
from jupyter_client.ssh.tunnel import select_random_ports
def test_random_ports():
for i in range(4096):
ports = select_random_ports(10)
assert len(ports) == 10
for p in ports:
assert ports.count(p) == 1

View File

@@ -0,0 +1,256 @@
"""Testing utils for jupyter_client tests
"""
import json
import os
import sys
from tempfile import TemporaryDirectory
from typing import Dict
from unittest.mock import patch
import pytest
from jupyter_client import AsyncKernelManager
from jupyter_client import AsyncMultiKernelManager
from jupyter_client import KernelManager
from jupyter_client import MultiKernelManager
pjoin = os.path.join
skip_win32 = pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows")
sample_kernel_json = {
"argv": ["cat", "{connection_file}"],
"display_name": "Test kernel",
}
def install_kernel(kernels_dir, argv=None, name="test", display_name=None):
"""install a kernel in a kernels directory"""
kernel_dir = pjoin(kernels_dir, name)
os.makedirs(kernel_dir)
kernel_json = {
"argv": argv or sample_kernel_json["argv"],
"display_name": display_name or sample_kernel_json["display_name"],
}
json_file = pjoin(kernel_dir, "kernel.json")
with open(json_file, "w") as f:
json.dump(kernel_json, f)
return kernel_dir
class test_env(object):
"""Set Jupyter path variables to a temporary directory
Useful as a context manager or with explicit start/stop
"""
def start(self):
self.test_dir = td = TemporaryDirectory()
self.env_patch = patch.dict(
os.environ,
{
"JUPYTER_CONFIG_DIR": pjoin(td.name, "jupyter"),
"JUPYTER_DATA_DIR": pjoin(td.name, "jupyter_data"),
"JUPYTER_RUNTIME_DIR": pjoin(td.name, "jupyter_runtime"),
"IPYTHONDIR": pjoin(td.name, "ipython"),
"TEST_VARS": "test_var_1",
},
)
self.env_patch.start()
def stop(self):
self.env_patch.stop()
try:
self.test_dir.cleanup()
except (PermissionError, NotADirectoryError):
if os.name != 'nt':
raise
def __enter__(self):
self.start()
return self.test_dir.name
def __exit__(self, *exc_info):
self.stop()
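# Usage sketch: as a context manager the temp dir path is yielded, while
# start()/stop() suit explicit setup/teardown hooks.
#
#     with test_env() as td:
#         assert os.environ["JUPYTER_DATA_DIR"].startswith(td)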
def execute(code="", kc=None, **kwargs):
"""wrapper for doing common steps for validating an execution request"""
from .test_message_spec import validate_message
if kc is None:
kc = KC # noqa
msg_id = kc.execute(code=code, **kwargs)
reply = kc.get_shell_msg(timeout=TIMEOUT) # noqa
validate_message(reply, "execute_reply", msg_id)
busy = kc.get_iopub_msg(timeout=TIMEOUT) # noqa
validate_message(busy, "status", msg_id)
assert busy["content"]["execution_state"] == "busy"
if not kwargs.get("silent"):
execute_input = kc.get_iopub_msg(timeout=TIMEOUT) # noqa
validate_message(execute_input, "execute_input", msg_id)
assert execute_input["content"]["code"] == code
return msg_id, reply["content"]
class RecordCallMixin:
method_calls: Dict[str, int]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.method_calls = {}
def record(self, method_name: str) -> None:
if method_name not in self.method_calls:
self.method_calls[method_name] = 0
self.method_calls[method_name] += 1
def call_count(self, method_name: str) -> int:
if method_name not in self.method_calls:
self.method_calls[method_name] = 0
return self.method_calls[method_name]
def reset_counts(self) -> None:
for record in self.method_calls:
self.method_calls[record] = 0
def subclass_recorder(f):
def wrapped(self, *args, **kwargs):
# record this call
self.record(f.__name__)
method = getattr(self._superclass, f.__name__)
# call the superclass method
r = method(self, *args, **kwargs)
# call anything defined in the actual class method
f(self, *args, **kwargs)
return r
return wrapped
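# A minimal sketch of the decorator in use (Base is a hypothetical class):
#
#     class Base:
#         def ping(self):
#             return "pong"
#
#     class Recorded(RecordCallMixin, Base):
#         _superclass = Base
#         @subclass_recorder
#         def ping(self):
#             """Record call and defer to superclass"""
#
#     r = Recorded()
#     assert r.ping() == "pong" and r.call_count("ping") == 1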
class KMSubclass(RecordCallMixin):
@subclass_recorder
def start_kernel(self, **kw):
"""Record call and defer to superclass"""
@subclass_recorder
def shutdown_kernel(self, now=False, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def restart_kernel(self, now=False, **kw):
"""Record call and defer to superclass"""
@subclass_recorder
def interrupt_kernel(self):
"""Record call and defer to superclass"""
@subclass_recorder
def request_shutdown(self, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def finish_shutdown(self, waittime=None, pollinterval=0.1, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def _launch_kernel(self, kernel_cmd, **kw):
"""Record call and defer to superclass"""
@subclass_recorder
def _kill_kernel(self):
"""Record call and defer to superclass"""
@subclass_recorder
def cleanup_resources(self, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def signal_kernel(self, signum: int):
"""Record call and defer to superclass"""
@subclass_recorder
def is_alive(self):
"""Record call and defer to superclass"""
@subclass_recorder
def _send_kernel_sigterm(self, restart: bool = False):
"""Record call and defer to superclass"""
class SyncKMSubclass(KMSubclass, KernelManager):
"""Used to test subclass hierarchies to ensure methods are called when expected."""
_superclass = KernelManager
class AsyncKMSubclass(KMSubclass, AsyncKernelManager):
"""Used to test subclass hierarchies to ensure methods are called when expected."""
_superclass = AsyncKernelManager
class MKMSubclass(RecordCallMixin):
def _kernel_manager_class_default(self):
return "jupyter_client.tests.utils.SyncKMSubclass"
@subclass_recorder
def get_kernel(self, kernel_id):
"""Record call and defer to superclass"""
@subclass_recorder
def remove_kernel(self, kernel_id):
"""Record call and defer to superclass"""
@subclass_recorder
def start_kernel(self, kernel_name=None, **kwargs):
"""Record call and defer to superclass"""
@subclass_recorder
def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def restart_kernel(self, kernel_id, now=False):
"""Record call and defer to superclass"""
@subclass_recorder
def interrupt_kernel(self, kernel_id):
"""Record call and defer to superclass"""
@subclass_recorder
def request_shutdown(self, kernel_id, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def cleanup_resources(self, kernel_id, restart=False):
"""Record call and defer to superclass"""
@subclass_recorder
def shutdown_all(self, now=False):
"""Record call and defer to superclass"""
class SyncMKMSubclass(MKMSubclass, MultiKernelManager):
_superclass = MultiKernelManager
def _kernel_manager_class_default(self):
return "jupyter_client.tests.utils.SyncKMSubclass"
class AsyncMKMSubclass(MKMSubclass, AsyncMultiKernelManager):
_superclass = AsyncMultiKernelManager
def _kernel_manager_class_default(self):
return "jupyter_client.tests.utils.AsyncKMSubclass"

View File

@@ -0,0 +1,307 @@
""" Defines a KernelClient that provides thread-safe sockets with async callbacks on message
replies.
"""
import asyncio
import atexit
import errno
import time
from threading import Event
from threading import Thread
from typing import Any
from typing import Awaitable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import zmq
from traitlets import Instance
from traitlets import Type
from zmq import ZMQError
from zmq.eventloop import ioloop
from zmq.eventloop import zmqstream
from .session import Session
from jupyter_client import KernelClient
from jupyter_client.channels import HBChannel
# Local imports
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
async def get_msg(msg: Awaitable) -> Union[List[bytes], List[zmq.Message]]:
return await msg
class ThreadedZMQSocketChannel(object):
"""A ZMQ socket invoking a callback in the ioloop"""
session = None
socket = None
ioloop = None
stream = None
_inspect = None
def __init__(
self,
socket: Optional[zmq.Socket],
session: Optional[Session],
loop: Optional[zmq.eventloop.ioloop.ZMQIOLoop],
) -> None:
"""Create a channel.
Parameters
----------
socket : :class:`zmq.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
A pyzmq ioloop to connect the socket to using a ZMQStream
"""
super().__init__()
self.socket = socket
self.session = session
self.ioloop = loop
evt = Event()
def setup_stream():
self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
self.stream.on_recv(self._handle_recv)
evt.set()
assert self.ioloop is not None
self.ioloop.add_callback(setup_stream)
evt.wait()
_is_alive = False
def is_alive(self) -> bool:
return self._is_alive
def start(self) -> None:
self._is_alive = True
def stop(self) -> None:
self._is_alive = False
def close(self) -> None:
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
def send(self, msg: Dict[str, Any]) -> None:
"""Queue a message to be sent from the IOLoop's thread.
Parameters
----------
msg : message to send
This is threadsafe, as it uses IOLoop.add_callback to give the loop's
thread control of the action.
"""
def thread_send():
assert self.session is not None
self.session.send(self.stream, msg)
assert self.ioloop is not None
self.ioloop.add_callback(thread_send)
def _handle_recv(self, future_msg: Awaitable) -> None:
"""Callback for stream.on_recv.
Unpacks message, and calls handlers with it.
"""
assert self.ioloop is not None
msg_list = self.ioloop._asyncio_event_loop.run_until_complete(get_msg(future_msg))
assert self.session is not None
ident, smsg = self.session.feed_identities(msg_list)
msg = self.session.deserialize(smsg)
# let client inspect messages
if self._inspect:
self._inspect(msg)
self.call_handlers(msg)
def call_handlers(self, msg: Dict[str, Any]) -> None:
"""This method is called in the ioloop thread when a message arrives.
Subclasses should override this method to handle incoming messages.
Bear in mind that this method is called in the ioloop thread, so some
mechanism is needed to ensure that the application-level handlers are
called in the application thread.
"""
pass
def process_events(self) -> None:
"""Subclasses should override this with a method
processing any pending GUI events.
"""
pass
def flush(self, timeout: float = 1.0) -> None:
"""Immediately processes all pending messages on this channel.
This is only used for the IOPub channel.
Callers should use this method to ensure that :meth:`call_handlers`
has been called for all messages that have been received on the
0MQ SUB socket of this channel.
This method is thread safe.
Parameters
----------
timeout : float, optional
The maximum amount of time to spend flushing, in seconds. The
default is one second.
"""
# We do the IOLoop callback process twice to ensure that the IOLoop
# gets to perform at least one full poll.
stop_time = time.time() + timeout
assert self.ioloop is not None
for _ in range(2):
self._flushed = False
self.ioloop.add_callback(self._flush)
while not self._flushed and time.time() < stop_time:
time.sleep(0.01)
def _flush(self) -> None:
"""Callback for :method:`self.flush`."""
assert self.stream is not None
self.stream.flush()
self._flushed = True
class IOLoopThread(Thread):
"""Run a pyzmq ioloop in a thread to send and receive messages"""
_exiting = False
ioloop = None
def __init__(self):
super().__init__()
self.daemon = True
@staticmethod
@atexit.register
def _notice_exit() -> None:
# Class definitions can be torn down during interpreter shutdown.
# We only need to set _exiting flag if this hasn't happened.
if IOLoopThread is not None:
IOLoopThread._exiting = True
def start(self) -> None:
"""Start the IOLoop thread
Don't return until self.ioloop is defined,
which is created in the thread
"""
self._start_event = Event()
Thread.start(self)
self._start_event.wait()
def run(self) -> None:
"""Run my loop, ignoring EINTR events in the poller"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.ioloop = ioloop.IOLoop()
self.ioloop._asyncio_event_loop = loop
# signal that self.ioloop is defined
self._start_event.set()
while True:
try:
self.ioloop.start()
except ZMQError as e:
if e.errno == errno.EINTR:
continue
else:
raise
except Exception:
if self._exiting:
break
else:
raise
else:
break
def stop(self) -> None:
"""Stop the channel's event loop and join its thread.
This calls :meth:`~threading.Thread.join` and returns when the thread
terminates. :class:`RuntimeError` will be raised if
:meth:`~threading.Thread.start` is called again.
"""
if self.ioloop is not None:
self.ioloop.add_callback(self.ioloop.stop)
self.join()
self.close()
self.ioloop = None
def __del__(self):
self.close()
def close(self) -> None:
if self.ioloop is not None:
try:
self.ioloop.close(all_fds=True)
except Exception:
pass
class ThreadedKernelClient(KernelClient):
"""A KernelClient that provides thread-safe sockets with async callbacks on message replies."""
@property
def ioloop(self):
return self.ioloop_thread.ioloop
ioloop_thread = Instance(IOLoopThread, allow_none=True)
def start_channels(
self,
shell: bool = True,
iopub: bool = True,
stdin: bool = True,
hb: bool = True,
control: bool = True,
) -> None:
self.ioloop_thread = IOLoopThread()
self.ioloop_thread.start()
if shell:
self.shell_channel._inspect = self._check_kernel_info_reply
super().start_channels(shell, iopub, stdin, hb, control)
def _check_kernel_info_reply(self, msg: Dict[str, Any]) -> None:
"""This is run in the ioloop thread when the kernel info reply is received"""
if msg["msg_type"] == "kernel_info_reply":
self._handle_kernel_info_reply(msg)
self.shell_channel._inspect = None
def stop_channels(self) -> None:
super().stop_channels()
if self.ioloop_thread.is_alive():
self.ioloop_thread.stop()
iopub_channel_class = Type(ThreadedZMQSocketChannel)
shell_channel_class = Type(ThreadedZMQSocketChannel)
stdin_channel_class = Type(ThreadedZMQSocketChannel)
hb_channel_class = Type(HBChannel)
control_channel_class = Type(ThreadedZMQSocketChannel)
def is_alive(self) -> bool:
"""Is the kernel process still running?"""
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
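# Usage sketch (assumes a running kernel; the connection-file path below is
# illustrative, not a real file):
#
#     kc = ThreadedKernelClient()
#     kc.load_connection_file("/path/to/kernel-abc.json")
#     kc.start_channels()
#     kc.shell_channel.send(kc.session.msg("kernel_info_request"))
#     ...
#     kc.stop_channels()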

View File

@@ -0,0 +1,118 @@
"""
utils:
- provides utility wrappers to run asynchronous functions in a blocking environment.
- vendors functions from ipython_genutils that should be retired at some point.
"""
import asyncio
import inspect
import os
def run_sync(coro):
def wrapped(*args, **kwargs):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Workaround for bugs.python.org/issue39529.
try:
loop = asyncio.get_event_loop_policy().get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
import nest_asyncio # type: ignore
nest_asyncio.apply(loop)
future = asyncio.ensure_future(coro(*args, **kwargs), loop=loop)
try:
return loop.run_until_complete(future)
except BaseException as e:
future.cancel()
raise e
wrapped.__doc__ = coro.__doc__
return wrapped
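# A minimal sketch of the wrapper (hypothetical coroutine; needs nest_asyncio
# installed, as the wrapper imports it when re-entering a running loop):
#
#     async def answer():
#         return 42
#
#     blocking_answer = run_sync(answer)
#     assert blocking_answer() == 42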
async def ensure_async(obj):
if inspect.isawaitable(obj):
return await obj
return obj
def _filefind(filename, path_dirs=None):
"""Find a file by looking through a sequence of paths.
This iterates through a sequence of paths looking for a file and returns
the full, absolute path of the first occurrence of the file. If no set of
path dirs is given, the filename is tested as is, after running through
:func:`expandvars` and :func:`expanduser`. Thus a simple call::
filefind('myfile.txt')
will find the file in the current working dir, but::
filefind('~/myfile.txt')
will find the file in the user's home directory. This function does not
automatically try any paths, such as the cwd or the user's home directory.
Parameters
----------
filename : str
The filename to look for.
path_dirs : str, None or sequence of str
The sequence of paths to look for the file in. If None, the filename
needs to be absolute or be in the cwd. If a string, the string is
put into a sequence and then searched. If a sequence, walk through
each element and join with ``filename``, calling :func:`expandvars`
and :func:`expanduser` before testing for existence.
Returns
-------
Raises :exc:`IOError` or returns absolute path to file.
"""
# If paths are quoted, abspath gets confused, strip them...
filename = filename.strip('"').strip("'")
# If the input is an absolute path, just check it exists
if os.path.isabs(filename) and os.path.isfile(filename):
return filename
if path_dirs is None:
path_dirs = ("",)
elif isinstance(path_dirs, str):
path_dirs = (path_dirs,)
for path in path_dirs:
if path == ".":
path = os.getcwd()
testname = _expand_path(os.path.join(path, filename))
if os.path.isfile(testname):
return os.path.abspath(testname)
raise IOError(
"File {!r} does not exist in any of the search paths: {!r}".format(filename, path_dirs)
)
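# Search-order sketch (paths are illustrative):
#
#     _filefind("kernel.json", ["/etc/jupyter", "."])
#     # -> absolute path of the first match, or IOError if none is found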
def _expand_path(s):
"""Expand $VARS and ~names in a string, like a shell
:Examples:
In [2]: os.environ['FOO']='test'
In [3]: expand_path('variable FOO is $FOO')
Out[3]: 'variable FOO is test'
"""
# This is a pretty subtle hack. When expanduser is given a UNC path
# on Windows (\\server\share$\%username%), os.path.expandvars removes
# the $ to get (\\server\share\%username%), presumably treating a lone
# $ as an empty var. But we need the $ to remain there (it indicates
# a hidden share).
if os.name == "nt":
s = s.replace("$\\", "IPYTHON_TEMP")
s = os.path.expandvars(os.path.expanduser(s))
if os.name == "nt":
s = s.replace("IPYTHON_TEMP", "$\\")
return s

View File

@@ -0,0 +1,43 @@
"""Use a Windows event to interrupt a child process like SIGINT.
The child needs to explicitly listen for this - see
ipykernel.parentpoller.ParentPollerWindows for a Python implementation.
"""
import ctypes
from typing import no_type_check
@no_type_check
def create_interrupt_event():
"""Create an interrupt event handle.
The parent process should call this to create the
interrupt event that is passed to the child process. It should store
this handle and use it with ``send_interrupt`` to interrupt the child
process.
"""
# Create a security attributes struct that permits inheritance of the
# handle by new processes.
# FIXME: We can clean up this mess by requiring pywin32 for IPython.
class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = [
("nLength", ctypes.c_int),
("lpSecurityDescriptor", ctypes.c_void_p),
("bInheritHandle", ctypes.c_int),
]
sa = SECURITY_ATTRIBUTES()
sa_p = ctypes.pointer(sa)
sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES)
sa.lpSecurityDescriptor = 0
sa.bInheritHandle = 1
return ctypes.windll.kernel32.CreateEventA(
sa_p,  # lpEventAttributes
False,  # bManualReset
False,  # bInitialState
"",  # lpName
)
@no_type_check
def send_interrupt(interrupt_handle):
"""Sends an interrupt event using the specified handle."""
ctypes.windll.kernel32.SetEvent(interrupt_handle)
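# Typical pairing (Windows-only sketch; making the handle inheritable and
# passing it to the child is left to the launch machinery):
#
#     handle = create_interrupt_event()
#     ...  # spawn the child process with the handle inheritable
#     send_interrupt(handle)  # the child's poller treats the event like SIGINT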