Skip to content

Commit 33dc25a

Browse files
authored
Implement work item filtering (#128)
* Implement work item filtering * PR Feedback * PR feedback * Lint * Isolate no-filter test funcs
1 parent 76052ea commit 33dc25a

File tree

9 files changed

+1649
-15
lines changed

9 files changed

+1649
-15
lines changed

docs/features.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,69 @@ with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_cha
311311

312312
> [!NOTE]
313313
> The worker and client output many logs at the `DEBUG` level that will be useful when understanding orchestration flow and diagnosing issues with Durable applications. Before submitting issues, please attempt a repro of the issue with debug logging enabled.
314+
315+
### Work item filtering
316+
317+
By default a worker receives **all** work items from the backend,
318+
regardless of which orchestrations, activities, or entities are
319+
registered. Work item filtering lets you explicitly tell the backend
320+
which work items a worker can handle so that only matching items are
321+
dispatched. This is useful when running multiple specialized workers
322+
against the same task hub.
323+
324+
Work item filtering is **opt-in**. Call `use_work_item_filters()` on
325+
the worker before starting it.
326+
327+
#### Auto-generated filters
328+
329+
Calling `use_work_item_filters()` with no arguments builds filters
330+
automatically from the worker's registry at start time:
331+
332+
```python
333+
with DurableTaskSchedulerWorker(...) as w:
334+
w.add_orchestrator(my_orchestrator)
335+
w.add_activity(my_activity)
336+
w.use_work_item_filters() # auto-generate from registry
337+
w.start()
338+
```
339+
340+
When versioning is configured with `VersionMatchStrategy.STRICT`,
341+
the worker's version is included in every filter so the backend
342+
only dispatches work items that match that exact version.
343+
344+
#### Explicit filters
345+
346+
Pass a `WorkItemFilters` instance for fine-grained control:
347+
348+
```python
349+
from durabletask.worker import (
350+
WorkItemFilters,
351+
OrchestrationWorkItemFilter,
352+
ActivityWorkItemFilter,
353+
EntityWorkItemFilter,
354+
)
355+
356+
w.use_work_item_filters(WorkItemFilters(
357+
orchestrations=[
358+
OrchestrationWorkItemFilter(name="my_orch", versions=["2.0.0"]),
359+
],
360+
activities=[
361+
ActivityWorkItemFilter(name="my_activity"),
362+
],
363+
entities=[
364+
EntityWorkItemFilter(name="my_entity"),
365+
],
366+
))
367+
```
368+
369+
#### Clearing filters
370+
371+
Pass `None` to clear any previously configured filters and return
372+
to the default behaviour of processing all work items:
373+
374+
```python
375+
w.use_work_item_filters(None)
376+
```
377+
378+
See the full
379+
[work item filtering sample](../examples/work_item_filtering.py).

docs/supported-patterns.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,51 @@ def my_orchestrator(ctx: task.OrchestrationContext, order: Order):
120120

121121
See the full [version-aware orchestrator sample](../examples/version_aware_orchestrator.py)
122122

123+
### Work item filtering
124+
125+
When running multiple workers against the same task hub, each
126+
worker can declare which work items it handles. The backend then
127+
dispatches only the matching orchestrations, activities, and
128+
entities, avoiding unnecessary round-trips. Filtering is opt-in
129+
and supports both auto-generated and explicit filter sets.
130+
131+
The simplest approach auto-generates filters from the worker's
132+
registry:
133+
134+
```python
135+
with DurableTaskSchedulerWorker(...) as w:
136+
w.add_orchestrator(greeting_orchestrator)
137+
w.add_activity(greet)
138+
w.use_work_item_filters() # auto-generate from registry
139+
w.start()
140+
```
141+
142+
For more control you can provide explicit filters, including
143+
version constraints:
144+
145+
```python
146+
from durabletask.worker import (
147+
WorkItemFilters,
148+
OrchestrationWorkItemFilter,
149+
ActivityWorkItemFilter,
150+
)
151+
152+
w.use_work_item_filters(WorkItemFilters(
153+
orchestrations=[
154+
OrchestrationWorkItemFilter(
155+
name="greeting_orchestrator",
156+
versions=["2.0.0"],
157+
),
158+
],
159+
activities=[
160+
ActivityWorkItemFilter(name="greet"),
161+
],
162+
))
163+
```
164+
165+
See the full
166+
[work item filtering sample](../examples/work_item_filtering.py).
167+
123168
### Large payload externalization
124169

125170
When orchestrations work with very large inputs, outputs, or event

durabletask/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,24 @@
44
"""Durable Task SDK for Python"""
55

66
from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore
7-
from durabletask.worker import ConcurrencyOptions, VersioningOptions
7+
from durabletask.worker import (
8+
ActivityWorkItemFilter,
9+
ConcurrencyOptions,
10+
EntityWorkItemFilter,
11+
OrchestrationWorkItemFilter,
12+
VersioningOptions,
13+
WorkItemFilters,
14+
)
815

916
__all__ = [
17+
"ActivityWorkItemFilter",
1018
"ConcurrencyOptions",
19+
"EntityWorkItemFilter",
1120
"LargePayloadStorageOptions",
21+
"OrchestrationWorkItemFilter",
1222
"PayloadStore",
1323
"VersioningOptions",
24+
"WorkItemFilters",
1425
]
1526

1627
PACKAGE_NAME = "durabletask"

durabletask/testing/in_memory_backend.py

Lines changed: 116 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import durabletask.internal.orchestrator_service_pb2 as pb
2727
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
2828
import durabletask.internal.helpers as helpers
29+
from durabletask.entities.entity_instance_id import EntityInstanceId
2930

3031

3132
@dataclass
@@ -56,6 +57,7 @@ class ActivityWorkItem:
5657
task_id: int
5758
input: Optional[str]
5859
completion_token: int
60+
version: Optional[str] = None
5961

6062

6163
@dataclass
@@ -451,16 +453,65 @@ def RestartInstance(self, request: pb.RestartInstanceRequest, context):
451453
f"Restarted instance '{request.instanceId}' as '{new_instance_id}'")
452454
return pb.RestartInstanceResponse(instanceId=new_instance_id)
453455

456+
@staticmethod
457+
def _parse_work_item_filters(request: pb.GetWorkItemsRequest):
458+
"""Extract filters from the request.
459+
460+
Returns a tuple of three values, one per work-item category. Each
461+
value is either ``None`` (no filtering -- dispatch everything) or a
462+
``dict`` mapping a task name to a ``frozenset`` of accepted versions
463+
(empty frozenset means *any* version of that name is accepted).
464+
An empty ``dict`` means the worker opted into filtering for that
465+
category but listed no names, so *nothing* should match.
466+
"""
467+
if not request.HasField("workItemFilters"):
468+
return None, None, None
469+
wf = request.workItemFilters
470+
471+
def _build_filter(filters):
472+
result: dict[str, frozenset[str]] = {}
473+
for f in filters:
474+
versions = frozenset(f.versions) if f.versions else frozenset()
475+
existing = result.get(f.name, frozenset())
476+
result[f.name] = existing | versions
477+
return result
478+
479+
orch_filter = _build_filter(wf.orchestrations)
480+
activity_filter = _build_filter(wf.activities)
481+
entity_filter = {f.name: frozenset() for f in wf.entities}
482+
return orch_filter, activity_filter, entity_filter
483+
484+
@staticmethod
485+
def _matches_filter(name: str, version: Optional[str],
486+
filt: Optional[dict[str, frozenset[str]]]) -> bool:
487+
"""Check whether a work item matches the parsed filter.
488+
489+
*filt* is ``None`` when the worker did not opt into filtering
490+
(everything matches). Otherwise it is a dict mapping accepted
491+
names to a frozenset of accepted versions. An empty frozenset
492+
means any version of that name is accepted.
493+
"""
494+
if filt is None:
495+
return True
496+
accepted_versions = filt.get(name)
497+
if accepted_versions is None:
498+
return False
499+
if not accepted_versions:
500+
return True # empty set -- any version
501+
return (version or "") in accepted_versions
502+
454503
def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
455504
"""Streams work items to the worker (orchestration and activity work items)."""
456505
self._logger.info("Worker connected and requesting work items")
506+
orch_filter, activity_filter, entity_filter = self._parse_work_item_filters(request)
457507

458508
try:
459509
while context.is_active() and not self._shutdown_event.is_set():
460510
work_item = None
461511

462512
with self._lock:
463513
# Check for orchestration work
514+
skipped_orchs: list[str] = []
464515
while self._orchestration_queue:
465516
instance_id = self._orchestration_queue.popleft()
466517
self._orchestration_queue_set.discard(instance_id)
@@ -469,11 +520,15 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
469520
if not instance or not instance.pending_events:
470521
continue
471522

523+
# Skip if orchestration doesn't match filters
524+
if not self._matches_filter(
525+
instance.name, instance.version, orch_filter):
526+
skipped_orchs.append(instance_id)
527+
continue
528+
472529
if instance_id in self._orchestration_in_flight:
473530
# Already being processed — re-add to queue
474-
if instance_id not in self._orchestration_queue_set:
475-
self._orchestration_queue.append(instance_id)
476-
self._orchestration_queue_set.add(instance_id)
531+
skipped_orchs.append(instance_id)
477532
break
478533

479534
# Move pending events to dispatched_events
@@ -500,27 +555,66 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
500555
)
501556
break
502557

558+
# Re-queue skipped orchestrations for other workers
559+
for s in skipped_orchs:
560+
if s not in self._orchestration_queue_set:
561+
self._orchestration_queue.append(s)
562+
self._orchestration_queue_set.add(s)
563+
503564
# Check for activity work
504565
if not work_item and self._activity_queue:
505-
activity = self._activity_queue.popleft()
506-
work_item = pb.WorkItem(
507-
completionToken=str(activity.completion_token),
508-
activityRequest=pb.ActivityRequest(
509-
name=activity.name,
510-
taskId=activity.task_id,
511-
input=wrappers_pb2.StringValue(value=activity.input) if activity.input else None,
512-
orchestrationInstance=pb.OrchestrationInstance(instanceId=activity.instance_id)
566+
# Scan for the first matching activity
567+
skipped: list = []
568+
matched_activity = None
569+
while self._activity_queue:
570+
candidate = self._activity_queue.popleft()
571+
if not self._matches_filter(
572+
candidate.name, candidate.version,
573+
activity_filter):
574+
skipped.append(candidate)
575+
continue
576+
matched_activity = candidate
577+
break
578+
# Put back non-matching items
579+
for s in skipped:
580+
self._activity_queue.append(s)
581+
582+
if matched_activity is not None:
583+
work_item = pb.WorkItem(
584+
completionToken=str(matched_activity.completion_token),
585+
activityRequest=pb.ActivityRequest(
586+
name=matched_activity.name,
587+
taskId=matched_activity.task_id,
588+
input=wrappers_pb2.StringValue(value=matched_activity.input) if matched_activity.input else None,
589+
orchestrationInstance=pb.OrchestrationInstance(instanceId=matched_activity.instance_id)
590+
)
513591
)
514-
)
515592

516593
# Check for entity work
517594
if not work_item:
595+
skipped_entities: list[str] = []
518596
while self._entity_queue:
519597
entity_id = self._entity_queue.popleft()
520598
self._entity_queue_set.discard(entity_id)
521599
entity = self._entities.get(entity_id)
522600

523601
if entity and entity.pending_operations:
602+
# Skip if entity name doesn't match filters
603+
if entity_filter is not None:
604+
try:
605+
parsed = EntityInstanceId.parse(entity_id)
606+
if not self._matches_filter(
607+
parsed.entity, None,
608+
entity_filter):
609+
skipped_entities.append(entity_id)
610+
continue
611+
except ValueError:
612+
self._logger.warning(
613+
f"Cannot parse entity ID '{entity_id}' "
614+
f"for filter matching; skipping")
615+
skipped_entities.append(entity_id)
616+
continue
617+
524618
# Skip if this entity is already being processed
525619
if entity_id in self._entity_in_flight:
526620
continue
@@ -547,6 +641,12 @@ def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
547641
)
548642
break
549643

644+
# Re-queue skipped entities for other workers
645+
for s in skipped_entities:
646+
if s not in self._entity_queue_set:
647+
self._entity_queue.append(s)
648+
self._entity_queue_set.add(s)
649+
550650
if work_item:
551651
yield work_item
552652
else:
@@ -1274,12 +1374,15 @@ def _process_schedule_task_action(self, instance: OrchestrationInstance,
12741374
instance.status = pb.ORCHESTRATION_STATUS_RUNNING
12751375

12761376
# Queue activity for execution
1377+
task_version = schedule_task.version.value \
1378+
if schedule_task.HasField("version") else None
12771379
self._activity_queue.append(ActivityWorkItem(
12781380
instance_id=instance.instance_id,
12791381
name=task_name,
12801382
task_id=task_id,
12811383
input=input_value,
1282-
completion_token=instance.completion_token
1384+
completion_token=instance.completion_token,
1385+
version=task_version,
12831386
))
12841387
self._work_available.set()
12851388

0 commit comments

Comments
 (0)