G

[Golang] 基于TCP的简单协议实现

RoLingG Golang 2025-05-26

基于 TCP 的简单协议实现

开头注明:本人这次属于学习,很大一部分靠的是AI生成的代码理解和一些对于TCP资料的学习,可能对于懂的人来说没什么技术含量,在此提前声明。

简介

写这个其实是因为某公司的笔试需要进行协议相关的开发,但因为本人菜鸡太久没写协议类开发代码和标准库 io 相关的代码导致没过测试,因此来写写相关代码。

代码解析

项目结构:

--Project
 --pkg
     --connInit.go
 --client.go
 --server.go

协议设计 connInit.go

package pkg

import (
    "bufio"
    "fmt"
    "io"
    "net"
    "strings"
    "sync"
)

// Conn 是你需要实现的一种连接类型,它支持下面描述的若干接口;
// 为了实现这些接口,你需要设计一个基于 TCP 的简单协议;
type Conn struct {
    tcpConn net.Conn
    writer  *bufio.Writer
    reader  *bufio.Reader
    mutex   sync.Mutex
    keys    sync.Map // 多个并发数据流的标识字典(单一连接可不需要)
}

// NewConn 从一个 TCP 连接得到一个你实现的连接对象
func NewConn(conn net.Conn) *Conn {
    return &Conn{
       tcpConn: conn,
       writer:  bufio.NewWriter(conn),
       reader:  bufio.NewReader(conn),
    }
}

// Send 传入一个 key 表示发送者将要传输的数据对应的标识;
// 返回 writer 可供发送者分多次写入大量该 key 对应的数据;
// 当发送者已将该 key 对应的所有数据写入后,调用 writer.Close 告知接收者:该 key 的数据已经完全写入;
func (conn *Conn) Send(key string) (io.WriteCloser, error) {
    conn.mutex.Lock()
    defer conn.mutex.Unlock()

    fmt.Println("Starting to send key:", key)
    if _, ok := conn.keys.Load(key); ok {
       return nil, fmt.Errorf("key %s already exists", key)
    }

    pipeReader, pipeWriter := io.Pipe()

    go func() {
       defer conn.keys.Delete(key)
       defer pipeReader.Close()

       if _, err := conn.writer.WriteString(key + "\n"); err != nil {
          pipeWriter.CloseWithError(err)
          return
       }

       if _, err := io.Copy(conn.writer, pipeReader); err != nil {
          pipeWriter.CloseWithError(err)
          return
       }

       if _, err := conn.writer.WriteString("END\n"); err != nil {
          pipeWriter.CloseWithError(err)
          return
       }

       if err := conn.writer.Flush(); err != nil {
          pipeWriter.CloseWithError(err)
          return
       }

       pipeWriter.Close()
    }()

    conn.keys.Store(key, pipeReader)
    return pipeWriter, nil
}

// Receive 返回一个 key 表示接收者将要接收到的数据对应的标识;
// 返回的 reader 可供接收者多次读取该 key 对应的数据;
// 当 reader 返回 io.EOF 错误时,表示接收者已经完整接收该 key 对应的数据;
func (conn *Conn) Receive() (string, io.Reader, error) {
    conn.mutex.Lock()
    defer conn.mutex.Unlock()

    keyLine, err := conn.reader.ReadString('\n')
    if err != nil {
       if err == io.EOF {
          fmt.Println("Connection closed while reading key")
          return "", nil, io.EOF
       }
       fmt.Println("Error reading key:", err)
       return "", nil, err
    }

    key := strings.TrimSuffix(keyLine, "\n")

    if _, ok := conn.keys.Load(key); ok {
       fmt.Printf("Warning: key %s already exists, overwriting\n", key)
    }

    reader, writer := io.Pipe()

    go func() {
       defer conn.keys.Delete(key)
       defer writer.Close()

       buffer := &strings.Builder{}
       for {
          line, err := conn.reader.ReadString('\n')
          if err != nil {
             writer.CloseWithError(err)
             return
          }

          if strings.HasSuffix(line, "END\n") {
             line = line[:len(line)-4]
             buffer.WriteString(line)
             break
          }
          buffer.WriteString(line)
       }

       if _, err := io.Copy(writer, strings.NewReader(buffer.String())); err != nil {
          writer.CloseWithError(err)
          return
       }

       writer.Close()
    }()

    conn.keys.Store(key, reader)
    return key, reader, nil
}

// Close 关闭你实现的连接对象及其底层的 TCP 连接
func (conn *Conn) Close() {
    conn.mutex.Lock()
    defer conn.mutex.Unlock()

    conn.tcpConn.Close()

    conn.keys.Range(func(key, value interface{}) bool {
       if closer, ok := value.(io.Closer); ok {
          if err := closer.Close(); err != nil {
             fmt.Printf("Failed to close %v: %v\n", key, err)
          }
       }
       return true
    })

    conn.keys = sync.Map{}
}

服务端 server.go

// server.go
package main

import (
    "fmt"
    "io"
    "net"
    "sync"
    "tcpInit/pkg"
)

func main() {
    server, err := net.Listen("tcp", "localhost:8080")
    if err != nil {
        fmt.Printf("Failed to start server: %v\n", err)
        return
    }
    defer server.Close()

    fmt.Println("Server started on :8080")

    var wg sync.WaitGroup

    for {
        conn, err := server.Accept()
        if err != nil {
            fmt.Printf("Failed to accept connection: %v\n", err)
            continue
        }
        wg.Add(1)
        go func(conn net.Conn) {
            defer wg.Done()
            defer conn.Close()

            connObj := pkg.NewConn(conn)

            for {
                key, reader, err := connObj.Receive()
                if err != nil {
                    fmt.Printf("Failed to receive data: %v\n", err)
                    return
                }
                fmt.Printf("Server received key: %s from %s\n", key, conn.RemoteAddr())

                data := make([]byte, 1024)
                n, err := reader.Read(data)
                if err != nil && err != io.EOF {
                    fmt.Printf("Failed to read data: %v\n", err)
                    return
                }
                fmt.Printf("Server received data: %s from %s\n", string(data[:n]), conn.RemoteAddr())

                if err == io.EOF {
                    break // 如果接收到 EOF,退出循环
                }
            }

            connObj.Close()
        }(conn)
    }

    wg.Wait()
}

客户端 client.go

// client.go
package main

import (
    "bufio"
    "fmt"
    "net"
    "os"
    "tcpInit/pkg"

    // 确保这个路径是正确的
    "time"
)

func main() {
    // 客户端进行连接服务端
    conn, err := net.Dial("tcp", "localhost:8080")
    if err != nil {
        fmt.Printf("Failed to connect to server: %v\n", err)
        return
    }
    defer conn.Close()

    connObj := pkg.NewConn(conn)

    for {
        fmt.Print("Enter key (type 'exit' to quit): ")
        var input string
        var key string
        fmt.Scanln(&key)
        fmt.Print("Enter message (type 'exit' to quit): ")
        scanner := bufio.NewScanner(os.Stdin)
        scanner.Scan() // 读取整行输入
        input = scanner.Text()

        if input == "exit" || key == "exit" {
            break
        }

        // 发送消息内容(其实更像是打开发送/接收通道,因为实质上的发送消息应该是下面的Write,写进去才能发,写进去之前Send里面的通道都是阻塞等待的)
        writer, err := connObj.Send(key)
        if err != nil {
            fmt.Printf("Failed to send data: %v\n", err)
            return
        }

        if _, err := writer.Write([]byte(input)); err != nil {
            fmt.Printf("Failed to write data: %v\n", err)
            return
        }

        if err := writer.Close(); err != nil {
            fmt.Printf("Failed to close writer: %v\n", err)
            return
        }

        time.Sleep(1 * time.Second) // 等待服务器处理
    }

    // 如果客户端想要保持连接,那么底层tcp关闭就必须在外面,里面只能close写入流
    connObj.Close()
    fmt.Println("Client exited.")
}

测试

测试下来显示,服务端能够被多个客户端连接,且不同客户端之间不会相互影响,服务端能够异步性的返回正常的结果。这其中要归功于 Connkeys 作为各个客户端的连接标识以及其 mutext 的互斥锁,保证了各个客户端的独立性与数据的一致性。避免了脏数据问题。

总结

从这过程中我学到了很多关于使用 io 实现服务端与客户端连接操作相关的知识、设计思路与写法。

了解到了一些地方容易发生阻塞问题,还有 interface 相关的内容,例如:

pipeReader, pipeWriter := io.Pipe()

readerwriter 只是被分别初始化成了:

// reader
type PipeReader struct {
    p *pipe
}
 
A PipeReader is the read half of a pipe.
方法对象: (*PipeReader):
    Read(data []byte) (n int, err error)
    Close() error
    CloseWithError(err error) error

// writer
type PipeWriter struct {
    p *pipe
}
 
A PipeWriter is the write half of a pipe.
方法对象: (*PipeWriter):
    Write(data []byte) (n int, err error)
    Close() error
    CloseWithError(err error) error

从上面可以看到这两个值并未被赋予别的值。但在下面这里要将其内内容赋值给 conn.writer

if _, err := io.Copy(conn.writer, pipeReader); err != nil {
    pipeWriter.CloseWithError(err)
    return
}

但这确实是获得到了 TCP 传输时内 SEND 之后的Writer传入的内容,其数据来源是 io.pipe 的缓存区域,由 io.pipe 的回参的 writer 写入,由 reader 读取,也就是说 SEND 方法 func (conn *Conn) Send(key string) (io.WriteCloser, error) 传出 pipeWriter 出去之后写入数据:

writer, err := connObj.Send(key)
...
data := "Hello, World!"
if _, err := writer.Write([]byte(data)); err != nil {
    fmt.Printf("Failed to write data: %v\n", err)
    return
}

调用了 SEND 之后得到 writer 写入,也就是 pipeWriter,写入数据到 io.pipe 的缓存空间。这样 SEND 内的写成里的 io.Copy(conn.writer, pipeReader) 因为之前读取不到数据一致阻塞,现在写入了数据就能读取到,也就能接着往下运行了。

另外为什么明明TCP有直接的实现WriteRead还是要在 Conn 里面写 writerreader 缓冲区去进行写入和读取呢?

这是因为 bufio.Readerbufio.Writer 提供了缓冲机制,可以减少底层 TCP 连接的读写操作次数。
直接使用 net.ConnReadWrite 方法时,每次调用都会触发底层的系统调用,这可能会导致较高的开销。而 bufio 的缓冲机制可以将多次小的读写操作合并为一次较大的操作,从而提高性能。减少系统调用次数,优化网络开销。

综上这就是这一次实现里面我所学习思考和设计的内容。

PREV
[算法学习] 线性动态规划顺/逆序

评论(0)

发布评论