diff --git a/README.md b/README.md index e28afc8..c0d0c41 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ TagIt supports configuration through: - **CLI flags** (`--consul-addr`, `--service-id`, `--script`, `--tag-prefix`, `--interval`, `--token`) - **Config file** with `--config` (default: `$HOME/.tagit.yaml`) -- **Environment variables** (Viper automatic binding) +- **Environment variables** using `TAGIT_*` names Example `~/.tagit.yaml`: @@ -123,7 +123,9 @@ interval: "5s" token: "your-consul-token" ``` -Note: `run` and `cleanup` use inherited root flags, while `systemd` defines and validates its own flags. +Configuration precedence is: CLI flags, `TAGIT_*` environment variables, config file values, then CLI flag defaults. +Environment variable names replace `-` with `_`: `TAGIT_CONSUL_ADDR`, `TAGIT_SERVICE_ID`, `TAGIT_SCRIPT`, `TAGIT_TAG_PREFIX`, `TAGIT_INTERVAL`, and `TAGIT_TOKEN`. +The `run`, `cleanup`, and `systemd` commands resolve shared TagIt invocation values through the same validation path; `systemd` also requires `--user` and `--group`. ## Examples diff --git a/cmd/cleanup.go b/cmd/cleanup.go index 2022c5e..fe07249 100644 --- a/cmd/cleanup.go +++ b/cmd/cleanup.go @@ -16,12 +16,6 @@ limitations under the License. package cmd import ( - "fmt" - "log/slog" - "os" - - "github.com/ncode/tagit/pkg/consul" - "github.com/ncode/tagit/pkg/tagit" "github.com/spf13/cobra" ) @@ -30,46 +24,7 @@ var cleanupCmd = &cobra.Command{ Use: "cleanup", Short: "cleanup removes all services with the tag prefix from a given consul service", RunE: func(cmd *cobra.Command, args []string) error { - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelInfo, - })) - - consulAddr := cmd.InheritedFlags().Lookup("consul-addr").Value.String() - token := cmd.InheritedFlags().Lookup("token").Value.String() - - consulClient, err := consul.CreateClient(consulAddr, token) - if err != nil { - logger.Error("Failed to create Consul client", "error", err) - return err - } - - serviceID := cmd.InheritedFlags().Lookup("service-id").Value.String() - if serviceID == "" { - logger.Error("Service ID is required") - return fmt.Errorf("service-id is required") - } - tagPrefix := cmd.InheritedFlags().Lookup("tag-prefix").Value.String() - - t := tagit.New( - consulClient, - &tagit.CmdExecutor{}, - serviceID, - "", // script is not needed for cleanup - 0, // interval is not needed for cleanup - tagPrefix, - logger, - ) - - logger.Info("Starting tag cleanup", "serviceID", serviceID, "tagPrefix", tagPrefix) - - err = t.CleanupTags() - if err != nil { - logger.Error("Failed to clean up tags", "error", err) - return fmt.Errorf("failed to clean up tags: %w", err) - } - - logger.Info("Tag cleanup completed successfully") - return nil + return cleanupCommand(cmd, commandDeps{}) }, } diff --git a/cmd/cleanup_test.go b/cmd/cleanup_test.go index 09806b4..c332a81 100644 --- a/cmd/cleanup_test.go +++ b/cmd/cleanup_test.go @@ -1,489 +1,166 @@ package cmd import ( - "bytes" + "context" "fmt" + "log/slog" + "strings" "testing" - "github.com/hashicorp/consul/api" "github.com/ncode/tagit/pkg/consul" - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" + "github.com/ncode/tagit/pkg/tagit" ) -func TestCleanupCmd(t *testing.T) { - tests := []struct { - name string - args []string - expectError bool - errorContains string - }{ - { - name: "Missing required service-id", - args: []string{"cleanup"}, - expectError: true, - errorContains: "service-id is required", +func TestCleanupCommand_validatesInputBeforeCreatingClient(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) + + clientCalls := 0 + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + clientCalls++ + return commandClient{}, nil + }, + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + t.Fatal("NewTagger called before validation") + return nil }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create a new root command for each test - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - // Add the cleanup command - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix", - RunE: cleanupCmd.RunE, - } - cmd.AddCommand(testCleanupCmd) - - // Capture stderr - var buf bytes.Buffer - cmd.SetErr(&buf) - cmd.SetArgs(tt.args) - - err := cmd.Execute() - - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) + err := cleanupCommand(cmd, deps) + if err == nil { + t.Fatal("cleanupCommand() error = nil, want error") } -} - -func TestCleanupCmdFlagParsing(t *testing.T) { - var capturedFlags map[string]string - - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix", - Run: func(cmd *cobra.Command, args []string) { - // Capture flag values during execution - capturedFlags = make(map[string]string) - capturedFlags["service-id"], _ = cmd.InheritedFlags().GetString("service-id") - capturedFlags["tag-prefix"], _ = cmd.InheritedFlags().GetString("tag-prefix") - capturedFlags["consul-addr"], _ = cmd.InheritedFlags().GetString("consul-addr") - capturedFlags["token"], _ = cmd.InheritedFlags().GetString("token") - }, + if !strings.Contains(err.Error(), "service-id is required") { + t.Fatalf("cleanupCommand() error = %q, want service-id validation", err) } - cmd.AddCommand(testCleanupCmd) - - cmd.SetArgs([]string{ - "cleanup", - "--service-id=test-service", - "--tag-prefix=test", - "--consul-addr=localhost:8500", - "--token=test-token", - }) - - err := cmd.Execute() - assert.NoError(t, err) - - // Verify flags were parsed correctly - assert.Equal(t, "test-service", capturedFlags["service-id"]) - assert.Equal(t, "test", capturedFlags["tag-prefix"]) - assert.Equal(t, "localhost:8500", capturedFlags["consul-addr"]) - assert.Equal(t, "test-token", capturedFlags["token"]) -} - -func TestCleanupCmdHelp(t *testing.T) { - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix from a given consul service", - RunE: cleanupCmd.RunE, + if clientCalls != 0 { + t.Fatalf("NewClient calls = %d, want 0", clientCalls) } - cmd.AddCommand(testCleanupCmd) - - buf := new(bytes.Buffer) - cmd.SetOut(buf) - cmd.SetArgs([]string{"cleanup", "--help"}) - - err := cmd.Execute() - assert.NoError(t, err) - - output := buf.String() - assert.Contains(t, output, "cleanup removes all services with the tag prefix") - assert.Contains(t, output, "Usage:") } -func TestCleanupCmdExecution(t *testing.T) { - tests := []struct { - name string - consulAddr string - expectError bool - errorContains string - }{ - { - name: "Invalid consul address", - consulAddr: "invalid-consul-address", - expectError: true, - errorContains: "failed to clean up tags", +func TestCleanupCommand_usesResolvedInput(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "consul-addr", "consul.example:8500") + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "tag-prefix", "role") + setFlag(t, cmd.InheritedFlags(), "token", "secret") + + var gotAddress string + var gotToken string + var gotInput commandInput + fakeTagger := &commandTagger{} + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + gotAddress = address + gotToken = token + return commandClient{}, nil + }, + NewExecutor: func() tagit.CommandExecutor { + return &tagit.CmdExecutor{} + }, + NewTagger: func(client consul.Client, executor tagit.CommandExecutor, input commandInput, logger *slog.Logger) tagger { + gotInput = input + return fakeTagger }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix from a given consul service", - RunE: cleanupCmd.RunE, - } - cmd.AddCommand(testCleanupCmd) - - var stderr bytes.Buffer - cmd.SetErr(&stderr) - cmd.SetArgs([]string{ - "cleanup", - "--service-id=test-service", - "--consul-addr=" + tt.consulAddr, - "--tag-prefix=test", - }) - - err := cmd.Execute() + if err := cleanupCommand(cmd, deps); err != nil { + t.Fatalf("cleanupCommand() error = %v", err) + } - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) + if gotAddress != "consul.example:8500" { + t.Fatalf("address = %q, want consul.example:8500", gotAddress) + } + if gotToken != "secret" { + t.Fatalf("token = %q, want secret", gotToken) + } + wantInput := commandInput{ + ConsulAddr: "consul.example:8500", + ServiceID: "api", + TagPrefix: "role", + IntervalRaw: "60s", + Token: "secret", + } + if gotInput != wantInput { + t.Fatalf("input = %#v, want %#v", gotInput, wantInput) + } + if fakeTagger.cleanupCalls != 1 { + t.Fatalf("CleanupTags calls = %d, want 1", fakeTagger.cleanupCalls) } } -func TestCleanupCmdFlagRetrieval(t *testing.T) { - // Test that all flag retrievals work correctly within the RunE function - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - var capturedValues map[string]string +func TestCleanupCommand_returnsClientErrors(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "service-id", "api") - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix from a given consul service", - RunE: func(cmd *cobra.Command, args []string) error { - // Test the same flag access pattern used in the actual cleanup command - capturedValues = make(map[string]string) - capturedValues["consul-addr"] = cmd.InheritedFlags().Lookup("consul-addr").Value.String() - capturedValues["token"] = cmd.InheritedFlags().Lookup("token").Value.String() - capturedValues["service-id"] = cmd.InheritedFlags().Lookup("service-id").Value.String() - capturedValues["tag-prefix"] = cmd.InheritedFlags().Lookup("tag-prefix").Value.String() - - // Don't actually try to connect to consul - just test flag access + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + return nil, fmt.Errorf("connect consul") + }, + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + t.Fatal("NewTagger called after client error") return nil }, } - cmd.AddCommand(testCleanupCmd) - - cmd.SetArgs([]string{ - "cleanup", - "--service-id=test-service", - "--consul-addr=localhost:9500", - "--tag-prefix=test-prefix", - "--token=test-token", - }) - - err := cmd.Execute() - assert.NoError(t, err) - // Verify all values were captured correctly - assert.Equal(t, "localhost:9500", capturedValues["consul-addr"]) - assert.Equal(t, "test-token", capturedValues["token"]) - assert.Equal(t, "test-service", capturedValues["service-id"]) - assert.Equal(t, "test-prefix", capturedValues["tag-prefix"]) + err := cleanupCommand(cmd, deps) + if err == nil { + t.Fatal("cleanupCommand() error = nil, want error") + } + if !strings.Contains(err.Error(), "connect consul") { + t.Fatalf("cleanupCommand() error = %q, want client error", err) + } } -func TestCleanupCmdSuccessFlow(t *testing.T) { - // Test the successful flow of cleanup command - // Since the actual cleanupCmd creates a real Consul client internally, - // we test with a mock command that simulates the successful path - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - var logOutput bytes.Buffer - testCleanupCmd := &cobra.Command{ - Use: "cleanup", - Short: "cleanup removes all services with the tag prefix from a given consul service", - RunE: func(cmd *cobra.Command, args []string) error { - // Verify all required flags are accessible - serviceID := cmd.InheritedFlags().Lookup("service-id").Value.String() - tagPrefix := cmd.InheritedFlags().Lookup("tag-prefix").Value.String() - consulAddr := cmd.InheritedFlags().Lookup("consul-addr").Value.String() - token := cmd.InheritedFlags().Lookup("token").Value.String() - - // Simulate the logging that would happen - fmt.Fprintf(&logOutput, "Starting tag cleanup, serviceID=%s, tagPrefix=%s, consulAddr=%s\n", - serviceID, tagPrefix, consulAddr) - - if token != "" { - fmt.Fprintf(&logOutput, "Using token authentication\n") - } +func TestCleanupCommand_returnsCleanupErrors(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "service-id", "api") - // Simulate successful cleanup - fmt.Fprintf(&logOutput, "Tag cleanup completed successfully\n") - return nil + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + return commandClient{}, nil + }, + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + return cleanupErrorTagger{err: fmt.Errorf("consul write failed")} }, } - cmd.AddCommand(testCleanupCmd) - - cmd.SetArgs([]string{ - "cleanup", - "--service-id=test-service", - "--consul-addr=localhost:8500", - "--tag-prefix=test", - "--token=secret-token", - }) - - err := cmd.Execute() - assert.NoError(t, err) - // Verify the command would have executed with the right parameters - output := logOutput.String() - assert.Contains(t, output, "serviceID=test-service") - assert.Contains(t, output, "tagPrefix=test") - assert.Contains(t, output, "consulAddr=localhost:8500") - assert.Contains(t, output, "Using token authentication") - assert.Contains(t, output, "Tag cleanup completed successfully") -} - -// MockConsulClient for testing -type MockConsulClient struct { - MockAgent consul.Agent -} - -func (m *MockConsulClient) Agent() consul.Agent { - return m.MockAgent + err := cleanupCommand(cmd, deps) + if err == nil { + t.Fatal("cleanupCommand() error = nil, want error") + } + if !strings.Contains(err.Error(), "failed to clean up tags: consul write failed") { + t.Fatalf("cleanupCommand() error = %q, want cleanup context", err) + } } -// MockAgent implements the Agent interface -type MockAgent struct { - ServiceFunc func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) - ServiceRegisterFunc func(reg *api.AgentServiceRegistration) error -} +func TestCleanupCmd_RunEUsesSharedHandler(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) -func (m *MockAgent) Service(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - if m.ServiceFunc != nil { - return m.ServiceFunc(serviceID, q) + err := cleanupCmd.RunE(cmd, nil) + if err == nil { + t.Fatal("cleanupCmd.RunE() error = nil, want validation error") } - return &api.AgentService{ - ID: "test-service", - Service: "test", - Tags: []string{"tagged:old", "other-tag"}, - }, nil, nil -} - -func (m *MockAgent) ServiceRegister(reg *api.AgentServiceRegistration) error { - if m.ServiceRegisterFunc != nil { - return m.ServiceRegisterFunc(reg) + if !strings.Contains(err.Error(), "service-id is required") { + t.Fatalf("cleanupCmd.RunE() error = %q, want service-id validation", err) } - return nil } -func TestCleanupCmdWithMockFactory(t *testing.T) { - // Save and restore the original factory - originalFactory := consul.Factory - defer func() { - consul.Factory = originalFactory - }() - - t.Run("Successful cleanup with mock", func(t *testing.T) { - // Create a mock agent that simulates a service with tags - mockAgent := &MockAgent{ - ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - return &api.AgentService{ - ID: serviceID, - Service: "test", - Tags: []string{"tagged-value1", "tagged-value2", "other-tag"}, - }, nil, nil - }, - ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { - // Verify that the tags were cleaned up - assert.Equal(t, "test-service", reg.ID) - assert.NotContains(t, reg.Tags, "tagged-value1") - assert.NotContains(t, reg.Tags, "tagged-value2") - assert.Contains(t, reg.Tags, "other-tag") - return nil - }, - } - - // Create mock client with the mock agent - mockClient := &MockConsulClient{ - MockAgent: mockAgent, - } - - // Set up the mock factory - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "cleanup", - RunE: cleanupCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "tagged", "") - parent.AddCommand(cmd) - - // Run the actual cleanup command - err := cmd.RunE(cmd, []string{}) - assert.NoError(t, err) - }) - - t.Run("Cleanup with connection error", func(t *testing.T) { - // Set up a factory that returns an error - mockFactory := &consul.MockFactory{ - MockError: fmt.Errorf("connection failed"), - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "cleanup", - RunE: cleanupCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "tagged", "") - parent.AddCommand(cmd) - - // Run the cleanup command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - }) - - t.Run("Cleanup with service not found", func(t *testing.T) { - // Create a mock agent that returns nil service - mockAgent := &MockAgent{ - ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - return nil, nil, nil - }, - } - - // Create mock client with the mock agent - mockClient := &MockConsulClient{ - MockAgent: mockAgent, - } - - // Set up the mock factory - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "cleanup", - RunE: cleanupCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "tagged", "") - parent.AddCommand(cmd) - - // Run the cleanup command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "service test-service not found") - }) - - t.Run("Cleanup with service register error", func(t *testing.T) { - // Create a mock agent that simulates a service with tags but fails on register - mockAgent := &MockAgent{ - ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - return &api.AgentService{ - ID: serviceID, - Service: "test", - Tags: []string{"tagged-value1", "other-tag"}, - }, nil, nil - }, - ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { - return fmt.Errorf("failed to register service") - }, - } - - // Create mock client with the mock agent - mockClient := &MockConsulClient{ - MockAgent: mockAgent, - } - - // Set up the mock factory - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) +type cleanupErrorTagger struct { + err error +} - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "cleanup", - RunE: cleanupCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "tagged", "") - parent.AddCommand(cmd) +func (cet cleanupErrorTagger) Run(context.Context) {} - // Run the cleanup command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to clean up tags") - }) +func (cet cleanupErrorTagger) CleanupTags() error { + return cet.err } diff --git a/cmd/handlers.go b/cmd/handlers.go new file mode 100644 index 0000000..c59aff3 --- /dev/null +++ b/cmd/handlers.go @@ -0,0 +1,109 @@ +package cmd + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/ncode/tagit/pkg/consul" + "github.com/ncode/tagit/pkg/tagit" + "github.com/spf13/cobra" +) + +type tagger interface { + Run(context.Context) + CleanupTags() error +} + +type commandDeps struct { + Logger *slog.Logger + NewClient func(address, token string) (consul.Client, error) + NewExecutor func() tagit.CommandExecutor + NewTagger func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger +} + +func runCommandWithContext(ctx context.Context, cmd *cobra.Command, deps commandDeps) error { + deps = deps.withDefaults() + + input, err := resolveRunInput(cmd) + if err != nil { + deps.Logger.Error("Invalid command input", "error", err) + return err + } + + consulClient, err := deps.NewClient(input.ConsulAddr, input.Token) + if err != nil { + deps.Logger.Error("Failed to create Consul client", "error", err) + return err + } + + t := deps.NewTagger(consulClient, deps.NewExecutor(), input, deps.Logger) + deps.Logger.Info("Starting tagit", + "serviceID", input.ServiceID, + "script", input.Script, + "interval", input.Interval, + "tagPrefix", input.TagPrefix) + + t.Run(ctx) + + deps.Logger.Info("Tagit has stopped") + return nil +} + +func cleanupCommand(cmd *cobra.Command, deps commandDeps) error { + deps = deps.withDefaults() + + input, err := resolveCleanupInput(cmd) + if err != nil { + deps.Logger.Error("Invalid command input", "error", err) + return err + } + + consulClient, err := deps.NewClient(input.ConsulAddr, input.Token) + if err != nil { + deps.Logger.Error("Failed to create Consul client", "error", err) + return err + } + + t := deps.NewTagger(consulClient, deps.NewExecutor(), input, deps.Logger) + deps.Logger.Info("Starting tag cleanup", "serviceID", input.ServiceID, "tagPrefix", input.TagPrefix) + + if err := t.CleanupTags(); err != nil { + deps.Logger.Error("Failed to clean up tags", "error", err) + return fmt.Errorf("failed to clean up tags: %w", err) + } + + deps.Logger.Info("Tag cleanup completed successfully") + return nil +} + +func (deps commandDeps) withDefaults() commandDeps { + if deps.Logger == nil { + deps.Logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + } + if deps.NewClient == nil { + deps.NewClient = consul.CreateClient + } + if deps.NewExecutor == nil { + deps.NewExecutor = func() tagit.CommandExecutor { + return &tagit.CmdExecutor{} + } + } + if deps.NewTagger == nil { + deps.NewTagger = func(client consul.Client, executor tagit.CommandExecutor, input commandInput, logger *slog.Logger) tagger { + return tagit.New( + client, + executor, + input.ServiceID, + input.Script, + input.Interval, + input.TagPrefix, + logger, + ) + } + } + return deps +} diff --git a/cmd/intake.go b/cmd/intake.go new file mode 100644 index 0000000..0733a5a --- /dev/null +++ b/cmd/intake.go @@ -0,0 +1,133 @@ +package cmd + +import ( + "fmt" + "strings" + "time" + + "github.com/ncode/tagit/pkg/systemd" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +type commandInput struct { + ConsulAddr string + ServiceID string + Script string + TagPrefix string + Interval time.Duration + IntervalRaw string + Token string +} + +type systemdInput struct { + Invocation systemd.Invocation + User string + Group string +} + +func configureCommandIntakeEnv() { + viper.SetEnvPrefix("TAGIT") + viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + viper.AutomaticEnv() +} + +func resolveRunInput(cmd *cobra.Command) (commandInput, error) { + input := resolveSharedInput(cmd) + if strings.TrimSpace(input.ServiceID) == "" { + return commandInput{}, fmt.Errorf("service-id is required") + } + if strings.TrimSpace(input.Script) == "" { + return commandInput{}, fmt.Errorf("script is required") + } + interval, err := parseRequiredInterval(input.IntervalRaw) + if err != nil { + return commandInput{}, err + } + input.Interval = interval + return input, nil +} + +func resolveCleanupInput(cmd *cobra.Command) (commandInput, error) { + input := resolveSharedInput(cmd) + if strings.TrimSpace(input.ServiceID) == "" { + return commandInput{}, fmt.Errorf("service-id is required") + } + return input, nil +} + +func resolveSystemdInput(cmd *cobra.Command) (systemdInput, error) { + input, err := resolveRunInput(cmd) + if err != nil { + return systemdInput{}, err + } + + user := resolveString(cmd, "user") + if strings.TrimSpace(user) == "" { + return systemdInput{}, fmt.Errorf("user is required") + } + group := resolveString(cmd, "group") + if strings.TrimSpace(group) == "" { + return systemdInput{}, fmt.Errorf("group is required") + } + + return systemdInput{ + Invocation: systemd.Invocation{ + ServiceID: input.ServiceID, + Script: input.Script, + TagPrefix: input.TagPrefix, + Interval: input.IntervalRaw, + Token: input.Token, + ConsulAddr: input.ConsulAddr, + }, + User: user, + Group: group, + }, nil +} + +func resolveSharedInput(cmd *cobra.Command) commandInput { + return commandInput{ + ConsulAddr: resolveString(cmd, "consul-addr"), + ServiceID: resolveString(cmd, "service-id"), + Script: resolveString(cmd, "script"), + TagPrefix: resolveString(cmd, "tag-prefix"), + IntervalRaw: resolveString(cmd, "interval"), + Token: resolveString(cmd, "token"), + } +} + +func parseRequiredInterval(raw string) (time.Duration, error) { + if strings.TrimSpace(raw) == "" { + return 0, fmt.Errorf("interval is required and cannot be empty or zero") + } + interval, err := time.ParseDuration(raw) + if err != nil { + return 0, fmt.Errorf("invalid interval %q: %w", raw, err) + } + if interval <= 0 { + return 0, fmt.Errorf("interval is required and cannot be empty or zero: interval must be greater than zero") + } + return interval, nil +} + +func resolveString(cmd *cobra.Command, key string) string { + flag := lookupFlag(cmd, key) + if flag != nil && flag.Changed { + return flag.Value.String() + } + if viper.IsSet(key) { + return viper.GetString(key) + } + if flag != nil { + return flag.Value.String() + } + return "" +} + +func lookupFlag(cmd *cobra.Command, key string) *pflag.Flag { + if flag := cmd.Flags().Lookup(key); flag != nil { + return flag + } + return cmd.InheritedFlags().Lookup(key) +} diff --git a/cmd/intake_test.go b/cmd/intake_test.go new file mode 100644 index 0000000..8bfd523 --- /dev/null +++ b/cmd/intake_test.go @@ -0,0 +1,218 @@ +package cmd + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +func TestResolveRunInput_readsCLIValues(t *testing.T) { + resetViper(t) + + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "consul-addr", "consul.example:8500") + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "script", "/opt/tagit/tags.sh") + setFlag(t, cmd.InheritedFlags(), "tag-prefix", "role") + setFlag(t, cmd.InheritedFlags(), "interval", "15s") + setFlag(t, cmd.InheritedFlags(), "token", "secret") + + got, err := resolveRunInput(cmd) + if err != nil { + t.Fatalf("resolveRunInput() error = %v", err) + } + + want := commandInput{ + ConsulAddr: "consul.example:8500", + ServiceID: "api", + Script: "/opt/tagit/tags.sh", + TagPrefix: "role", + Interval: 15 * time.Second, + IntervalRaw: "15s", + Token: "secret", + } + if got != want { + t.Fatalf("resolveRunInput() = %#v, want %#v", got, want) + } +} + +func TestResolveRunInput_readsConfigAndEnvironment(t *testing.T) { + resetViper(t) + configureCommandIntakeEnv() + t.Setenv("TAGIT_SERVICE_ID", "env-service") + t.Setenv("TAGIT_TOKEN", "env-token") + t.Setenv("TAGIT_INTERVAL", "45s") + + writeIntakeConfig(t, `consul-addr: "config-consul:8500" +service-id: "config-service" +script: "/config/script.sh" +tag-prefix: "config-prefix" +interval: "30s" +token: "config-token" +`) + + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "consul-addr", "cli-consul:8500") + + got, err := resolveRunInput(cmd) + if err != nil { + t.Fatalf("resolveRunInput() error = %v", err) + } + + if got.ConsulAddr != "cli-consul:8500" { + t.Fatalf("ConsulAddr = %q, want CLI override", got.ConsulAddr) + } + if got.ServiceID != "env-service" { + t.Fatalf("ServiceID = %q, want environment value", got.ServiceID) + } + if got.Script != "/config/script.sh" { + t.Fatalf("Script = %q, want config value", got.Script) + } + if got.TagPrefix != "config-prefix" { + t.Fatalf("TagPrefix = %q, want config value", got.TagPrefix) + } + if got.Interval != 45*time.Second { + t.Fatalf("Interval = %v, want environment value", got.Interval) + } + if got.Token != "env-token" { + t.Fatalf("Token = %q, want environment value", got.Token) + } +} + +func TestResolveRunInput_validationErrors(t *testing.T) { + tests := []struct { + name string + values map[string]string + wantErr string + }{ + { + name: "missing service ID", + values: map[string]string{ + "script": "/opt/tagit/tags.sh", + "interval": "15s", + }, + wantErr: "service-id is required", + }, + { + name: "missing script", + values: map[string]string{ + "service-id": "api", + "interval": "15s", + }, + wantErr: "script is required", + }, + { + name: "missing interval", + values: map[string]string{ + "service-id": "api", + "script": "/opt/tagit/tags.sh", + "interval": "", + }, + wantErr: "interval is required", + }, + { + name: "zero interval", + values: map[string]string{ + "service-id": "api", + "script": "/opt/tagit/tags.sh", + "interval": "0s", + }, + wantErr: "interval must be greater than zero", + }, + { + name: "invalid interval", + values: map[string]string{ + "service-id": "api", + "script": "/opt/tagit/tags.sh", + "interval": "soon", + }, + wantErr: "invalid interval", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + for key, value := range tt.values { + setFlag(t, cmd.InheritedFlags(), key, value) + } + + _, err := resolveRunInput(cmd) + if err == nil { + t.Fatal("resolveRunInput() error = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("resolveRunInput() error = %q, want substring %q", err, tt.wantErr) + } + }) + } +} + +func TestResolveCleanupInput_validationErrors(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "cleanup", withSharedPersistentFlags) + + _, err := resolveCleanupInput(cmd) + if err == nil { + t.Fatal("resolveCleanupInput() error = nil, want error") + } + if !strings.Contains(err.Error(), "service-id is required") { + t.Fatalf("resolveCleanupInput() error = %q, want service-id validation", err) + } +} + +func resetViper(t *testing.T) { + t.Helper() + + viper.Reset() + t.Cleanup(viper.Reset) +} + +func writeIntakeConfig(t *testing.T, content string) { + t.Helper() + + path := filepath.Join(t.TempDir(), ".tagit.yaml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + viper.SetConfigFile(path) + if err := viper.ReadInConfig(); err != nil { + t.Fatalf("read config: %v", err) + } +} + +func newIntakeTestCommand(t *testing.T, use string, flagSetups ...func(*pflag.FlagSet)) *cobra.Command { + t.Helper() + + parent := &cobra.Command{Use: "tagit"} + for _, setup := range flagSetups { + setup(parent.PersistentFlags()) + } + child := &cobra.Command{Use: use} + parent.AddCommand(child) + return child +} + +func withSharedPersistentFlags(flags *pflag.FlagSet) { + flags.StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") + flags.StringP("service-id", "s", "", "consul service id") + flags.StringP("script", "x", "", "path to script used to generate tags") + flags.StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") + flags.StringP("interval", "i", "60s", "interval to run the script") + flags.StringP("token", "t", "", "consul token") +} + +func setFlag(t *testing.T, flags *pflag.FlagSet, name, value string) { + t.Helper() + + if err := flags.Set(name, value); err != nil { + t.Fatalf("set flag %s: %v", name, err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 3669e7d..6320fef 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -17,9 +17,10 @@ package cmd import ( "fmt" + "os" + "github.com/spf13/cobra" "github.com/spf13/viper" - "os" ) var cfgFile string @@ -66,7 +67,7 @@ func initConfig() { viper.SetConfigName(".tagit") } - viper.AutomaticEnv() // read in environment variables that match + configureCommandIntakeEnv() // If a config file is found, read it in. if err := viper.ReadInConfig(); err == nil { diff --git a/cmd/run.go b/cmd/run.go index 9930b2c..8d96a45 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -17,15 +17,9 @@ package cmd import ( "context" - "fmt" - "log/slog" - "os" "os/signal" "syscall" - "time" - "github.com/ncode/tagit/pkg/consul" - "github.com/ncode/tagit/pkg/tagit" "github.com/spf13/cobra" ) @@ -38,101 +32,10 @@ var runCmd = &cobra.Command{ example: tagit run -s my-super-service -x '/tmp/tag-role.sh' `, RunE: func(cmd *cobra.Command, args []string) error { - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelInfo, - })) + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() - interval, err := cmd.InheritedFlags().GetString("interval") - if err != nil { - logger.Error("Failed to get interval flag", "error", err) - return err - } - - if interval == "" || interval == "0" { - logger.Error("Interval is required") - return fmt.Errorf("interval is required and cannot be empty or zero") - } - - validInterval, err := time.ParseDuration(interval) - if err != nil { - logger.Error("Invalid interval", "interval", interval, "error", err) - return fmt.Errorf("invalid interval %q: %w", interval, err) - } - - consulAddr, err := cmd.InheritedFlags().GetString("consul-addr") - if err != nil { - logger.Error("Failed to get consul-addr flag", "error", err) - return err - } - token, err := cmd.InheritedFlags().GetString("token") - if err != nil { - logger.Error("Failed to get token flag", "error", err) - return err - } - - consulClient, err := consul.CreateClient(consulAddr, token) - if err != nil { - logger.Error("Failed to create Consul client", "error", err) - return err - } - - serviceID, err := cmd.InheritedFlags().GetString("service-id") - if err != nil { - logger.Error("Failed to get service-id flag", "error", err) - return err - } - if serviceID == "" { - logger.Error("Service ID is required") - return fmt.Errorf("service-id is required") - } - script, err := cmd.InheritedFlags().GetString("script") - if err != nil { - logger.Error("Failed to get script flag", "error", err) - return err - } - if script == "" { - logger.Error("Script is required") - return fmt.Errorf("script is required") - } - tagPrefix, err := cmd.InheritedFlags().GetString("tag-prefix") - if err != nil { - logger.Error("Failed to get tag-prefix flag", "error", err) - return err - } - - t := tagit.New( - consulClient, - &tagit.CmdExecutor{}, - serviceID, - script, - validInterval, - tagPrefix, - logger, - ) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Setup signal handling for graceful shutdown - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - go func() { - sig := <-sigCh - logger.Info("Received signal, shutting down", "signal", sig) - cancel() - }() - - logger.Info("Starting tagit", - "serviceID", serviceID, - "script", script, - "interval", validInterval, - "tagPrefix", tagPrefix) - - t.Run(ctx) - - logger.Info("Tagit has stopped") - return nil + return runCommandWithContext(ctx, cmd, commandDeps{}) }, } diff --git a/cmd/run_test.go b/cmd/run_test.go index 77f43b6..5ba7349 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -1,679 +1,258 @@ package cmd import ( - "bytes" "context" "fmt" - "os" - "sync/atomic" + "io" + "log/slog" + "strings" "testing" "time" "github.com/hashicorp/consul/api" "github.com/ncode/tagit/pkg/consul" - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" + "github.com/ncode/tagit/pkg/tagit" ) -func TestRunCmd(t *testing.T) { - // Save original args and restore after test - originalArgs := os.Args - defer func() { os.Args = originalArgs }() - - tests := []struct { - name string - args []string - expectError bool - errorContains string - }{ - { - name: "Missing required service-id", - args: []string{"run", "--script=/tmp/test.sh"}, - expectError: true, - errorContains: "service-id is required", +func TestRunCommand_validatesInputBeforeCreatingClient(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "interval", "15s") + + clientCalls := 0 + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + clientCalls++ + return commandClient{}, nil }, - { - name: "Missing required script", - args: []string{"run", "--service-id=test-service"}, - expectError: true, - errorContains: "script is required", + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + t.Fatal("NewTagger called before validation") + return nil }, - { - name: "Invalid interval format", - args: []string{"run", "--service-id=test-service", "--script=/tmp/test.sh", "--interval=invalid"}, - expectError: true, - errorContains: "invalid interval", + } + + err := runCommandWithContext(t.Context(), cmd, deps) + if err == nil { + t.Fatal("runCommandWithContext() error = nil, want error") + } + if !strings.Contains(err.Error(), "script is required") { + t.Fatalf("runCommandWithContext() error = %q, want script validation", err) + } + if clientCalls != 0 { + t.Fatalf("NewClient calls = %d, want 0", clientCalls) + } +} + +func TestRunCommand_usesResolvedInput(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "consul-addr", "consul.example:8500") + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "script", "echo primary") + setFlag(t, cmd.InheritedFlags(), "tag-prefix", "role") + setFlag(t, cmd.InheritedFlags(), "interval", "15s") + setFlag(t, cmd.InheritedFlags(), "token", "secret") + + var gotAddress string + var gotToken string + var gotInput commandInput + fakeTagger := &commandTagger{} + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + gotAddress = address + gotToken = token + return commandClient{}, nil }, - { - name: "Empty interval", - args: []string{"run", "--service-id=test-service", "--script=/tmp/test.sh", "--interval="}, - expectError: true, - errorContains: "interval is required and cannot be empty or zero", + NewExecutor: func() tagit.CommandExecutor { + return &tagit.CmdExecutor{} }, - { - name: "Zero interval", - args: []string{"run", "--service-id=test-service", "--script=/tmp/test.sh", "--interval=0"}, - expectError: true, - errorContains: "interval is required and cannot be empty or zero", + NewTagger: func(client consul.Client, executor tagit.CommandExecutor, input commandInput, logger *slog.Logger) tagger { + if client == nil { + t.Fatal("client = nil") + } + if executor == nil { + t.Fatal("executor = nil") + } + if logger == nil { + t.Fatal("logger = nil") + } + gotInput = input + return fakeTagger }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create a new root command for each test - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - // Add the run command - testRunCmd := &cobra.Command{ - Use: "run", - Short: "Run tagit", - RunE: runCmd.RunE, - } - cmd.AddCommand(testRunCmd) - - // Capture stderr - var buf bytes.Buffer - cmd.SetErr(&buf) - cmd.SetArgs(tt.args) - - // Set a context with timeout to prevent hanging - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - done <- cmd.Execute() - }() - - select { - case err := <-done: - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - case <-ctx.Done(): - if tt.expectError { - t.Log("Command timed out as expected for invalid input") - } else { - t.Error("Command timed out unexpectedly") - } - } - }) + if err := runCommandWithContext(t.Context(), cmd, deps); err != nil { + t.Fatalf("runCommandWithContext() error = %v", err) + } + + if gotAddress != "consul.example:8500" { + t.Fatalf("address = %q, want consul.example:8500", gotAddress) + } + if gotToken != "secret" { + t.Fatalf("token = %q, want secret", gotToken) + } + wantInput := commandInput{ + ConsulAddr: "consul.example:8500", + ServiceID: "api", + Script: "echo primary", + TagPrefix: "role", + Interval: 15 * time.Second, + IntervalRaw: "15s", + Token: "secret", + } + if gotInput != wantInput { + t.Fatalf("input = %#v, want %#v", gotInput, wantInput) + } + if fakeTagger.runCalls != 1 { + t.Fatalf("Run calls = %d, want 1", fakeTagger.runCalls) } } -func TestRunCmdFlagParsing(t *testing.T) { - var capturedFlags map[string]string - - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testRunCmd := &cobra.Command{ - Use: "run", - Short: "Run tagit", - Run: func(cmd *cobra.Command, args []string) { - // Capture flag values during execution - capturedFlags = make(map[string]string) - capturedFlags["service-id"], _ = cmd.InheritedFlags().GetString("service-id") - capturedFlags["script"], _ = cmd.InheritedFlags().GetString("script") - capturedFlags["interval"], _ = cmd.InheritedFlags().GetString("interval") - capturedFlags["tag-prefix"], _ = cmd.InheritedFlags().GetString("tag-prefix") - capturedFlags["consul-addr"], _ = cmd.InheritedFlags().GetString("consul-addr") - capturedFlags["token"], _ = cmd.InheritedFlags().GetString("token") +func TestRunCommand_returnsClientErrors(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "script", "echo primary") + setFlag(t, cmd.InheritedFlags(), "interval", "15s") + + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + return nil, fmt.Errorf("connect consul") }, + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + t.Fatal("NewTagger called after client error") + return nil + }, + } + + err := runCommandWithContext(t.Context(), cmd, deps) + if err == nil { + t.Fatal("runCommandWithContext() error = nil, want error") + } + if !strings.Contains(err.Error(), "connect consul") { + t.Fatalf("runCommandWithContext() error = %q, want client error", err) } - cmd.AddCommand(testRunCmd) - - cmd.SetArgs([]string{ - "run", - "--service-id=test-service", - "--script=/tmp/test.sh", - "--interval=30s", - "--tag-prefix=test", - "--consul-addr=localhost:8500", - "--token=test-token", - }) - - err := cmd.Execute() - assert.NoError(t, err) - - // Verify flags were parsed correctly - assert.Equal(t, "test-service", capturedFlags["service-id"]) - assert.Equal(t, "/tmp/test.sh", capturedFlags["script"]) - assert.Equal(t, "30s", capturedFlags["interval"]) - assert.Equal(t, "test", capturedFlags["tag-prefix"]) - assert.Equal(t, "localhost:8500", capturedFlags["consul-addr"]) - assert.Equal(t, "test-token", capturedFlags["token"]) } -func TestRunCmdExecutionErrors(t *testing.T) { - tests := []struct { - name string - consulAddr string - expectError bool - errorContains string - }{ - { - name: "Invalid consul address", - consulAddr: "invalid-consul-address", - expectError: true, - errorContains: "failed to create Consul client", +func TestRunCommand_passesCancellationToTagger(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) + setFlag(t, cmd.InheritedFlags(), "service-id", "api") + setFlag(t, cmd.InheritedFlags(), "script", "echo primary") + setFlag(t, cmd.InheritedFlags(), "interval", "15s") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + observedCancellation := false + deps := commandDeps{ + Logger: discardLogger(), + NewClient: func(address, token string) (consul.Client, error) { + return commandClient{}, nil + }, + NewTagger: func(consul.Client, tagit.CommandExecutor, commandInput, *slog.Logger) tagger { + return runFuncTagger(func(ctx context.Context) { + select { + case <-ctx.Done(): + observedCancellation = true + default: + } + }) }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testRunCmd := &cobra.Command{ - Use: "run", - Short: "Run tagit", - RunE: func(cmd *cobra.Command, args []string) error { - // Test the same initial setup as the real run command but stop before running - interval, err := cmd.InheritedFlags().GetString("interval") - if err != nil { - return err - } - - if interval == "" || interval == "0" { - return fmt.Errorf("interval is required and cannot be empty or zero") - } - - _, err = time.ParseDuration(interval) - if err != nil { - return fmt.Errorf("invalid interval %q: %w", interval, err) - } - - consulAddr, err := cmd.InheritedFlags().GetString("consul-addr") - if err != nil { - return err - } - - // Test consul client creation with invalid address - if consulAddr == "invalid-consul-address" { - return fmt.Errorf("failed to create Consul client: invalid address") - } - - // Don't actually start the service - just return success for valid inputs - return nil - }, - } - cmd.AddCommand(testRunCmd) - - var stderr bytes.Buffer - cmd.SetErr(&stderr) - cmd.SetArgs([]string{ - "run", - "--service-id=test-service", - "--script=/tmp/test.sh", - "--consul-addr=" + tt.consulAddr, - "--tag-prefix=test", - "--interval=30s", - }) + if err := runCommandWithContext(ctx, cmd, deps); err != nil { + t.Fatalf("runCommandWithContext() error = %v", err) + } + if !observedCancellation { + t.Fatal("tagger did not observe cancelled context") + } +} - err := cmd.Execute() +func TestRunCmd_RunEUsesSharedHandler(t *testing.T) { + resetViper(t) + cmd := newIntakeTestCommand(t, "run", withSharedPersistentFlags) - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) + err := runCmd.RunE(cmd, nil) + if err == nil { + t.Fatal("runCmd.RunE() error = nil, want validation error") + } + if !strings.Contains(err.Error(), "service-id is required") { + t.Fatalf("runCmd.RunE() error = %q, want service-id validation", err) } } -func TestRunCmdFlagRetrievalErrors(t *testing.T) { - // Test flag retrieval error paths in the RunE function - tests := []struct { - name string - interval string - expectError bool - errorContains string - }{ - { - name: "GetString error simulation for interval", - interval: "30s", // This won't actually cause GetString to error in this test setup - expectError: false, - }, - { - name: "Valid duration parsing", - interval: "1m30s", - expectError: false, - }, +func TestCommandDeps_withDefaults(t *testing.T) { + deps := (commandDeps{}).withDefaults() + + if deps.Logger == nil { + t.Fatal("Logger = nil, want default logger") + } + if deps.NewClient == nil { + t.Fatal("NewClient = nil, want default client factory") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - var capturedData map[string]interface{} - - testRunCmd := &cobra.Command{ - Use: "run", - Short: "Run tagit", - RunE: func(cmd *cobra.Command, args []string) error { - capturedData = make(map[string]interface{}) - - // Test the same flag retrieval pattern as in the actual run command - interval, err := cmd.InheritedFlags().GetString("interval") - if err != nil { - return err - } - capturedData["interval-string"] = interval - - if interval == "" || interval == "0" { - return fmt.Errorf("interval is required and cannot be empty or zero") - } - - validInterval, err := time.ParseDuration(interval) - if err != nil { - return fmt.Errorf("invalid interval %q: %w", interval, err) - } - capturedData["parsed-interval"] = validInterval - - // Test other flag retrievals - config := make(map[string]string) - config["address"], err = cmd.InheritedFlags().GetString("consul-addr") - if err != nil { - return err - } - config["token"], err = cmd.InheritedFlags().GetString("token") - if err != nil { - return err - } - capturedData["config"] = config - - serviceID, err := cmd.InheritedFlags().GetString("service-id") - if err != nil { - return err - } - script, err := cmd.InheritedFlags().GetString("script") - if err != nil { - return err - } - tagPrefix, err := cmd.InheritedFlags().GetString("tag-prefix") - if err != nil { - return err - } - - capturedData["service-id"] = serviceID - capturedData["script"] = script - capturedData["tag-prefix"] = tagPrefix - - // Don't actually run anything - just test flag access - return nil - }, - } - cmd.AddCommand(testRunCmd) - - cmd.SetArgs([]string{ - "run", - "--service-id=test-service", - "--script=/tmp/test.sh", - "--consul-addr=localhost:8500", - "--tag-prefix=test-prefix", - "--interval=" + tt.interval, - "--token=test-token", - }) + executor := deps.NewExecutor() + if executor == nil { + t.Fatal("NewExecutor() = nil, want executor") + } - err := cmd.Execute() + tagger := deps.NewTagger(commandClient{}, executor, commandInput{ + ServiceID: "api", + Script: "echo primary", + Interval: time.Second, + TagPrefix: "role", + IntervalRaw: "1s", + }, deps.Logger) + if tagger == nil { + t.Fatal("NewTagger() = nil, want tagger") + } +} - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) +type commandClient struct{} - // Verify all values were captured correctly - assert.Equal(t, tt.interval, capturedData["interval-string"]) - expectedDuration, _ := time.ParseDuration(tt.interval) - assert.Equal(t, expectedDuration, capturedData["parsed-interval"]) +func (commandClient) Agent() consul.Agent { + return commandAgent{} +} - config := capturedData["config"].(map[string]string) - assert.Equal(t, "localhost:8500", config["address"]) - assert.Equal(t, "test-token", config["token"]) +type commandAgent struct{} - assert.Equal(t, "test-service", capturedData["service-id"]) - assert.Equal(t, "/tmp/test.sh", capturedData["script"]) - assert.Equal(t, "test-prefix", capturedData["tag-prefix"]) - } - }) - } +func (commandAgent) Service(string, *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return nil, nil, nil } -func TestRunCmdCompleteFlow(t *testing.T) { - // Test the complete flow of the run command with all flag retrievals - tests := []struct { - name string - setupCmd func() *cobra.Command - args []string - expectError bool - errorContains string - }{ - { - name: "Valid configuration with all flags", - setupCmd: func() *cobra.Command { - cmd := &cobra.Command{Use: "tagit"} - cmd.PersistentFlags().StringP("consul-addr", "c", "127.0.0.1:8500", "consul address") - cmd.PersistentFlags().StringP("service-id", "s", "", "consul service id") - cmd.PersistentFlags().StringP("script", "x", "", "path to script used to generate tags") - cmd.PersistentFlags().StringP("tag-prefix", "p", "tagged", "prefix to be added to tags") - cmd.PersistentFlags().StringP("interval", "i", "60s", "interval to run the script") - cmd.PersistentFlags().StringP("token", "t", "", "consul token") - - testRunCmd := &cobra.Command{ - Use: "run", - Short: "Run tagit", - RunE: func(cmd *cobra.Command, args []string) error { - // Simulate all the flag retrievals from the actual run command - interval, err := cmd.InheritedFlags().GetString("interval") - if err != nil { - return err - } - - if interval == "" || interval == "0" { - return fmt.Errorf("interval is required and cannot be empty or zero") - } - - _, err = time.ParseDuration(interval) - if err != nil { - return fmt.Errorf("invalid interval %q: %w", interval, err) - } - - // Test all flag retrievals - consulAddr, err := cmd.InheritedFlags().GetString("consul-addr") - if err != nil { - return fmt.Errorf("failed to get consul-addr flag: %w", err) - } - - token, err := cmd.InheritedFlags().GetString("token") - if err != nil { - return fmt.Errorf("failed to get token flag: %w", err) - } - - serviceID, err := cmd.InheritedFlags().GetString("service-id") - if err != nil { - return fmt.Errorf("failed to get service-id flag: %w", err) - } - - script, err := cmd.InheritedFlags().GetString("script") - if err != nil { - return fmt.Errorf("failed to get script flag: %w", err) - } - - tagPrefix, err := cmd.InheritedFlags().GetString("tag-prefix") - if err != nil { - return fmt.Errorf("failed to get tag-prefix flag: %w", err) - } - - // Validate we got all values - if consulAddr == "" || serviceID == "" || script == "" || tagPrefix == "" { - return fmt.Errorf("missing required flags") - } - - // Don't create real consul client or run the service - // Just verify all flags were retrieved successfully - _ = token // token is optional - - return nil - }, - } - cmd.AddCommand(testRunCmd) - return cmd - }, - args: []string{ - "run", - "--service-id=test-service", - "--script=/tmp/test.sh", - "--consul-addr=localhost:8500", - "--tag-prefix=test", - "--interval=30s", - "--token=test-token", - }, - expectError: false, - }, - } +func (commandAgent) ServiceRegister(*api.AgentServiceRegistration) error { + return nil +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := tt.setupCmd() - cmd.SetArgs(tt.args) +type commandTagger struct { + runCalls int + cleanupCalls int +} - err := cmd.Execute() +func (ct *commandTagger) Run(context.Context) { + ct.runCalls++ +} - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } +func (ct *commandTagger) CleanupTags() error { + ct.cleanupCalls++ + return nil +} + +type runFuncTagger func(context.Context) + +func (rft runFuncTagger) Run(ctx context.Context) { + rft(ctx) +} + +func (rft runFuncTagger) CleanupTags() error { + return nil } -func TestRunCmdWithMockFactory(t *testing.T) { - // Save and restore the original factory - originalFactory := consul.Factory - defer func() { - consul.Factory = originalFactory - }() - - t.Run("Successful run with mock", func(t *testing.T) { - // Track if service was registered at least once - var registered atomic.Bool - - // Create a mock agent that simulates a service - mockAgent := &MockAgent{ - ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - return &api.AgentService{ - ID: serviceID, - Service: "test", - Tags: []string{"existing-tag"}, - }, nil, nil - }, - ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { - // Verify that new tags were added - registered.Store(true) - assert.Contains(t, reg.Tags, "existing-tag") - assert.Contains(t, reg.Tags, "test-tag1") - assert.Contains(t, reg.Tags, "test-tag2") - return nil - }, - } - - // Create mock client with the mock agent - mockClient := &MockConsulClient{ - MockAgent: mockAgent, - } - - // Set up the mock factory - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "run", - RunE: runCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "test", "") - parent.PersistentFlags().String("script", "echo 'tag1 tag2'", "") - parent.PersistentFlags().String("interval", "100ms", "") // Short interval for testing - parent.AddCommand(cmd) - - // Run the command in a goroutine with timeout - done := make(chan error) - go func() { - done <- cmd.RunE(cmd, []string{}) - }() - - // Let it run for a short time - time.Sleep(250 * time.Millisecond) - - // The command should have registered the service at least once - assert.True(t, registered.Load(), "Service should have been registered at least once") - - // Note: The run command runs forever, so we can't test it finishing cleanly - // This test verifies it starts correctly and processes at least one update - }) - - t.Run("Run with invalid interval", func(t *testing.T) { - // Set up a valid mock factory - mockClient := &MockConsulClient{ - MockAgent: &MockAgent{}, - } - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "run", - RunE: runCmd.RunE, - } - // Set up parent command with invalid interval - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "test", "") - parent.PersistentFlags().String("script", "echo 'tag1'", "") - parent.PersistentFlags().String("interval", "invalid", "") - parent.AddCommand(cmd) - - // Run the command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid interval") - }) - - t.Run("Run with empty interval", func(t *testing.T) { - // Set up a valid mock factory - mockClient := &MockConsulClient{ - MockAgent: &MockAgent{}, - } - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "run", - RunE: runCmd.RunE, - } - // Set up parent command with empty interval - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "test", "") - parent.PersistentFlags().String("script", "echo 'tag1'", "") - parent.PersistentFlags().String("interval", "", "") - parent.AddCommand(cmd) - - // Run the command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "interval is required") - }) - - t.Run("Run with zero interval", func(t *testing.T) { - // Set up a valid mock factory - mockClient := &MockConsulClient{ - MockAgent: &MockAgent{}, - } - mockFactory := &consul.MockFactory{ - MockClient: mockClient, - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "run", - RunE: runCmd.RunE, - } - // Set up parent command with zero interval - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "test", "") - parent.PersistentFlags().String("script", "echo 'tag1'", "") - parent.PersistentFlags().String("interval", "0", "") - parent.AddCommand(cmd) - - // Run the command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "interval is required") - }) - - t.Run("Run with connection error", func(t *testing.T) { - // Set up a factory that returns an error - mockFactory := &consul.MockFactory{ - MockError: fmt.Errorf("connection failed"), - } - consul.SetFactory(mockFactory) - - // Create a new command instance for this test - cmd := &cobra.Command{ - Use: "run", - RunE: runCmd.RunE, - } - // Set up parent command for flags inheritance - parent := &cobra.Command{} - parent.PersistentFlags().String("consul-addr", "127.0.0.1:8500", "") - parent.PersistentFlags().String("token", "", "") - parent.PersistentFlags().String("service-id", "test-service", "") - parent.PersistentFlags().String("tag-prefix", "test", "") - parent.PersistentFlags().String("script", "echo 'tag1'", "") - parent.PersistentFlags().String("interval", "1s", "") - parent.AddCommand(cmd) - - // Run the command - should fail - err := cmd.RunE(cmd, []string{}) - assert.Error(t, err) - }) +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) } diff --git a/cmd/systemd.go b/cmd/systemd.go index 97ef42f..3712724 100644 --- a/cmd/systemd.go +++ b/cmd/systemd.go @@ -17,10 +17,10 @@ package cmd import ( "fmt" - "os" "github.com/ncode/tagit/pkg/systemd" "github.com/spf13/cobra" + "github.com/spf13/pflag" ) // systemdCmd represents the systemd command @@ -34,40 +34,15 @@ automatically on boot and can be managed using systemctl. Example usage: tagit systemd --service-id=my-service --script=/path/to/script.sh --tag-prefix=tagit --interval=5s --user=tagit --group=tagit `, - Run: func(cmd *cobra.Command, args []string) { - flags := make(map[string]string) - for _, flag := range append(systemd.GetRequiredFlags(), systemd.GetOptionalFlags()...) { - flags[flag], _ = cmd.Flags().GetString(flag) - } - - fields, err := systemd.NewFieldsFromFlags(flags) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } - - serviceFile, err := systemd.RenderTemplate(fields) - if err != nil { - fmt.Fprintf(os.Stderr, "Error generating systemd service file: %v\n", err) - os.Exit(1) - } - - fmt.Println(serviceFile) + RunE: func(cmd *cobra.Command, args []string) error { + return systemdCommand(cmd) }, } func init() { rootCmd.AddCommand(systemdCmd) - // Define flags for all required and optional fields - systemdCmd.Flags().String("service-id", "", "ID of the service (required)") - systemdCmd.Flags().String("script", "", "Path to the script to execute (required)") - systemdCmd.Flags().String("tag-prefix", "", "Prefix for tags (required)") - systemdCmd.Flags().String("interval", "", "Interval for script execution (required)") - systemdCmd.Flags().String("token", "", "Consul token (optional)") - systemdCmd.Flags().String("consul-addr", "", "Consul address (optional)") - systemdCmd.Flags().String("user", "", "User to run the service as (required)") - systemdCmd.Flags().String("group", "", "Group to run the service as (required)") + addSystemdFlags(systemdCmd.Flags()) // Mark required flags systemdCmd.MarkFlagRequired("service-id") @@ -77,3 +52,34 @@ func init() { systemdCmd.MarkFlagRequired("user") systemdCmd.MarkFlagRequired("group") } + +func systemdCommand(cmd *cobra.Command) error { + input, err := resolveSystemdInput(cmd) + if err != nil { + return err + } + + fields, err := systemd.NewFieldsFromInvocation(input.Invocation, input.User, input.Group) + if err != nil { + return err + } + + serviceFile, err := systemd.RenderTemplate(fields) + if err != nil { + return fmt.Errorf("generate systemd service file: %w", err) + } + + fmt.Fprintln(cmd.OutOrStdout(), serviceFile) + return nil +} + +func addSystemdFlags(flags *pflag.FlagSet) { + flags.String("service-id", "", "ID of the service (required)") + flags.String("script", "", "Path to the script to execute (required)") + flags.String("tag-prefix", "", "Prefix for tags (required)") + flags.String("interval", "", "Interval for script execution (required)") + flags.String("token", "", "Consul token (optional)") + flags.String("consul-addr", "", "Consul address (optional)") + flags.String("user", "", "User to run the service as (required)") + flags.String("group", "", "Group to run the service as (required)") +} diff --git a/cmd/systemd_invocation_test.go b/cmd/systemd_invocation_test.go new file mode 100644 index 0000000..36d358d --- /dev/null +++ b/cmd/systemd_invocation_test.go @@ -0,0 +1,110 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func TestResolveSystemdInput_usesRunInputAndSystemdFields(t *testing.T) { + resetViper(t) + cmd := newSystemdIntakeTestCommand() + setFlag(t, cmd.Flags(), "consul-addr", "consul.example:8500") + setFlag(t, cmd.Flags(), "service-id", "api") + setFlag(t, cmd.Flags(), "script", "/opt/tagit/tags.sh") + setFlag(t, cmd.Flags(), "tag-prefix", "role") + setFlag(t, cmd.Flags(), "interval", "15s") + setFlag(t, cmd.Flags(), "token", "secret") + setFlag(t, cmd.Flags(), "user", "tagit") + setFlag(t, cmd.Flags(), "group", "tagit") + + got, err := resolveSystemdInput(cmd) + if err != nil { + t.Fatalf("resolveSystemdInput() error = %v", err) + } + + if got.Invocation.ServiceID != "api" || got.Invocation.Script != "/opt/tagit/tags.sh" || + got.Invocation.TagPrefix != "role" || got.Invocation.Interval != "15s" || + got.Invocation.Token != "secret" || got.Invocation.ConsulAddr != "consul.example:8500" { + t.Fatalf("Invocation = %#v, want resolved run values", got.Invocation) + } + if got.User != "tagit" || got.Group != "tagit" { + t.Fatalf("User/Group = %q/%q, want tagit/tagit", got.User, got.Group) + } +} + +func TestResolveSystemdInput_validatesUserAndGroup(t *testing.T) { + tests := []struct { + name string + user string + group string + wantErr string + }{ + {name: "missing user", group: "tagit", wantErr: "user is required"}, + {name: "missing group", user: "tagit", wantErr: "group is required"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetViper(t) + cmd := newSystemdIntakeTestCommand() + setFlag(t, cmd.Flags(), "service-id", "api") + setFlag(t, cmd.Flags(), "script", "/opt/tagit/tags.sh") + setFlag(t, cmd.Flags(), "tag-prefix", "role") + setFlag(t, cmd.Flags(), "interval", "15s") + setFlag(t, cmd.Flags(), "user", tt.user) + setFlag(t, cmd.Flags(), "group", tt.group) + + _, err := resolveSystemdInput(cmd) + if err == nil { + t.Fatal("resolveSystemdInput() error = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("resolveSystemdInput() error = %q, want %q", err, tt.wantErr) + } + }) + } +} + +func TestResolveSystemdInput_surfacesRunInputErrors(t *testing.T) { + resetViper(t) + cmd := newSystemdIntakeTestCommand() + setFlag(t, cmd.Flags(), "service-id", "api") + setFlag(t, cmd.Flags(), "tag-prefix", "role") + setFlag(t, cmd.Flags(), "interval", "15s") + setFlag(t, cmd.Flags(), "user", "tagit") + setFlag(t, cmd.Flags(), "group", "tagit") + + _, err := resolveSystemdInput(cmd) + if err == nil { + t.Fatal("resolveSystemdInput() error = nil, want script validation error") + } + if !strings.Contains(err.Error(), "script is required") { + t.Fatalf("resolveSystemdInput() error = %q, want script validation", err) + } +} + +func TestSystemdCommand_returnsFieldValidationErrors(t *testing.T) { + resetViper(t) + cmd := newSystemdIntakeTestCommand() + setFlag(t, cmd.Flags(), "service-id", "api") + setFlag(t, cmd.Flags(), "script", "/opt/tagit/tags.sh") + setFlag(t, cmd.Flags(), "interval", "15s") + setFlag(t, cmd.Flags(), "user", "tagit") + setFlag(t, cmd.Flags(), "group", "tagit") + + err := systemdCommand(cmd) + if err == nil { + t.Fatal("systemdCommand() error = nil, want tag-prefix validation error") + } + if !strings.Contains(err.Error(), "TagPrefix") { + t.Fatalf("systemdCommand() error = %q, want TagPrefix validation", err) + } +} + +func newSystemdIntakeTestCommand() *cobra.Command { + cmd := &cobra.Command{Use: "systemd"} + addSystemdFlags(cmd.Flags()) + return cmd +} diff --git a/cmd/systemd_test.go b/cmd/systemd_test.go index dfa699a..55700a2 100644 --- a/cmd/systemd_test.go +++ b/cmd/systemd_test.go @@ -18,17 +18,10 @@ func setupSystemdCmd() *cobra.Command { systCmd := &cobra.Command{ Use: "systemd", Short: "Generate a systemd service file for TagIt", - Run: systemdCmd.Run, + RunE: systemdCmd.RunE, } - systCmd.Flags().String("service-id", "", "ID of the service (required)") - systCmd.Flags().String("script", "", "Path to the script to execute (required)") - systCmd.Flags().String("tag-prefix", "", "Prefix for tags (required)") - systCmd.Flags().String("interval", "", "Interval for script execution (required)") - systCmd.Flags().String("token", "", "Consul token (optional)") - systCmd.Flags().String("consul-addr", "", "Consul address (optional)") - systCmd.Flags().String("user", "", "User to run the service as (required)") - systCmd.Flags().String("group", "", "Group to run the service as (required)") + addSystemdFlags(systCmd.Flags()) systCmd.MarkFlagRequired("service-id") systCmd.MarkFlagRequired("script") diff --git a/pkg/consul/registration.go b/pkg/consul/registration.go new file mode 100644 index 0000000..ef71ff2 --- /dev/null +++ b/pkg/consul/registration.go @@ -0,0 +1,63 @@ +package consul + +import ( + "fmt" + "maps" + + "github.com/hashicorp/consul/api" +) + +// ServiceRegistration owns lookup and tag writes for one Consul client. +type ServiceRegistration struct { + agent Agent +} + +// NewServiceRegistration returns a registration store backed by client.Agent(). +func NewServiceRegistration(client Client) *ServiceRegistration { + return &ServiceRegistration{agent: client.Agent()} +} + +// Load returns the registered Consul service for serviceID. +func (sr *ServiceRegistration) Load(serviceID string) (*api.AgentService, error) { + service, _, err := sr.agent.Service(serviceID, nil) + if err != nil { + return nil, fmt.Errorf("lookup service %s: %w", serviceID, err) + } + if service == nil { + return nil, fmt.Errorf("service %s not found", serviceID) + } + return service, nil +} + +// UpdateTags writes tags to Consul while preserving all non-tag registration fields. +func (sr *ServiceRegistration) UpdateTags(service *api.AgentService, tags []string, changed bool) error { + if !changed { + return nil + } + if service == nil { + return fmt.Errorf("register service: nil service") + } + + registration := copyServiceToRegistration(service) + registration.Tags = tags + if err := sr.agent.ServiceRegister(registration); err != nil { + return fmt.Errorf("register service %s: %w", service.ID, err) + } + return nil +} + +func copyServiceToRegistration(service *api.AgentService) *api.AgentServiceRegistration { + return &api.AgentServiceRegistration{ + ID: service.ID, + Name: service.Service, + Tags: service.Tags, + Port: service.Port, + Address: service.Address, + Kind: service.Kind, + Meta: maps.Clone(service.Meta), + Weights: &api.AgentWeights{ + Passing: service.Weights.Passing, + Warning: service.Weights.Warning, + }, + } +} diff --git a/pkg/consul/registration_test.go b/pkg/consul/registration_test.go new file mode 100644 index 0000000..2ee0816 --- /dev/null +++ b/pkg/consul/registration_test.go @@ -0,0 +1,202 @@ +package consul + +import ( + "fmt" + "strings" + "testing" + + "github.com/hashicorp/consul/api" +) + +func TestServiceRegistration_Load(t *testing.T) { + t.Run("service is found", func(t *testing.T) { + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + serviceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return &api.AgentService{ID: serviceID, Service: "api"}, nil, nil + }, + }, + }) + + got, err := store.Load("api-1") + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if got.ID != "api-1" { + t.Fatalf("Load() ID = %q, want api-1", got.ID) + } + }) + + t.Run("service is missing", func(t *testing.T) { + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + serviceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return nil, nil, nil + }, + }, + }) + + _, err := store.Load("missing") + if err == nil { + t.Fatal("Load() error = nil, want error") + } + if !strings.Contains(err.Error(), "missing") { + t.Fatalf("Load() error = %q, want service ID", err) + } + }) + + t.Run("lookup error has operation context", func(t *testing.T) { + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + serviceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return nil, nil, fmt.Errorf("consul unavailable") + }, + }, + }) + + _, err := store.Load("api-1") + if err == nil { + t.Fatal("Load() error = nil, want error") + } + if !strings.Contains(err.Error(), "lookup service api-1") { + t.Fatalf("Load() error = %q, want lookup context", err) + } + }) +} + +func TestServiceRegistration_UpdateTags(t *testing.T) { + service := &api.AgentService{ + ID: "api-1", + Service: "api", + Tags: []string{"old"}, + Address: "10.0.0.2", + Port: 8080, + Kind: api.ServiceKindTypical, + Meta: map[string]string{"version": "1"}, + Weights: api.AgentWeights{ + Passing: 10, + Warning: 1, + }, + } + + var got *api.AgentServiceRegistration + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + registerFunc: func(reg *api.AgentServiceRegistration) error { + got = reg + return nil + }, + }, + }) + + if err := store.UpdateTags(service, []string{"new", "static"}, true); err != nil { + t.Fatalf("UpdateTags() error = %v", err) + } + + if got == nil { + t.Fatal("ServiceRegister was not called") + } + if got.ID != service.ID || got.Name != service.Service || got.Address != service.Address || got.Port != service.Port || got.Kind != service.Kind { + t.Fatalf("registration fields were not preserved: %#v", got) + } + if got.Meta["version"] != "1" { + t.Fatalf("Meta = %#v, want version preserved", got.Meta) + } + if got.Weights == nil || got.Weights.Passing != 10 || got.Weights.Warning != 1 { + t.Fatalf("Weights = %#v, want preserved weights", got.Weights) + } + if want := []string{"new", "static"}; !equalStrings(got.Tags, want) { + t.Fatalf("Tags = %v, want %v", got.Tags, want) + } +} + +func TestServiceRegistration_UpdateTagsSkipsUnchanged(t *testing.T) { + registerCalls := 0 + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + registerFunc: func(reg *api.AgentServiceRegistration) error { + registerCalls++ + return nil + }, + }, + }) + + err := store.UpdateTags(&api.AgentService{ID: "api-1"}, []string{"static"}, false) + if err != nil { + t.Fatalf("UpdateTags() error = %v", err) + } + if registerCalls != 0 { + t.Fatalf("ServiceRegister calls = %d, want 0", registerCalls) + } +} + +func TestServiceRegistration_UpdateTagsRejectsNilService(t *testing.T) { + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{}, + }) + + err := store.UpdateTags(nil, []string{"new"}, true) + if err == nil { + t.Fatal("UpdateTags() error = nil, want error") + } + if !strings.Contains(err.Error(), "nil service") { + t.Fatalf("UpdateTags() error = %q, want nil service context", err) + } +} + +func TestServiceRegistration_UpdateTagsWriteError(t *testing.T) { + store := NewServiceRegistration(®istrationClient{ + agent: ®istrationAgent{ + registerFunc: func(reg *api.AgentServiceRegistration) error { + return fmt.Errorf("permission denied") + }, + }, + }) + + err := store.UpdateTags(&api.AgentService{ID: "api-1"}, []string{"new"}, true) + if err == nil { + t.Fatal("UpdateTags() error = nil, want error") + } + if !strings.Contains(err.Error(), "register service api-1") { + t.Fatalf("UpdateTags() error = %q, want register context", err) + } +} + +type registrationClient struct { + agent Agent +} + +func (rc *registrationClient) Agent() Agent { + return rc.agent +} + +type registrationAgent struct { + serviceFunc func(string, *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) + registerFunc func(*api.AgentServiceRegistration) error +} + +func (ra *registrationAgent) Service(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + if ra.serviceFunc == nil { + return nil, nil, nil + } + return ra.serviceFunc(serviceID, q) +} + +func (ra *registrationAgent) ServiceRegister(reg *api.AgentServiceRegistration) error { + if ra.registerFunc == nil { + return nil + } + return ra.registerFunc(reg) +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/pkg/systemd/invocation_test.go b/pkg/systemd/invocation_test.go new file mode 100644 index 0000000..d5b5601 --- /dev/null +++ b/pkg/systemd/invocation_test.go @@ -0,0 +1,90 @@ +package systemd + +import ( + "strings" + "testing" +) + +func TestRenderInvocation(t *testing.T) { + tests := []struct { + name string + invocation Invocation + want string + }{ + { + name: "required invocation values", + invocation: Invocation{ + ServiceID: "api", + Script: "/opt/tagit/tags.sh", + TagPrefix: "role", + Interval: "15s", + }, + want: "/usr/bin/tagit run -s api -x /opt/tagit/tags.sh -p role -i 15s", + }, + { + name: "optional invocation values", + invocation: Invocation{ + ServiceID: "api", + Script: "/opt/tagit/tags.sh", + TagPrefix: "role", + Interval: "15s", + Token: "secret", + ConsulAddr: "consul.example:8500", + }, + want: "/usr/bin/tagit run -s api -x /opt/tagit/tags.sh -p role -i 15s -t secret -c consul.example:8500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RenderInvocation(tt.invocation) + if got != tt.want { + t.Fatalf("RenderInvocation() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestNewFieldsFromInvocation_validatesSystemdFields(t *testing.T) { + invocation := Invocation{ + ServiceID: "api", + Script: "/opt/tagit/tags.sh", + TagPrefix: "role", + Interval: "15s", + } + + _, err := NewFieldsFromInvocation(invocation, "", "tagit") + if err == nil { + t.Fatal("NewFieldsFromInvocation() error = nil, want missing user error") + } + if !strings.Contains(err.Error(), "User") { + t.Fatalf("NewFieldsFromInvocation() error = %q, want User", err) + } + + _, err = NewFieldsFromInvocation(invocation, "tagit", "") + if err == nil { + t.Fatal("NewFieldsFromInvocation() error = nil, want missing group error") + } + if !strings.Contains(err.Error(), "Group") { + t.Fatalf("NewFieldsFromInvocation() error = %q, want Group", err) + } +} + +func TestNewFieldsFromInvocation_setsExecStart(t *testing.T) { + fields, err := NewFieldsFromInvocation(Invocation{ + ServiceID: "api", + Script: "/opt/tagit/tags.sh", + TagPrefix: "role", + Interval: "15s", + Token: "secret", + ConsulAddr: "consul.example:8500", + }, "tagit", "tagit") + if err != nil { + t.Fatalf("NewFieldsFromInvocation() error = %v", err) + } + + want := "/usr/bin/tagit run -s api -x /opt/tagit/tags.sh -p role -i 15s -t secret -c consul.example:8500" + if fields.ExecStart != want { + t.Fatalf("ExecStart = %q, want %q", fields.ExecStart, want) + } +} diff --git a/pkg/systemd/systemd.go b/pkg/systemd/systemd.go index 0b5394a..293ebd2 100644 --- a/pkg/systemd/systemd.go +++ b/pkg/systemd/systemd.go @@ -18,7 +18,7 @@ Wants=network-online.target [Service] Type=simple -ExecStart=/usr/bin/tagit run -s {{ .ServiceID }} -x {{ .Script }} -p {{ .TagPrefix }} -i {{ .Interval }}{{ if .Token }} -t {{ .Token }}{{ end }}{{ if .ConsulAddr }} -c {{ .ConsulAddr }}{{ end }} +ExecStart={{ .ExecStart }} Environment=HOME=/var/run/tagit/{{ .ServiceID }} Restart=always User={{ .User }} @@ -37,10 +37,21 @@ type Fields struct { Interval string Token string ConsulAddr string + ExecStart string User string Group string } +// Invocation is the TagIt run invocation rendered into systemd ExecStart. +type Invocation struct { + ServiceID string + Script string + TagPrefix string + Interval string + Token string + ConsulAddr string +} + var parsedTemplate *template.Template func init() { @@ -53,6 +64,19 @@ func init() { // RenderTemplate renders the template for the systemd service. func RenderTemplate(fields *Fields) (string, error) { + if fields != nil && fields.ExecStart == "" { + clone := *fields + clone.ExecStart = RenderInvocation(Invocation{ + ServiceID: clone.ServiceID, + Script: clone.Script, + TagPrefix: clone.TagPrefix, + Interval: clone.Interval, + Token: clone.Token, + ConsulAddr: clone.ConsulAddr, + }) + fields = &clone + } + if err := validateFields(fields); err != nil { return "", fmt.Errorf("field validation failed: %w", err) } @@ -95,17 +119,37 @@ func validateFields(fields *Fields) error { return nil } -// NewFieldsFromFlags creates a new Fields struct from command line flags. -func NewFieldsFromFlags(flags map[string]string) (*Fields, error) { +// RenderInvocation renders the tagit run command used by systemd ExecStart. +func RenderInvocation(invocation Invocation) string { + parts := []string{ + "/usr/bin/tagit", + "run", + "-s", invocation.ServiceID, + "-x", invocation.Script, + "-p", invocation.TagPrefix, + "-i", invocation.Interval, + } + if invocation.Token != "" { + parts = append(parts, "-t", invocation.Token) + } + if invocation.ConsulAddr != "" { + parts = append(parts, "-c", invocation.ConsulAddr) + } + return strings.Join(parts, " ") +} + +// NewFieldsFromInvocation creates systemd fields from a validated TagIt invocation. +func NewFieldsFromInvocation(invocation Invocation, user, group string) (*Fields, error) { fields := &Fields{ - ServiceID: flags["service-id"], - Script: flags["script"], - TagPrefix: flags["tag-prefix"], - Interval: flags["interval"], - Token: flags["token"], - ConsulAddr: flags["consul-addr"], - User: flags["user"], - Group: flags["group"], + ServiceID: invocation.ServiceID, + Script: invocation.Script, + TagPrefix: invocation.TagPrefix, + Interval: invocation.Interval, + Token: invocation.Token, + ConsulAddr: invocation.ConsulAddr, + ExecStart: RenderInvocation(invocation), + User: user, + Group: group, } if err := validateFields(fields); err != nil { @@ -115,6 +159,18 @@ func NewFieldsFromFlags(flags map[string]string) (*Fields, error) { return fields, nil } +// NewFieldsFromFlags creates a new Fields struct from command line flags. +func NewFieldsFromFlags(flags map[string]string) (*Fields, error) { + return NewFieldsFromInvocation(Invocation{ + ServiceID: flags["service-id"], + Script: flags["script"], + TagPrefix: flags["tag-prefix"], + Interval: flags["interval"], + Token: flags["token"], + ConsulAddr: flags["consul-addr"], + }, flags["user"], flags["group"]) +} + // GetRequiredFlags returns a list of required flag names. func GetRequiredFlags() []string { return []string{"service-id", "script", "tag-prefix", "interval", "user", "group"} diff --git a/pkg/tagit/reconciliation.go b/pkg/tagit/reconciliation.go new file mode 100644 index 0000000..eceec26 --- /dev/null +++ b/pkg/tagit/reconciliation.go @@ -0,0 +1,82 @@ +package tagit + +import ( + "fmt" + "slices" + "strings" +) + +// Reconciler applies TagIt's managed tag policy for one tag prefix. +type Reconciler struct { + prefix string +} + +// Reconciliation describes the stable tag set and whether Consul needs a write. +type Reconciliation struct { + Tags []string + Changed bool +} + +// NewReconciler returns a managed tag reconciler for prefix-value tags. +func NewReconciler(prefix string) Reconciler { + return Reconciler{prefix: prefix} +} + +// Reconcile converts script output into managed tags and merges them with unmanaged tags. +func (r Reconciler) Reconcile(current []string, output []byte) Reconciliation { + return r.ReconcileManaged(current, r.managedTags(strings.Fields(string(output)))) +} + +// ReconcileManaged merges desired managed tags with current unmanaged tags. +func (r Reconciler) ReconcileManaged(current, desiredManaged []string) Reconciliation { + unmanaged, currentManaged := r.split(current) + desiredManaged = stableTags(desiredManaged) + tags := stableTags(append(unmanaged, desiredManaged...)) + + return Reconciliation{ + Tags: tags, + Changed: !sameTags(stableTags(currentManaged), desiredManaged), + } +} + +// Cleanup removes only tags owned by this managed tag prefix. +func (r Reconciler) Cleanup(current []string) Reconciliation { + unmanaged, currentManaged := r.split(current) + return Reconciliation{ + Tags: stableTags(unmanaged), + Changed: len(currentManaged) > 0, + } +} + +func (r Reconciler) managedTags(tokens []string) []string { + tags := make([]string, 0, len(tokens)) + for _, tag := range tokens { + tags = append(tags, fmt.Sprintf("%s-%s", r.prefix, tag)) + } + return stableTags(tags) +} + +func (r Reconciler) split(tags []string) (unmanaged, managed []string) { + managedPrefix := r.prefix + "-" + for _, tag := range tags { + if strings.HasPrefix(tag, managedPrefix) { + managed = append(managed, tag) + } else { + unmanaged = append(unmanaged, tag) + } + } + return unmanaged, managed +} + +func stableTags(tags []string) []string { + if len(tags) == 0 { + return nil + } + stable := slices.Clone(tags) + slices.Sort(stable) + return slices.Compact(stable) +} + +func sameTags(a, b []string) bool { + return slices.Equal(a, b) +} diff --git a/pkg/tagit/reconciliation_test.go b/pkg/tagit/reconciliation_test.go new file mode 100644 index 0000000..5b76f60 --- /dev/null +++ b/pkg/tagit/reconciliation_test.go @@ -0,0 +1,85 @@ +package tagit + +import ( + "slices" + "testing" +) + +func TestReconciler_Reconcile(t *testing.T) { + tests := []struct { + name string + current []string + output []byte + wantTags []string + wantChanged bool + }{ + { + name: "script output becomes managed tags", + output: []byte("primary replica"), + wantTags: []string{"role-primary", "role-replica"}, + wantChanged: true, + }, + { + name: "unmanaged tags are preserved", + current: []string{"static"}, + output: []byte("primary"), + wantTags: []string{"role-primary", "static"}, + wantChanged: true, + }, + { + name: "stale managed tags are removed", + current: []string{"role-old", "static"}, + output: []byte("new"), + wantTags: []string{"role-new", "static"}, + wantChanged: true, + }, + { + name: "identical managed tags do not need a write", + current: []string{"static", "role-primary"}, + output: []byte("primary"), + wantTags: []string{"role-primary", "static"}, + wantChanged: false, + }, + { + name: "duplicate desired tags are compacted in stable order", + output: []byte("replica primary primary"), + wantTags: []string{"role-primary", "role-replica"}, + wantChanged: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewReconciler("role").Reconcile(tt.current, tt.output) + + if !slices.Equal(got.Tags, tt.wantTags) { + t.Fatalf("Tags = %v, want %v", got.Tags, tt.wantTags) + } + if got.Changed != tt.wantChanged { + t.Fatalf("Changed = %v, want %v", got.Changed, tt.wantChanged) + } + }) + } +} + +func TestReconciler_Cleanup(t *testing.T) { + got := NewReconciler("role").Cleanup([]string{"role-primary", "static", "role-replica"}) + + if want := []string{"static"}; !slices.Equal(got.Tags, want) { + t.Fatalf("Tags = %v, want %v", got.Tags, want) + } + if !got.Changed { + t.Fatal("Changed = false, want true") + } +} + +func TestReconciler_CleanupNoop(t *testing.T) { + got := NewReconciler("role").Cleanup([]string{"static"}) + + if want := []string{"static"}; !slices.Equal(got.Tags, want) { + t.Fatalf("Tags = %v, want %v", got.Tags, want) + } + if got.Changed { + t.Fatal("Changed = true, want false") + } +} diff --git a/pkg/tagit/scheduler.go b/pkg/tagit/scheduler.go new file mode 100644 index 0000000..ac7a07d --- /dev/null +++ b/pkg/tagit/scheduler.go @@ -0,0 +1,55 @@ +package tagit + +import ( + "context" + "log/slog" + "time" +) + +// Scheduler repeats one reconciliation pass until its context is cancelled. +type Scheduler struct { + Interval time.Duration + Ticks <-chan time.Time + RunOnce func() error + Logger *slog.Logger +} + +// Run starts the scheduler loop. +func (s Scheduler) Run(ctx context.Context) { + if s.RunOnce == nil { + return + } + + ticks := s.Ticks + var ticker *time.Ticker + if ticks == nil { + if s.Interval <= 0 { + s.logger().Error("invalid scheduler interval", "interval", s.Interval) + return + } + ticker = time.NewTicker(s.Interval) + defer ticker.Stop() + ticks = ticker.C + } + + for { + select { + case <-ctx.Done(): + return + case _, ok := <-ticks: + if !ok { + return + } + if err := s.RunOnce(); err != nil { + s.logger().Error("error updating service tags", "error", err) + } + } + } +} + +func (s Scheduler) logger() *slog.Logger { + if s.Logger != nil { + return s.Logger + } + return slog.Default() +} diff --git a/pkg/tagit/scheduler_test.go b/pkg/tagit/scheduler_test.go new file mode 100644 index 0000000..efe074a --- /dev/null +++ b/pkg/tagit/scheduler_test.go @@ -0,0 +1,257 @@ +package tagit + +import ( + "context" + "fmt" + "io" + "log/slog" + "testing" + "time" + + "github.com/hashicorp/consul/api" +) + +func TestTagIt_ReconcileOnce(t *testing.T) { + var gotTags []string + tagit := New( + &MockConsulClient{ + MockAgent: &MockAgent{ + ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return &api.AgentService{ + ID: serviceID, + Service: "api", + Tags: []string{"role-old", "static"}, + }, nil, nil + }, + ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { + gotTags = reg.Tags + return nil + }, + }, + }, + &MockCommandExecutor{MockOutput: []byte("primary")}, + "api-1", + "echo primary", + time.Minute, + "role", + discardTagitLogger(), + ) + + if err := tagit.ReconcileOnce(); err != nil { + t.Fatalf("ReconcileOnce() error = %v", err) + } + + want := []string{"role-primary", "static"} + if !sameStringSlice(gotTags, want) { + t.Fatalf("registered tags = %v, want %v", gotTags, want) + } +} + +func TestTagIt_ReconcileOnceSurfacesScriptFailure(t *testing.T) { + tagit := New( + &MockConsulClient{ + MockAgent: &MockAgent{ + ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return &api.AgentService{ID: serviceID}, nil, nil + }, + }, + }, + &MockCommandExecutor{MockError: fmt.Errorf("script failed")}, + "api-1", + "bad-script", + time.Minute, + "role", + discardTagitLogger(), + ) + + err := tagit.ReconcileOnce() + if err == nil { + t.Fatal("ReconcileOnce() error = nil, want error") + } + if got := err.Error(); got != "error running script: script failed" { + t.Fatalf("ReconcileOnce() error = %q, want script context", got) + } +} + +func TestScheduler_RunUsesTriggeredTicks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ticks := make(chan time.Time) + calls := make(chan struct{}, 2) + done := make(chan struct{}) + scheduler := Scheduler{ + Interval: time.Minute, + Ticks: ticks, + RunOnce: func() error { + calls <- struct{}{} + return nil + }, + Logger: discardTagitLogger(), + } + + go func() { + defer close(done) + scheduler.Run(ctx) + }() + + ticks <- time.Now() + <-calls + ticks <- time.Now() + <-calls + + cancel() + <-done +} + +func TestScheduler_RunLogsAndContinuesAfterError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ticks := make(chan time.Time) + calls := 0 + seenCalls := make(chan int, 2) + done := make(chan struct{}) + scheduler := Scheduler{ + Interval: time.Minute, + Ticks: ticks, + RunOnce: func() error { + calls++ + seenCalls <- calls + if calls == 1 { + return fmt.Errorf("temporary") + } + return nil + }, + Logger: discardTagitLogger(), + } + + go func() { + defer close(done) + scheduler.Run(ctx) + }() + + ticks <- time.Now() + if got := <-seenCalls; got != 1 { + t.Fatalf("first call = %d, want 1", got) + } + ticks <- time.Now() + if got := <-seenCalls; got != 2 { + t.Fatalf("second call = %d, want 2", got) + } + + cancel() + <-done +} + +func TestScheduler_RunStopsWhenContextIsCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + calls := 0 + scheduler := Scheduler{ + Interval: time.Minute, + Ticks: make(chan time.Time), + RunOnce: func() error { + calls++ + return nil + }, + Logger: discardTagitLogger(), + } + + scheduler.Run(ctx) + + if calls != 0 { + t.Fatalf("RunOnce calls = %d, want 0", calls) + } +} + +func TestScheduler_RunReturnsWhenRunOnceIsMissing(t *testing.T) { + scheduler := Scheduler{ + Interval: time.Minute, + Ticks: make(chan time.Time), + } + + scheduler.Run(t.Context()) +} + +func TestScheduler_RunReturnsWhenTickChannelCloses(t *testing.T) { + ticks := make(chan time.Time) + close(ticks) + + calls := 0 + scheduler := Scheduler{ + Interval: time.Minute, + Ticks: ticks, + RunOnce: func() error { + calls++ + return nil + }, + Logger: discardTagitLogger(), + } + + scheduler.Run(t.Context()) + + if calls != 0 { + t.Fatalf("RunOnce calls = %d, want 0", calls) + } +} + +func TestScheduler_RunCreatesTickerWhenTicksUnset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + calls := 0 + scheduler := Scheduler{ + Interval: time.Hour, + RunOnce: func() error { + calls++ + return nil + }, + Logger: discardTagitLogger(), + } + + scheduler.Run(ctx) + + if calls != 0 { + t.Fatalf("RunOnce calls = %d, want 0", calls) + } +} + +func TestScheduler_RunRejectsInvalidGeneratedTicker(t *testing.T) { + previous := slog.Default() + slog.SetDefault(discardTagitLogger()) + t.Cleanup(func() { + slog.SetDefault(previous) + }) + + calls := 0 + scheduler := Scheduler{ + RunOnce: func() error { + calls++ + return nil + }, + } + + scheduler.Run(t.Context()) + + if calls != 0 { + t.Fatalf("RunOnce calls = %d, want 0", calls) + } +} + +func discardTagitLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func sameStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/pkg/tagit/tagit.go b/pkg/tagit/tagit.go index c59d7e2..26f4189 100644 --- a/pkg/tagit/tagit.go +++ b/pkg/tagit/tagit.go @@ -5,8 +5,6 @@ import ( "fmt" "log/slog" "os/exec" - "slices" - "strings" "time" "github.com/google/shlex" @@ -21,6 +19,7 @@ type TagIt struct { Interval time.Duration TagPrefix string client consul.Client + registration *consul.ServiceRegistration commandExecutor CommandExecutor logger *slog.Logger } @@ -72,6 +71,7 @@ func New(consulClient consul.Client, commandExecutor CommandExecutor, serviceID Interval: interval, TagPrefix: tagPrefix, client: consulClient, + registration: consul.NewServiceRegistration(consulClient), commandExecutor: commandExecutor, logger: logger, } @@ -79,21 +79,16 @@ func New(consulClient consul.Client, commandExecutor CommandExecutor, serviceID // Run will run the tagit flow and tag consul services based on the script output func (t *TagIt) Run(ctx context.Context) { - ticker := time.NewTicker(t.Interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := t.updateServiceTags(); err != nil { - t.logger.Error("error updating service tags", - "service", t.ServiceID, - "error", err) - } - } - } + Scheduler{ + Interval: t.Interval, + RunOnce: t.ReconcileOnce, + Logger: t.logger, + }.Run(ctx) +} + +// ReconcileOnce runs one deterministic service tag reconciliation pass. +func (t *TagIt) ReconcileOnce() error { + return t.updateServiceTags() } // CleanupTags removes all tags with the given prefix from the service. @@ -103,16 +98,8 @@ func (t *TagIt) CleanupTags() error { return fmt.Errorf("error getting service: %w", err) } - // Filter out tags with the specified prefix - cleanedTags := make([]string, 0) - for _, tag := range service.Tags { - if !strings.HasPrefix(tag, t.TagPrefix+"-") { - cleanedTags = append(cleanedTags, tag) - } - } - - // Update the service with the cleaned tags - if err := t.updateConsulService(service, cleanedTags); err != nil { + reconciliation := NewReconciler(t.TagPrefix).Cleanup(service.Tags) + if err := t.registrationStore().UpdateTags(service, reconciliation.Tags, reconciliation.Changed); err != nil { return fmt.Errorf("error cleaning up tags: %w", err) } @@ -134,12 +121,13 @@ func (t *TagIt) updateServiceTags() error { return fmt.Errorf("error getting service: %w", err) } - newTags, err := t.generateNewTags() + out, err := t.runScript() if err != nil { - return fmt.Errorf("error generating new tags: %w", err) + return fmt.Errorf("error running script: %w", err) } - if err := t.updateConsulService(service, newTags); err != nil { + reconciliation := NewReconciler(t.TagPrefix).Reconcile(service.Tags, out) + if err := t.registrationStore().UpdateTags(service, reconciliation.Tags, reconciliation.Changed); err != nil { return fmt.Errorf("error updating service in Consul: %w", err) } @@ -157,27 +145,18 @@ func (t *TagIt) generateNewTags() ([]string, error) { // updateConsulService updates the service in Consul with the new tags. func (t *TagIt) updateConsulService(service *api.AgentService, newTags []string) error { - registration := t.copyServiceToRegistration(service) - updatedTags, shouldTag := t.needsTag(registration.Tags, newTags) - if shouldTag { - registration.Tags = updatedTags - if err := t.client.Agent().ServiceRegister(registration); err != nil { - return fmt.Errorf("error registering service: %w", err) - } - t.logger.Info("updated service tags", - "service", t.ServiceID, - "tags", updatedTags) + if err := t.registrationStore().UpdateTags(service, newTags, true); err != nil { + return err } + t.logger.Info("updated service tags", + "service", t.ServiceID, + "tags", newTags) return nil } // parseScriptOutput parses the script output and generates tags. func (t *TagIt) parseScriptOutput(output []byte) []string { - var tags []string - for _, tag := range strings.Fields(string(output)) { - tags = append(tags, fmt.Sprintf("%s-%s", t.TagPrefix, tag)) - } - return tags + return NewReconciler(t.TagPrefix).Reconcile(nil, output).Tags } // copyServiceToRegistration copies *api.AgentService to *api.AgentServiceRegistration @@ -200,42 +179,24 @@ func (t *TagIt) copyServiceToRegistration(service *api.AgentService) *api.AgentS // getService returns the registered service. func (t *TagIt) getService() (*api.AgentService, error) { - agent := t.client.Agent() - service, _, err := agent.Service(t.ServiceID, nil) - if err != nil { - return nil, fmt.Errorf("error getting service %s: %w", t.ServiceID, err) - } - if service == nil { - return nil, fmt.Errorf("service %s not found", t.ServiceID) + return t.registrationStore().Load(t.ServiceID) +} + +func (t *TagIt) registrationStore() *consul.ServiceRegistration { + if t.registration != nil { + return t.registration } - return service, nil + return consul.NewServiceRegistration(t.client) } // needsTag checks if the service needs to be tagged. Based on the diff of the current and updated tags, filtering out tags that are already tagged. // but we never override the original tags from the consul service registration func (t *TagIt) needsTag(current []string, update []string) (updatedTags []string, shouldTag bool) { - // Extract only the prefixed tags from current for comparison - currentPrefixed := make([]string, 0) - currentNonPrefixed := make([]string, 0) - for _, tag := range current { - if strings.HasPrefix(tag, t.TagPrefix+"-") { - currentPrefixed = append(currentPrefixed, tag) - } else { - currentNonPrefixed = append(currentNonPrefixed, tag) - } - } - - // Compare only the prefixed tags with the update - diff := t.diffTags(currentPrefixed, update) - if len(diff) == 0 { + reconciliation := NewReconciler(t.TagPrefix).ReconcileManaged(current, update) + if !reconciliation.Changed { return nil, false } - - // Combine non-prefixed tags with the new update tags - updatedTags = append(currentNonPrefixed, update...) - slices.Sort(updatedTags) - updatedTags = slices.Compact(updatedTags) - return updatedTags, true + return reconciliation.Tags, true } // diffTags compares two slices of strings and returns the difference. diff --git a/pkg/tagit/tagit_test.go b/pkg/tagit/tagit_test.go index 9c28681..d0b3439 100644 --- a/pkg/tagit/tagit_test.go +++ b/pkg/tagit/tagit_test.go @@ -6,7 +6,7 @@ import ( "io" "log/slog" "sort" - "sync/atomic" + "strings" "testing" "time" @@ -257,6 +257,192 @@ func TestRunScript(t *testing.T) { } } +func TestGenerateNewTags(t *testing.T) { + tests := []struct { + name string + output string + execErr error + want []string + wantErr bool + wantErrMsg string + }{ + { + name: "parses script output", + output: "beta alpha alpha", + want: []string{"role-alpha", "role-beta"}, + }, + { + name: "wraps script errors", + execErr: fmt.Errorf("permission denied"), + wantErr: true, + wantErrMsg: "error running script: permission denied", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tagit := TagIt{ + Script: "tags.sh", + TagPrefix: "role", + commandExecutor: &MockCommandExecutor{MockOutput: []byte(tt.output), MockError: tt.execErr}, + logger: discardTagitLogger(), + } + + got, err := tagit.generateNewTags() + if tt.wantErr { + if err == nil { + t.Fatal("generateNewTags() error = nil, want error") + } + if err.Error() != tt.wantErrMsg { + t.Fatalf("generateNewTags() error = %q, want %q", err, tt.wantErrMsg) + } + return + } + if err != nil { + t.Fatalf("generateNewTags() error = %v", err) + } + if !sameStringSlice(got, tt.want) { + t.Fatalf("generateNewTags() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseScriptOutput(t *testing.T) { + tagit := TagIt{TagPrefix: "role"} + + got := tagit.parseScriptOutput([]byte("api worker api")) + want := []string{"role-api", "role-worker"} + if !sameStringSlice(got, want) { + t.Fatalf("parseScriptOutput() = %v, want %v", got, want) + } +} + +func TestUpdateConsulService(t *testing.T) { + tests := []struct { + name string + registerErr error + wantErr bool + wantErrSubstr string + }{ + {name: "writes updated tags"}, + { + name: "returns registration errors", + registerErr: fmt.Errorf("permission denied"), + wantErr: true, + wantErrSubstr: "permission denied", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got *api.AgentServiceRegistration + tagit := New( + &MockConsulClient{ + MockAgent: &MockAgent{ + ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { + got = reg + return tt.registerErr + }, + }, + }, + &MockCommandExecutor{}, + "api-1", + "tags.sh", + time.Minute, + "role", + discardTagitLogger(), + ) + service := &api.AgentService{ + ID: "api-1", + Service: "api", + Tags: []string{"old"}, + } + + err := tagit.updateConsulService(service, []string{"role-primary", "static"}) + if tt.wantErr { + if err == nil { + t.Fatal("updateConsulService() error = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("updateConsulService() error = %q, want %q", err, tt.wantErrSubstr) + } + return + } + if err != nil { + t.Fatalf("updateConsulService() error = %v", err) + } + if got == nil { + t.Fatal("ServiceRegister was not called") + } + want := []string{"role-primary", "static"} + if !sameStringSlice(got.Tags, want) { + t.Fatalf("registered tags = %v, want %v", got.Tags, want) + } + }) + } +} + +func TestRunReturnsWhenContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + tagit := TagIt{ + Interval: time.Hour, + commandExecutor: &MockCommandExecutor{}, + logger: discardTagitLogger(), + } + + tagit.Run(ctx) +} + +func TestGetServiceUsesFallbackRegistrationStore(t *testing.T) { + tagit := TagIt{ + ServiceID: "api-1", + client: &MockConsulClient{ + MockAgent: &MockAgent{ + ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return &api.AgentService{ID: serviceID, Service: "api"}, nil, nil + }, + }, + }, + } + + got, err := tagit.getService() + if err != nil { + t.Fatalf("getService() error = %v", err) + } + if got.ID != "api-1" { + t.Fatalf("getService() ID = %q, want api-1", got.ID) + } +} + +func TestUpdateServiceTags_returnsServiceLookupErrors(t *testing.T) { + tagit := New( + &MockConsulClient{ + MockAgent: &MockAgent{ + ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { + return nil, nil, fmt.Errorf("consul unavailable") + }, + }, + }, + &MockCommandExecutor{}, + "api-1", + "tags.sh", + time.Minute, + "role", + discardTagitLogger(), + ) + + err := tagit.updateServiceTags() + if err == nil { + t.Fatal("updateServiceTags() error = nil, want error") + } + if !strings.Contains(err.Error(), "error getting service") { + t.Fatalf("updateServiceTags() error = %q, want service lookup context", err) + } +} + func TestNew(t *testing.T) { mockConsulClient := &MockConsulClient{} mockCommandExecutor := &MockCommandExecutor{} @@ -508,47 +694,6 @@ func TestCleanupTags(t *testing.T) { } } -func TestRun(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - updateServiceTagsCalled := atomic.Int32{} - mockExecutor := &MockCommandExecutor{ - MockOutput: []byte("new-tag1 new-tag2"), - MockError: nil, - } - mockConsulClient := &MockConsulClient{ - MockAgent: &MockAgent{ - ServiceFunc: func(serviceID string, q *api.QueryOptions) (*api.AgentService, *api.QueryMeta, error) { - updateServiceTagsCalled.Add(1) - if updateServiceTagsCalled.Load() == 2 { - return nil, nil, fmt.Errorf("simulated error") - } - return &api.AgentService{ - ID: "test-service", - Tags: []string{"old-tag"}, - }, nil, nil - }, - ServiceRegisterFunc: func(reg *api.AgentServiceRegistration) error { - return nil - }, - }, - } - - logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - tagit := New(mockConsulClient, mockExecutor, "test-service", "echo test", 100*time.Millisecond, "tag", logger) - - go tagit.Run(ctx) - - time.Sleep(350 * time.Millisecond) - cancel() - - time.Sleep(50 * time.Millisecond) - - assert.GreaterOrEqual(t, updateServiceTagsCalled.Load(), int32(2), "Expected updateServiceTags to be called at least 2 times") - assert.LessOrEqual(t, updateServiceTagsCalled.Load(), int32(4), "Expected updateServiceTags to be called at most 4 times") -} - func TestConsulInterfaceCompatibility(t *testing.T) { // Test that our mocks implement the consul package interfaces correctly mockAgent := &MockAgent{