Skip to content

Commit 80b6431

Browse files
author
mkyla
authored
refactor: implement UDP frame read/write functions for improved data handling
1 parent a87717f commit 80b6431

1 file changed

Lines changed: 40 additions & 10 deletions

File tree

internal/common.go

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,40 @@ func (c *Common) decode(data []byte) ([]byte, error) {
326326
return c.xor(decoded), nil
327327
}
328328

329+
func writeUDPFrame(w net.Conn, data []byte) error {
330+
length := len(data)
331+
if length > 65535 {
332+
return fmt.Errorf("writeUDPFrame: datagram too large: %d", length)
333+
}
334+
header := [2]byte{byte(length >> 8), byte(length)}
335+
if _, err := w.Write(header[:]); err != nil {
336+
return err
337+
}
338+
_, err := w.Write(data)
339+
return err
340+
}
341+
342+
func readUDPFrame(conn net.Conn, buf []byte, timeout time.Duration) (int, error) {
343+
if timeout > 0 {
344+
conn.SetReadDeadline(time.Now().Add(timeout))
345+
}
346+
var header [2]byte
347+
if _, err := io.ReadFull(conn, header[:]); err != nil {
348+
return 0, err
349+
}
350+
length := int(header[0])<<8 | int(header[1])
351+
if length == 0 {
352+
return 0, nil
353+
}
354+
if length > len(buf) {
355+
return 0, fmt.Errorf("readUDPFrame: datagram too large: %d > buffer %d", length, len(buf))
356+
}
357+
if _, err := io.ReadFull(conn, buf[:length]); err != nil {
358+
return 0, err
359+
}
360+
return length, nil
361+
}
362+
329363
func (c *Common) resolve(network, address string) (any, error) {
330364
now := time.Now()
331365

@@ -1389,10 +1423,9 @@ func (c *Common) commonUDPLoop() {
13891423

13901424
buffer := c.getUDPBuffer()
13911425
defer c.putUDPBuffer(buffer)
1392-
reader := &conn.TimeoutReader{Conn: remoteConn, Timeout: udpReadTimeout}
13931426

13941427
for c.ctx.Err() == nil {
1395-
x, err := reader.Read(buffer)
1428+
x, err := readUDPFrame(remoteConn, buffer, udpReadTimeout)
13961429
if err != nil {
13971430
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
13981431
c.logger.Debug("UDP session abort: %v", err)
@@ -1426,10 +1459,9 @@ func (c *Common) commonUDPLoop() {
14261459
c.logger.Debug("Starting transfer: %v <-> %v", remoteConn.LocalAddr(), c.targetUDPConn.LocalAddr())
14271460
}
14281461

1429-
_, err = remoteConn.Write(buffer[:x])
1430-
if err != nil {
1462+
if err = writeUDPFrame(remoteConn, buffer[:x]); err != nil {
14311463
if err != io.EOF {
1432-
c.logger.Error("commonUDPLoop: write to tunnel failed: %v", err)
1464+
c.logger.Error("tunnelUDPLoop: write to tunnel failed: %v", err)
14331465
}
14341466
c.targetUDPSession.Delete(sessionKey)
14351467
remoteConn.Close()
@@ -1704,10 +1736,9 @@ func (c *Common) commonUDPOnce(signal Signal) {
17041736

17051737
buffer := c.getUDPBuffer()
17061738
defer c.putUDPBuffer(buffer)
1707-
reader := &conn.TimeoutReader{Conn: remoteConn, Timeout: udpReadTimeout}
17081739

17091740
for c.ctx.Err() == nil {
1710-
x, err := reader.Read(buffer)
1741+
x, err := readUDPFrame(remoteConn, buffer, udpReadTimeout)
17111742
if err != nil {
17121743
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
17131744
c.logger.Debug("UDP session abort: %v", err)
@@ -1747,10 +1778,9 @@ func (c *Common) commonUDPOnce(signal Signal) {
17471778
return
17481779
}
17491780

1750-
_, err = remoteConn.Write(buffer[:x])
1751-
if err != nil {
1781+
if err = writeUDPFrame(remoteConn, buffer[:x]); err != nil {
17521782
if err != io.EOF {
1753-
c.logger.Error("commonUDPOnce: write to tunnel failed: %v", err)
1783+
c.logger.Error("tunnelUDPOnce: write to tunnel failed: %v", err)
17541784
}
17551785
return
17561786
}

0 commit comments

Comments
 (0)