diff --git a/protocol.go b/protocol.go index 24eaac3..b666057 100644 --- a/protocol.go +++ b/protocol.go @@ -6,7 +6,6 @@ import ( "strconv" "strings" - "github.com/gorilla/websocket" "github.com/vmihailenco/msgpack/v5" ) @@ -99,16 +98,15 @@ type ErrorInfo struct { } // ParseProtocolMessage decodes a WebSocket frame into a ProtocolMessage. -// For text frames (JSON) and binary frames (msgpack). +// The format is taken from the connection's "format" query parameter: +// "msgpack" is decoded as msgpack, everything else (including "json" and an +// absent value) is decoded as JSON. // Returns the parsed message. On failure, returns a message with Action=-1. -func ParseProtocolMessage(data []byte, messageType int) ProtocolMessage { - if messageType == websocket.TextMessage { - return parseJSON(data) - } - if messageType == websocket.BinaryMessage { +func ParseProtocolMessage(data []byte, format string) ProtocolMessage { + if format == "msgpack" { return parseMsgpack(data) } - return ProtocolMessage{Action: -1} + return parseJSON(data) } func parseJSON(data []byte) ProtocolMessage { diff --git a/proxy_test.go b/proxy_test.go index c3cd439..2181312 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -315,7 +315,7 @@ func TestWsPassthrough(t *testing.T) { defer deleteSession(t, controlURL, session.SessionID) // Connect through proxy - conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "key=test.key:secret") + conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json", "key=test.key:secret") defer conn.Close() // Should receive CONNECTED from mock upstream @@ -467,7 +467,7 @@ func TestWsFrameSuppression(t *testing.T) { }) defer deleteSession(t, controlURL, session.SessionID) - conn := connectWs(t, fmt.Sprintf("localhost:%d", port)) + conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json") defer conn.Close() // Read CONNECTED @@ -515,7 +515,7 @@ func TestWsInjectAndClose(t *testing.T) { }) defer deleteSession(t, controlURL, session.SessionID) - conn := connectWs(t, fmt.Sprintf("localhost:%d", port)) + conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json") defer conn.Close() // Should receive the injected DISCONNECTED message @@ -672,7 +672,7 @@ func TestWsSuppressOnwards(t *testing.T) { }) defer deleteSession(t, controlURL, session.SessionID) - conn := connectWs(t, fmt.Sprintf("localhost:%d", port)) + conn := connectWs(t, fmt.Sprintf("localhost:%d", port), "format=json") defer conn.Close() // Read CONNECTED (arrives before suppress_onwards fires) @@ -854,7 +854,7 @@ func TestRuleTimesLimit(t *testing.T) { func TestProtocolParseJSON(t *testing.T) { msg := `{"action":10,"channel":"test-channel"}` - pm := ParseProtocolMessage([]byte(msg), websocket.TextMessage) + pm := ParseProtocolMessage([]byte(msg), "json") if pm.Action != ActionAttach { t.Fatalf("expected action %d, got %d", ActionAttach, pm.Action) } @@ -865,7 +865,7 @@ func TestProtocolParseJSON(t *testing.T) { func TestProtocolParseJSONWithError(t *testing.T) { msg := `{"action":9,"error":{"code":40142,"statusCode":401,"message":"Token expired"}}` - pm := ParseProtocolMessage([]byte(msg), websocket.TextMessage) + pm := ParseProtocolMessage([]byte(msg), "json") if pm.Action != ActionError { t.Fatalf("expected action %d, got %d", ActionError, pm.Action) } @@ -885,7 +885,7 @@ func TestProtocolParseMsgpack(t *testing.T) { "channel": "test", } data := mustMarshalMsgpack(t, raw) - pm := ParseProtocolMessage(data, websocket.BinaryMessage) + pm := ParseProtocolMessage(data, "msgpack") if pm.Action != ActionAttach { t.Fatalf("expected action %d, got %d", ActionAttach, pm.Action) } diff --git a/ws_proxy.go b/ws_proxy.go index 344a848..dae222e 100644 --- a/ws_proxy.go +++ b/ws_proxy.go @@ -28,6 +28,10 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) { } } + // Protocol format ("json" or "msgpack") is declared by the SDK via the + // "format" query param and is fixed for the lifetime of the connection. + format := queryParams["format"] + // Create WsConnection and register it wc := NewWsConnection(0) session.AddWsConn(wc) @@ -142,13 +146,13 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) { // server → client relay go func() { defer wg.Done() - relayFrames(session, wc, serverConn, clientConn, "server_to_client", "ws_frame_to_client") + relayFrames(session, wc, serverConn, clientConn, "server_to_client", "ws_frame_to_client", format) }() // client → server relay go func() { defer wg.Done() - relayFrames(session, wc, clientConn, serverConn, "client_to_server", "ws_frame_to_server") + relayFrames(session, wc, clientConn, serverConn, "client_to_server", "ws_frame_to_server", format) }() wg.Wait() @@ -163,7 +167,7 @@ func HandleWsProxy(session *Session, w http.ResponseWriter, r *http.Request) { } // relayFrames reads frames from src and writes to dst, applying rules. -func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, direction, matchType string) { +func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, direction, matchType, format string) { for { if wc.IsClosed() { return @@ -202,11 +206,11 @@ func relayFrames(session *Session, wc *WsConnection, src, dst *websocket.Conn, d } // Parse protocol message for rule matching and logging - pm := ParseProtocolMessage(data, msgType) + pm := ParseProtocolMessage(data, format) // Log the frame (as JSON for readability, even if binary) var logMsg json.RawMessage - if msgType == websocket.TextMessage { + if format != "msgpack" { logMsg = json.RawMessage(data) } else { // For binary frames, log the parsed summary