Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strconv"
"strings"

"github.com/gorilla/websocket"
"github.com/vmihailenco/msgpack/v5"
)

Expand Down Expand Up @@ -99,16 +98,15 @@ type ErrorInfo struct {
}

// ParseProtocolMessage decodes a WebSocket frame into a ProtocolMessage.
// For text frames (JSON) and binary frames (msgpack).
// The format is taken from the connection's "format" query parameter:
// "msgpack" is decoded as msgpack, everything else (including "json" and an
// absent value) is decoded as JSON.
// Returns the parsed message. On failure, returns a message with Action=-1.
func ParseProtocolMessage(data []byte, messageType int) ProtocolMessage {
if messageType == websocket.TextMessage {
return parseJSON(data)
}
if messageType == websocket.BinaryMessage {
func ParseProtocolMessage(data []byte, format string) ProtocolMessage {
if format == "msgpack" {
return parseMsgpack(data)
}
return ProtocolMessage{Action: -1}
return parseJSON(data)
}

func parseJSON(data []byte) ProtocolMessage {
Expand Down
14 changes: 7 additions & 7 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func TestWsPassthrough(t *testing.T) {
defer deleteSession(t, controlURL, session.SessionID)

// Connect through proxy
conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "key=test.key:secret")
conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json", "key=test.key:secret")
defer conn.Close()

// Should receive CONNECTED from mock upstream
Expand Down Expand Up @@ -467,7 +467,7 @@ func TestWsFrameSuppression(t *testing.T) {
})
defer deleteSession(t, controlURL, session.SessionID)

conn := connectWs(t, fmt.Sprintf("localhost:%d", port))
conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json")
defer conn.Close()

// Read CONNECTED
Expand Down Expand Up @@ -515,7 +515,7 @@ func TestWsInjectAndClose(t *testing.T) {
})
defer deleteSession(t, controlURL, session.SessionID)

conn := connectWs(t, fmt.Sprintf("localhost:%d", port))
conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json")
defer conn.Close()

// Should receive the injected DISCONNECTED message
Expand Down Expand Up @@ -672,7 +672,7 @@ func TestWsSuppressOnwards(t *testing.T) {
})
defer deleteSession(t, controlURL, session.SessionID)

conn := connectWs(t, fmt.Sprintf("localhost:%d", port))
conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json")
defer conn.Close()

// Read CONNECTED (arrives before suppress_onwards fires)
Expand Down Expand Up @@ -854,7 +854,7 @@ func TestRuleTimesLimit(t *testing.T) {

func TestProtocolParseJSON(t *testing.T) {
msg := `{"action":10,"channel":"test-channel"}`
pm := ParseProtocolMessage([]byte(msg), websocket.TextMessage)
pm := ParseProtocolMessage([]byte(msg), "json")
if pm.Action != ActionAttach {
t.Fatalf("expected action %d, got %d", ActionAttach, pm.Action)
}
Expand All @@ -865,7 +865,7 @@ func TestProtocolParseJSON(t *testing.T) {

func TestProtocolParseJSONWithError(t *testing.T) {
msg := `{"action":9,"error":{"code":40142,"statusCode":401,"message":"Token expired"}}`
pm := ParseProtocolMessage([]byte(msg), websocket.TextMessage)
pm := ParseProtocolMessage([]byte(msg), "json")
if pm.Action != ActionError {
t.Fatalf("expected action %d, got %d", ActionError, pm.Action)
}
Expand All @@ -885,7 +885,7 @@ func TestProtocolParseMsgpack(t *testing.T) {
"channel": "test",
}
data := mustMarshalMsgpack(t, raw)
pm := ParseProtocolMessage(data, websocket.BinaryMessage)
pm := ParseProtocolMessage(data, "msgpack")
if pm.Action != ActionAttach {
t.Fatalf("expected action %d, got %d", ActionAttach, pm.Action)
}
Expand Down
14 changes: 9 additions & 5 deletions ws_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) {
}
}

// Protocol format ("json" or "msgpack") is declared by the SDK via the
// "format" query param and is fixed for the lifetime of the connection.
format := queryParams["format"]

// Create WsConnection and register it
wc := NewWsConnection(0)
session.AddWsConn(wc)
Expand Down Expand Up @@ -142,13 +146,13 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) {
// server → client relay
go func() {
defer wg.Done()
relayFrames(session, wc, serverConn, clientConn, "server_to_client", "ws_frame_to_client")
relayFrames(session, wc, serverConn, clientConn, "server_to_client", "ws_frame_to_client", format)
}()

// client → server relay
go func() {
defer wg.Done()
relayFrames(session, wc, clientConn, serverConn, "client_to_server", "ws_frame_to_server")
relayFrames(session, wc, clientConn, serverConn, "client_to_server", "ws_frame_to_server", format)
}()

wg.Wait()
Expand All @@ -163,7 +167,7 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) {
}

// relayFrames reads frames from src and writes to dst, applying rules.
func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, direction, matchType string) {
func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, direction, matchType, format string) {
for {
if wc.IsClosed() {
return
Expand Down Expand Up @@ -202,11 +206,11 @@ func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, d
}

// Parse protocol message for rule matching and logging
pm := ParseProtocolMessage(data, msgType)
pm := ParseProtocolMessage(data, format)

// Log the frame (as JSON for readability, even if binary)
var logMsg json.RawMessage
if msgType == websocket.TextMessage {
if format != "msgpack" {
logMsg = json.RawMessage(data)
} else {
// For binary frames, log the parsed summary
Expand Down
Loading