From de4153d9d2776a36e11469e1cb9a56d4c65349f9 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Tue, 19 May 2026 14:31:55 -0700 Subject: [PATCH 1/6] fix: update ssh port test --- pkg/cmd/util/externalnode_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/cmd/util/externalnode_test.go b/pkg/cmd/util/externalnode_test.go index c8521b83d..c842634ba 100644 --- a/pkg/cmd/util/externalnode_test.go +++ b/pkg/cmd/util/externalnode_test.go @@ -83,8 +83,8 @@ func TestResolveExternalNodeSSH_UsesServerPortNotPortNumber(t *testing.T) { Ports: []*nodev1.Port{ { Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, - PortNumber: 22, // well-known port — NOT what we connect to - ServerPort: 41920, // netbird-assigned port — this is correct + PortNumber: 41920, // netbird-assigned port — this is correct + ServerPort: 22, // well-known port — NOT what we connect to Hostname: strPtr("gateway.example.com"), }, }, From c82da6c9d91f19e6e90c26db6f32826c7a362b67 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Tue, 19 May 2026 14:42:26 -0700 Subject: [PATCH 2/6] fix --- pkg/cmd/refresh/refresh_test.go | 4 ++-- pkg/cmd/shell/shell_test.go | 4 ++-- pkg/cmd/util/externalnode_test.go | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/cmd/refresh/refresh_test.go b/pkg/cmd/refresh/refresh_test.go index bf13a84ac..216e785d9 100644 --- a/pkg/cmd/refresh/refresh_test.go +++ b/pkg/cmd/refresh/refresh_test.go @@ -19,8 +19,8 @@ func TestResolveNodeSSHEntry_HappyPath(t *testing.T) { Ports: []*nodev1.Port{ { Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, - PortNumber: 22, - ServerPort: 41920, + PortNumber: 41920, + ServerPort: 22, Hostname: strPtr("10.0.0.5"), }, }, diff --git a/pkg/cmd/shell/shell_test.go b/pkg/cmd/shell/shell_test.go index 172e3b3c9..991ea9c49 100644 --- a/pkg/cmd/shell/shell_test.go +++ b/pkg/cmd/shell/shell_test.go @@ -22,8 +22,8 @@ func TestResolveExternalNodeSSH_BuildsCorrectInfo(t *testing.T) { Ports: []*nodev1.Port{ { Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, - PortNumber: 22, - ServerPort: 41920, + PortNumber: 41920, + ServerPort: 22, Hostname: strPtr("10.0.0.5"), }, }, diff --git a/pkg/cmd/util/externalnode_test.go b/pkg/cmd/util/externalnode_test.go index c842634ba..667972d64 100644 --- a/pkg/cmd/util/externalnode_test.go +++ b/pkg/cmd/util/externalnode_test.go @@ -32,7 +32,7 @@ func (m *mockExternalNodeStore) GetCurrentUser() (*entity.User, error) { func strPtr(s string) *string { return &s } -func makeTestNode(name, userID, linuxUser, hostname string, serverPort int32) *nodev1.ExternalNode { +func makeTestNode(name, userID, linuxUser, hostname string, portNumber int32) *nodev1.ExternalNode { return &nodev1.ExternalNode{ ExternalNodeId: "unode_test", Name: name, @@ -42,8 +42,8 @@ func makeTestNode(name, userID, linuxUser, hostname string, serverPort int32) *n Ports: []*nodev1.Port{ { Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, - PortNumber: 22, - ServerPort: serverPort, + PortNumber: portNumber, + ServerPort: 22, Hostname: &hostname, }, }, From 48fb5934eae774fcded4c81787702eaf817b2b54 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Tue, 19 May 2026 14:49:07 -0700 Subject: [PATCH 3/6] another one --- pkg/cmd/refresh/refresh_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/cmd/refresh/refresh_test.go b/pkg/cmd/refresh/refresh_test.go index 216e785d9..7c53b598c 100644 --- a/pkg/cmd/refresh/refresh_test.go +++ b/pkg/cmd/refresh/refresh_test.go @@ -53,8 +53,8 @@ func TestResolveNodeSSHEntry_UsesServerPortNotPortNumber(t *testing.T) { Ports: []*nodev1.Port{ { Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, - PortNumber: 22, // well-known port — NOT what we should connect to - ServerPort: 51234, // netbird-assigned port — correct + PortNumber: 51234, // netbird-assigned port — correct + ServerPort: 22, // well-known port — NOT what we should connect to Hostname: strPtr("gateway.example.com"), }, }, From 6887ca7b1abe1a1b4dfcf3342b8209eac975fe32 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 20 May 2026 12:41:06 -0700 Subject: [PATCH 4/6] use port ids --- Makefile | 4 +- go.mod | 8 +- go.sum | 12 +- pkg/cmd/enablessh/enablessh.go | 44 +-- pkg/cmd/enablessh/enablessh_test.go | 293 ++--------------- pkg/cmd/grantssh/grantssh.go | 53 +++- pkg/cmd/grantssh/grantssh_test.go | 76 +++++ pkg/cmd/register/register.go | 4 +- pkg/cmd/register/register_test.go | 4 +- pkg/cmd/register/sshkeys.go | 300 ++++++++++++------ pkg/cmd/register/sshkeys_port_resolve_test.go | 137 ++++++++ pkg/cmd/register/sshkeys_port_test.go | 45 +++ pkg/cmd/register/sshkeys_test.go | 259 +++++++-------- pkg/cmd/revokessh/revokessh.go | 63 ++-- pkg/cmd/revokessh/revokessh_test.go | 13 +- 15 files changed, 743 insertions(+), 572 deletions(-) create mode 100644 pkg/cmd/register/sshkeys_port_resolve_test.go create mode 100644 pkg/cmd/register/sshkeys_port_test.go diff --git a/Makefile b/Makefile index 9e3b27b3c..cfe7c7fa0 100644 --- a/Makefile +++ b/Makefile @@ -346,11 +346,11 @@ develop-with-nix: nix develop . .PHONY: update-devplane-deps -update-devplane-deps: ## update devplane dependencies (use: make update-devplane-deps commit=, defaults to latest) +update-devplane-deps: ## update devplane Buf modules (use: make update-devplane-deps commit=, defaults to latest) @COMMIT=$${commit:-latest}; \ echo "Updating devplane dependencies to: $$COMMIT"; \ GOPRIVATE=github.com/brevdev/* go get -u github.com/brevdev/dev-plane@$$COMMIT; \ - go get buf.build/gen/go/brevdev/devplane/grpc/go@$$COMMIT; \ + go get buf.build/gen/go/brevdev/devplane/connectrpc/go@$$COMMIT; \ go get buf.build/gen/go/brevdev/devplane/protocolbuffers/go@$$COMMIT; \ GOPRIVATE=github.com/brevdev/* go mod tidy; \ echo "Successfully updated to $$COMMIT" \ No newline at end of file diff --git a/go.mod b/go.mod index 459fd2ec1..6e88264ce 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/brevdev/brev-cli go 1.25.0 require ( - buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2 - buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260309172248-8105d701fdce.1 - connectrpc.com/connect v1.19.1 + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.2-20260520183101-9f4cb67aff2c.1 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260520183101-9f4cb67aff2c.1 + connectrpc.com/connect v1.19.2 github.com/NVIDIA/go-nvml v0.13.0-1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 @@ -150,7 +150,7 @@ require ( golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.42.0 - golang.org/x/term v0.41.0 // indirect + golang.org/x/term v0.41.0 golang.org/x/time v0.12.0 // indirect google.golang.org/protobuf v1.36.11 gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index dca355ac5..ad346e93d 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2 h1:Sq0kIa/xKzScbJcqB5EbPVhOL0QYHPr3araQaupL2lk= -buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2/go.mod h1:Yh34p9aADmWsKv2umYlMpnCZuBmNBE9N+HImgRriJXM= -buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260309172248-8105d701fdce.1 h1:lWdcuXsXpMfPOer4yawjwomVbtSAnGgFAWYF8ggK9g4= -buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260309172248-8105d701fdce.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.2-20260520183101-9f4cb67aff2c.1 h1:OtdZWOk/dypzAe4bylO+TFfcw9J3Ndyeh1yylWSNgRc= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.2-20260520183101-9f4cb67aff2c.1/go.mod h1:eaa0R5ozu4wxcy62DEtRxO6hahJ0WuFsMAG33Zj/lVQ= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260520183101-9f4cb67aff2c.1 h1:fDUuYv/K3h8IpEGf0uic/1/A1nBN+Vao4jzVWDRMLLc= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260520183101-9f4cb67aff2c.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= @@ -41,8 +41,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= -connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= -connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= +connectrpc.com/connect v1.19.2 h1:McQ83FGdzL+t60peksi0gXC7MQ/iLKgLduAnThbM0mo= +connectrpc.com/connect v1.19.2/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go index 43a846708..9788b0e6f 100644 --- a/pkg/cmd/enablessh/enablessh.go +++ b/pkg/cmd/enablessh/enablessh.go @@ -33,6 +33,7 @@ type enableSSHDeps struct { platform externalnode.PlatformChecker nodeClients externalnode.NodeClientFactory registrationStore register.RegistrationStore + prompter terminal.Selector } func defaultEnableSSHDeps() enableSSHDeps { @@ -40,6 +41,7 @@ func defaultEnableSSHDeps() enableSSHDeps { platform: register.LinuxPlatform{}, nodeClients: register.DefaultNodeClientFactory{}, registrationStore: register.NewFileRegistrationStore(), + prompter: register.TerminalPrompter{}, } } @@ -103,27 +105,17 @@ func enableSSH( t.Vprintf(" Linux user: %s\n", linuxUsername) t.Vprint("") - // Check if the node already has an SSH port allocated (e.g. for another linux user) - port, err := existingSSHPort(ctx, deps, tokenProvider, reg) + node, err := fetchRegisteredNode(ctx, deps, tokenProvider, reg) if err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not check for existing ports: %v", err))) + return fmt.Errorf("enable SSH failed: %w", err) } - if port != 0 { - t.Vprintf(" Using existing SSH port %d.\n", port) - } else { - t.Vprint("") - port, err = register.PromptSSHPort(t) - if err != nil { - return fmt.Errorf("invalid SSH port: %w", err) - } - - if err := register.OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port); err != nil { - return fmt.Errorf("enable SSH failed: %w", err) - } + brevPortID, err := register.ResolveSSHAccessPort(ctx, t, deps.prompter, deps.nodeClients, tokenProvider, reg, node) + if err != nil { + return fmt.Errorf("enable SSH failed: %w", err) } - if err := register.SetupAndRegisterNodeSSHAccess(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, linuxUsername); err != nil { + if err := register.SetupAndRegisterNodeSSHAccess(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, linuxUsername, brevPortID); err != nil { return fmt.Errorf("enable SSH failed: %w", err) } @@ -131,25 +123,21 @@ func enableSSH( return nil } -// existingSSHPort calls GetNode and returns the PortNumber of an already-allocated -// SSH port, or 0 if none exists -func existingSSHPort(ctx context.Context, deps enableSSHDeps, tokenProvider externalnode.TokenProvider, reg *register.DeviceRegistration) (int32, error) { +func fetchRegisteredNode( + ctx context.Context, + deps enableSSHDeps, + tokenProvider externalnode.TokenProvider, + reg *register.DeviceRegistration, +) (*nodev1.ExternalNode, error) { client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) resp, err := client.GetNode(ctx, connect.NewRequest(&nodev1.GetNodeRequest{ ExternalNodeId: reg.ExternalNodeID, OrganizationId: reg.OrgID, })) if err != nil { - return 0, fmt.Errorf("error retrieving node: %w", err) - } - - for _, p := range resp.Msg.GetExternalNode().GetPorts() { - // TODO if we ever allow more than one SSH port, this should be modified - if p.GetProtocol() == nodev1.PortProtocol_PORT_PROTOCOL_SSH { - return p.GetPortNumber(), nil - } + return nil, fmt.Errorf("error retrieving node: %w", err) } - return 0, nil + return resp.Msg.GetExternalNode(), nil } // checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index beec62831..7df94144d 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -2,7 +2,6 @@ package enablessh import ( "context" - "fmt" "net/http/httptest" "os" "os/user" @@ -34,146 +33,6 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string { return string(data) } -// --- InstallAuthorizedKey --- - -func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { - u := tempUser(t) - - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - - content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { - t.Errorf("expected key tagged with %q, got:\n%s", register.BrevKeyPrefix, content) - } -} - -func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { - u := tempUser(t) - - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("first install: %v", err) - } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("second install: %v", err) - } - - content := readAuthorizedKeys(t, u) - count := strings.Count(content, "ssh-rsa AAAA testkey") - if count != 1 { - t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) - } -} - -func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { - u := tempUser(t) - - // Pre-seed a tagged key (as if brev already installed it). - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+register.BrevKeyPrefix+"\n"), 0o600); err != nil { - t.Fatal(err) - } - - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - - content := readAuthorizedKeys(t, u) - count := strings.Count(content, "ssh-rsa AAAA testkey") - if count != 1 { - t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) - } -} - -func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { - u := tempUser(t) - - if _, err := register.InstallAuthorizedKey(u, "", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - if _, err := register.InstallAuthorizedKey(u, " ", ""); err != nil { - t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) - } - - // authorized_keys should not exist since nothing was written. - path := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") - if _, err := os.Stat(path); !os.IsNotExist(err) { - t.Errorf("expected authorized_keys to not exist, but it does") - } -} - -func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { - u := tempUser(t) - - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - - info, err := os.Stat(filepath.Join(u.HomeDir, ".ssh")) - if err != nil { - t.Fatalf("stat .ssh: %v", err) - } - if !info.IsDir() { - t.Error(".ssh is not a directory") - } -} - -func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { - u := tempUser(t) - - // Pre-seed a non-brev key. - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa EXISTING user@host\n"), 0o600); err != nil { - t.Fatal(err) - } - - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - - content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa EXISTING user@host") { - t.Errorf("existing key was lost:\n%s", content) - } - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { - t.Errorf("new key not found:\n%s", content) - } -} - -func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) { - u := tempUser(t) - - // Pre-seed a key without the brev-cli tag. - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey\n"), 0o600); err != nil { - t.Fatal(err) - } - - // InstallAuthorizedKey should tag the existing key rather than adding a duplicate. - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) - } - - content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { - t.Errorf("expected existing key to be tagged with %q, got:\n%s", register.BrevKeyPrefix, content) - } - count := strings.Count(content, "ssh-rsa AAAA testkey") - if count != 1 { - t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) - } -} - // --- RemoveBrevAuthorizedKeys --- func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { @@ -185,9 +44,9 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { content := strings.Join([]string{ "ssh-rsa EXISTING user@host", - "ssh-rsa BREVKEY1 " + register.BrevKeyPrefix, + "ssh-rsa BREVKEY1 " + register.DevplaneAuthorizedKeysComment("p1", "u1"), "ssh-ed25519 OTHERKEY admin@server", - "ssh-rsa BREVKEY2 " + register.BrevKeyPrefix, + "ssh-rsa BREVKEY2 " + register.DevplaneAuthorizedKeysComment("p2", "u2"), "", }, "\n") if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { @@ -204,7 +63,7 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { } result := readAuthorizedKeys(t, u) - if strings.Contains(result, register.BrevKeyPrefix) { + if strings.Contains(result, "#brev-portID:") { t.Errorf("brev keys still present:\n%s", result) } if !strings.Contains(result, "ssh-rsa EXISTING user@host") { @@ -264,7 +123,7 @@ func Test_RemoveAuthorizedKey_RemovesOnlyTargetKey(t *testing.T) { content := strings.Join([]string{ "ssh-rsa KEEP1 user@host", - "ssh-rsa TARGET " + register.BrevKeyPrefix, + "ssh-rsa TARGET " + register.DevplaneAuthorizedKeysComment("p1", "u1"), "ssh-rsa KEEP2 admin@server", "", }, "\n") @@ -337,8 +196,8 @@ func Test_RemoveAuthorizedKey_DoesNotRemoveOtherBrevKeys(t *testing.T) { } content := strings.Join([]string{ - "ssh-rsa ALICE_KEY " + register.BrevKeyPrefix, - "ssh-rsa BOB_KEY " + register.BrevKeyPrefix, + "ssh-rsa ALICE_KEY " + register.DevplaneAuthorizedKeysComment("p1", "u1"), + "ssh-rsa BOB_KEY " + register.DevplaneAuthorizedKeysComment("p2", "u2"), "", }, "\n") if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { @@ -359,71 +218,6 @@ func Test_RemoveAuthorizedKey_DoesNotRemoveOtherBrevKeys(t *testing.T) { } } -// --- Round-trip: install then remove (all brev keys) --- - -func Test_InstallThenRemove_RoundTrip(t *testing.T) { - u := tempUser(t) - - // Pre-seed a non-brev key. - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa EXISTING user@host\n"), 0o600); err != nil { - t.Fatal(err) - } - - // Install two brev keys. - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1", "user_1"); err != nil { - t.Fatal(err) - } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2", "user_2"); err != nil { - t.Fatal(err) - } - - // Remove all brev keys. - if _, err := register.RemoveBrevAuthorizedKeys(u); err != nil { - t.Fatal(err) - } - - result := readAuthorizedKeys(t, u) - if strings.Contains(result, register.BrevKeyPrefix) { - t.Errorf("brev keys still present after removal:\n%s", result) - } - if !strings.Contains(result, "ssh-rsa EXISTING user@host") { - t.Errorf("non-brev key was removed:\n%s", result) - } -} - -// --- Round-trip: install then rollback specific key --- - -func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) { - u := tempUser(t) - - // Install two brev keys (simulating two users granted access). - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE", "user_a"); err != nil { - t.Fatal(err) - } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB", "user_b"); err != nil { - t.Fatal(err) - } - - // Simulate rollback: remove only Bob's key (e.g. his grant RPC failed). - if err := register.RemoveAuthorizedKey(u, "ssh-rsa BOB"); err != nil { - t.Fatal(err) - } - - result := readAuthorizedKeys(t, u) - if strings.Contains(result, "BOB") { - t.Errorf("Bob's key still present after rollback:\n%s", result) - } - if !strings.Contains(result, "ssh-rsa ALICE") { - t.Errorf("Alice's key was removed during Bob's rollback:\n%s", result) - } -} - -// --- existingSSHPort --- - type mockNodeClientFactory struct{ serverURL string } func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { @@ -461,64 +255,27 @@ func startFakeServer(t *testing.T, svc *fakeNodeService) (enableSSHDeps, *httpte }, server } -func Test_existingSSHPort(t *testing.T) { - ssh := nodev1.PortProtocol_PORT_PROTOCOL_SSH - tcp := nodev1.PortProtocol_PORT_PROTOCOL_TCP - - tests := []struct { - name string - resp *nodev1.GetNodeResponse - rpcErr error - wantPort int32 - wantErr bool - }{ - { - name: "ReturnsExistingPort", - resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{ - Ports: []*nodev1.Port{{Protocol: ssh, PortNumber: 2222}}, - }}, - wantPort: 2222, - }, - { - name: "ReturnsZeroWhenNoPorts", - resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{}}, - wantPort: 0, - }, - { - name: "ReturnsErrorOnRPCFailure", - rpcErr: connect.NewError(connect.CodeInternal, fmt.Errorf("server error")), - wantErr: true, - }, - { - name: "IgnoresNonSSHPorts", - resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{ - Ports: []*nodev1.Port{ - {Protocol: tcp, PortNumber: 8080}, - {Protocol: ssh, PortNumber: 3333}, - }, - }}, - wantPort: 3333, +func Test_fetchRegisteredNode(t *testing.T) { + svc := &fakeNodeService{ + getNodeFn: func(req *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + if req.GetExternalNodeId() != "unode_abc" { + t.Fatalf("unexpected node id %q", req.GetExternalNodeId()) + } + return &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + Ports: []*nodev1.Port{{PortId: "port_1", PortNumber: 11640, ServerPort: 22}}, + }}, nil }, } + deps, _ := startFakeServer(t, svc) + store := &mockEnableSSHStore{token: "tok"} + reg := ®ister.DeviceRegistration{ExternalNodeID: "unode_abc", OrgID: "org_1"} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - svc := &fakeNodeService{ - getNodeFn: func(_ *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { - return tt.resp, tt.rpcErr - }, - } - deps, _ := startFakeServer(t, svc) - store := &mockEnableSSHStore{token: "tok"} - reg := ®ister.DeviceRegistration{ExternalNodeID: "unode_abc"} - - port, err := existingSSHPort(context.Background(), deps, store, reg) - if (err != nil) != tt.wantErr { - t.Fatalf("err = %v, wantErr = %v", err, tt.wantErr) - } - if port != tt.wantPort { - t.Errorf("port = %d, want %d", port, tt.wantPort) - } - }) + node, err := fetchRegisteredNode(context.Background(), deps, store, reg) + if err != nil { + t.Fatal(err) + } + if len(node.GetPorts()) != 1 || node.GetPorts()[0].GetPortId() != "port_1" { + t.Fatalf("unexpected node: %+v", node) } } diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index 7597bce05..a0d705acd 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -60,6 +60,7 @@ func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { var nodeFlag string var userFlag string var linuxUser string + var portIDFlag string var approveFlag bool cmd := &cobra.Command{ @@ -67,8 +68,8 @@ func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { Use: "grant-ssh", DisableFlagsInUseLine: true, Short: "Grant SSH access to a node for another org member", - Long: "Grant SSH access to a node for another member of your organization. Interactive: no flags, prompts for org and user. Non-interactive: --org, --user, --linux-user required.", - Example: " brev grant-ssh\n brev grant-ssh --org my-org --node my-node --user user@example.com --linux-user ubuntu --approve", + Long: "Grant SSH access to a node for another member of your organization. Interactive: no flags, prompts for org, node, port, and user. Non-interactive: --org, --node, --user, --linux-user, and --port-id required.", + Example: " brev grant-ssh\n brev grant-ssh --org my-org --node my-node --user user@example.com --linux-user ubuntu --port-id port_abc --approve", RunE: func(cmd *cobra.Command, args []string) error { interactive := orgFlag == "" && nodeFlag == "" && userFlag == "" opts := grantSSHOpts{ @@ -77,6 +78,7 @@ func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { nodeName: nodeFlag, userIDOrEmail: userFlag, linuxUser: linuxUser, + portID: portIDFlag, skipConfirm: approveFlag, } return runGrantSSH(cmd.Context(), t, store, opts, defaultGrantSSHDeps()) @@ -87,6 +89,7 @@ func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { cmd.Flags().StringVarP(&nodeFlag, "node", "n", "", "node name (required in non-interactive mode)") cmd.Flags().StringVarP(&userFlag, "user", "u", "", "Brev user ID or email to grant (required in non-interactive mode)") cmd.Flags().StringVar(&linuxUser, "linux-user", "", "Linux username on the target node (required in non-interactive mode)") + cmd.Flags().StringVar(&portIDFlag, "port-id", "", "Brev port ID to grant access on (required in non-interactive mode)") cmd.Flags().BoolVar(&approveFlag, "approve", false, "skip confirmation prompt (assume yes)") return cmd @@ -99,12 +102,12 @@ type grantSSHOpts struct { nodeName string userIDOrEmail string linuxUser string + portID string skipConfirm bool } // runGrantSSH runs the grant-ssh flow; the only difference by mode is whether we prompt or use opts. func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opts grantSSHOpts, deps grantSSHDeps) error { //nolint:gocognit,gocyclo,funlen // ok - // Run through the login flow currentUser, err := s.GetCurrentUser() if err != nil { return breverrors.WrapAndTrace(err) @@ -115,11 +118,13 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt return fmt.Errorf("in non-interactive mode --org, --node, and --user are required") } if opts.linuxUser == "" { - return fmt.Errorf("--linux-user is required in non-interactive mode (no cached value found)") + return fmt.Errorf("--linux-user is required in non-interactive mode") + } + if opts.portID == "" { + return fmt.Errorf("--port-id is required in non-interactive mode") } } - // Capture the target organization var org *entity.Organization if opts.interactive { allOrgs, listErr := s.ListOrganizations() @@ -136,7 +141,6 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) - // Capture the target node var node *nodev1.ExternalNode if opts.interactive { resp, listErr := client.ListNodes(ctx, connect.NewRequest(&nodev1.ListNodesRequest{ @@ -162,6 +166,11 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt return err } + brevPortID, portLabel, err := resolveGrantPort(ctx, t, opts, deps, node) + if err != nil { + return err + } + var selectedUser *entity.User var linuxUser string if opts.interactive { @@ -199,6 +208,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt t.Vprint("") } t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Node:")), t.BoldBlue(node.GetName()+" ("+node.GetExternalNodeId()+")")) + t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Port:")), t.BoldBlue(portLabel)) t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Brev user:")), t.BoldBlue(selectedUser.Name+" ("+selectedUser.ID+")")) t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Linux user:")), t.BoldBlue(linuxUser)) t.Vprint("") @@ -213,6 +223,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt _, err = client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ ExternalNodeId: node.GetExternalNodeId(), + PortId: brevPortID, UserId: selectedUser.ID, LinuxUser: linuxUser, })) @@ -226,6 +237,36 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, opt return nil } +// resolveGrantPort returns the Brev port ID and display label for the grant target. +func resolveGrantPort(ctx context.Context, t *terminal.Terminal, opts grantSSHOpts, deps grantSSHDeps, node *nodev1.ExternalNode) (portID, portLabel string, err error) { + if !opts.interactive { + p := findPortByID(node, opts.portID) + if p == nil { + return "", "", fmt.Errorf("no port with id %q on node %q", opts.portID, node.GetName()) + } + return p.GetPortId(), register.FormatPortLabel(p), nil + } + + ports := node.GetPorts() + if len(ports) == 0 { + return "", "", fmt.Errorf("no ports found on node %q", node.GetName()) + } + selected, selErr := register.SelectPortFromList(ctx, t, deps.prompter, ports) + if selErr != nil { + return "", "", selErr + } + return selected.GetPortId(), register.FormatPortLabel(selected), nil +} + +func findPortByID(node *nodev1.ExternalNode, portID string) *nodev1.Port { + for _, p := range node.GetPorts() { + if p.GetPortId() == portID { + return p + } + } + return nil +} + // uniqueLinuxUsersFromNodeSSHAccess returns distinct Linux users from the node's existing SSH access (for picker). func uniqueLinuxUsersFromNodeSSHAccess(node *nodev1.ExternalNode) []string { if node == nil { diff --git a/pkg/cmd/grantssh/grantssh_test.go b/pkg/cmd/grantssh/grantssh_test.go index 083a754ca..f8b22b697 100644 --- a/pkg/cmd/grantssh/grantssh_test.go +++ b/pkg/cmd/grantssh/grantssh_test.go @@ -245,6 +245,14 @@ func Test_runGrantSSH_HappyPath(t *testing.T) { ExternalNodeId: "unode_abc", Name: "My Spark", SshAccess: []*nodev1.SSHAccess{{UserId: "user_1", LinuxUser: "ubuntu"}}, + Ports: []*nodev1.Port{ + { + PortId: "port_ssh", + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, + ServerPort: 41920, + }, + }, }, }, }, nil @@ -277,6 +285,74 @@ func Test_runGrantSSH_HappyPath(t *testing.T) { if gotReq.GetLinuxUser() != "ubuntu" { t.Errorf("expected linux user ubuntu, got %s", gotReq.GetLinuxUser()) } + if gotReq.GetPortId() != "port_ssh" { + t.Errorf("expected port ID port_ssh, got %s", gotReq.GetPortId()) + } +} + +func Test_runGrantSSH_NonInteractiveWithPortID(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + targetUser := &entity.User{ID: "user_2", Name: "Alice", Email: "alice@example.com"} + + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", + attachments: []entity.OrgRoleAttachment{ + {Subject: "user_1"}, + {Subject: "user_2"}, + }, + users: map[string]*entity.User{"user_2": targetUser}, + } + + var gotReq *nodev1.GrantNodeSSHAccessRequest + svc := &fakeNodeService{ + listNodesFn: func(_ *nodev1.ListNodesRequest) (*nodev1.ListNodesResponse, error) { + return &nodev1.ListNodesResponse{ + Items: []*nodev1.ExternalNode{ + { + ExternalNodeId: "unode_abc", + Name: "My Spark", + Ports: []*nodev1.Port{ + {PortId: "port_ssh", Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, PortNumber: 22}, + }, + }, + }, + }, nil + }, + grantSSHFn: func(req *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + gotReq = req + return &nodev1.GrantNodeSSHAccessResponse{}, nil + }, + } + + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + opts := grantSSHOpts{ + interactive: false, + orgName: "TestOrg", + nodeName: "My Spark", + userIDOrEmail: "alice@example.com", + linuxUser: "ubuntu", + portID: "port_ssh", + skipConfirm: true, + } + err := runGrantSSH(context.Background(), term, store, opts, deps) + if err != nil { + t.Fatalf("runGrantSSH failed: %v", err) + } + if gotReq == nil || gotReq.GetPortId() != "port_ssh" { + t.Fatalf("expected port_ssh in request, got %+v", gotReq) + } } func Test_runGrantSSH_RPCFailure(t *testing.T) { diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index af8ec628f..e4c8367b2 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -457,12 +457,12 @@ func grantSSHAccessWithPort(ctx context.Context, t *terminal.Terminal, deps regi } func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User, port int32) error { - err := OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port) + brevPortID, err := OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port) if err != nil { return fmt.Errorf("allocate SSH port failed: %w", err) } - err = SetupAndRegisterNodeSSHAccess(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser.Username) + err = SetupAndRegisterNodeSSHAccess(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser.Username, brevPortID) if err != nil { return fmt.Errorf("grant SSH failed: %w", err) } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index c7c09fa3a..d98b1a92a 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -1010,8 +1010,8 @@ func Test_runRegister_OpenSSHPort(t *testing.T) { // nolint:funlen, gocyclo, goc if openReq.GetExternalNodeId() != "unode_abc" { t.Errorf("expected node ID unode_abc, got %s", openReq.GetExternalNodeId()) } - if openReq.GetProtocol() != nodev1.PortProtocol_PORT_PROTOCOL_SSH { - t.Errorf("expected PORT_PROTOCOL_SSH, got %s", openReq.GetProtocol()) + if openReq.GetProtocol() != nodev1.PortProtocol_PORT_PROTOCOL_TCP { + t.Errorf("expected PORT_PROTOCOL_TCP, got %s", openReq.GetProtocol()) } if openReq.GetPortNumber() != 2222 { t.Errorf("expected port 2222, got %d", openReq.GetPortNumber()) diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index 941729eca..f865070a8 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -55,6 +55,78 @@ func SelectNodeFromList(ctx context.Context, t *terminal.Terminal, prompter term return selected, nil } + +func SelectPortFromList(_ context.Context, t *terminal.Terminal, prompter terminal.Selector, ports []*nodev1.Port) (*nodev1.Port, error) { + if len(ports) == 0 { + return nil, fmt.Errorf("no ports to select") + } + if len(ports) == 1 { + return ports[0], nil + } + t.Vprint("") + labels := make([]string, len(ports)) + for i, p := range ports { + labels[i] = FormatPortLabel(p) + } + chosen := prompter.Select("Select port", labels) + for i, label := range labels { + if label == chosen { + return ports[i], nil + } + } + return nil, fmt.Errorf("selected item did not match any port") +} + +const ( + // PortChoiceUseExisting is the interactive option to grant SSH on an allocated port. + PortChoiceUseExisting = "Use an existing port" + // PortChoiceOpenNew is the interactive option to open a new port before granting SSH. + PortChoiceOpenNew = "Open a new port" +) + +// ResolveSSHAccessPort prompts for an existing or new port and returns its Brev port ID. +func ResolveSSHAccessPort( + ctx context.Context, + t *terminal.Terminal, + prompter terminal.Selector, + nodeClients externalnode.NodeClientFactory, + tokenProvider externalnode.TokenProvider, + reg *DeviceRegistration, + node *nodev1.ExternalNode, +) (string, error) { + ports := node.GetPorts() + if len(ports) == 0 { + return openPortForSSHAccess(ctx, t, nodeClients, tokenProvider, reg) + } + + t.Vprint("") + choice := prompter.Select("SSH port", []string{PortChoiceUseExisting, PortChoiceOpenNew}) + switch choice { + case PortChoiceUseExisting: + selected, err := SelectPortFromList(ctx, t, prompter, ports) + if err != nil { + return "", err + } + t.Vprintf(" Using port %s.\n", FormatPortLabel(selected)) + return selected.GetPortId(), nil + case PortChoiceOpenNew: + return openPortForSSHAccess(ctx, t, nodeClients, tokenProvider, reg) + default: + return "", fmt.Errorf("invalid port choice %q", choice) + } +} + +// FormatPortLabel formats a port as "nodePort->connectPort" (e.g. "11640->22"). +func FormatPortLabel(p *nodev1.Port) string { + if p == nil { + return "" + } + if p.GetServerPort() != 0 { + return fmt.Sprintf("%d->%d", p.GetPortNumber(), p.GetServerPort()) + } + return fmt.Sprintf("%d", p.GetPortNumber()) +} + const ( backoffInitialInterval = 1 * time.Second backoffMaxInterval = 10 * time.Second @@ -63,30 +135,72 @@ const ( backoffPrintRound = 500 * time.Millisecond ) -// BrevKeyPrefix is the marker prefix appended to every SSH key that Brev -// installs. Both old-format ("# brev-cli") and new-format -// ("# brev-cli user_id=xxx") keys start with this prefix. -const BrevKeyPrefix = "# brev-cli" +// BrevKeyPrefixLegacy marks keys written by older CLI versions (# brev-cli). +const BrevKeyPrefixLegacy = "# brev-cli" -// BrevKeyTag returns a comment tag that encodes the Brev user ID. -// Example: "# brev-cli user_id=user_abc123" -func BrevKeyTag(userID string) string { - if userID == "" { - return BrevKeyPrefix - } - return fmt.Sprintf("%s user_id=%s", BrevKeyPrefix, userID) +// BrevKeyPrefix is an alias for BrevKeyPrefixLegacy (tests and migration). +const BrevKeyPrefix = BrevKeyPrefixLegacy + +const ( + brevPortIDField = "brev-portID:" + brevUserIDField = "brev-userID:" +) + +// DevplaneAuthorizedKeysComment is the suffix on Brev-managed authorized_keys lines. +// The CLI writes this before GrantNodeSSHAccess so devplane need not modify the file. +func DevplaneAuthorizedKeysComment(portID, userID string) string { + return fmt.Sprintf("#brev-portID:%s,brev-userID:%s", portID, userID) } -// BrevAuthorizedKey represents a single Brev-managed key found in -// authorized_keys. +// BrevAuthorizedKey represents a single Brev-managed key found in authorized_keys. type BrevAuthorizedKey struct { Line string // full line from authorized_keys - KeyContent string // the ssh key portion (without the brev comment) - UserID string // parsed from "user_id=xxx", empty for old-format keys + KeyContent string // key type + material (and optional ssh comment), without brev suffix + PortID string // from devplane #brev-portID:... + UserID string // from devplane brev-userID:... or legacy user_id= } -// ListBrevAuthorizedKeys reads ~/.ssh/authorized_keys and returns all lines -// containing the BrevKeyPrefix marker. +func isBrevManagedAuthorizedKeysLine(line string) bool { + return strings.Contains(line, BrevKeyPrefixLegacy) || strings.Contains(line, "#brev-portID:") +} + +func parseBrevAuthorizedKeyLine(trimmed string) BrevAuthorizedKey { + bk := BrevAuthorizedKey{Line: trimmed} + + if idx := strings.Index(trimmed, "#brev-portID:"); idx >= 0 { + bk.KeyContent = strings.TrimSpace(trimmed[:idx]) + tag := trimmed[idx+1:] + for _, part := range strings.Split(tag, ",") { + part = strings.TrimSpace(part) + switch { + case strings.HasPrefix(part, brevPortIDField): + bk.PortID = strings.TrimPrefix(part, brevPortIDField) + case strings.HasPrefix(part, brevUserIDField): + bk.UserID = strings.TrimPrefix(part, brevUserIDField) + } + } + return bk + } + + if idx := strings.Index(trimmed, " "+BrevKeyPrefixLegacy); idx >= 0 { + bk.KeyContent = strings.TrimSpace(trimmed[:idx]) + tag := trimmed[idx+1:] + if uidIdx := strings.Index(tag, "user_id="); uidIdx >= 0 { + rest := tag[uidIdx+len("user_id="):] + if spIdx := strings.Index(rest, " "); spIdx >= 0 { + bk.UserID = rest[:spIdx] + } else { + bk.UserID = rest + } + } + return bk + } + + bk.KeyContent = trimmed + return bk +} + +// ListBrevAuthorizedKeys reads ~/.ssh/authorized_keys and returns Brev-managed lines. func ListBrevAuthorizedKeys(u *user.User) ([]BrevAuthorizedKey, error) { authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") @@ -100,35 +214,14 @@ func ListBrevAuthorizedKeys(u *user.User) ([]BrevAuthorizedKey, error) { var keys []BrevAuthorizedKey for _, line := range strings.Split(string(data), "\n") { - if !strings.Contains(line, BrevKeyPrefix) { + if !isBrevManagedAuthorizedKeysLine(line) { continue } trimmed := strings.TrimSpace(line) if trimmed == "" { continue } - - bk := BrevAuthorizedKey{Line: trimmed} - - // Split on " # brev-cli" to get the key content before the tag. - if idx := strings.Index(trimmed, " "+BrevKeyPrefix); idx >= 0 { - bk.KeyContent = trimmed[:idx] - tag := trimmed[idx+1:] // the "# brev-cli ..." part - // Parse user_id if present. - if uidIdx := strings.Index(tag, "user_id="); uidIdx >= 0 { - rest := tag[uidIdx+len("user_id="):] - // user_id value ends at next space or end of string. - if spIdx := strings.Index(rest, " "); spIdx >= 0 { - bk.UserID = rest[:spIdx] - } else { - bk.UserID = rest - } - } - } else { - bk.KeyContent = trimmed - } - - keys = append(keys, bk) + keys = append(keys, parseBrevAuthorizedKeyLine(trimmed)) } return keys, nil @@ -165,10 +258,23 @@ func RemoveAuthorizedKeyLine(u *user.User, line string) error { return nil } -// OpenSSHPort calls the OpenPort RPC to allocate an SSH port on the node. -// This must be called before GrantSSHAccessToNode when enabling SSH for the -// first time on a device. The call is idempotent — if the port is already -// open, the server returns the existing allocation. +func openPortForSSHAccess( + ctx context.Context, + t *terminal.Terminal, + nodeClients externalnode.NodeClientFactory, + tokenProvider externalnode.TokenProvider, + reg *DeviceRegistration, +) (string, error) { + t.Vprint("") + port, err := PromptSSHPort(t) + if err != nil { + return "", fmt.Errorf("invalid port: %w", err) + } + return OpenSSHPort(ctx, t, nodeClients, tokenProvider, reg, port) +} + +// OpenSSHPort calls the OpenPort RPC to allocate a port on the node for SSH access. +// The call is idempotent — if the port is already open, the server returns the existing allocation. func OpenSSHPort( ctx context.Context, t *terminal.Terminal, @@ -176,28 +282,27 @@ func OpenSSHPort( tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, port int32, -) error { +) (string, error) { if port < 1 || port > 65535 { - return fmt.Errorf("invalid SSH port %d: port must be between 1 and 65535", port) + return "", fmt.Errorf("invalid port %d: port must be between 1 and 65535", port) } client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - _, err := client.OpenPort(ctx, connect.NewRequest(&nodev1.OpenPortRequest{ + brevPort, err := client.OpenPort(ctx, connect.NewRequest(&nodev1.OpenPortRequest{ ExternalNodeId: reg.ExternalNodeID, - Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_TCP, PortNumber: port, })) if err != nil { - return fmt.Errorf("failed to allocate SSH port: %w", err) + return "", fmt.Errorf("failed to allocate port: %w", err) } - t.Vprintf(" SSH port %d allocated.\n", port) - return nil + t.Vprintf(" Port %d allocated (%s).\n", port, FormatPortLabel(brevPort.Msg.GetPort())) + return brevPort.Msg.GetPort().GetPortId(), nil } -// SetupAndRegisterNodeSSHAccess installs the user's public key in authorized_keys and -// calls GrantNodeSSHAccess to record access server-side. If the RPC fails, -// the installed key is rolled back. osUser is the target username on the -// remote device (e.g. "ubuntu"), used both in the RPC and to locate the -// correct ~/.ssh/authorized_keys for local key installation. +// SetupAndRegisterNodeSSHAccess installs the user's public key in authorized_keys +// with the devplane comment tag, then calls GrantNodeSSHAccess. The key must be +// present locally before devplane can grant; matching tags make the grant a no-op +// on the authorized_keys file. func SetupAndRegisterNodeSSHAccess( ctx context.Context, t *terminal.Terminal, @@ -206,19 +311,26 @@ func SetupAndRegisterNodeSSHAccess( reg *DeviceRegistration, targetUser *entity.User, linuxUsername string, + brevPortID string, ) error { - // Look up the local OS user for key installation. This may differ from - // the caller (e.g. running as root, installing keys for "ubuntu"). + if brevPortID == "" { + return fmt.Errorf("port id is required to grant SSH access") + } + osUser, lookupErr := user.Lookup(linuxUsername) if lookupErr != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not look up local user %q — skipping local key installation", linuxUsername))) + return fmt.Errorf("could not look up local user %q: %w", linuxUsername, lookupErr) } - if osUser != nil && targetUser.PublicKey != "" { - if added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey, targetUser.ID); err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else if added { - t.Vprint(" Brev public key added to authorized_keys.") + var keyAdded bool + if targetUser.PublicKey != "" { + added, installErr := InstallAuthorizedKey(osUser, targetUser.PublicKey, brevPortID, targetUser.ID) + if installErr != nil { + return fmt.Errorf("failed to install SSH public key: %w", installErr) + } + keyAdded = added + if added { + t.Vprint(" Public key added to authorized_keys.") } } @@ -233,6 +345,7 @@ func SetupAndRegisterNodeSSHAccess( opToTry := func() error { _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ ExternalNodeId: reg.ExternalNodeID, + PortId: brevPortID, UserId: targetUser.ID, LinuxUser: linuxUsername, })) @@ -243,9 +356,10 @@ func SetupAndRegisterNodeSSHAccess( return fmt.Errorf("failed to grant SSH access (transient): %w", err) } - // Permanent error — roll back the key so we don't leave an unrecorded entry and abort the backoff retry - if osUser != nil && targetUser.PublicKey != "" { - if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { + // Permanent error — roll back only the line we added for this port + if keyAdded && targetUser.PublicKey != "" { + line := targetUser.PublicKey + " " + DevplaneAuthorizedKeysComment(brevPortID, targetUser.ID) + if rerr := RemoveAuthorizedKeyLine(osUser, line); rerr != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) } } @@ -308,12 +422,11 @@ func PromptSSHPort(t *terminal.Terminal) (int32, error) { } } -// InstallAuthorizedKey appends the given public key to the user's -// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with -// a brev-cli comment (including the user ID) so it can be identified and -// removed later by RemoveBrevAuthorizedKeys or ListBrevAuthorizedKeys. -// Returns true if the key was newly written, false if it was already present. -func InstallAuthorizedKey(u *user.User, pubKey string, brevUserID string) (bool, error) { +// InstallAuthorizedKey adds an authorized_keys line for this port and user. The same +// public key may appear on multiple lines (one per port). Untagged or legacy-tagged +// lines are upgraded in place; lines that already have a #brev-portID tag are left +// alone and a new line is appended. +func InstallAuthorizedKey(u *user.User, pubKey, portID, brevUserID string) (bool, error) { pubKey = strings.TrimSpace(pubKey) if pubKey == "" { return false, nil @@ -331,28 +444,35 @@ func InstallAuthorizedKey(u *user.User, pubKey string, brevUserID string) (bool, return false, fmt.Errorf("reading authorized_keys: %w", err) } - taggedKey := pubKey + " " + BrevKeyTag(brevUserID) + taggedLine := pubKey + " " + DevplaneAuthorizedKeysComment(portID, brevUserID) - if strings.Contains(string(existing), taggedKey) { - return false, nil // already present with tag + if strings.Contains(string(existing), taggedLine) { + return false, nil } - // If the key exists but isn't tagged, replace it with the tagged version - // so that RemoveBrevAuthorizedKeys can find it later. - if strings.Contains(string(existing), pubKey) { - updated := strings.ReplaceAll(string(existing), pubKey, taggedKey) - if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { - return false, fmt.Errorf("writing authorized_keys: %w", err) + lines := strings.Split(string(existing), "\n") + var out []string + upgraded := false + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + if !upgraded && strings.Contains(trimmed, pubKey) && !strings.Contains(trimmed, "#brev-portID:") { + out = append(out, taggedLine) + upgraded = true + continue } - return true, nil + out = append(out, trimmed) } - // Ensure existing content ends with a newline before appending. - content := string(existing) - if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content := strings.Join(out, "\n") + if len(out) > 0 { content += "\n" } - content += taggedKey + "\n" + if !upgraded { + content += taggedLine + "\n" + } if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { return false, fmt.Errorf("writing authorized_keys: %w", err) @@ -395,9 +515,7 @@ func RemoveAuthorizedKey(u *user.User, pubKey string) error { return nil } -// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli -// comment from the user's ~/.ssh/authorized_keys. It returns the lines that -// were removed so callers can report what was cleaned up. +// RemoveBrevAuthorizedKeys removes all Brev-managed SSH keys from authorized_keys. func RemoveBrevAuthorizedKeys(u *user.User) ([]string, error) { authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") @@ -412,7 +530,7 @@ func RemoveBrevAuthorizedKeys(u *user.User) ([]string, error) { var kept []string var removed []string for _, line := range strings.Split(string(existing), "\n") { - if strings.Contains(line, BrevKeyPrefix) { + if isBrevManagedAuthorizedKeysLine(line) { if trimmed := strings.TrimSpace(line); trimmed != "" { removed = append(removed, trimmed) } diff --git a/pkg/cmd/register/sshkeys_port_resolve_test.go b/pkg/cmd/register/sshkeys_port_resolve_test.go new file mode 100644 index 000000000..243342015 --- /dev/null +++ b/pkg/cmd/register/sshkeys_port_resolve_test.go @@ -0,0 +1,137 @@ +package register + +import ( + "context" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + "github.com/brevdev/brev-cli/pkg/terminal" +) + +type mockPortSelector struct { + choices []string + idx int +} + +func (m *mockPortSelector) Select(_ string, items []string) string { + if m.idx < len(m.choices) { + ch := m.choices[m.idx] + m.idx++ + return ch + } + if len(items) > 0 { + return items[0] + } + return "" +} + +type portOpenTestStore struct{ token string } + +func (p portOpenTestStore) GetAccessToken() (string, error) { return p.token, nil } + +type openPortCapture struct { + number int32 + protocol nodev1.PortProtocol +} + +func startPortOpenTestServer(t *testing.T) (mockNodeClientFactory, *openPortCapture) { + t.Helper() + cap := &openPortCapture{} + svc := &fakeNodeService{ + openPortFn: func(req *nodev1.OpenPortRequest) (*nodev1.OpenPortResponse, error) { + cap.number = req.GetPortNumber() + cap.protocol = req.GetProtocol() + return &nodev1.OpenPortResponse{Port: &nodev1.Port{ + PortId: "port_new", + Protocol: req.GetProtocol(), + PortNumber: req.GetPortNumber(), + ServerPort: 22, + }}, nil + }, + } + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + return mockNodeClientFactory{serverURL: server.URL}, cap +} + +func TestResolveSSHAccessPort_noPortsOpensNew(t *testing.T) { + SetTestSSHPort(2222) + defer ClearTestSSHPort() + + clients, cap := startPortOpenTestServer(t) + portID, err := ResolveSSHAccessPort( + context.Background(), + terminal.New(), + &mockPortSelector{}, + clients, + portOpenTestStore{token: "tok"}, + &DeviceRegistration{ExternalNodeID: "unode_abc", OrgID: "org_1"}, + &nodev1.ExternalNode{ExternalNodeId: "unode_abc"}, + ) + if err != nil { + t.Fatal(err) + } + if portID != "port_new" { + t.Fatalf("got port id %q", portID) + } + if cap.number != 2222 { + t.Fatalf("open port number = %d, want 2222", cap.number) + } + if cap.protocol != nodev1.PortProtocol_PORT_PROTOCOL_TCP { + t.Fatalf("protocol = %v, want TCP", cap.protocol) + } +} + +func TestResolveSSHAccessPort_useExisting(t *testing.T) { + ports := []*nodev1.Port{ + {PortId: "port_a", PortNumber: 11640, ServerPort: 22}, + {PortId: "port_b", PortNumber: 8080, ServerPort: 8080}, + } + sel := &mockPortSelector{choices: []string{PortChoiceUseExisting, "11640->22"}} + + portID, err := ResolveSSHAccessPort( + context.Background(), + terminal.New(), + sel, + mockNodeClientFactory{serverURL: "http://unused"}, + portOpenTestStore{token: "tok"}, + &DeviceRegistration{ExternalNodeID: "unode_abc", OrgID: "org_1"}, + &nodev1.ExternalNode{Ports: ports}, + ) + if err != nil { + t.Fatal(err) + } + if portID != "port_a" { + t.Fatalf("got port id %q, want port_a", portID) + } +} + +func TestResolveSSHAccessPort_openNewWhenPortsExist(t *testing.T) { + SetTestSSHPort(2222) + defer ClearTestSSHPort() + + clients, cap := startPortOpenTestServer(t) + sel := &mockPortSelector{choices: []string{PortChoiceOpenNew}} + portID, err := ResolveSSHAccessPort( + context.Background(), + terminal.New(), + sel, + clients, + portOpenTestStore{token: "tok"}, + &DeviceRegistration{ExternalNodeID: "unode_abc", OrgID: "org_1"}, + &nodev1.ExternalNode{Ports: []*nodev1.Port{{PortId: "port_existing", PortNumber: 11640, ServerPort: 22}}}, + ) + if err != nil { + t.Fatal(err) + } + if portID != "port_new" { + t.Fatalf("got port id %q", portID) + } + if cap.number != 2222 { + t.Fatalf("open port number = %d, want 2222", cap.number) + } +} diff --git a/pkg/cmd/register/sshkeys_port_test.go b/pkg/cmd/register/sshkeys_port_test.go new file mode 100644 index 000000000..c392adbd9 --- /dev/null +++ b/pkg/cmd/register/sshkeys_port_test.go @@ -0,0 +1,45 @@ +package register + +import ( + "context" + "testing" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + "github.com/brevdev/brev-cli/pkg/terminal" +) + +func TestSelectPortFromList_singlePortAutoSelect(t *testing.T) { + ports := []*nodev1.Port{{PortId: "port_1", PortNumber: 22, ServerPort: 41920}} + p, err := SelectPortFromList(context.Background(), terminal.New(), mockSelectorAlwaysFirst{}, ports) + if err != nil { + t.Fatal(err) + } + if p.GetPortId() != "port_1" { + t.Fatalf("got %q", p.GetPortId()) + } +} + +func TestFormatPortLabel(t *testing.T) { + label := FormatPortLabel(&nodev1.Port{PortId: "port_1", PortNumber: 11640, ServerPort: 22}) + want := "11640->22" + if label != want { + t.Fatalf("got %q, want %q", label, want) + } +} + +func TestFormatPortLabel_noServerPort(t *testing.T) { + label := FormatPortLabel(&nodev1.Port{PortNumber: 22}) + if label != "22" { + t.Fatalf("got %q", label) + } +} + +type mockSelectorAlwaysFirst struct{} + +func (mockSelectorAlwaysFirst) Select(_ string, items []string) string { + if len(items) > 0 { + return items[0] + } + return "" +} diff --git a/pkg/cmd/register/sshkeys_test.go b/pkg/cmd/register/sshkeys_test.go index c474377d5..acecb74ab 100644 --- a/pkg/cmd/register/sshkeys_test.go +++ b/pkg/cmd/register/sshkeys_test.go @@ -35,31 +35,20 @@ func readKeys(t *testing.T, u *user.User) string { return string(data) } -// --- BrevKeyTag --- - -func TestBrevKeyTag_WithUserID(t *testing.T) { - tag := BrevKeyTag("user_abc123") - expected := "# brev-cli user_id=user_abc123" - if tag != expected { - t.Errorf("expected %q, got %q", expected, tag) +func TestDevplaneAuthorizedKeysComment(t *testing.T) { + got := DevplaneAuthorizedKeysComment("nport_abc", "user_xyz") + want := "#brev-portID:nport_abc,brev-userID:user_xyz" + if got != want { + t.Fatalf("got %q, want %q", got, want) } } -func TestBrevKeyTag_EmptyUserID(t *testing.T) { - tag := BrevKeyTag("") - if tag != BrevKeyPrefix { - t.Errorf("expected %q, got %q", BrevKeyPrefix, tag) - } -} - -// --- ListBrevAuthorizedKeys --- - -func TestListBrevAuthorizedKeys_ParsesNewFormat(t *testing.T) { +func TestListBrevAuthorizedKeys_ParsesDevplaneFormat(t *testing.T) { u := tempUser(t) seedKeys(t, u, strings.Join([]string{ "ssh-rsa EXISTING user@host", - "ssh-ed25519 AAAA_ALICE # brev-cli user_id=user_1", - "ssh-rsa AAAA_BOB # brev-cli user_id=user_2", + "ssh-ed25519 AAAA_ALICE user@a.com " + DevplaneAuthorizedKeysComment("port_1", "user_1"), + "ssh-rsa AAAA_BOB " + DevplaneAuthorizedKeysComment("port_2", "user_2"), "", }, "\n")) @@ -67,52 +56,39 @@ func TestListBrevAuthorizedKeys_ParsesNewFormat(t *testing.T) { if err != nil { t.Fatalf("ListBrevAuthorizedKeys: %v", err) } - if len(keys) != 2 { t.Fatalf("expected 2 keys, got %d", len(keys)) } - - if keys[0].KeyContent != "ssh-ed25519 AAAA_ALICE" { - t.Errorf("expected key content 'ssh-ed25519 AAAA_ALICE', got %q", keys[0].KeyContent) - } - if keys[0].UserID != "user_1" { - t.Errorf("expected user_id 'user_1', got %q", keys[0].UserID) - } - - if keys[1].KeyContent != "ssh-rsa AAAA_BOB" { - t.Errorf("expected key content 'ssh-rsa AAAA_BOB', got %q", keys[1].KeyContent) + if keys[0].PortID != "port_1" || keys[0].UserID != "user_1" { + t.Errorf("key[0]: port=%q user=%q", keys[0].PortID, keys[0].UserID) } - if keys[1].UserID != "user_2" { - t.Errorf("expected user_id 'user_2', got %q", keys[1].UserID) + if keys[1].PortID != "port_2" || keys[1].UserID != "user_2" { + t.Errorf("key[1]: port=%q user=%q", keys[1].PortID, keys[1].UserID) } } -func TestListBrevAuthorizedKeys_ParsesOldFormat(t *testing.T) { +func TestListBrevAuthorizedKeys_ParsesLegacyFormat(t *testing.T) { u := tempUser(t) - seedKeys(t, u, "ssh-ed25519 AAAA_OLD # brev-cli\n") + seedKeys(t, u, "ssh-ed25519 AAAA_OLD # brev-cli user_id=uid_42\n") keys, err := ListBrevAuthorizedKeys(u) if err != nil { t.Fatalf("ListBrevAuthorizedKeys: %v", err) } - if len(keys) != 1 { t.Fatalf("expected 1 key, got %d", len(keys)) } - if keys[0].KeyContent != "ssh-ed25519 AAAA_OLD" { - t.Errorf("expected key content 'ssh-ed25519 AAAA_OLD', got %q", keys[0].KeyContent) - } - if keys[0].UserID != "" { - t.Errorf("expected empty user_id for old format, got %q", keys[0].UserID) + if keys[0].UserID != "uid_42" { + t.Errorf("expected user_id uid_42, got %q", keys[0].UserID) } } func TestListBrevAuthorizedKeys_MixedFormats(t *testing.T) { u := tempUser(t) seedKeys(t, u, strings.Join([]string{ - "ssh-rsa AAAA_OLD # brev-cli", + "ssh-rsa AAAA_LEGACY # brev-cli", "ssh-rsa NONBREV user@host", - "ssh-ed25519 AAAA_NEW # brev-cli user_id=uid_42", + "ssh-ed25519 AAAA_NEW " + DevplaneAuthorizedKeysComment("p1", "uid_42"), "", }, "\n")) @@ -120,24 +96,13 @@ func TestListBrevAuthorizedKeys_MixedFormats(t *testing.T) { if err != nil { t.Fatalf("ListBrevAuthorizedKeys: %v", err) } - if len(keys) != 2 { t.Fatalf("expected 2 brev keys, got %d", len(keys)) } - - // Old format - if keys[0].UserID != "" { - t.Errorf("expected empty user_id for old format, got %q", keys[0].UserID) - } - // New format - if keys[1].UserID != "uid_42" { - t.Errorf("expected user_id 'uid_42', got %q", keys[1].UserID) - } } func TestListBrevAuthorizedKeys_NoFile(t *testing.T) { u := tempUser(t) - keys, err := ListBrevAuthorizedKeys(u) if err != nil { t.Fatalf("expected no error for missing file, got: %v", err) @@ -150,7 +115,6 @@ func TestListBrevAuthorizedKeys_NoFile(t *testing.T) { func TestListBrevAuthorizedKeys_NoBrevKeys(t *testing.T) { u := tempUser(t) seedKeys(t, u, "ssh-rsa NONBREV user@host\n") - keys, err := ListBrevAuthorizedKeys(u) if err != nil { t.Fatalf("ListBrevAuthorizedKeys: %v", err) @@ -160,130 +124,147 @@ func TestListBrevAuthorizedKeys_NoBrevKeys(t *testing.T) { } } -// --- RemoveAuthorizedKeyLine --- - func TestRemoveAuthorizedKeyLine_RemovesExactLine(t *testing.T) { u := tempUser(t) + line := "ssh-ed25519 REMOVE " + DevplaneAuthorizedKeysComment("p1", "user_1") seedKeys(t, u, strings.Join([]string{ "ssh-rsa KEEP user@host", - "ssh-ed25519 REMOVE # brev-cli user_id=user_1", - "ssh-rsa KEEP2 admin@server", + line, "", }, "\n")) - if err := RemoveAuthorizedKeyLine(u, "ssh-ed25519 REMOVE # brev-cli user_id=user_1"); err != nil { + if err := RemoveAuthorizedKeyLine(u, line); err != nil { t.Fatalf("RemoveAuthorizedKeyLine: %v", err) } + if strings.Contains(readKeys(t, u), "REMOVE") { + t.Fatal("line was not removed") + } +} - result := readKeys(t, u) - if strings.Contains(result, "REMOVE") { - t.Errorf("line was not removed:\n%s", result) +func TestRemoveBrevAuthorizedKeys_DevplaneLines(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, strings.Join([]string{ + "ssh-rsa KEEP user@host", + "ssh-rsa BREV1 " + DevplaneAuthorizedKeysComment("p1", "u1"), + "ssh-rsa BREV2 " + DevplaneAuthorizedKeysComment("p2", "u2"), + "", + }, "\n")) + + removed, err := RemoveBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) } - if !strings.Contains(result, "ssh-rsa KEEP user@host") { - t.Errorf("other key was removed:\n%s", result) + if len(removed) != 2 { + t.Fatalf("expected 2 removed, got %d", len(removed)) + } + result := readKeys(t, u) + if strings.Contains(result, "#brev-portID:") { + t.Errorf("brev keys remain:\n%s", result) } - if !strings.Contains(result, "ssh-rsa KEEP2 admin@server") { - t.Errorf("other key was removed:\n%s", result) + if !strings.Contains(result, "KEEP") { + t.Error("non-brev key was removed") } } -func TestRemoveAuthorizedKeyLine_NoopForEmptyLine(t *testing.T) { +func TestRemoveBrevAuthorizedKeys_LegacyLines(t *testing.T) { u := tempUser(t) - if err := RemoveAuthorizedKeyLine(u, ""); err != nil { - t.Fatalf("expected no error, got: %v", err) + seedKeys(t, u, "ssh-rsa BREVKEY "+BrevKeyPrefixLegacy+"\n") + removed, err := RemoveBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) + } + if len(removed) != 1 { + t.Fatalf("expected 1 removed, got %d", len(removed)) } } -func TestRemoveAuthorizedKeyLine_NoopForMissingFile(t *testing.T) { +func TestInstallAuthorizedKey_AppendsDevplaneComment(t *testing.T) { u := tempUser(t) - if err := RemoveAuthorizedKeyLine(u, "ssh-rsa SOMETHING"); err != nil { - t.Fatalf("expected no error, got: %v", err) + pub := "ssh-rsa AAAA testkey user@example.com" + wantTag := DevplaneAuthorizedKeysComment("port_1", "user_1") + if _, err := InstallAuthorizedKey(u, pub, "port_1", "user_1"); err != nil { + t.Fatal(err) + } + content := readKeys(t, u) + if !strings.Contains(content, pub+" "+wantTag) { + t.Fatalf("expected devplane-tagged key, got:\n%s", content) } } -// --- InstallAuthorizedKey with user ID --- - -func TestInstallAuthorizedKey_IncludesUserID(t *testing.T) { +func TestInstallAuthorizedKey_SkipsDuplicate(t *testing.T) { u := tempUser(t) - - if _, err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", "user_abc"); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) + pub := "ssh-rsa AAAA testkey" + if _, err := InstallAuthorizedKey(u, pub, "port_1", "user_1"); err != nil { + t.Fatal(err) } + if _, err := InstallAuthorizedKey(u, pub, "port_1", "user_1"); err != nil { + t.Fatal(err) + } + if strings.Count(readKeys(t, u), "ssh-rsa AAAA testkey") != 1 { + t.Fatal("expected single key line") + } +} +func TestInstallAuthorizedKey_UpgradesUntaggedKey(t *testing.T) { + u := tempUser(t) + pub := "ssh-rsa AAAA testkey" + seedKeys(t, u, pub+"\n") + wantTag := DevplaneAuthorizedKeysComment("port_1", "user_1") + if _, err := InstallAuthorizedKey(u, pub, "port_1", "user_1"); err != nil { + t.Fatal(err) + } content := readKeys(t, u) - expected := "ssh-rsa AAAA testkey # brev-cli user_id=user_abc" - if !strings.Contains(content, expected) { - t.Errorf("expected %q in authorized_keys, got:\n%s", expected, content) + if !strings.Contains(content, pub+" "+wantTag) { + t.Fatalf("expected upgraded tag, got:\n%s", content) } } -// --- PromptSSHPort --- - -func promptSSHPortWithInput(t *testing.T, input string) (int32, error) { - t.Helper() +func TestInstallAuthorizedKey_secondPortAppendsNewLine(t *testing.T) { + u := tempUser(t) + pub := "ssh-rsa AAAA testkey user@example.com" + line1 := pub + " " + DevplaneAuthorizedKeysComment("port_a", "user_1") + seedKeys(t, u, line1+"\n") - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("creating pipe: %v", err) + if _, err := InstallAuthorizedKey(u, pub, "port_b", "user_1"); err != nil { + t.Fatal(err) } - defer r.Close() - - // Write input and close writer so ReadString sees EOF after newline. - if _, err := w.WriteString(input); err != nil { - t.Fatalf("writing to pipe: %v", err) + content := readKeys(t, u) + if strings.Count(content, "#brev-portID:") != 2 { + t.Fatalf("expected two port-tagged lines, got:\n%s", content) } - w.Close() - - origStdin := os.Stdin - os.Stdin = r - defer func() { os.Stdin = origStdin }() - - ClearTestSSHPort() // ensure we go through the real path - term := terminal.New() - return PromptSSHPort(term) -} - -func TestPromptSSHPort(t *testing.T) { - tests := []struct { - name string - input string - want int32 - }{ - {"Default", "\n", 22}, - {"CustomPort", "2222\n", 2222}, - {"MinPort", "1\n", 1}, - {"MaxPort", "65535\n", 65535}, - {"RetryAfterOutOfRange", "99999\n22\n", 22}, - {"RetryAfterZero", "0\n443\n", 443}, - {"RetryAfterNonNumeric", "abc\n8080\n", 8080}, - {"RetryAfterNegative", "-1\n22\n", 22}, - {"RetryMultipleThenValid", "foo\n99999\n22\n", 22}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - port, err := promptSSHPortWithInput(t, tt.input) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if port != tt.want { - t.Errorf("expected port %d, got %d", tt.want, port) - } - }) + if !strings.Contains(content, DevplaneAuthorizedKeysComment("port_a", "user_1")) { + t.Fatal("first port line should remain") + } + if !strings.Contains(content, DevplaneAuthorizedKeysComment("port_b", "user_1")) { + t.Fatal("second port line should be added") + } + if strings.Count(content, "port_a") != 1 || strings.Count(content, "port_b") != 1 { + t.Fatalf("merged comments on one line:\n%s", content) } } -func TestInstallAuthorizedKey_EmptyUserID_UsesPrefix(t *testing.T) { +func TestRemoveAuthorizedKey_ByPublicKeyMaterial(t *testing.T) { u := tempUser(t) - - if _, err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { - t.Fatalf("InstallAuthorizedKey: %v", err) + pub := "ssh-rsa AAAA testkey" + seedKeys(t, u, pub+" user@host "+DevplaneAuthorizedKeysComment("p1", "u1")+"\n") + if err := RemoveAuthorizedKey(u, pub); err != nil { + t.Fatal(err) + } + if strings.Contains(readKeys(t, u), "AAAA") { + t.Fatal("key material should be removed") } +} - content := readKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyPrefix) { - t.Errorf("expected key tagged with prefix, got:\n%s", content) +// --- PromptSSHPort --- + +func TestPromptSSHPort(t *testing.T) { + SetTestSSHPort(2222) + defer ClearTestSSHPort() + port, err := PromptSSHPort(terminal.New()) + if err != nil { + t.Fatalf("PromptSSHPort: %v", err) } - if strings.Contains(content, "user_id=") { - t.Errorf("should not contain user_id when empty, got:\n%s", content) + if port != 2222 { + t.Errorf("expected 2222, got %d", port) } } diff --git a/pkg/cmd/revokessh/revokessh.go b/pkg/cmd/revokessh/revokessh.go index c8bc34266..61d90e860 100644 --- a/pkg/cmd/revokessh/revokessh.go +++ b/pkg/cmd/revokessh/revokessh.go @@ -52,6 +52,7 @@ func NewCmdRevokeSSH(t *terminal.Terminal, store RevokeSSHStore) *cobra.Command var nodeFlag string var userFlag string var linuxUserFlag string + var portIDFlag string var approveFlag bool cmd := &cobra.Command{ @@ -59,16 +60,17 @@ func NewCmdRevokeSSH(t *terminal.Terminal, store RevokeSSHStore) *cobra.Command Use: "revoke-ssh", DisableFlagsInUseLine: true, Short: "Revoke SSH access to a node for an org member", - Long: "Revoke SSH access to a node for a member of your organization. Interactive: no flags, prompts for org and which access to revoke. Non-interactive: --org, --user, --linux-user required.", - Example: " brev revoke-ssh\n brev revoke-ssh --org my-org --node my-node --user user@example.com --linux-user ubuntu --approve", + Long: "Revoke SSH access to a node for a member of your organization. Interactive: no flags, prompts for org, node, and which access entry to revoke. Non-interactive: --org, --node, --user, --linux-user, and --port-id required.", + Example: " brev revoke-ssh\n brev revoke-ssh --org my-org --node my-node --user user@example.com --linux-user ubuntu --port-id port_abc --approve", RunE: func(cmd *cobra.Command, args []string) error { - interactive := orgFlag == "" && nodeFlag == "" && userFlag == "" && linuxUserFlag == "" + interactive := orgFlag == "" && nodeFlag == "" && userFlag == "" && linuxUserFlag == "" && portIDFlag == "" opts := revokeSSHOpts{ interactive: interactive, orgName: orgFlag, nodeName: nodeFlag, userIDOrEmail: userFlag, linuxUser: linuxUserFlag, + portID: portIDFlag, skipConfirm: approveFlag, } return runRevokeSSH(cmd.Context(), t, store, opts, defaultRevokeSSHDeps()) @@ -79,31 +81,31 @@ func NewCmdRevokeSSH(t *terminal.Terminal, store RevokeSSHStore) *cobra.Command cmd.Flags().StringVarP(&nodeFlag, "node", "n", "", "node name (required in non-interactive mode)") cmd.Flags().StringVarP(&userFlag, "user", "u", "", "Brev user ID or email to revoke (required in non-interactive mode)") cmd.Flags().StringVar(&linuxUserFlag, "linux-user", "", "Linux username on the target node (required in non-interactive mode)") + cmd.Flags().StringVar(&portIDFlag, "port-id", "", "Brev port ID for the SSH access entry (required in non-interactive mode)") cmd.Flags().BoolVar(&approveFlag, "approve", false, "skip confirmation prompt (assume yes)") return cmd } -// revokeSSHOpts carries mode and inputs: when interactive, org/user/linuxUser from prompts; otherwise from flags. +// revokeSSHOpts carries mode and inputs: when interactive, org/user/linuxUser/port from prompts; otherwise from flags. type revokeSSHOpts struct { interactive bool orgName string nodeName string userIDOrEmail string linuxUser string + portID string skipConfirm bool } // runRevokeSSH runs the revoke-ssh flow; the only difference by mode is whether we prompt or use opts. func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, opts revokeSSHOpts, deps revokeSSHDeps) error { //nolint:gocognit,gocyclo,funlen // ok - // Basic validation if !opts.interactive { - if opts.orgName == "" || opts.nodeName == "" || opts.userIDOrEmail == "" || opts.linuxUser == "" { - return fmt.Errorf("in non-interactive mode --org, --node, --user, and --linux-user are required") + if opts.orgName == "" || opts.nodeName == "" || opts.userIDOrEmail == "" || opts.linuxUser == "" || opts.portID == "" { + return fmt.Errorf("in non-interactive mode --org, --node, --user, --linux-user, and --port-id are required") } } - // Capture the target organization var selectedOrg *entity.Organization if opts.interactive { list, listErr := s.ListOrganizations() @@ -125,7 +127,6 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) - // Capture the target node var selectedNode *nodev1.ExternalNode if opts.interactive { resp, listErr := client.ListNodes(ctx, connect.NewRequest(&nodev1.ListNodesRequest{ @@ -157,16 +158,11 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o return nil } - // Capture the target SSH access - var targetUserID, targetLinuxUser string + var targetUserID, targetLinuxUser, targetPortID, portLabel string if opts.interactive { labels := make([]string, len(nodeSshAccesses)) for i, sa := range nodeSshAccesses { - userName := sa.GetUserId() - if u, err := s.GetUserByID(sa.GetUserId()); err == nil && u != nil { - userName = fmt.Sprintf("%s (%s)", u.Name, sa.GetUserId()) - } - labels[i] = fmt.Sprintf("%s, linux_user: %s", userName, sa.GetLinuxUser()) + labels[i] = formatSSHAccessLabel(s, sa, selectedNode) } selected := deps.prompter.Select("Select SSH access to revoke:", labels) @@ -184,6 +180,8 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o targetUserID = selectedAccess.GetUserId() targetLinuxUser = selectedAccess.GetLinuxUser() + targetPortID = selectedAccess.GetPortId() + portLabel = portLabelForAccess(selectedNode, selectedAccess) } else { resolvedUserID, err := resolveUserID(s, selectedOrg.ID, opts.userIDOrEmail) if err != nil { @@ -191,17 +189,20 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o } var found bool for _, sa := range nodeSshAccesses { - if sa.GetUserId() == resolvedUserID && sa.GetLinuxUser() == opts.linuxUser { + if sa.GetUserId() == resolvedUserID && sa.GetLinuxUser() == opts.linuxUser && sa.GetPortId() == opts.portID { found = true + portLabel = portLabelForAccess(selectedNode, sa) break } } if !found { - return fmt.Errorf("no SSH access entry found for user %q and linux_user %q on node %q", targetUserID, targetLinuxUser, selectedNode.GetName()) + return fmt.Errorf("no SSH access entry found for user %q, linux_user %q, and port_id %q on node %q", + resolvedUserID, opts.linuxUser, opts.portID, selectedNode.GetName()) } targetUserID = resolvedUserID targetLinuxUser = opts.linuxUser + targetPortID = opts.portID } userDisplay := targetUserID @@ -219,6 +220,7 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o t.Vprint("") } t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Node:")), t.BoldBlue(selectedNode.GetName()+" ("+selectedNode.GetExternalNodeId()+")")) + t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Port:")), t.BoldBlue(portLabel)) t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "User:")), t.BoldBlue(userDisplay)) t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "Linux user:")), t.BoldBlue(targetLinuxUser)) t.Vprint("") @@ -233,6 +235,7 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o _, err := client.RevokeNodeSSHAccess(ctx, connect.NewRequest(&nodev1.RevokeNodeSSHAccessRequest{ ExternalNodeId: selectedNode.GetExternalNodeId(), + PortId: targetPortID, UserId: targetUserID, LinuxUser: targetLinuxUser, })) @@ -244,13 +247,32 @@ func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, o return nil } +func formatSSHAccessLabel(s RevokeSSHStore, sa *nodev1.SSHAccess, node *nodev1.ExternalNode) string { + userName := sa.GetUserId() + if u, err := s.GetUserByID(sa.GetUserId()); err == nil && u != nil && u.Name != "" { + userName = u.Name + } + return fmt.Sprintf("%s, linux_user: %s, port: %s", userName, sa.GetLinuxUser(), portLabelForAccess(node, sa)) +} + +func portLabelForAccess(node *nodev1.ExternalNode, sa *nodev1.SSHAccess) string { + for _, p := range node.GetPorts() { + if p.GetPortId() == sa.GetPortId() { + return register.FormatPortLabel(p) + } + } + if sa.GetPortId() != "" { + return sa.GetPortId() + } + return "unknown" +} + // resolveUserID resolves idOrEmail to a Brev user ID using org members when it looks like an email. func resolveUserID(s RevokeSSHStore, orgID string, idOrEmail string) (string, error) { if idOrEmail == "" { return "", fmt.Errorf("user is required") } - // Search by ID if !strings.Contains(idOrEmail, "@") { u, err := s.GetUserByID(idOrEmail) if err == nil && u != nil { @@ -259,7 +281,6 @@ func resolveUserID(s RevokeSSHStore, orgID string, idOrEmail string) (string, er return idOrEmail, nil } - // Search by email attachments, err := s.GetOrgRoleAttachments(orgID) if err != nil { return "", fmt.Errorf("failed to list org members: %w", err) @@ -267,7 +288,6 @@ func resolveUserID(s RevokeSSHStore, orgID string, idOrEmail string) (string, er for _, a := range attachments { u, err := s.GetUserByID(a.Subject) if err != nil { - // Ignore error and continue continue } if u != nil && strings.EqualFold(u.Email, idOrEmail) { @@ -275,6 +295,5 @@ func resolveUserID(s RevokeSSHStore, orgID string, idOrEmail string) (string, er } } - // No match found return "", fmt.Errorf("no org member found with email %q", idOrEmail) } diff --git a/pkg/cmd/revokessh/revokessh_test.go b/pkg/cmd/revokessh/revokessh_test.go index ef031f46b..d9743251e 100644 --- a/pkg/cmd/revokessh/revokessh_test.go +++ b/pkg/cmd/revokessh/revokessh_test.go @@ -239,7 +239,10 @@ func Test_runRevokeSSH_RevokeAccess(t *testing.T) { ExternalNodeId: "unode_abc", Name: "My Spark", SshAccess: []*nodev1.SSHAccess{ - {UserId: "user_2", LinuxUser: "ubuntu"}, + {UserId: "user_2", LinuxUser: "ubuntu", PortId: "port_ssh"}, + }, + Ports: []*nodev1.Port{ + {PortId: "port_ssh", Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, PortNumber: 22, ServerPort: 41920}, }, }, }, @@ -274,6 +277,9 @@ func Test_runRevokeSSH_RevokeAccess(t *testing.T) { if gotReq.GetLinuxUser() != "ubuntu" { t.Errorf("expected linux user ubuntu, got %s", gotReq.GetLinuxUser()) } + if gotReq.GetPortId() != "port_ssh" { + t.Errorf("expected port ID port_ssh, got %s", gotReq.GetPortId()) + } } func Test_runRevokeSSH_RPCFailure(t *testing.T) { @@ -296,7 +302,10 @@ func Test_runRevokeSSH_RPCFailure(t *testing.T) { ExternalNodeId: "unode_abc", Name: "My Spark", SshAccess: []*nodev1.SSHAccess{ - {UserId: "user_3", LinuxUser: "testuser"}, + {UserId: "user_3", LinuxUser: "testuser", PortId: "port_ssh"}, + }, + Ports: []*nodev1.Port{ + {PortId: "port_ssh", Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, PortNumber: 22}, }, }, }, From 3e74c9e3d10117b26121369e0bb510cba3553523 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 20 May 2026 12:48:59 -0700 Subject: [PATCH 5/6] lint --- pkg/cmd/grantssh/grantssh.go | 2 +- pkg/cmd/register/sshkeys.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index a0d705acd..74971700d 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -253,7 +253,7 @@ func resolveGrantPort(ctx context.Context, t *terminal.Terminal, opts grantSSHOp } selected, selErr := register.SelectPortFromList(ctx, t, deps.prompter, ports) if selErr != nil { - return "", "", selErr + return "", "", breverrors.WrapAndTrace(selErr) } return selected.GetPortId(), register.FormatPortLabel(selected), nil } diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index f865070a8..1188766dc 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -55,7 +55,6 @@ func SelectNodeFromList(ctx context.Context, t *terminal.Terminal, prompter term return selected, nil } - func SelectPortFromList(_ context.Context, t *terminal.Terminal, prompter terminal.Selector, ports []*nodev1.Port) (*nodev1.Port, error) { if len(ports) == 0 { return nil, fmt.Errorf("no ports to select") From e104e514fd90c260b92ebe1297528ad40c8f4877 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 20 May 2026 12:56:02 -0700 Subject: [PATCH 6/6] go mod tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 6e88264ce..bd3acbbf1 100644 --- a/go.mod +++ b/go.mod @@ -150,7 +150,7 @@ require ( golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.42.0 - golang.org/x/term v0.41.0 + golang.org/x/term v0.41.0 // indirect golang.org/x/time v0.12.0 // indirect google.golang.org/protobuf v1.36.11 gopkg.in/inf.v0 v0.9.1 // indirect