Skip to content

Commit 00f407b

Browse files
committed
Shift contextvars into utils notably, removing kernel.job.
Add Kernel.concurrency_mode replacing get_handler_and_run_mode with get_run_mode and get_handler. Add more methods to utils.
1 parent e5d462b commit 00f407b

7 files changed

Lines changed: 236 additions & 218 deletions

File tree

docs/notebooks/simple_example.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@
4040
"source": [
4141
"import anyio\n",
4242
"import ipywidgets as ipw\n",
43-
"import utils\n",
4443
"\n",
45-
"from async_kernel import Caller\n",
44+
"from async_kernel import Caller, utils\n",
4645
"\n",
4746
"\n",
4847
"async def demo():\n",
@@ -53,7 +52,7 @@
5352
" for i in range(1, 4):\n",
5453
" b.description = f\"Continue {i}\"\n",
5554
" event = anyio.Event()\n",
56-
" b.on_click(lambda _: caller.call_soon(event.set)) # noqa: B023\n",
55+
" b.on_click(lambda _: caller.call_soon(event.set)) # noqa: B023 # pyright: ignore[reportUnknownLambdaType]\n",
5756
" print(f\"Waiting {i}\", end=\"\\r\")\n",
5857
" await event.wait()\n",
5958
" b.close()\n",
@@ -156,7 +155,9 @@
156155
"slideshow": {
157156
"slide_type": ""
158157
},
159-
"tags": []
158+
"tags": [
159+
"suppress-error"
160+
]
160161
},
161162
"outputs": [],
162163
"source": [

src/async_kernel/asyncshell.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import json
88
import pathlib
99
import sys
10-
from contextvars import ContextVar
1110
from typing import TYPE_CHECKING, Any, ClassVar, Literal
1211

1312
import anyio
@@ -22,6 +21,7 @@
2221
from typing_extensions import override
2322

2423
import async_kernel
24+
from async_kernel import utils
2525
from async_kernel.caller import Caller
2626
from async_kernel.compiler import XCachingCompiler
2727
from async_kernel.typing import Content, Tags
@@ -97,7 +97,7 @@ def publish( # pyright: ignore[reportIncompatibleMethodOverride]
9797
9898
[Reference](https://jupyter-client.readthedocs.io/en/stable/messaging.html#update-display-data)
9999
"""
100-
async_kernel.Kernel().iopub_send(
100+
utils.get_kernel().iopub_send(
101101
msg_or_type="update_display_data" if update else "display_data",
102102
content={"data": data, "metadata": metadata or {}, "transient": transient or {}} | kwargs,
103103
ident=self.topic,
@@ -112,7 +112,7 @@ def clear_output(self, wait: bool = False) -> None:
112112
instead waiting for the next display before clearing.
113113
This reduces bounce during repeated clear & display loops.
114114
"""
115-
async_kernel.Kernel().iopub_send(msg_or_type="clear_output", content={"wait": wait}, ident=self.topic)
115+
utils.get_kernel().iopub_send(msg_or_type="clear_output", content={"wait": wait}, ident=self.topic)
116116

117117

118118
class AsyncInteractiveShell(InteractiveShell):
@@ -136,7 +136,7 @@ class AsyncInteractiveShell(InteractiveShell):
136136
compile: Instance[XCachingCompiler]
137137
user_ns_hidden = Dict()
138138
_main_mod_cache = Dict()
139-
_execute_request_timeout: ContextVar[float | None] = ContextVar("execute_request_timeout", default=None)
139+
140140
run_cell = None # pyright: ignore[reportAssignmentType]
141141
"**not-supported**"
142142
should_run_async = None # pyright: ignore[reportAssignmentType]
@@ -162,21 +162,7 @@ def _default_banner1(self) -> str:
162162
@property
163163
def kernel(self) -> Kernel:
164164
"The current kernel."
165-
return async_kernel.Kernel()
166-
167-
@property
168-
def execute_request_timeout(self) -> float | None:
169-
"""A timeout in context of the [run_cell_async][async_kernel.asyncshell.AsyncInteractiveShell].
170-
171-
See also:
172-
173-
- [async_kernel.typing.MetadataKeys.timeout][].
174-
"""
175-
return self._execute_request_timeout.get()
176-
177-
@execute_request_timeout.setter
178-
def execute_request_timeout(self, value: float | None) -> None:
179-
self._execute_request_timeout.set(value)
165+
return utils.get_kernel()
180166

181167
@observe("exit_now")
182168
def _update_exit_now(self, _) -> None:
@@ -250,7 +236,7 @@ async def run_cell_async(
250236
This function runs [execute requests][async_kernel.Kernel.execute_request] for the kernel
251237
wrapping [InteractiveShell][IPython.core.interactiveshell.InteractiveShell.run_cell_async].
252238
"""
253-
with anyio.fail_after(delay=self.execute_request_timeout):
239+
with anyio.fail_after(delay=utils.get_execute_request_timeout()):
254240
result: ExecutionResult = await super().run_cell_async(
255241
raw_cell=raw_cell,
256242
store_history=store_history,
@@ -269,7 +255,7 @@ async def run_cell_async(
269255
def _showtraceback(self, etype, evalue, stb) -> None:
270256
if Tags.suppress_error in async_kernel.utils.get_tags():
271257
return
272-
if self.execute_request_timeout is not None and etype is self.kernel.CancelledError:
258+
if utils.get_execute_request_timeout() is not None and etype is self.kernel.CancelledError:
273259
etype, evalue, stb = TimeoutError, "Cell execute timeout", []
274260
self.kernel.iopub_send(
275261
msg_or_type="error",
@@ -294,14 +280,11 @@ class KernelMagics(Magics):
294280
@line_magic
295281
def connect_info(self, _) -> None:
296282
"""Print information for connecting other clients to this kernel."""
297-
298-
kernel = async_kernel.Kernel()
283+
kernel = utils.get_kernel()
299284
connection_file = pathlib.Path(kernel.connection_file)
300-
301285
# if it's in the default dir, truncate to basename
302286
if jupyter_runtime_dir() == str(connection_file.parent):
303287
connection_file = connection_file.name
304-
305288
info = kernel.get_connection_info()
306289
print(
307290
json.dumps(info, indent=2, default=json_default),

src/async_kernel/caller.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ class Caller:
238238
_pool_instances: ClassVar[weakref.WeakSet[Self]] = weakref.WeakSet()
239239
_executor_queue: dict
240240
_taskgroup: TaskGroup | None = None
241-
_jobs: deque[tuple[contextvars.Context, tuple[Future, float, float, Callable, tuple, dict]] | Callable[[], Any]]
242-
_jobs_added: threading.Event
241+
_callers: deque[tuple[contextvars.Context, tuple[Future, float, float, Callable, tuple, dict]] | Callable[[], Any]]
242+
_callers_added: threading.Event
243243
_stopped = False
244244
_protected = False
245245
_running = False
@@ -292,8 +292,8 @@ def __new__(
292292
inst.backend = Backend(sniffio.current_async_library())
293293
inst.thread = thread
294294
inst.log = log or logging.LoggerAdapter(logging.getLogger())
295-
inst._jobs = deque()
296-
inst._jobs_added = threading.Event()
295+
inst._callers = deque()
296+
inst._callers_added = threading.Event()
297297
inst._protected = protected
298298
inst._executor_queue = {}
299299
cls._instances[thread] = inst
@@ -325,8 +325,8 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
325325
self.iopub_sockets[self.thread] = socket
326326
task_status.started()
327327
while not self._stopped:
328-
while len(self._jobs):
329-
job = self._jobs.popleft()
328+
while len(self._callers):
329+
job = self._callers.popleft()
330330
if isinstance(job, Callable):
331331
try:
332332
job()
@@ -335,11 +335,11 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
335335
else:
336336
context, args = job
337337
context.run(tg.start_soon, self._wrap_call, *args)
338-
self._jobs_added.clear()
339-
await wait_thread_event(self._jobs_added)
338+
self._callers_added.clear()
339+
await wait_thread_event(self._callers_added)
340340
finally:
341341
self._running = False
342-
for job in self._jobs:
342+
for job in self._callers:
343343
if not callable(job):
344344
job[1][0].set_exception(FutureCancelledError())
345345
socket.close()
@@ -417,7 +417,7 @@ def stop(self, *, force=False) -> None:
417417
if self._protected and not force:
418418
return
419419
self._stopped = True
420-
self._jobs_added.set()
420+
self._callers_added.set()
421421
self._instances.pop(self.thread, None)
422422
if self in self._to_thread_pool:
423423
self._to_thread_pool.remove(self)
@@ -439,8 +439,8 @@ def call_later(
439439
if threading.current_thread() is self.thread and (tg := self._taskgroup):
440440
tg.start_soon(self._wrap_call, fut, time.monotonic(), delay, func, args, kwargs)
441441
else:
442-
self._jobs.append((contextvars.copy_context(), (fut, time.monotonic(), delay, func, args, kwargs)))
443-
self._jobs_added.set()
442+
self._callers.append((contextvars.copy_context(), (fut, time.monotonic(), delay, func, args, kwargs)))
443+
self._callers_added.set()
444444
self._outstanding += 1
445445
return fut
446446

@@ -462,8 +462,8 @@ def call_no_context(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.
462462
*args: Arguments to use with func.
463463
**kwargs: Keyword arguments to use with func.
464464
"""
465-
self._jobs.append(functools.partial(func, *args, **kwargs))
466-
self._jobs_added.set()
465+
self._callers.append(functools.partial(func, *args, **kwargs))
466+
self._callers_added.set()
467467

468468
def has_execution_queue(self, func: Callable) -> bool:
469469
"Returns True if an execution queue exists for `func`."

0 commit comments

Comments
 (0)