Skip to content

Commit 65b6e43

Browse files
santibclaude
andcommitted
Gate executes branches directly + class-level DSL
Phase 4 of the Mars v2 refactor. - Gate: add class-level `condition`/`branch` DSL for reusable gates. Gate#run now executes the matched branch directly instead of returning a Runnable for Sequential to detect - Aggregator: context-aware — accepts ExecutionContext and passes its outputs to the operation - Sequential: remove is_a?(Runnable) check, now just chains step results Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5817684 commit 65b6e43

9 files changed

Lines changed: 396 additions & 141 deletions

File tree

lib/mars/aggregator.rb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ def initialize(name = "Aggregator", operation: nil, **kwargs)
1111
end
1212

1313
def run(inputs)
14-
operation.call(inputs)
14+
if inputs.is_a?(ExecutionContext)
15+
operation.call(inputs.outputs)
16+
else
17+
operation.call(inputs)
18+
end
1519
end
1620
end
1721
end

lib/mars/execution_context.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def [](step_name)
1515
end
1616

1717
def record(step_name, output)
18-
@outputs[step_name] = output
18+
@outputs[step_name.to_sym] = output
1919
@current_input = output
2020
end
2121

lib/mars/gate.rb

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,29 @@ def branch(key, runnable)
1616
def branches_map
1717
@branches_map ||= {}
1818
end
19+
20+
def halt_scope(scope = nil)
21+
scope ? @halt_scope = scope : @halt_scope
22+
end
1923
end
2024

21-
def initialize(name = "Gate", condition: nil, branches: nil, **kwargs)
25+
def initialize(name = "Gate", condition: nil, branches: nil, halt_scope: nil, **kwargs)
2226
super(name: name, **kwargs)
2327

2428
@condition = condition || self.class.condition_block
2529
@branches = branches || self.class.branches_map
30+
@halt_scope = halt_scope || self.class.halt_scope || :local
2631
end
2732

2833
def run(input)
2934
result = condition.call(input)
30-
branch = branches[result]
3135

32-
return input unless branch
36+
return input unless result
37+
38+
branch = branches[result]
39+
raise ArgumentError, "No branch registered for #{result.inspect}" unless branch
3340

34-
resolve_branch(branch).run(input)
41+
Halt.new(resolve_branch(branch).run(input), scope: @halt_scope)
3542
end
3643

3744
private

lib/mars/workflows/parallel.rb

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,56 @@ def initialize(name, steps:, aggregator: nil, **kwargs)
1111
end
1212

1313
def run(input)
14+
context = input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input)
15+
1416
errors = []
15-
results = execute_steps(input, errors)
17+
child_contexts = run_steps_async(context, errors)
1618

1719
raise AggregateError, errors if errors.any?
1820

19-
has_global_halt = results.any? { |r| r.is_a?(Halt) && r.global? }
20-
unwrapped = results.map { |r| r.is_a?(Halt) ? r.result : r }
21-
result = aggregator.run(unwrapped)
21+
has_global_halt = child_contexts.any? { |c| c.is_a?(Halt) && c.global? }
22+
23+
valid_contexts = child_contexts.map { |c| c.is_a?(Halt) ? c.result : c }
24+
context.merge(valid_contexts)
25+
result = aggregator.run(context)
2226
has_global_halt ? Halt.new(result, scope: :global) : result
2327
end
2428

2529
private
2630

2731
attr_reader :steps, :aggregator
2832

29-
def execute_steps(input, errors)
33+
def run_steps_async(context, errors)
3034
Async do |workflow|
3135
tasks = steps.map do |step|
32-
workflow.async do
33-
step.run(input)
34-
rescue StandardError => e
35-
errors << { error: e, step_name: step.name }
36-
end
36+
workflow.async { run_step(context.fork, step, errors) }
3737
end
3838

3939
tasks.map(&:wait)
4040
end.result
4141
end
42+
43+
def run_step(child, step, errors)
44+
step.run_before_hooks(child)
45+
step_input = step.formatter.format_input(child)
46+
result = step.run(step_input)
47+
48+
if result.is_a?(Halt)
49+
formatted = step.formatter.format_output(result.result)
50+
child.record(step.name, formatted)
51+
step.run_after_hooks(child, formatted)
52+
return Halt.new(child, scope: result.scope) if result.global?
53+
54+
return child
55+
end
56+
57+
formatted = step.formatter.format_output(result)
58+
child.record(step.name, formatted)
59+
step.run_after_hooks(child, formatted)
60+
child
61+
rescue StandardError => e
62+
errors << { error: e, step_name: step.name }
63+
end
4264
end
4365
end
4466
end

lib/mars/workflows/sequential.rb

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,52 @@ def initialize(name, steps:, **kwargs)
99
@steps = steps
1010
end
1111

12+
def self.build(name, **kwargs, &)
13+
builder = Builder.new
14+
builder.instance_eval(&)
15+
new(name, steps: builder.steps, **kwargs)
16+
end
17+
1218
def run(input)
13-
@steps.each do |step|
14-
input = step.run(input)
19+
context = input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input)
20+
21+
steps.each do |step|
22+
step.run_before_hooks(context)
23+
step_input = step.formatter.format_input(context)
24+
result = step.run(step_input)
25+
26+
if result.is_a?(Halt)
27+
return result if result.global?
28+
29+
formatted = step.formatter.format_output(result.result)
30+
context.record(step.name, formatted)
31+
step.run_after_hooks(context, formatted)
32+
break
33+
end
34+
35+
formatted = step.formatter.format_output(result)
36+
context.record(step.name, formatted)
37+
step.run_after_hooks(context, formatted)
1538
end
1639

17-
input
40+
context
1841
end
1942

2043
private
2144

2245
attr_reader :steps
46+
47+
class Builder
48+
attr_reader :steps
49+
50+
def initialize
51+
@steps = []
52+
end
53+
54+
def step(runnable_class, **kwargs)
55+
@steps << runnable_class.new(**kwargs)
56+
end
57+
end
2358
end
2459
end
2560
end

spec/mars/aggregator_spec.rb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,20 @@
1919
expect(result).to eq("abc")
2020
end
2121
end
22+
23+
context "when given an ExecutionContext" do
24+
let(:aggregator) do
25+
described_class.new("ContextAggregator", operation: ->(outputs) { outputs.values.join(", ") })
26+
end
27+
28+
it "passes the context outputs to the operation" do
29+
context = MARS::ExecutionContext.new(input: "query")
30+
context.record(:step_a, "result_a")
31+
context.record(:step_b, "result_b")
32+
33+
result = aggregator.run(context)
34+
expect(result).to eq("result_a, result_b")
35+
end
36+
end
2237
end
2338
end

spec/mars/gate_spec.rb

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,49 @@ def run(input)
1919

2020
describe "#run" do
2121
context "with constructor-based configuration" do
22-
let(:short_step) do
23-
Class.new(MARS::Runnable) do
24-
def run(input)
25-
"short: #{input}"
26-
end
27-
end.new
28-
end
22+
it "passes through when condition returns falsy" do
23+
gate = described_class.new(
24+
"PassGate",
25+
condition: ->(_input) {},
26+
branches: { fail: fallback_step }
27+
)
2928

30-
let(:long_step) do
31-
Class.new(MARS::Runnable) do
32-
def run(input)
33-
"long: #{input}"
34-
end
35-
end.new
29+
expect(gate.run("hello")).to eq("hello")
3630
end
3731

38-
let(:gate) do
39-
described_class.new(
40-
"LengthGate",
41-
condition: ->(input) { input.length > 5 ? :long : :short },
42-
branches: { short: short_step, long: long_step }
32+
it "halts with branch result when condition returns a key" do
33+
gate = described_class.new(
34+
"FailGate",
35+
condition: ->(_input) { :fail },
36+
branches: { fail: fallback_step }
4337
)
44-
end
4538

46-
it "executes the matched branch directly" do
47-
expect(gate.run("hi")).to eq("short: hi")
39+
result = gate.run("hello")
40+
expect(result).to be_a(MARS::Halt)
41+
expect(result.result).to eq("fallback: hello")
4842
end
4943

50-
it "executes the other branch for different input" do
51-
expect(gate.run("longstring")).to eq("long: longstring")
44+
it "raises when condition returns an unregistered key" do
45+
gate = described_class.new(
46+
"BadGate",
47+
condition: ->(_input) { :unknown },
48+
branches: { fail: fallback_step }
49+
)
50+
51+
expect { gate.run("hello") }.to raise_error(ArgumentError, /No branch registered for :unknown/)
5252
end
5353

54-
it "returns input when no branch matches" do
54+
it "selects among multiple branches" do
5555
gate = described_class.new(
56-
"NoMatch",
57-
condition: ->(_input) { :unknown },
58-
branches: { short: short_step }
56+
"MultiBranch",
57+
condition: ->(input) { input[:error_type] },
58+
branches: { timeout: fallback_step, auth: error_step }
5959
)
6060

61-
expect(gate.run("hello")).to eq("hello")
61+
input = { error_type: :auth }
62+
result = gate.run(input)
63+
expect(result).to be_a(MARS::Halt)
64+
expect(result.result).to eq("error: #{input}")
6265
end
6366
end
6467

@@ -90,8 +93,46 @@ def run(input)
9093
end
9194

9295
gate = gate_class.new("DSLGate")
93-
expect(gate.run("hi")).to eq("quick: hi")
94-
expect(gate.run("longstring")).to eq("deep: longstring")
96+
expect(gate.run("hi").result).to eq("quick: hi")
97+
expect(gate.run("longstring").result).to eq("deep: longstring")
98+
end
99+
100+
it "supports halt_scope DSL" do
101+
cls = short_step_class
102+
gate_class = Class.new(described_class) do
103+
condition { |_input| :fail }
104+
branch :fail, cls
105+
halt_scope :global
106+
end
107+
108+
result = gate_class.new("GlobalGate").run("test")
109+
expect(result).to be_a(MARS::Halt)
110+
expect(result).to be_global
111+
end
112+
end
113+
114+
context "with halt scope" do
115+
it "defaults to local scope" do
116+
gate = described_class.new(
117+
"LocalGate",
118+
condition: ->(_input) { :fail },
119+
branches: { fail: fallback_step }
120+
)
121+
122+
result = gate.run("hello")
123+
expect(result).to be_local
124+
end
125+
126+
it "respects constructor halt_scope" do
127+
gate = described_class.new(
128+
"GlobalGate",
129+
condition: ->(_input) { :fail },
130+
branches: { fail: fallback_step },
131+
halt_scope: :global
132+
)
133+
134+
result = gate.run("hello")
135+
expect(result).to be_global
95136
end
96137
end
97138

@@ -123,15 +164,15 @@ def run(input)
123164
end
124165

125166
it "routes to low branch" do
126-
expect(gate.run(5)).to eq("low:5")
167+
expect(gate.run(5).result).to eq("low:5")
127168
end
128169

129170
it "routes to medium branch" do
130-
expect(gate.run(25)).to eq("med:25")
171+
expect(gate.run(25).result).to eq("med:25")
131172
end
132173

133174
it "routes to high branch" do
134-
expect(gate.run(100)).to eq("high:100")
175+
expect(gate.run(100).result).to eq("high:100")
135176
end
136177
end
137178
end

0 commit comments

Comments
 (0)