diff --git a/.gitignore b/.gitignore index 83c65e1..0e425d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ v2ray-plugin* /bin/ /.idea/ +/.gocache/ +/.gomodcache/ diff --git a/README.md b/README.md index d22d3df..c7b7192 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,68 @@ On your client ss-local -c config.json -p 443 --plugin v2ray-plugin --plugin-opts "mode=quic;host=mydomain.me" ``` +### SIP003U UDP forwarding + +SIP003U support is split between shadowsocks-libev and this plugin: + +* shadowsocks-libev `--plugin-mode` decides whether UDP relay traffic is routed through the plugin port. +* v2ray-plugin `udpMode` decides whether this plugin starts its native UDP relay and which UDP transport it uses. + +The TCP transport still follows `mode`. Enabling UDP forwarding does not change an existing TCP deployment unless both `mode=websocket` and `udpMode=websocket` are used on the server, where the plugin owns the public WebSocket listener and routes TCP and UDP by path. + +#### UDP over QUIC Datagram + +On your server + +```sh +ss-server -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "server;tls;host=mydomain.me;udpMode=quic" +``` + +On your client + +```sh +ss-local -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "tls;host=mydomain.me;udpMode=quic" +``` + +To keep TCP on WebSocket while sending UDP through QUIC Datagram, leave `mode` unset or set it to `websocket`: + +```sh +ss-server -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "server;tls;host=mydomain.me;mode=websocket;udpMode=quic" +ss-local -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "tls;host=mydomain.me;mode=websocket;udpMode=quic" +``` + +#### UDP over WebSocket + +Use `udpMode=websocket` when UDP traffic also needs to pass through an HTTP/WebSocket proxy path, for example a regular Cloudflare proxied hostname. The TCP WebSocket path remains controlled by `path`; the UDP WebSocket path is controlled by `udpPath` and defaults to `/ray-udp`. + +On your server + +```sh +ss-server -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "server;tls;host=mydomain.me;mode=websocket;path=/ray;udpMode=websocket;udpPath=/ray-udp" +``` + +On your client + +```sh +ss-local -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "tls;host=mydomain.me;mode=websocket;path=/ray;udpMode=websocket;udpPath=/ray-udp" +``` + +In this mode the server-side plugin listens on the public TCP port, accepts `/ray-udp` itself, and reverse-proxies the normal TCP WebSocket path to an internal loopback v2ray-core listener. `path` and `udpPath` must both start with `/`, must not contain `?` or `#`, and must be different. + +`udpMode=websocket` can also be combined with `mode=quic`. In that layout, TCP relay traffic uses v2ray-core QUIC on UDP while UDP relay traffic uses the plugin's WebSocket listener on TCP. Because TCP and UDP sockets are separate, the same numeric port can be reused without the public WebSocket reverse-proxy layer. + +Each encrypted Shadowsocks UDP packet is sent as one WebSocket binary message. The plugin preserves packet boundaries and keeps Shadowsocks UDP payloads opaque; it does not parse, decrypt, modify, coalesce, or fragment UDP payloads. + +`udpTimeout` controls the plugin's own UDP flow table and defaults to 30 seconds: + +```sh +ss-local -c config.json -p 443 -u --plugin v2ray-plugin --plugin-mode tcp_and_udp --plugin-opts "tls;host=mydomain.me;udpMode=quic;udpTimeout=60" +``` + +This timeout is separate from shadowsocks-libev's internal UDP relay timeout. The implementation does not add separate UDP local or remote port options and does not fragment oversized UDP datagrams; oversized packets are dropped and logged. Certificate options are shared with the TCP TLS path, so certificate mismatch errors usually mean `host`, `cert`, `certRaw`, or `key` differs between client and server. If TCP works but UDP bypasses the plugin, check that shadowsocks-libev was started with `--plugin-mode tcp_and_udp` or another UDP-capable plugin mode. + +`udpMode=quic` uses QUIC Datagram and needs end-to-end UDP reachability to the plugin. It will not work through a regular Cloudflare orange-cloud HTTP proxy because Cloudflare terminates QUIC/HTTP3 at the edge and speaks HTTP to the origin. `udpMode=websocket` is Cloudflare-compatible, but it carries UDP packets over a reliable WebSocket/TCP stream, so packet loss can cause head-of-line blocking. + ### Issue a cert for TLS and QUIC `v2ray-plugin` will look for TLS certificates signed by [acme.sh](https://github.com/acmesh-official/acme.sh) by default. diff --git a/go.mod b/go.mod index 8fb51f5..5c5b41e 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.23.2 require ( github.com/golang/protobuf v1.5.4 + github.com/quic-go/quic-go v0.48.1 github.com/v2fly/v2ray-core/v5 v5.22.0 google.golang.org/protobuf v1.35.1 ) @@ -19,7 +20,6 @@ require ( github.com/miekg/dns v1.1.62 // indirect github.com/onsi/ginkgo/v2 v2.21.0 // indirect github.com/pires/go-proxyproto v0.8.0 // indirect - github.com/quic-go/quic-go v0.48.1 // indirect go.uber.org/mock v0.5.0 // indirect golang.org/x/crypto v0.28.0 // indirect golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect diff --git a/main.go b/main.go index 51d1970..2c740b7 100644 --- a/main.go +++ b/main.go @@ -45,12 +45,15 @@ var ( remoteAddr = flag.String("remoteAddr", "127.0.0.1", "remote address to forward.") remotePort = flag.String("remotePort", "1080", "remote port to forward.") path = flag.String("path", "/", "URL path for websocket.") + udpPath = flag.String("udpPath", "/ray-udp", "URL path for UDP over websocket.") host = flag.String("host", "cloudfront.com", "Hostname for server.") tlsEnabled = flag.Bool("tls", false, "Enable TLS.") cert = flag.String("cert", "", "Path to TLS certificate file. Overrides certRaw. Default: ~/.acme.sh/{host}/fullchain.cer") certRaw = flag.String("certRaw", "", "Raw TLS certificate content. Intended only for Android.") key = flag.String("key", "", "(server) Path to TLS key file. Default: ~/.acme.sh/{host}/{host}.key") mode = flag.String("mode", "websocket", "Transport mode: websocket, quic (enforced tls).") + udpMode = flag.String("udpMode", "", "UDP transport mode: quic, websocket.") + udpTimeout = flag.Int("udpTimeout", 30, "UDP relay timeout in seconds.") mux = flag.Int("mux", 1, "Concurrent multiplexed connections (websocket client mode only).") server = flag.Bool("server", false, "Run in server mode") logLevel = flag.String("loglevel", "", "loglevel for v2ray: debug, info, warning (default), error, none.") @@ -104,24 +107,56 @@ func parseLocalAddr(localAddr string) []string { return strings.Split(localAddr, "|") } -func generateConfig() (*core.Config, error) { - lport, err := net.PortFromString(*localPort) - if err != nil { - return nil, newError("invalid localPort:", *localPort).Base(err) +func applyUDPOptions(opts Args) error { + if c, b := opts.Get("udpMode"); b { + *udpMode = c } - rport, err := strconv.ParseUint(*remotePort, 10, 32) - if err != nil { - return nil, newError("invalid remotePort:", *remotePort).Base(err) + if c, b := opts.Get("udpPath"); b { + *udpPath = c } - outboundProxy := serial.ToTypedMessage(&freedom.Config{ - DestinationOverride: &freedom.DestinationOverride{ - Server: &protocol.ServerEndpoint{ - Address: net.NewIPOrDomain(net.ParseAddress(*remoteAddr)), - Port: uint32(rport), - }, - }, - }) + if c, b := opts.Get("udpTimeout"); b { + i, err := strconv.Atoi(c) + if err != nil { + return newError("invalid udpTimeout:", c).Base(err) + } + *udpTimeout = i + } + return validateUDPOptions() +} +func validateUDPOptions() error { + switch *udpMode { + case "", "quic", "websocket": + default: + return newError("unsupported udpMode:", *udpMode) + } + if *udpTimeout <= 0 { + return newError("invalid udpTimeout:", *udpTimeout) + } + if *udpMode == "websocket" { + switch *mode { + case "websocket", "quic": + default: + return newError("udpMode=websocket requires mode=websocket or mode=quic") + } + if *mode == "websocket" && !isValidWebSocketPath(*path) { + return newError("invalid websocket path:", *path) + } + if !isValidWebSocketPath(*udpPath) { + return newError("invalid udp websocket path:", *udpPath) + } + if *mode == "websocket" && *udpPath == *path { + return newError("udpPath must differ from path:", *udpPath) + } + } + return nil +} + +func isValidWebSocketPath(p string) bool { + return strings.HasPrefix(p, "/") && !strings.ContainsAny(p, "?#") +} + +func generateTCPStreamConfig() (internet.StreamConfig, bool, error) { var transportSettings proto.Message var connectionReuse bool switch *mode { @@ -141,7 +176,7 @@ func generateConfig() (*core.Config, error) { } *tlsEnabled = true default: - return nil, newError("unsupported mode:", *mode) + return internet.StreamConfig{}, false, newError("unsupported mode:", *mode) } streamConfig := internet.StreamConfig{ @@ -170,9 +205,10 @@ func generateConfig() (*core.Config, error) { *cert = fmt.Sprintf("%s/.acme.sh/%s/fullchain.cer", homeDir(), *host) logWarn("No TLS cert specified, trying", *cert) } + var err error certificate.Certificate, err = readCertificate() if err != nil { - return nil, newError("failed to read cert").Base(err) + return internet.StreamConfig{}, false, newError("failed to read cert").Base(err) } if *key == "" { *key = fmt.Sprintf("%[1]s/.acme.sh/%[2]s/%[2]s.key", homeDir(), *host) @@ -180,14 +216,15 @@ func generateConfig() (*core.Config, error) { } certificate.Key, err = filesystem.ReadFile(*key) if err != nil { - return nil, newError("failed to read key file").Base(err) + return internet.StreamConfig{}, false, newError("failed to read key file").Base(err) } tlsConfig.Certificate = []*tls.Certificate{&certificate} } else if *cert != "" || *certRaw != "" { certificate := tls.Certificate{Usage: tls.Certificate_AUTHORITY_VERIFY} + var err error certificate.Certificate, err = readCertificate() if err != nil { - return nil, newError("failed to read cert").Base(err) + return internet.StreamConfig{}, false, newError("failed to read cert").Base(err) } tlsConfig.Certificate = []*tls.Certificate{&certificate} } @@ -195,6 +232,32 @@ func generateConfig() (*core.Config, error) { streamConfig.SecuritySettings = []*anypb.Any{serial.ToTypedMessage(&tlsConfig)} } + return streamConfig, connectionReuse, nil +} + +func generateConfig() (*core.Config, error) { + lport, err := net.PortFromString(*localPort) + if err != nil { + return nil, newError("invalid localPort:", *localPort).Base(err) + } + rport, err := strconv.ParseUint(*remotePort, 10, 32) + if err != nil { + return nil, newError("invalid remotePort:", *remotePort).Base(err) + } + outboundProxy := serial.ToTypedMessage(&freedom.Config{ + DestinationOverride: &freedom.DestinationOverride{ + Server: &protocol.ServerEndpoint{ + Address: net.NewIPOrDomain(net.ParseAddress(*remoteAddr)), + Port: uint32(rport), + }, + }, + }) + + streamConfig, connectionReuse, err := generateTCPStreamConfig() + if err != nil { + return nil, err + } + apps := []*anypb.Any{ serial.ToTypedMessage(&dispatcher.Config{}), serial.ToTypedMessage(&proxyman.InboundConfig{}), @@ -258,6 +321,77 @@ func generateConfig() (*core.Config, error) { } } +type pluginServer struct { + tcp core.Server + udp *udpRelay +} + +func (s *pluginServer) Start() error { + if err := s.tcp.Start(); err != nil { + return err + } + if err := s.udp.Start(); err != nil { + if closeErr := s.tcp.Close(); closeErr != nil { + logWarn(closeErr.Error()) + } + return err + } + return nil +} + +type combinedWebSocketServer struct { + tcp core.Server + udp *udpRelay + router *webSocketRouter +} + +func (s *combinedWebSocketServer) Start() error { + if err := s.tcp.Start(); err != nil { + return err + } + if err := s.udp.Start(); err != nil { + if closeErr := s.tcp.Close(); closeErr != nil { + logWarn(closeErr.Error()) + } + return err + } + if err := s.router.Start(); err != nil { + if closeErr := s.udp.Close(); closeErr != nil { + logWarn(closeErr.Error()) + } + if closeErr := s.tcp.Close(); closeErr != nil { + logWarn(closeErr.Error()) + } + return err + } + return nil +} + +func (s *combinedWebSocketServer) Close() error { + var closeErr error + if err := s.router.Close(); err != nil { + closeErr = err + } + if err := s.udp.Close(); err != nil && closeErr == nil { + closeErr = err + } + if err := s.tcp.Close(); err != nil && closeErr == nil { + closeErr = err + } + return closeErr +} + +func (s *pluginServer) Close() error { + var closeErr error + if err := s.udp.Close(); err != nil { + closeErr = err + } + if err := s.tcp.Close(); err != nil && closeErr == nil { + closeErr = err + } + return closeErr +} + func startV2Ray() (core.Server, error) { opts, err := parseEnv() @@ -342,9 +476,19 @@ func startV2Ray() (core.Server, error) { } } + if err := applyUDPOptions(opts); err != nil { + return nil, err + } + if *vpn { registerControlFunc() } + } else if err := validateUDPOptions(); err != nil { + return nil, err + } + + if *server && *mode == "websocket" && *udpMode == "websocket" { + return startCombinedWebSocketServer() } config, err := generateConfig() @@ -355,7 +499,11 @@ func startV2Ray() (core.Server, error) { if err != nil { return nil, newError("failed to create v2ray instance").Base(err) } - return instance, nil + udpRelay, err := newUDPRelayFromOptions() + if err != nil { + return nil, err + } + return &pluginServer{tcp: instance, udp: udpRelay}, nil } func printCoreVersion() { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..69e9b0c --- /dev/null +++ b/main_test.go @@ -0,0 +1,1147 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +const testWebSocketUDPPath = "/ray-udp" + +var testWebSocketDialer = websocket.Dialer{} + +func withTCPOptionState(t *testing.T) { + t.Helper() + oldFastOpen := *fastOpen + oldLocalAddr := *localAddr + oldLocalPort := *localPort + oldRemoteAddr := *remoteAddr + oldRemotePort := *remotePort + oldPath := *path + oldUDPPath := *udpPath + oldHost := *host + oldTLSEnabled := *tlsEnabled + oldCert := *cert + oldCertRaw := *certRaw + oldKey := *key + oldMode := *mode + oldMux := *mux + oldServer := *server + oldFWMark := *fwmark + + *fastOpen = false + *localAddr = "127.0.0.1" + *localPort = "1984" + *remoteAddr = "127.0.0.1" + *remotePort = "1080" + *path = "/" + *udpPath = "/ray-udp" + *host = "cloudfront.com" + *tlsEnabled = false + *cert = "" + *certRaw = "" + *key = "" + *mode = "websocket" + *mux = 1 + *server = false + *fwmark = 0 + + t.Cleanup(func() { + *fastOpen = oldFastOpen + *localAddr = oldLocalAddr + *localPort = oldLocalPort + *remoteAddr = oldRemoteAddr + *remotePort = oldRemotePort + *path = oldPath + *udpPath = oldUDPPath + *host = oldHost + *tlsEnabled = oldTLSEnabled + *cert = oldCert + *certRaw = oldCertRaw + *key = oldKey + *mode = oldMode + *mux = oldMux + *server = oldServer + *fwmark = oldFWMark + }) +} + +func withUDPOptionState(t *testing.T, mode string, timeout int) { + t.Helper() + oldMode := *udpMode + oldTimeout := *udpTimeout + *udpMode = mode + *udpTimeout = timeout + t.Cleanup(func() { + *udpMode = oldMode + *udpTimeout = oldTimeout + }) +} + +func TestApplyUDPOptions(t *testing.T) { + withUDPOptionState(t, "", 30) + + opts := Args{ + "udpMode": []string{"quic"}, + "udpTimeout": []string{"45"}, + } + + if err := applyUDPOptions(opts); err != nil { + t.Fatalf("applyUDPOptions returned error: %v", err) + } + if *udpMode != "quic" { + t.Fatalf("udpMode = %q, want quic", *udpMode) + } + if *udpTimeout != 45 { + t.Fatalf("udpTimeout = %d, want 45", *udpTimeout) + } +} + +func TestGenerateTCPStreamConfigKeepsDefaultWebSocket(t *testing.T) { + withTCPOptionState(t) + + streamConfig, connectionReuse, err := generateTCPStreamConfig() + if err != nil { + t.Fatalf("generateTCPStreamConfig returned error: %v", err) + } + if streamConfig.ProtocolName != "websocket" { + t.Fatalf("ProtocolName = %q, want websocket", streamConfig.ProtocolName) + } + if !connectionReuse { + t.Fatal("connectionReuse = false, want true for default mux") + } + if *tlsEnabled { + t.Fatal("tlsEnabled = true, want false for default websocket mode") + } +} + +func TestGenerateTCPStreamConfigIgnoresUDPMode(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "quic", 30) + + streamConfig, _, err := generateTCPStreamConfig() + if err != nil { + t.Fatalf("generateTCPStreamConfig returned error: %v", err) + } + if streamConfig.ProtocolName != "websocket" { + t.Fatalf("ProtocolName = %q, want websocket", streamConfig.ProtocolName) + } +} + +func TestGenerateTCPStreamConfigQUICEnablesTLS(t *testing.T) { + withTCPOptionState(t) + *mode = "quic" + + streamConfig, connectionReuse, err := generateTCPStreamConfig() + if err != nil { + t.Fatalf("generateTCPStreamConfig returned error: %v", err) + } + if streamConfig.ProtocolName != "quic" { + t.Fatalf("ProtocolName = %q, want quic", streamConfig.ProtocolName) + } + if connectionReuse { + t.Fatal("connectionReuse = true, want false for quic mode") + } + if !*tlsEnabled { + t.Fatal("tlsEnabled = false, want true for quic mode") + } +} + +func TestValidateUDPOptionsAllowsDefaultDisabledMode(t *testing.T) { + withUDPOptionState(t, "", 30) + + if err := validateUDPOptions(); err != nil { + t.Fatalf("validateUDPOptions returned error: %v", err) + } +} + +func TestValidateUDPOptionsRejectsUnsupportedMode(t *testing.T) { + withUDPOptionState(t, "h3", 30) + + if err := validateUDPOptions(); err == nil { + t.Fatal("validateUDPOptions returned nil, want unsupported mode error") + } +} + +func TestValidateUDPOptionsAllowsWebSocketMode(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "websocket", 30) + *path = "/ray" + *udpPath = "/ray-udp" + + if err := validateUDPOptions(); err != nil { + t.Fatalf("validateUDPOptions returned error: %v", err) + } +} + +func TestValidateUDPOptionsAllowsWebSocketUDPWithQUICTCPMode(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "websocket", 30) + *mode = "quic" + *path = "unused-quic-tcp-path" + *udpPath = "/ray-udp" + + if err := validateUDPOptions(); err != nil { + t.Fatalf("validateUDPOptions returned error: %v", err) + } +} + +func TestValidateUDPOptionsIgnoresTCPWebSocketPathForQUICTCPMode(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "websocket", 30) + *mode = "quic" + *path = "not/a/websocket/path" + *udpPath = "/ray-udp" + + if err := validateUDPOptions(); err != nil { + t.Fatalf("validateUDPOptions returned error: %v", err) + } +} + +func TestApplyUDPOptionsReadsWebSocketUDPPath(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "", 30) + *path = "/ray" + + opts := Args{ + "udpMode": []string{"websocket"}, + "udpPath": []string{"/ray-udp"}, + } + if err := applyUDPOptions(opts); err != nil { + t.Fatalf("applyUDPOptions returned error: %v", err) + } + if *udpMode != "websocket" { + t.Fatalf("udpMode = %q, want websocket", *udpMode) + } + if *udpPath != "/ray-udp" { + t.Fatalf("udpPath = %q, want /ray-udp", *udpPath) + } +} + +func TestValidateUDPOptionsRejectsSameWebSocketPath(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "websocket", 30) + *path = "/ray" + *udpPath = "/ray" + + if err := validateUDPOptions(); err == nil { + t.Fatal("validateUDPOptions returned nil, want same path error") + } +} + +func TestValidateUDPOptionsRejectsInvalidUDPPath(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "websocket", 30) + *path = "/ray" + *udpPath = "ray-udp" + + if err := validateUDPOptions(); err == nil { + t.Fatal("validateUDPOptions returned nil, want invalid udp path error") + } +} + +func TestApplyUDPOptionsRejectsInvalidTimeout(t *testing.T) { + withUDPOptionState(t, "", 30) + + opts := Args{"udpTimeout": []string{"soon"}} + if err := applyUDPOptions(opts); err == nil { + t.Fatal("applyUDPOptions returned nil, want invalid timeout error") + } +} + +func TestValidateUDPOptionsRejectsNonPositiveTimeout(t *testing.T) { + withUDPOptionState(t, "quic", 0) + + if err := validateUDPOptions(); err == nil { + t.Fatal("validateUDPOptions returned nil, want invalid timeout error") + } +} + +func TestNewUDPRelayFromOptionsDisabled(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "", 30) + + relay, err := newUDPRelayFromOptions() + if err != nil { + t.Fatalf("newUDPRelayFromOptions returned error: %v", err) + } + if relay != nil { + t.Fatal("newUDPRelayFromOptions returned relay, want nil") + } +} + +func TestNewUDPRelayFromOptionsEnabled(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "quic", 45) + *server = true + *localAddr = "127.0.0.1|::1" + *localPort = "1984" + *remoteAddr = "127.0.0.1" + *remotePort = "1080" + *host = "example.com" + + relay, err := newUDPRelayFromOptions() + if err != nil { + t.Fatalf("newUDPRelayFromOptions returned error: %v", err) + } + if relay == nil { + t.Fatal("newUDPRelayFromOptions returned nil, want relay") + } + if !relay.config.Server { + t.Fatal("relay.config.Server = false, want true") + } + if relay.config.LocalAddr != "127.0.0.1|::1" { + t.Fatalf("relay.config.LocalAddr = %q, want 127.0.0.1|::1", relay.config.LocalAddr) + } + if relay.config.Timeout != 45*time.Second { + t.Fatalf("relay.config.Timeout = %s, want 45s", relay.config.Timeout) + } +} + +func TestUDPRelayCanShareTCPPortNumber(t *testing.T) { + tcpListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen tcp returned error: %v", err) + } + defer tcpListener.Close() + + port := tcpListener.Addr().(*net.TCPAddr).Port + relay := newUDPRelay(udpRelayConfig{ + LocalAddr: "127.0.0.1", + LocalPort: strconv.Itoa(port), + Timeout: 30 * time.Second, + }) + if err := relay.Start(); err != nil { + t.Fatalf("udp relay Start returned error: %v", err) + } + if err := relay.Close(); err != nil { + t.Fatalf("udp relay Close returned error: %v", err) + } +} + +func TestUDPRelayPreservesDatagramBoundaries(t *testing.T) { + clientRelay, serverRelay, clientAddr, closeRelays := startUDPRelayPair(t, 5*time.Second) + defer closeRelays() + _ = clientRelay + _ = serverRelay + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + + first := []byte("first datagram") + second := []byte("second datagram stays separate") + writeUDPTestDatagram(t, appConn, clientAddr, first) + writeUDPTestDatagram(t, appConn, clientAddr, second) + + gotFirst := readUDPTestDatagram(t, appConn) + gotSecond := readUDPTestDatagram(t, appConn) + if string(gotFirst) != string(first) { + t.Fatalf("first response = %q, want %q", gotFirst, first) + } + if string(gotSecond) != string(second) { + t.Fatalf("second response = %q, want %q", gotSecond, second) + } +} + +func TestUDPRelayRunsAlongsideTCPModes(t *testing.T) { + for _, tcpMode := range []string{"websocket", "quic"} { + t.Run(tcpMode, func(t *testing.T) { + withTCPOptionState(t) + withUDPOptionState(t, "quic", 30) + *mode = tcpMode + + if _, _, err := generateTCPStreamConfig(); err != nil { + t.Fatalf("generateTCPStreamConfig returned error: %v", err) + } + clientRelay, serverRelay, clientAddr, closeRelays := startUDPRelayPair(t, 5*time.Second) + defer closeRelays() + _ = clientRelay + _ = serverRelay + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + + payload := []byte("udp with tcp mode " + tcpMode) + writeUDPTestDatagram(t, appConn, clientAddr, payload) + if got := readUDPTestDatagram(t, appConn); string(got) != string(payload) { + t.Fatalf("response = %q, want %q", got, payload) + } + }) + } +} + +func TestUDPRelayExpiresIdleFlows(t *testing.T) { + clientRelay, serverRelay, clientAddr, closeRelays := startUDPRelayPair(t, 200*time.Millisecond) + defer closeRelays() + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + + writeUDPTestDatagram(t, appConn, clientAddr, []byte("flow trigger")) + _ = readUDPTestDatagram(t, appConn) + + eventually(t, time.Second, func() bool { + clientRelay.mu.Lock() + defer clientRelay.mu.Unlock() + return len(clientRelay.clientFlows) == 1 + }) + eventually(t, time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 1 + }) + eventually(t, 2*time.Second, func() bool { + clientRelay.mu.Lock() + defer clientRelay.mu.Unlock() + return len(clientRelay.clientFlows) == 0 + }) + eventually(t, 2*time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 0 + }) +} + +func TestWebSocketUDPRelayPreservesDatagramBoundaries(t *testing.T) { + clientRelay, serverRelay, clientAddr, closeRelays := startWebSocketUDPRelayPair(t, 5*time.Second) + defer closeRelays() + _ = clientRelay + _ = serverRelay + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + + first := []byte("first websocket udp datagram") + second := []byte("second websocket udp datagram stays separate") + writeUDPTestDatagram(t, appConn, clientAddr, first) + writeUDPTestDatagram(t, appConn, clientAddr, second) + + if got := readUDPTestDatagram(t, appConn); string(got) != string(first) { + t.Fatalf("first response = %q, want %q", got, first) + } + if got := readUDPTestDatagram(t, appConn); string(got) != string(second) { + t.Fatalf("second response = %q, want %q", got, second) + } +} + +func TestWebSocketUDPRelayRoutesMultipleClientFlows(t *testing.T) { + clientRelay, serverRelay, clientAddr, closeRelays := startWebSocketUDPRelayPair(t, 5*time.Second) + defer closeRelays() + + firstConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket first app udp returned error: %v", err) + } + defer firstConn.Close() + secondConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket second app udp returned error: %v", err) + } + defer secondConn.Close() + + first := []byte("first websocket flow") + second := []byte("second websocket flow") + writeUDPTestDatagram(t, firstConn, clientAddr, first) + writeUDPTestDatagram(t, secondConn, clientAddr, second) + + if got := readUDPTestDatagram(t, firstConn); string(got) != string(first) { + t.Fatalf("first response = %q, want %q", got, first) + } + if got := readUDPTestDatagram(t, secondConn); string(got) != string(second) { + t.Fatalf("second response = %q, want %q", got, second) + } + eventually(t, time.Second, func() bool { + clientRelay.mu.Lock() + defer clientRelay.mu.Unlock() + return len(clientRelay.clientFlows) == 2 + }) + eventually(t, time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 2 + }) +} + +func TestWebSocketUDPRelayExpiresIdleFlows(t *testing.T) { + clientRelay, serverRelay, clientAddr, closeRelays := startWebSocketUDPRelayPair(t, 200*time.Millisecond) + defer closeRelays() + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + + writeUDPTestDatagram(t, appConn, clientAddr, []byte("websocket flow trigger")) + _ = readUDPTestDatagram(t, appConn) + + eventually(t, time.Second, func() bool { + clientRelay.mu.Lock() + defer clientRelay.mu.Unlock() + return len(clientRelay.clientFlows) == 1 + }) + eventually(t, time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 1 + }) + eventually(t, 2*time.Second, func() bool { + clientRelay.mu.Lock() + defer clientRelay.mu.Unlock() + return len(clientRelay.clientFlows) == 0 + }) + eventually(t, 2*time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 0 + }) +} + +func TestWebSocketRouterRoutesTCPAndUDPPaths(t *testing.T) { + internalHits := make(chan string, 1) + internalListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen internal tcp returned error: %v", err) + } + internalServer := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + internalHits <- req.URL.Path + w.WriteHeader(http.StatusNoContent) + })} + internalDone := make(chan struct{}) + go func() { + defer close(internalDone) + _ = internalServer.Serve(internalListener) + }() + defer func() { + internalServer.Close() + <-internalDone + }() + + echoConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket echo udp returned error: %v", err) + } + echoDone := make(chan struct{}) + go serveUDPEcho(echoConn, echoDone) + defer func() { + echoConn.Close() + <-echoDone + }() + + serverRelay := newUDPRelay(udpRelayConfig{ + Server: true, + Mode: "websocket", + RemoteAddr: "127.0.0.1", + RemotePort: strconv.Itoa(echoConn.LocalAddr().(*net.UDPAddr).Port), + Host: "127.0.0.1", + Path: "/ray-udp", + Timeout: 5 * time.Second, + }) + if err := serverRelay.Start(); err != nil { + t.Fatalf("server websocket udp relay Start returned error: %v", err) + } + defer serverRelay.Close() + + router, err := newWebSocketRouter(webSocketRouterConfig{ + LocalAddr: "127.0.0.1", + LocalPort: "0", + TCPPath: "/ray", + UDPPath: "/ray-udp", + InternalAddr: "127.0.0.1", + InternalPort: strconv.Itoa(internalListener.Addr().(*net.TCPAddr).Port), + }, serverRelay) + if err != nil { + t.Fatalf("newWebSocketRouter returned error: %v", err) + } + if err := router.Start(); err != nil { + t.Fatalf("websocket router Start returned error: %v", err) + } + defer router.Close() + + routerAddr := router.listeners[0].Addr().String() + resp, err := http.Get("http://" + routerAddr + "/ray") + if err != nil { + t.Fatalf("http.Get tcp path returned error: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("tcp path status = %d, want 204", resp.StatusCode) + } + select { + case got := <-internalHits: + if got != "/ray" { + t.Fatalf("internal path = %q, want /ray", got) + } + case <-time.After(time.Second): + t.Fatal("internal tcp proxy was not reached") + } + + clientRelay := newUDPRelay(udpRelayConfig{ + Mode: "websocket", + LocalAddr: "127.0.0.1", + LocalPort: "0", + RemoteAddr: "127.0.0.1", + RemotePort: strconv.Itoa(router.listeners[0].Addr().(*net.TCPAddr).Port), + Host: "127.0.0.1", + Path: "/ray-udp", + Timeout: 5 * time.Second, + }) + if err := clientRelay.Start(); err != nil { + t.Fatalf("client websocket udp relay Start returned error: %v", err) + } + defer clientRelay.Close() + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + payload := []byte("udp through shared tcp port") + writeUDPTestDatagram(t, appConn, clientRelay.listeners[0].LocalAddr(), payload) + if got := readUDPTestDatagram(t, appConn); string(got) != string(payload) { + t.Fatalf("udp response = %q, want %q", got, payload) + } +} + +func TestWebSocketUDPRelayRejectsTextMessages(t *testing.T) { + serverRelay, serverAddr, closeServer := startWebSocketUDPRelayServer(t, 5*time.Second) + defer closeServer() + + conn, _, err := testWebSocketDialer.Dial("ws://"+serverAddr+testWebSocketUDPPath, nil) + if err != nil { + t.Fatalf("websocket Dial returned error: %v", err) + } + defer conn.Close() + if err := conn.WriteMessage(websocket.TextMessage, []byte("not a udp packet")); err != nil { + t.Fatalf("websocket WriteMessage returned error: %v", err) + } + + eventually(t, time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 0 + }) +} + +func TestWebSocketUDPServerExpiresIdleFlowWithoutClientCleanup(t *testing.T) { + serverRelay, serverAddr, closeServer := startWebSocketUDPRelayServer(t, 200*time.Millisecond) + defer closeServer() + + conn, _, err := testWebSocketDialer.Dial("ws://"+serverAddr+testWebSocketUDPPath, nil) + if err != nil { + t.Fatalf("websocket Dial returned error: %v", err) + } + defer conn.Close() + if err := conn.WriteMessage(websocket.BinaryMessage, []byte("server idle cleanup trigger")); err != nil { + t.Fatalf("websocket WriteMessage returned error: %v", err) + } + _, got, err := conn.ReadMessage() + if err != nil { + t.Fatalf("websocket ReadMessage returned error: %v", err) + } + if string(got) != "server idle cleanup trigger" { + t.Fatalf("websocket response = %q, want trigger payload", got) + } + + eventually(t, time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 1 + }) + eventually(t, 2*time.Second, func() bool { + serverRelay.mu.Lock() + defer serverRelay.mu.Unlock() + return len(serverRelay.serverFlows) == 0 + }) +} + +func TestStandaloneWebSocketUDPRelayCanShareQUICPortNumber(t *testing.T) { + quicUDPPort, closeQUICUDPPort := reserveUDPPort(t) + defer closeQUICUDPPort() + + echoConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket echo udp returned error: %v", err) + } + echoDone := make(chan struct{}) + go serveUDPEcho(echoConn, echoDone) + defer func() { + echoConn.Close() + <-echoDone + }() + + serverRelay := newUDPRelay(udpRelayConfig{ + Server: true, + Mode: "websocket", + StandaloneWebSocketServer: true, + LocalAddr: "127.0.0.1", + LocalPort: quicUDPPort, + RemoteAddr: "127.0.0.1", + RemotePort: strconv.Itoa(echoConn.LocalAddr().(*net.UDPAddr).Port), + Host: "127.0.0.1", + Path: testWebSocketUDPPath, + Timeout: 5 * time.Second, + }) + if err := serverRelay.Start(); err != nil { + t.Fatalf("standalone websocket udp relay Start returned error: %v", err) + } + defer serverRelay.Close() + + clientRelay := newUDPRelay(udpRelayConfig{ + Mode: "websocket", + LocalAddr: "127.0.0.1", + LocalPort: "0", + RemoteAddr: "127.0.0.1", + RemotePort: quicUDPPort, + Host: "127.0.0.1", + Path: testWebSocketUDPPath, + Timeout: 5 * time.Second, + }) + if err := clientRelay.Start(); err != nil { + t.Fatalf("client websocket udp relay Start returned error: %v", err) + } + defer clientRelay.Close() + + appConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket app udp returned error: %v", err) + } + defer appConn.Close() + payload := []byte("websocket udp beside quic udp") + writeUDPTestDatagram(t, appConn, clientRelay.listeners[0].LocalAddr(), payload) + if got := readUDPTestDatagram(t, appConn); string(got) != string(payload) { + t.Fatalf("udp response = %q, want %q", got, payload) + } +} + +func TestStandaloneWebSocketUDPRelayCleansUpAfterStartFailure(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen reserve tcp returned error: %v", err) + } + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + if err := listener.Close(); err != nil { + t.Fatalf("reserved tcp Close returned error: %v", err) + } + + relay := newUDPRelay(udpRelayConfig{ + Server: true, + Mode: "websocket", + StandaloneWebSocketServer: true, + LocalAddr: "127.0.0.1|127.0.0.1", + LocalPort: port, + RemoteAddr: "127.0.0.1", + RemotePort: "9", + Host: "127.0.0.1", + Path: testWebSocketUDPPath, + Timeout: 5 * time.Second, + }) + if err := relay.Start(); err == nil { + relay.Close() + t.Fatal("standalone websocket udp relay Start returned nil, want duplicate listener error") + } + waitDone := make(chan struct{}) + go func() { + relay.wg.Wait() + close(waitDone) + }() + select { + case <-waitDone: + case <-time.After(time.Second): + t.Fatal("udp relay goroutines did not exit after Start failure") + } + if relay.ctx == nil { + t.Fatal("relay ctx = nil, want canceled context") + } + select { + case <-relay.ctx.Done(): + default: + t.Fatal("relay context was not canceled after Start failure") + } + if len(relay.wsListeners) != 0 { + t.Fatalf("wsListeners length = %d, want 0 after Start failure", len(relay.wsListeners)) + } +} + +func startUDPRelayPair(t *testing.T, timeout time.Duration) (*udpRelay, *udpRelay, net.Addr, func()) { + t.Helper() + + echoConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket echo udp returned error: %v", err) + } + echoDone := make(chan struct{}) + go serveUDPEcho(echoConn, echoDone) + + certPath, keyPath := writeTestCertificate(t, "127.0.0.1") + echoPort := strconv.Itoa(echoConn.LocalAddr().(*net.UDPAddr).Port) + serverRelay := newUDPRelay(udpRelayConfig{ + Server: true, + LocalAddr: "127.0.0.1", + LocalPort: "0", + RemoteAddr: "127.0.0.1", + RemotePort: echoPort, + Host: "127.0.0.1", + Cert: certPath, + Key: keyPath, + Timeout: timeout, + }) + if err := serverRelay.Start(); err != nil { + echoConn.Close() + <-echoDone + t.Fatalf("server udp relay Start returned error: %v", err) + } + serverPort := strconv.Itoa(serverRelay.listeners[0].LocalAddr().(*net.UDPAddr).Port) + + clientRelay := newUDPRelay(udpRelayConfig{ + LocalAddr: "127.0.0.1", + LocalPort: "0", + RemoteAddr: "127.0.0.1", + RemotePort: serverPort, + Host: "127.0.0.1", + Cert: certPath, + Timeout: timeout, + }) + if err := clientRelay.Start(); err != nil { + serverRelay.Close() + echoConn.Close() + <-echoDone + t.Fatalf("client udp relay Start returned error: %v", err) + } + clientAddr := clientRelay.listeners[0].LocalAddr() + + closeRelays := func() { + if err := clientRelay.Close(); err != nil { + t.Fatalf("client udp relay Close returned error: %v", err) + } + if err := serverRelay.Close(); err != nil { + t.Fatalf("server udp relay Close returned error: %v", err) + } + echoConn.Close() + <-echoDone + } + return clientRelay, serverRelay, clientAddr, closeRelays +} + +func startWebSocketUDPRelayPair(t *testing.T, timeout time.Duration) (*udpRelay, *udpRelay, net.Addr, func()) { + t.Helper() + + echoConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket echo udp returned error: %v", err) + } + echoDone := make(chan struct{}) + go serveUDPEcho(echoConn, echoDone) + + echoPort := strconv.Itoa(echoConn.LocalAddr().(*net.UDPAddr).Port) + serverRelay := newUDPRelay(udpRelayConfig{ + Server: true, + Mode: "websocket", + RemoteAddr: "127.0.0.1", + RemotePort: echoPort, + Host: "127.0.0.1", + Path: "/ray-udp", + Timeout: timeout, + }) + if err := serverRelay.Start(); err != nil { + echoConn.Close() + <-echoDone + t.Fatalf("server websocket udp relay Start returned error: %v", err) + } + + httpServer := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path != "/ray-udp" { + http.NotFound(w, req) + return + } + serverRelay.ServeWebSocket(w, req) + })} + httpListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + serverRelay.Close() + echoConn.Close() + <-echoDone + t.Fatalf("net.Listen http returned error: %v", err) + } + httpDone := make(chan struct{}) + go func() { + defer close(httpDone) + _ = httpServer.Serve(httpListener) + }() + + serverPort := strconv.Itoa(httpListener.Addr().(*net.TCPAddr).Port) + clientRelay := newUDPRelay(udpRelayConfig{ + Mode: "websocket", + LocalAddr: "127.0.0.1", + LocalPort: "0", + RemoteAddr: "127.0.0.1", + RemotePort: serverPort, + Host: "127.0.0.1", + Path: "/ray-udp", + Timeout: timeout, + }) + if err := clientRelay.Start(); err != nil { + httpServer.Close() + <-httpDone + serverRelay.Close() + echoConn.Close() + <-echoDone + t.Fatalf("client websocket udp relay Start returned error: %v", err) + } + clientAddr := clientRelay.listeners[0].LocalAddr() + + closeRelays := func() { + if err := clientRelay.Close(); err != nil { + t.Fatalf("client websocket udp relay Close returned error: %v", err) + } + httpServer.Close() + <-httpDone + if err := serverRelay.Close(); err != nil { + t.Fatalf("server websocket udp relay Close returned error: %v", err) + } + echoConn.Close() + <-echoDone + } + return clientRelay, serverRelay, clientAddr, closeRelays +} + +func startWebSocketUDPRelayServer(t *testing.T, timeout time.Duration) (*udpRelay, string, func()) { + t.Helper() + + echoConn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket echo udp returned error: %v", err) + } + echoDone := make(chan struct{}) + go serveUDPEcho(echoConn, echoDone) + + serverRelay := newUDPRelay(udpRelayConfig{ + Server: true, + Mode: "websocket", + RemoteAddr: "127.0.0.1", + RemotePort: strconv.Itoa(echoConn.LocalAddr().(*net.UDPAddr).Port), + Host: "127.0.0.1", + Path: testWebSocketUDPPath, + Timeout: timeout, + }) + if err := serverRelay.Start(); err != nil { + echoConn.Close() + <-echoDone + t.Fatalf("server websocket udp relay Start returned error: %v", err) + } + + httpServer := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path != testWebSocketUDPPath { + http.NotFound(w, req) + return + } + serverRelay.ServeWebSocket(w, req) + })} + httpListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + serverRelay.Close() + echoConn.Close() + <-echoDone + t.Fatalf("net.Listen http returned error: %v", err) + } + httpDone := make(chan struct{}) + go func() { + defer close(httpDone) + _ = httpServer.Serve(httpListener) + }() + + closeServer := func() { + httpServer.Close() + <-httpDone + if err := serverRelay.Close(); err != nil { + t.Fatalf("server websocket udp relay Close returned error: %v", err) + } + echoConn.Close() + <-echoDone + } + return serverRelay, httpListener.Addr().String(), closeServer +} + +func serveUDPEcho(conn net.PacketConn, done chan<- struct{}) { + defer close(done) + buf := make([]byte, udpRelayMaxPacketSize) + for { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + return + } + payload := make([]byte, n) + copy(payload, buf[:n]) + _, _ = conn.WriteTo(payload, addr) + } +} + +func writeUDPTestDatagram(t *testing.T, conn net.PacketConn, addr net.Addr, payload []byte) { + t.Helper() + if _, err := conn.WriteTo(payload, addr); err != nil { + t.Fatalf("udp WriteTo returned error: %v", err) + } +} + +func readUDPTestDatagram(t *testing.T, conn net.PacketConn) []byte { + t.Helper() + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline returned error: %v", err) + } + buf := make([]byte, udpRelayMaxPacketSize) + n, _, err := conn.ReadFrom(buf) + if err != nil { + t.Fatalf("udp ReadFrom returned error: %v", err) + } + payload := make([]byte, n) + copy(payload, buf[:n]) + return payload +} + +func reserveUDPPort(t *testing.T) (string, func()) { + t.Helper() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket reserve udp returned error: %v", err) + } + return strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port), func() { + if err := conn.Close(); err != nil { + t.Fatalf("reserved udp Close returned error: %v", err) + } + } +} + +func eventually(t *testing.T, timeout time.Duration, condition func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(20 * time.Millisecond) + } + if !condition() { + t.Fatal("condition was not met before timeout") + } +} + +func writeTestCertificate(t *testing.T, host string) (string, string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey returned error: %v", err) + } + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: host}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP(host)}, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + t.Fatalf("x509.CreateCertificate returned error: %v", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + if err := os.WriteFile(certPath, certPEM, 0600); err != nil { + t.Fatalf("os.WriteFile cert returned error: %v", err) + } + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + t.Fatalf("os.WriteFile key returned error: %v", err) + } + return certPath, keyPath +} + +type fakeCoreServer struct { + started bool + closed bool +} + +func (s *fakeCoreServer) Start() error { + s.started = true + return nil +} + +func (s *fakeCoreServer) Close() error { + s.closed = true + return nil +} + +func TestPluginServerClosesTCPWhenUDPStartFails(t *testing.T) { + tcp := &fakeCoreServer{} + server := &pluginServer{ + tcp: tcp, + udp: newUDPRelay(udpRelayConfig{ + LocalAddr: "127.0.0.1", + LocalPort: "not-a-port", + Timeout: 30 * time.Second, + }), + } + + if err := server.Start(); err == nil { + t.Fatal("pluginServer Start returned nil, want udp start error") + } + if !tcp.started { + t.Fatal("tcp server was not started") + } + if !tcp.closed { + t.Fatal("tcp server was not closed after udp start failure") + } +} + +func TestPluginServerReturnsTCPStartError(t *testing.T) { + startErr := errors.New("start failed") + server := &pluginServer{ + tcp: &failingCoreServer{startErr: startErr}, + } + + if err := server.Start(); !errors.Is(err, startErr) { + t.Fatalf("pluginServer Start error = %v, want %v", err, startErr) + } +} + +type failingCoreServer struct { + startErr error +} + +func (s *failingCoreServer) Start() error { + return s.startErr +} + +func (s *failingCoreServer) Close() error { + return nil +} diff --git a/udp_relay.go b/udp_relay.go new file mode 100644 index 0000000..6be32f9 --- /dev/null +++ b/udp_relay.go @@ -0,0 +1,875 @@ +package main + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/quic-go/quic-go" + "github.com/v2fly/v2ray-core/v5/common/platform/filesystem" +) + +const udpRelayMaxPacketSize = 65507 + +type udpRelayConfig struct { + Server bool + Mode string + StandaloneWebSocketServer bool + LocalAddr string + LocalPort string + RemoteAddr string + RemotePort string + Host string + Path string + TLS bool + Cert string + CertRaw string + Key string + Timeout time.Duration +} + +type udpPacketSession interface { + Context() context.Context + Receive(context.Context) ([]byte, error) + Send([]byte) error + Close(string) error +} + +type quicPacketSession struct { + conn quic.Connection +} + +func (s quicPacketSession) Context() context.Context { + return s.conn.Context() +} + +func (s quicPacketSession) Receive(ctx context.Context) ([]byte, error) { + return s.conn.ReceiveDatagram(ctx) +} + +func (s quicPacketSession) Send(payload []byte) error { + return s.conn.SendDatagram(payload) +} + +func (s quicPacketSession) Close(reason string) error { + return s.conn.CloseWithError(0, reason) +} + +type webSocketPacketSession struct { + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +func newWebSocketPacketSession(parent context.Context, conn *websocket.Conn) *webSocketPacketSession { + ctx, cancel := context.WithCancel(parent) + conn.SetReadLimit(udpRelayMaxPacketSize) + return &webSocketPacketSession{conn: conn, ctx: ctx, cancel: cancel} +} + +func (s *webSocketPacketSession) Context() context.Context { + return s.ctx +} + +func (s *webSocketPacketSession) Receive(ctx context.Context) ([]byte, error) { + type readResult struct { + messageType int + payload []byte + err error + } + result := make(chan readResult, 1) + go func() { + messageType, payload, err := s.conn.ReadMessage() + result <- readResult{messageType: messageType, payload: payload, err: err} + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-s.ctx.Done(): + return nil, s.ctx.Err() + case r := <-result: + if r.err != nil { + return nil, r.err + } + if r.messageType != websocket.BinaryMessage { + return nil, newError("udp websocket received non-binary message") + } + return r.payload, nil + } +} + +func (s *webSocketPacketSession) Send(payload []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.conn.WriteMessage(websocket.BinaryMessage, payload) +} + +func (s *webSocketPacketSession) Close(reason string) error { + s.cancel() + s.mu.Lock() + defer s.mu.Unlock() + _ = s.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason), time.Now().Add(time.Second)) + return s.conn.Close() +} + +type udpRelay struct { + config udpRelayConfig + + mu sync.Mutex + listeners []net.PacketConn + quicServers []*quic.Listener + wsServers []*http.Server + wsListeners []net.Listener + closed bool + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + clientFlows map[string]*udpClientFlow + serverFlows map[udpPacketSession]*udpServerFlow +} + +type udpClientFlow struct { + source net.Addr + conn udpPacketSession + cancel context.CancelFunc + lastSeen time.Time +} + +type udpServerFlow struct { + conn udpPacketSession + udpConn net.Conn + cancel context.CancelFunc + lastSeen time.Time +} + +func newUDPRelay(config udpRelayConfig) *udpRelay { + return &udpRelay{config: config} +} + +func newUDPRelayFromOptions() (*udpRelay, error) { + if *udpMode == "" { + return nil, nil + } + if err := validateUDPOptions(); err != nil { + return nil, err + } + return newUDPRelay(udpRelayConfig{ + Server: *server, + Mode: *udpMode, + StandaloneWebSocketServer: *server && *udpMode == "websocket" && *mode != "websocket", + LocalAddr: *localAddr, + LocalPort: *localPort, + RemoteAddr: *remoteAddr, + RemotePort: *remotePort, + Host: *host, + Path: *udpPath, + TLS: *tlsEnabled, + Cert: *cert, + CertRaw: *certRaw, + Key: *key, + Timeout: time.Duration(*udpTimeout) * time.Second, + }), nil +} + +func (r *udpRelay) Start() error { + if r == nil { + return nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return newError("udp relay already closed") + } + if len(r.listeners) > 0 || len(r.wsListeners) > 0 { + return nil + } + + r.ctx, r.cancel = context.WithCancel(context.Background()) + if !r.config.Server { + r.clientFlows = make(map[string]*udpClientFlow) + } else { + r.serverFlows = make(map[udpPacketSession]*udpServerFlow) + } + r.wg.Add(1) + go r.cleanupIdleFlows() + if r.config.Server && r.config.Mode == "websocket" { + if r.config.StandaloneWebSocketServer { + if err := r.startServerWebSocketListenersLocked(); err != nil { + return r.cleanupFailedStartLocked(err) + } + } + return nil + } + + for _, addr := range parseLocalAddr(r.config.LocalAddr) { + listenAddr := net.JoinHostPort(addr, r.config.LocalPort) + conn, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return r.cleanupFailedStartLocked(newError("failed to start udp relay at", listenAddr).Base(err)) + } + r.listeners = append(r.listeners, conn) + if r.config.Server { + server, err := quic.Listen(conn, r.serverTLSConfig(), r.quicConfig()) + if err != nil { + return r.cleanupFailedStartLocked(newError("failed to start udp quic relay at", listenAddr).Base(err)) + } + r.quicServers = append(r.quicServers, server) + r.wg.Add(1) + go r.serveServerQUIC(server) + } else { + r.wg.Add(1) + go r.serveClientUDP(conn) + } + } + + return nil +} + +func (r *udpRelay) cleanupFailedStartLocked(startErr error) error { + if r.cancel != nil { + r.cancel() + r.cancel = nil + } + for _, quicServer := range r.quicServers { + if err := quicServer.Close(); err != nil { + logWarn(err.Error()) + } + } + r.quicServers = nil + r.closeServerWebSocketListenersLocked() + for _, listener := range r.listeners { + if err := listener.Close(); err != nil { + logWarn(err.Error()) + } + } + r.listeners = nil + return startErr +} + +func (r *udpRelay) serveServerQUIC(server *quic.Listener) { + defer r.wg.Done() + + for { + conn, err := server.Accept(r.ctx) + if err != nil { + if r.isClosed() { + return + } + logWarn("udp quic accept failed:", err.Error()) + continue + } + if err := r.startServerFlow(quicPacketSession{conn: conn}); err != nil { + logWarn("failed to start udp server flow:", err.Error()) + if closeErr := conn.CloseWithError(0, "udp server flow failed"); closeErr != nil { + logWarn(closeErr.Error()) + } + } + } +} + +func (r *udpRelay) startServerWebSocketListenersLocked() error { + for _, addr := range parseLocalAddr(r.config.LocalAddr) { + listenAddr := net.JoinHostPort(addr, r.config.LocalPort) + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + r.closeServerWebSocketListenersLocked() + return newError("failed to start udp websocket relay at", listenAddr).Base(err) + } + if r.config.TLS { + certificate, err := r.serverWebSocketTLSCertificate() + if err != nil { + listener.Close() + r.closeServerWebSocketListenersLocked() + return newError("failed to load udp websocket relay certificate").Base(err) + } + listener = tls.NewListener(listener, &tls.Config{Certificates: []tls.Certificate{certificate}}) + } + server := &http.Server{Handler: http.HandlerFunc(r.serveStandaloneWebSocket)} + r.wsListeners = append(r.wsListeners, listener) + r.wsServers = append(r.wsServers, server) + r.wg.Add(1) + go func(s *http.Server, l net.Listener) { + defer r.wg.Done() + if err := s.Serve(l); err != nil && err != http.ErrServerClosed && !r.isClosed() { + logWarn("udp websocket relay serve failed:", err.Error()) + } + }(server, listener) + } + return nil +} + +func (r *udpRelay) serveStandaloneWebSocket(w http.ResponseWriter, req *http.Request) { + if req.URL.Path != r.config.Path { + http.NotFound(w, req) + return + } + r.ServeWebSocket(w, req) +} + +func (r *udpRelay) startServerFlow(conn udpPacketSession) error { + ctx, cancel := context.WithCancel(r.ctx) + remoteAddr := net.JoinHostPort(r.config.RemoteAddr, r.config.RemotePort) + udpConn, err := (&net.Dialer{}).DialContext(ctx, "udp", remoteAddr) + if err != nil { + cancel() + return err + } + + flow := &udpServerFlow{ + conn: conn, + udpConn: udpConn, + cancel: cancel, + lastSeen: time.Now(), + } + + r.mu.Lock() + if r.closed { + r.mu.Unlock() + cancel() + if err := udpConn.Close(); err != nil { + logWarn(err.Error()) + } + return newError("udp relay closed") + } + r.serverFlows[conn] = flow + r.wg.Add(2) + go r.forwardServerQUICToUDP(flow) + go r.forwardServerUDPToQUIC(flow) + r.mu.Unlock() + + return nil +} + +func (r *udpRelay) forwardServerQUICToUDP(flow *udpServerFlow) { + defer r.wg.Done() + + for { + payload, err := flow.conn.Receive(flow.conn.Context()) + if err != nil { + if !r.isClosed() { + logWarn("udp server flow quic receive failed:", err.Error()) + } + r.closeServerFlow(flow) + return + } + if len(payload) == 0 { + continue + } + if _, err := flow.udpConn.Write(payload); err != nil { + if !r.isClosed() { + logWarn("failed to forward udp datagram to ss-server:", err.Error()) + } + r.closeServerFlow(flow) + return + } + r.touchServerFlow(flow) + } +} + +func (r *udpRelay) forwardServerUDPToQUIC(flow *udpServerFlow) { + defer r.wg.Done() + + buf := make([]byte, udpRelayMaxPacketSize) + for { + n, err := flow.udpConn.Read(buf) + if err != nil { + if !r.isClosed() { + logWarn("udp server flow socket read failed:", err.Error()) + } + r.closeServerFlow(flow) + return + } + if n == 0 { + continue + } + payload := make([]byte, n) + copy(payload, buf[:n]) + if err := flow.conn.Send(payload); err != nil { + var tooLarge *quic.DatagramTooLargeError + if errors.As(err, &tooLarge) { + logWarn("drop oversized udp response datagram max", tooLarge.MaxDatagramPayloadSize) + continue + } + if !r.isClosed() { + logWarn("failed to send udp response datagram:", err.Error()) + } + r.closeServerFlow(flow) + return + } + r.touchServerFlow(flow) + } +} + +func (r *udpRelay) cleanupIdleFlows() { + defer r.wg.Done() + + interval := r.config.Timeout / 2 + if interval <= 0 { + interval = time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + return + case now := <-ticker.C: + r.expireIdleFlows(now) + } + } +} + +func (r *udpRelay) expireIdleFlows(now time.Time) { + var clientFlows []*udpClientFlow + var serverFlows []*udpServerFlow + + r.mu.Lock() + if r.closed { + r.mu.Unlock() + return + } + for key, flow := range r.clientFlows { + if now.Sub(flow.lastSeen) >= r.config.Timeout { + delete(r.clientFlows, key) + clientFlows = append(clientFlows, flow) + } + } + for conn, flow := range r.serverFlows { + if now.Sub(flow.lastSeen) >= r.config.Timeout { + delete(r.serverFlows, conn) + serverFlows = append(serverFlows, flow) + } + } + r.mu.Unlock() + + for _, flow := range clientFlows { + flow.cancel() + if err := flow.conn.Close("udp flow idle timeout"); err != nil && !r.isClosed() { + logWarn(err.Error()) + } + } + for _, flow := range serverFlows { + r.shutdownServerFlow(flow, "udp flow idle timeout") + } +} + +func (r *udpRelay) serveClientUDP(listener net.PacketConn) { + defer r.wg.Done() + + buf := make([]byte, udpRelayMaxPacketSize) + for { + n, source, err := listener.ReadFrom(buf) + if err != nil { + if r.isClosed() { + return + } + logWarn("udp relay read failed:", err.Error()) + continue + } + if n == 0 { + continue + } + + payload := make([]byte, n) + copy(payload, buf[:n]) + + flow, err := r.getClientFlow(listener, source) + if err != nil { + if !r.isClosed() { + logWarn("failed to create udp client flow:", err.Error()) + } + continue + } + r.touchClientFlow(source.String()) + + if err := flow.conn.Send(payload); err != nil { + var tooLarge *quic.DatagramTooLargeError + if errors.As(err, &tooLarge) { + logWarn("drop oversized udp datagram from", source.String(), "max", tooLarge.MaxDatagramPayloadSize) + continue + } + logWarn("failed to send udp datagram:", err.Error()) + } + } +} + +func (r *udpRelay) getClientFlow(listener net.PacketConn, source net.Addr) (*udpClientFlow, error) { + key := source.String() + + r.mu.Lock() + if flow := r.clientFlows[key]; flow != nil { + r.mu.Unlock() + return flow, nil + } + ctx := r.ctx + r.mu.Unlock() + + flowCtx, cancel := context.WithCancel(ctx) + remoteAddr := net.JoinHostPort(r.config.RemoteAddr, r.config.RemotePort) + conn, err := r.dialClientSession(flowCtx, remoteAddr) + if err != nil { + cancel() + return nil, err + } + + flow := &udpClientFlow{ + source: source, + conn: conn, + cancel: cancel, + lastSeen: time.Now(), + } + + r.mu.Lock() + if existing := r.clientFlows[key]; existing != nil { + r.mu.Unlock() + cancel() + if err := conn.Close("duplicate udp flow"); err != nil { + logWarn(err.Error()) + } + return existing, nil + } + if r.closed { + r.mu.Unlock() + cancel() + if err := conn.Close("udp relay closed"); err != nil { + logWarn(err.Error()) + } + return nil, newError("udp relay closed") + } + r.clientFlows[key] = flow + r.wg.Add(1) + go r.receiveClientDatagrams(listener, key, flow) + r.mu.Unlock() + + return flow, nil +} + +func (r *udpRelay) receiveClientDatagrams(listener net.PacketConn, key string, flow *udpClientFlow) { + defer r.wg.Done() + + for { + payload, err := flow.conn.Receive(flow.conn.Context()) + if err != nil { + if !r.isClosed() { + logWarn("udp client flow receive failed:", err.Error()) + } + r.removeClientFlow(key, flow) + return + } + if len(payload) == 0 { + continue + } + if _, err := listener.WriteTo(payload, flow.source); err != nil { + if !r.isClosed() { + logWarn("failed to write udp datagram to local endpoint:", err.Error()) + } + continue + } + r.touchClientFlow(key) + } +} + +func (r *udpRelay) clientTLSConfig() *tls.Config { + config := &tls.Config{ + ServerName: r.config.Host, + NextProtos: []string{ + "v2ray-plugin-sip003u", + }, + } + if r.config.Cert == "" && r.config.CertRaw == "" { + return config + } + + certPEM, err := r.readConfiguredCertificate(r.config.Cert) + if err != nil { + logWarn("failed to read udp relay certificate:", err.Error()) + return config + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(certPEM) { + logWarn("failed to parse udp relay certificate authority") + return config + } + config.RootCAs = roots + return config +} + +func (r *udpRelay) dialClientSession(ctx context.Context, remoteAddr string) (udpPacketSession, error) { + if r.config.Mode == "websocket" { + scheme := "ws" + dialer := websocket.Dialer{} + if r.config.TLS { + scheme = "wss" + dialer.TLSClientConfig = r.webSocketClientTLSConfig() + } else if r.config.Cert != "" || r.config.CertRaw != "" { + dialer.TLSClientConfig = r.webSocketClientTLSConfig() + } + u := url.URL{Scheme: scheme, Host: remoteAddr, Path: r.config.Path} + header := http.Header{} + if r.config.Host != "" { + header.Set("Host", r.config.Host) + } + conn, _, err := dialer.DialContext(ctx, u.String(), header) + if err != nil { + return nil, err + } + return newWebSocketPacketSession(ctx, conn), nil + } + + conn, err := quic.DialAddr(ctx, remoteAddr, r.clientTLSConfig(), r.quicConfig()) + if err != nil { + return nil, err + } + return quicPacketSession{conn: conn}, nil +} + +func (r *udpRelay) webSocketClientTLSConfig() *tls.Config { + config := &tls.Config{ServerName: r.config.Host} + if r.config.Cert == "" && r.config.CertRaw == "" { + return config + } + certPEM, err := r.readConfiguredCertificate(r.config.Cert) + if err != nil { + logWarn("failed to read udp websocket certificate:", err.Error()) + return config + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(certPEM) { + logWarn("failed to parse udp websocket certificate authority") + return config + } + config.RootCAs = roots + return config +} + +func (r *udpRelay) ServeWebSocket(w http.ResponseWriter, req *http.Request) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + logWarn("udp websocket upgrade failed:", err.Error()) + return + } + session := newWebSocketPacketSession(r.ctx, conn) + if err := r.startServerFlow(session); err != nil { + logWarn("failed to start udp websocket server flow:", err.Error()) + if closeErr := session.Close("udp websocket server flow failed"); closeErr != nil { + logWarn(closeErr.Error()) + } + } +} + +func (r *udpRelay) serverTLSConfig() *tls.Config { + certificate, err := r.serverTLSCertificate() + if err != nil { + logWarn("failed to load udp relay server certificate:", err.Error()) + return &tls.Config{ + NextProtos: []string{"v2ray-plugin-sip003u"}, + } + } + return &tls.Config{ + Certificates: []tls.Certificate{certificate}, + NextProtos: []string{"v2ray-plugin-sip003u"}, + } +} + +func (r *udpRelay) serverTLSCertificate() (tls.Certificate, error) { + certPath := r.config.Cert + keyPath := r.config.Key + if certPath == "" && r.config.CertRaw == "" { + certPath = fmt.Sprintf("%s/.acme.sh/%s/fullchain.cer", homeDir(), r.config.Host) + logWarn("No UDP TLS cert specified, trying", certPath) + } + certPEM, err := r.readConfiguredCertificate(certPath) + if err != nil { + return tls.Certificate{}, err + } + if keyPath == "" { + keyPath = fmt.Sprintf("%[1]s/.acme.sh/%[2]s/%[2]s.key", homeDir(), r.config.Host) + logWarn("No UDP TLS key specified, trying", keyPath) + } + keyPEM, err := filesystem.ReadFile(keyPath) + if err != nil { + return tls.Certificate{}, err + } + return tls.X509KeyPair(certPEM, keyPEM) +} + +func (r *udpRelay) serverWebSocketTLSCertificate() (tls.Certificate, error) { + certPath := r.config.Cert + keyPath := r.config.Key + if certPath == "" && r.config.CertRaw == "" { + certPath = fmt.Sprintf("%s/.acme.sh/%s/fullchain.cer", homeDir(), r.config.Host) + logWarn("No UDP WebSocket TLS cert specified, trying", certPath) + } + certPEM, err := r.readConfiguredCertificate(certPath) + if err != nil { + return tls.Certificate{}, err + } + if keyPath == "" { + keyPath = fmt.Sprintf("%[1]s/.acme.sh/%[2]s/%[2]s.key", homeDir(), r.config.Host) + logWarn("No UDP WebSocket TLS key specified, trying", keyPath) + } + keyPEM, err := filesystem.ReadFile(keyPath) + if err != nil { + return tls.Certificate{}, err + } + return tls.X509KeyPair(certPEM, keyPEM) +} + +func (r *udpRelay) readConfiguredCertificate(certPath string) ([]byte, error) { + if certPath != "" { + return filesystem.ReadFile(certPath) + } + if r.config.CertRaw != "" { + certHead := "-----BEGIN CERTIFICATE-----" + certTail := "-----END CERTIFICATE-----" + fixedCert := certHead + "\n" + r.config.CertRaw + "\n" + certTail + return []byte(fixedCert), nil + } + return nil, newError("missing udp relay certificate") +} + +func (r *udpRelay) quicConfig() *quic.Config { + return &quic.Config{ + EnableDatagrams: true, + MaxIdleTimeout: r.config.Timeout, + KeepAlivePeriod: r.config.Timeout / 2, + } +} + +func (r *udpRelay) touchServerFlow(flow *udpServerFlow) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.serverFlows[flow.conn] == flow { + flow.lastSeen = time.Now() + } +} + +func (r *udpRelay) closeServerFlow(flow *udpServerFlow) { + r.mu.Lock() + if r.serverFlows[flow.conn] == flow { + delete(r.serverFlows, flow.conn) + } + r.mu.Unlock() + + r.shutdownServerFlow(flow, "udp server flow closed") +} + +func (r *udpRelay) shutdownServerFlow(flow *udpServerFlow, reason string) { + flow.cancel() + if err := flow.udpConn.Close(); err != nil && !r.isClosed() { + logWarn(err.Error()) + } + if err := flow.conn.Close(reason); err != nil && !r.isClosed() { + logWarn(err.Error()) + } +} + +func (r *udpRelay) touchClientFlow(key string) { + r.mu.Lock() + defer r.mu.Unlock() + + if flow := r.clientFlows[key]; flow != nil { + flow.lastSeen = time.Now() + } +} + +func (r *udpRelay) removeClientFlow(key string, flow *udpClientFlow) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.clientFlows[key] == flow { + delete(r.clientFlows, key) + } +} + +func (r *udpRelay) isClosed() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.closed +} + +func (r *udpRelay) Close() error { + if r == nil { + return nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return nil + } + r.closed = true + if r.cancel != nil { + r.cancel() + } + + var closeErr error + for _, quicServer := range r.quicServers { + if err := quicServer.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + r.quicServers = nil + + r.closeServerWebSocketListenersLocked() + + for _, listener := range r.listeners { + if err := listener.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + r.listeners = nil + + for key, flow := range r.clientFlows { + flow.cancel() + if err := flow.conn.Close("udp relay closed"); err != nil && closeErr == nil { + closeErr = err + } + delete(r.clientFlows, key) + } + for conn, flow := range r.serverFlows { + flow.cancel() + if err := flow.udpConn.Close(); err != nil && closeErr == nil { + closeErr = err + } + if err := conn.Close("udp relay closed"); err != nil && closeErr == nil { + closeErr = err + } + delete(r.serverFlows, conn) + } + r.mu.Unlock() + r.wg.Wait() + r.mu.Lock() + + return closeErr +} + +func (r *udpRelay) closeServerWebSocketListenersLocked() { + for _, server := range r.wsServers { + if err := server.Close(); err != nil { + logWarn(err.Error()) + } + } + r.wsServers = nil + for _, listener := range r.wsListeners { + if err := listener.Close(); err != nil { + logWarn(err.Error()) + } + } + r.wsListeners = nil +} diff --git a/websocket_router.go b/websocket_router.go new file mode 100644 index 0000000..e463d29 --- /dev/null +++ b/websocket_router.go @@ -0,0 +1,204 @@ +package main + +import ( + "context" + gotls "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "sync" + "time" + + core "github.com/v2fly/v2ray-core/v5" + "github.com/v2fly/v2ray-core/v5/common/platform/filesystem" +) + +type webSocketRouterConfig struct { + LocalAddr string + LocalPort string + TCPPath string + UDPPath string + InternalAddr string + InternalPort string + TLS bool + Host string + Cert string + CertRaw string + Key string +} + +type webSocketRouter struct { + config webSocketRouterConfig + udpRelay *udpRelay + proxy *httputil.ReverseProxy + servers []*http.Server + listeners []net.Listener + wg sync.WaitGroup +} + +func newWebSocketRouter(config webSocketRouterConfig, udpRelay *udpRelay) (*webSocketRouter, error) { + target := &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(config.InternalAddr, config.InternalPort), + } + return &webSocketRouter{ + config: config, + udpRelay: udpRelay, + proxy: httputil.NewSingleHostReverseProxy(target), + }, nil +} + +func (r *webSocketRouter) Start() error { + for _, addr := range parseLocalAddr(r.config.LocalAddr) { + listenAddr := net.JoinHostPort(addr, r.config.LocalPort) + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + r.Close() + return newError("failed to start websocket router at", listenAddr).Base(err) + } + if r.config.TLS { + cert, err := r.tlsCertificate() + if err != nil { + listener.Close() + r.Close() + return newError("failed to load websocket router certificate").Base(err) + } + listener = gotls.NewListener(listener, &gotls.Config{Certificates: []gotls.Certificate{cert}}) + } + + server := &http.Server{Handler: http.HandlerFunc(r.serveHTTP)} + r.listeners = append(r.listeners, listener) + r.servers = append(r.servers, server) + r.wg.Add(1) + go func(s *http.Server, l net.Listener) { + defer r.wg.Done() + if err := s.Serve(l); err != nil && err != http.ErrServerClosed { + logWarn("websocket router serve failed:", err.Error()) + } + }(server, listener) + } + return nil +} + +func (r *webSocketRouter) serveHTTP(w http.ResponseWriter, req *http.Request) { + switch req.URL.Path { + case r.config.UDPPath: + r.udpRelay.ServeWebSocket(w, req) + case r.config.TCPPath: + r.proxy.ServeHTTP(w, req) + default: + http.NotFound(w, req) + } +} + +func (r *webSocketRouter) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var closeErr error + for _, server := range r.servers { + if err := server.Shutdown(ctx); err != nil && err != http.ErrServerClosed && closeErr == nil { + closeErr = err + } + } + for _, listener := range r.listeners { + if err := listener.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + r.wg.Wait() + return closeErr +} + +func (r *webSocketRouter) tlsCertificate() (gotls.Certificate, error) { + certPath := r.config.Cert + keyPath := r.config.Key + if certPath == "" && r.config.CertRaw == "" { + certPath = fmt.Sprintf("%s/.acme.sh/%s/fullchain.cer", homeDir(), r.config.Host) + logWarn("No TLS cert specified, trying", certPath) + } + certPEM, err := readRouterCertificate(certPath, r.config.CertRaw) + if err != nil { + return gotls.Certificate{}, err + } + if keyPath == "" { + keyPath = fmt.Sprintf("%[1]s/.acme.sh/%[2]s/%[2]s.key", homeDir(), r.config.Host) + logWarn("No TLS key specified, trying", keyPath) + } + keyPEM, err := filesystem.ReadFile(keyPath) + if err != nil { + return gotls.Certificate{}, err + } + return gotls.X509KeyPair(certPEM, keyPEM) +} + +func readRouterCertificate(certPath, certRaw string) ([]byte, error) { + if certPath != "" { + return filesystem.ReadFile(certPath) + } + if certRaw != "" { + return []byte("-----BEGIN CERTIFICATE-----\n" + certRaw + "\n-----END CERTIFICATE-----"), nil + } + return nil, newError("missing websocket router certificate") +} + +func startCombinedWebSocketServer() (core.Server, error) { + internalPort, err := allocateLoopbackTCPPort() + if err != nil { + return nil, err + } + + publicLocalAddr := *localAddr + publicLocalPort := *localPort + oldLocalAddr := *localAddr + oldLocalPort := *localPort + oldTLSEnabled := *tlsEnabled + *localAddr = "127.0.0.1" + *localPort = internalPort + *tlsEnabled = false + config, err := generateConfig() + *localAddr = oldLocalAddr + *localPort = oldLocalPort + *tlsEnabled = oldTLSEnabled + if err != nil { + return nil, newError("failed to parse internal websocket config").Base(err) + } + + instance, err := core.New(config) + if err != nil { + return nil, newError("failed to create internal v2ray instance").Base(err) + } + udpRelay, err := newUDPRelayFromOptions() + if err != nil { + return nil, err + } + router, err := newWebSocketRouter(webSocketRouterConfig{ + LocalAddr: publicLocalAddr, + LocalPort: publicLocalPort, + TCPPath: *path, + UDPPath: *udpPath, + InternalAddr: "127.0.0.1", + InternalPort: internalPort, + TLS: oldTLSEnabled, + Host: *host, + Cert: *cert, + CertRaw: *certRaw, + Key: *key, + }, udpRelay) + if err != nil { + return nil, err + } + return &combinedWebSocketServer{tcp: instance, udp: udpRelay, router: router}, nil +} + +func allocateLoopbackTCPPort() (string, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", newError("failed to allocate internal tcp port").Base(err) + } + defer listener.Close() + return strconv.Itoa(listener.Addr().(*net.TCPAddr).Port), nil +}