Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/mars/agent_step.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def agent(klass = nil)
end
end

def run(input)
self.class.agent.new.ask(input).content
def run(context)
self.class.agent.new.ask(context.current_input).content
end
end
end
7 changes: 4 additions & 3 deletions lib/mars/execution_context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

module MARS
class ExecutionContext
attr_reader :current_input, :outputs, :global_state
attr_reader :outputs, :global_state
attr_accessor :current_input

def initialize(input: nil, global_state: {})
@current_input = input
Expand All @@ -19,8 +20,8 @@ def record(step_name, output)
@current_input = output
end

def fork(input: current_input)
self.class.new(input: input, global_state: global_state)
def fork(input: current_input, state: {})
self.class.new(input: input, global_state: global_state.merge(state))
end

def merge(child_contexts)
Expand Down
5 changes: 3 additions & 2 deletions lib/mars/workflows/parallel.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def aggregate_results(results)
def execute_steps(context, errors, child_contexts)
Async do |workflow|
tasks = steps.map do |step|
child_ctx = context.fork
child_ctx = context.fork(state: step.state)
child_contexts << child_ctx

workflow.async do
Expand All @@ -54,7 +54,8 @@ def workflow_step(step, child_ctx)
step.run_before_hooks(child_ctx)

step_input = step.formatter.format_input(child_ctx)
result = step.run(step_input)
child_ctx.current_input = step_input
result = step.run(child_ctx)

if result.is_a?(Halt)
step.run_after_hooks(child_ctx, result)
Expand Down
3 changes: 2 additions & 1 deletion lib/mars/workflows/sequential.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def run(input)
step.run_before_hooks(context)

step_input = step.formatter.format_input(context)
result = step.run(step_input)
context.current_input = step_input
result = step.run(context)

if result.is_a?(Halt)
if result.global?
Expand Down
2 changes: 1 addition & 1 deletion spec/mars/agent_step_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

it "creates a new agent instance and calls ask" do
step = step_class.new
result = step.run("hello")
result = step.run(MARS::ExecutionContext.new(input: "hello"))

expect(result).to eq("agent response")
expect(mock_agent_class).to have_received(:new)
Expand Down
24 changes: 12 additions & 12 deletions spec/mars/workflows/sequential_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def initialize(value, **kwargs)
end

def run(input)
input + @value
input.current_input + @value
end
end
end
Expand All @@ -22,7 +22,7 @@ def initialize(multiplier, **kwargs)
end

def run(input)
input * @multiplier
input.current_input * @multiplier
end
end
end
Expand Down Expand Up @@ -67,11 +67,11 @@ def run(_input)

it "records outputs in context accessible by step name" do
step1 = Class.new(MARS::Runnable) do
def run(input) = "from_step1:#{input}"
def run(input) = "from_step1:#{input.current_input}"
end.new(name: "step1")

step2 = Class.new(MARS::Runnable) do
def run(input) = "from_step2:#{input}"
def run(input) = "from_step2:#{input.current_input}"
end.new(name: "step2")

context = MARS::ExecutionContext.new(input: "hello")
Expand All @@ -84,7 +84,7 @@ def run(input) = "from_step2:#{input}"

it "wraps raw input in ExecutionContext automatically" do
step = Class.new(MARS::Runnable) do
def run(input) = "processed:#{input}"
def run(input) = "processed:#{input.current_input}"
end.new(name: "step")

workflow = described_class.new("auto_wrap", steps: [step])
Expand All @@ -100,7 +100,7 @@ def format_output(output)
end

step = Class.new(MARS::Runnable) do
def run(input) = "result:#{input}"
def run(input) = "result:#{input.current_input}"
end.new(name: "step", formatter: uppercase_formatter.new)

workflow = described_class.new("fmt_workflow", steps: [step])
Expand All @@ -115,7 +115,7 @@ def run(input) = "result:#{input}"
before_run { |_ctx, step| hook_log << "before:#{step.name}" }
after_run { |_ctx, _result, step| hook_log << "after:#{step.name}" }

def run(input) = input
def run(input) = input.current_input
end

step = step_class.new(name: "hooked")
Expand All @@ -133,7 +133,7 @@ def run(input) = input
fallbacks: {
branch: Class.new(MARS::Runnable) do
def run(input)
"branched:#{input}"
"branched:#{input.current_input}"
end
end.new(name: "branch_step")
}
Expand All @@ -156,7 +156,7 @@ def run(input)
fallbacks: {
branch: Class.new(MARS::Runnable) do
def run(input)
"branched:#{input}"
"branched:#{input.current_input}"
end
end.new(name: "branch_step")
},
Expand All @@ -179,7 +179,7 @@ def run(input)
fallbacks: {
stop: Class.new(MARS::Runnable) do
def run(input)
"stopped:#{input}"
"stopped:#{input.current_input}"
end
end.new(name: "stop_step")
},
Expand All @@ -202,7 +202,7 @@ def run(input)
fallbacks: {
stop: Class.new(MARS::Runnable) do
def run(input)
"stopped:#{input}"
"stopped:#{input.current_input}"
end
end.new(name: "stop_step")
}
Expand All @@ -212,7 +212,7 @@ def run(input)

string_step = Class.new(MARS::Runnable) do
def run(input)
"after:#{input}"
"after:#{input.current_input}"
end
end.new(name: "after_step")

Expand Down
Loading