From cd96aa900e7fbdf9b94f29b80e953f384182defd Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Fri, 12 Jun 2026 23:24:52 -0700 Subject: [PATCH] refactor: implement stateful Display and inject io.Writer Closes #64. Adds stateful Display printing for streaming outputs. Removes os.Stdout.Sync() overhead and cleans up dead code. Replaces global stdout redirecting with injected io.Writer dependency, enabling safe parallel unit tests. --- cmd/ax/exec.go | 12 +-- cmd/ax/internal/display.go | 80 +++++++++++++---- cmd/ax/internal/display_test.go | 147 ++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 24 deletions(-) create mode 100644 cmd/ax/internal/display_test.go diff --git a/cmd/ax/exec.go b/cmd/ax/exec.go index 0dfba85..4efc578 100644 --- a/cmd/ax/exec.go +++ b/cmd/ax/exec.go @@ -124,7 +124,7 @@ func runExec(cmd *cobra.Command, args []string) error { } func execLoop(ctx context.Context, id string, agentID string, input string, lastSeq int32) error { - d := internal.NewDisplay(id) + d := internal.NewDisplay(id, os.Stdout) d.DisplayHeader() var inputs []*proto.Message @@ -347,7 +347,7 @@ func displayContents(d *internal.Display, contents []*proto.Message) { } switch o := content.Type.(type) { case *proto.Content_Text: - d.DisplayOutput(o.Text.Text) + d.DisplayText(o.Text.Text) case *proto.Content_Confirmation: // Let the confirmation prompt handle displaying the question. case *proto.Content_ToolCall: @@ -359,20 +359,20 @@ func displayContents(d *internal.Display, contents []*proto.Message) { if fr.GetResponse() != nil { respMap := fr.GetResponse().AsMap() if errStr, ok := respMap["error"]; ok { - d.DisplayOutput(fmt.Sprintf("\n[TOOL ERROR for %s]\n%v\n", fr.Name, errStr)) + d.DisplaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) } } } case *proto.Content_Thought: for _, summary := range o.Thought.GetSummary() { if textContent := summary.GetText(); textContent != nil { - d.DisplayOutput(fmt.Sprintf("Thinking: %s", textContent.Text)) + d.DisplayThought(textContent.Text) } } case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: - d.DisplayOutput(fmt.Sprintf("unsupported output type for display: %T", o)) + d.DisplaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) default: - d.DisplayOutput(fmt.Sprintf("unknown output type: %v", o)) + d.DisplaySystem(fmt.Sprintf("unknown output type: %v", o)) } } } diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index 22434e3..46efbd5 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -16,8 +16,8 @@ package internal import ( "fmt" + "io" "os" - "sync/atomic" "charm.land/huh/v2" "charm.land/lipgloss/v2" @@ -38,55 +38,99 @@ var ( // ErrUserAborted is returned when the user aborts a prompt. var ErrUserAborted = huh.ErrUserAborted +type displayState int + +const ( + stateNone displayState = iota + stateText + stateThought +) + type Display struct { id string + w io.Writer // Target output writer, e.g., os.Stdout or a test buffer userStyle lipgloss.Style checkpointStyle lipgloss.Style idStyle lipgloss.Style resumeStyle lipgloss.Style - loadingVisible atomic.Bool - loadingStopCh chan bool + state displayState // Tracks the last printed chunk type to correctly format transition newlines } -func NewDisplay(id string) *Display { +func NewDisplay(id string, w io.Writer) *Display { + if w == nil { + w = os.Stdout + } return &Display{ id: id, + w: w, userStyle: lipgloss.NewStyle().Foreground(purple), checkpointStyle: lipgloss.NewStyle().Foreground(comment), idStyle: lipgloss.NewStyle().Foreground(comment), resumeStyle: lipgloss.NewStyle().Foreground(comment), - loadingStopCh: make(chan bool), + state: stateNone, } } // DisplayInput displays the user input. func (d *Display) DisplayInput(text string) { - fmt.Printf("%s %s\n", + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintf(d.w, "%s %s\n", d.userStyle.Render("⏺"), text, ) - fmt.Println() + fmt.Fprintln(d.w) +} + +// DisplayText prints a chunk of model text response. +func (d *Display) DisplayText(text string) { + if d.state == stateThought { + fmt.Fprintln(d.w) // end the thinking line + } + d.state = stateText + fmt.Fprint(d.w, text) } -// DisplayOutput displays an output fragment. -func (d *Display) DisplayOutput(text string) { - fmt.Println(text) - fmt.Println() +// DisplayThought prints a chunk of model thinking process. +func (d *Display) DisplayThought(text string) { + if d.state != stateThought { + if d.state == stateText { + fmt.Fprintln(d.w) + } + fmt.Fprint(d.w, "Thinking: ") + } + d.state = stateThought + fmt.Fprint(d.w, text) +} + +// DisplaySystem prints a system/error message on a new line. +func (d *Display) DisplaySystem(text string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintln(d.w, text) } // FinishOutput completes the streaming output and shows info if provided func (d *Display) FinishOutput(info string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone if info != "" { - fmt.Println(d.checkpointStyle.Render(info)) + fmt.Fprintln(d.w, d.checkpointStyle.Render(info)) } - fmt.Println() + fmt.Fprintln(d.w) } func (d *Display) DisplayHeader() { - fmt.Println(d.idStyle.Render("Conversation: " + d.id)) - fmt.Println() + fmt.Fprintln(d.w, d.idStyle.Render("Conversation: " + d.id)) + fmt.Fprintln(d.w) } // PromptForApproval shows an accept/reject dialog @@ -128,10 +172,10 @@ func (d *Display) PromptForInput() (string, error) { } func (d *Display) ShowResumption(id string, server string) { - fmt.Println(d.resumeStyle.Render("To resume the conversation,")) + fmt.Fprintln(d.w, d.resumeStyle.Render("To resume the conversation,")) if server != "" { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) } else { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) } } diff --git a/cmd/ax/internal/display_test.go b/cmd/ax/internal/display_test.go new file mode 100644 index 0000000..2551170 --- /dev/null +++ b/cmd/ax/internal/display_test.go @@ -0,0 +1,147 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "testing" +) + +func TestDisplay_Streaming(t *testing.T) { + t.Run("consecutive text chunks are concatenated", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello ") + d.DisplayText("world") + d.DisplayText("!") + + got := buf.String() + want := "Hello world!" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("consecutive thought chunks are concatenated with prefix", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayThought("thinking ") + d.DisplayThought("deeply") + + got := buf.String() + want := "Thinking: thinking deeply" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from thought to text adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayThought("thinking") + d.DisplayText("Hello") + + got := buf.String() + want := "Thinking: thinking\nHello" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from text to thought adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplayThought("thinking") + + got := buf.String() + want := "Hello\nThinking: thinking" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput empty resets state and adds newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.FinishOutput("") + + got := buf.String() + want := "Hello\n\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput with info prints info and resets state", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.FinishOutput("seq=1") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("seq=1")) { + t.Errorf("expected output to contain seq=1, got %q", got) + } + }) + + t.Run("DisplaySystem resets state and prints newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplaySystem("system message") + + got := buf.String() + want := "Hello\nsystem message\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("DisplayInput resets state and adds separation newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplayInput("prompt") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("prompt")) { + t.Errorf("expected output to contain prompt, got %q", got) + } + }) +}