Go实现简单的RPC框架项目

knoci 发布于 2025-02-26 111 次阅读


kRPC

​ kRPC是参考极客兔兔的7days-golang系列的GeeRPC,其中做了中文注释和一些细节的修改。本文对项目的关键技术要点进行解释。

消息编码

​ 一个典型的 RPC 调用如下:

err = client.Call("Arith.Multiply", args, &reply)

​ 客户端发送的请求包括服务名 Arith,方法名 Multiply,参数 args三个,服务端的响应包括错误 error,返回值 reply 2 个

​ 我们将请求和响应中的参数和返回值抽象为 body,剩余的信息放在 header 中,那么就可以抽象出数据结构 Header

type Header struct {
    ServiceMethod string // 服务名和方法名
    Seq           uint64 // 请求的序号
    Error         string
}

​ 进一步,抽象出对消息体进行编解码的接口 Codec,抽象出接口是为了实现不同的 Codec 实例

// 编码接口
type Codec interface {
    io.Closer
    ReadHeader(*Header) error
    ReadBody(interface{}) error
    Write(*Header, interface{}) error
}

​ 紧接着,抽象出 Codec 的构造函数,客户端和服务端可以通过 Codec 的 Type 得到构造函数

type NewCodecFunc func(io.ReadWriteCloser)Codec //NewCodecFunc函数接收一个 io.ReadWriteCloser 类型的参数,并返回一个 Codec 类型的值

type Type string

const (
    GobType  Type = "application/gob"
    JsonType Type = "application/json"
)

var NewCodecFuncMap map[Type]NewCodecFunc

func init() {
    NewCodecFuncMap = make(map[Type]NewCodecFunc)
    NewCodecFuncMap[GobType] = NewGobCodec
    NewCodecFuncMap[JsonType] = NewJsonCodec
}

​ Codec作为编解码的接口,抽象出来的结构体要实现消息编码、消息解码、消息写入三个重要功能,这里用JsonCodec来做例子,dec 和 enc 对应 Json 的 Decoder 和 Encoder,buf 是为了防止阻塞而创建的带缓冲的 Writer

ReadHeader()函数和ReadBody()用dec对接收的消息解码,Write()函数用enc解码消息到buf然后Flush()写入Conn连接

// JsonCodec 结构体
type JsonCodec struct {
    conn io.ReadWriteCloser
    buf  *bufio.Writer
    dec  *json.Decoder
    enc  *json.Encoder
}

// 类型断言,确保 JsonCodec 实现了 Codec 接口
var _ Codec = (*JsonCodec)(nil)

// NewJsonCodec 创建一个新的 JsonCodec 实例
func NewJsonCodec(conn io.ReadWriteCloser) Codec {
    buf := bufio.NewWriter(conn)
    return &JsonCodec{
        conn: conn,
        buf:  buf,
        dec:  json.NewDecoder(conn), // 连接conn的 json 解码器
        enc:  json.NewEncoder(buf),  // 缓冲buf的 json 编码器
    }
}

// ReadHeader 从连接中读取头部信息
func (c *JsonCodec) ReadHeader(h *Header) error {

    return c.dec.Decode(h)
}

// ReadBody 从连接中读取正文内容
func (c *JsonCodec) ReadBody(body interface{}) error {
    return c.dec.Decode(body)
}

// Write 将头部和正文写入连接
func (c *JsonCodec) Write(h *Header, body interface{}) (err error) {
    defer func() {
        _ = c.buf.Flush()
        if err != nil {
            _ = c.Close()
        }
    }()
    if err := c.enc.Encode(h); err != nil {
        log.Println("rpc codec: json error encoding header:", err)
        return err
    }
    if err := c.enc.Encode(body); err != nil {
        log.Println("rpc codec: json error encoding body:", err)
        return err
    }
    return nil
}

// Close 关闭连接
func (c *JsonCodec) Close() error {
    return c.conn.Close()
}

单机服务器端和客户端

服务器

​ 对于 kRPC 来说,目前需要协商的唯一一项内容是消息的编解码方式。我们将这部分信息,放到服务器端结构体 Option 中承载

type Option struct {
    MagicNumber    int // MagicNumber 用于标识这是一个RPC请求
    CodecType      codec.Type
    ConnectTimeout time.Duration // 0 means no limit
    HandleTimeout  time.Duration
}

var DefaultOption = &Option{
    MagicNumber:    MagicNumber,
    CodecType:      codec.GobType,
    ConnectTimeout: time.Second * 10,
}

​ 服务器端由启动,监听网络连接,并为每个新连接启动一个 Aceept(Conn)ServeConn 协程来处理请求

ServeConn首先使用 json.NewDecoder 反序列化得到 Option 实例,检查 MagicNumber 和 CodeType 的值是否正确,然后根据 CodeType 得到对应的消息编解码器,接下来的处理交给 serverCodec

// Accept 监听网络连接,并为每个新连接启动一个协程来处理请求
func (server *Server) Accept(lis net.Listener) {
    for {
        conn, err := lis.Accept()
        if err != nil {
            log.Println("rpc server: accept error:", err)
            return
        }
        go server.ServeConn(conn)
    }
}

// Accept accepts connections on the listener and serves requests for each incoming connection.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }

// ServeConn 处理单个客户端连接。
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
    defer func() { _ = conn.Close() }()
    var opt Option
    if err := json.NewDecoder(conn).Decode(&opt); err != nil {
        log.Println("rpc server: options error: ", err)
        return
    }
    if opt.MagicNumber != MagicNumber {
        log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
        return
    }
    f := codec.NewCodecFuncMap[opt.CodecType]
    if f == nil {
        log.Printf("rpc server: invalid codec type %s", opt.CodecType)
        return
    }
    server.serveCodec(f(conn))
}

// invalidRequest is a placeholder for response argv when error occurs
var invalidRequest = struct{}{}

serveCodec的过程非常简单。主要包含三个阶段

  • 读取请求 readRequest
  • 处理请求 handleRequest
  • 回复请求 sendResponse
// serveCodec 处理来自客户端的RPC请求
func (server *Server) serveCodec(cc codec.Codec) {
    sending := new(sync.Mutex) // 确保发送完整的响应
    wg := new(sync.WaitGroup)  // 等待所有请求处理完毕
    for {
        req, err := server.readRequest(cc) // 读取请求
        if err != nil {
            if req == nil {
                break // 无法恢复,请关闭连接
            }
            req.h.Error = err.Error()
            server.sendResponse(cc, req.h, invalidRequest, sending) //
            continue
        }
        wg.Add(1)
        go server.handleRequest(cc, req, sending, wg, DefaultOption.ConnectTimeout) // 回复请求
    }
    wg.Wait()
    _ = cc.Close()
}

​ 然后我们对request进行抽象,根据典型的rpc调用,request肯定包括调用的服务service、访问传入参数args、传出指针&reply,我们再加上包含请求信息的请求头,就构成了request结构体。

err = client.Call("service", args, &reply)

type request struct {
    h            *codec.Header // 请求头
    argv, replyv reflect.Value // argv and replyv of request
    mtype        *service.MethodType
    svc          *service.Service
}

​ 下面是处理request的方法,其中findService()和一些相关service代码会在服务注册中讲到,这里理解处理request的大概流程就行。

​ 我们可以看到在代码中读取和返回参数都运用了反射,因为在RPC框架中,客户端和服务端之间传递的参数和返回值的类型是动态的,无法在编译时确定,而反射允许程序在运行时检查这些参数和返回值的类型,并动态地创建和操作它们

​ 此外,还通过<-time.After(timeout)实现了超时机制

// 从连接中读取RPC请求的头部和参数
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
    var h codec.Header
    if err := cc.ReadHeader(&h); err != nil {
        if err != io.EOF && err != io.ErrUnexpectedEOF {
            log.Println("rpc server: read header error:", err)
        }
        return nil, err
    }
    return &h, nil
}

func (server *Server) readRequest(cc codec.Codec) (*request, error) {
    // 读取请求头信息
    h, err := server.readRequestHeader(cc)
    if err != nil {
        return nil, err
    }
    // 初始化请求对象
    req := &request{h: h}
    // findService根据请求头中的 ServiceMethod 查找对应的服务和方法
    req.svc, req.mtype, err = server.findService(h.ServiceMethod)
    if err != nil {
        return req, err
    }
    // 创建请求参数和返回值的反射值
    req.argv = req.mtype.NewArgv()
    req.replyv = req.mtype.NewReplyv()
    // 确保 argv 是一个指针类型,因为 codec.ReadBody 需要一个指针作为参数
    argvi := req.argv.Interface()
    if req.argv.Type().Kind() != reflect.Ptr {
        argvi = req.argv.Addr().Interface() // 如果不是指针类型,取地址
    }
    // 读取请求正文并填充到 argv 中
    if err = cc.ReadBody(argvi); err != nil {
        log.Println("rpc server: read body err:", err)
        return req, err
    }
    return req, nil
}

// 将响应写入连接
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
    sending.Lock()
    defer sending.Unlock()
    // 调用接编码接口codec的write方法把h和body写入conn
    if err := cc.Write(h, body); err != nil {
        log.Println("rpc server: write response error:", err)
    }
}

// 处理RPC请求
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
    defer wg.Done()
    called := make(chan struct{}) // 方法调用标识管道
    sent := make(chan struct{}) // 消息返回标识管道
    go func() {
        err := req.svc.Call(req.mtype, req.argv, req.replyv) //调用方法
        called <- struct{}{} // 向管道传入空结构体表示已调用
        if err != nil {
            req.h.Error = err.Error()
            server.sendResponse(cc, req.h, invalidRequest, sending)
            sent <- struct{}{}
            return
        }
        server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
        sent <- struct{}{}
    }()
    // 如果超时时间timeout设置为0,则堵塞等待方法调用和消息返回
    if timeout == 0 {
        <-called
        <-sent
        return
    }
    select {
        case <-time.After(timeout): // 经过时间timeout后向time管道传入,用于超时提醒
        req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
        server.sendResponse(cc, req.h, invalidRequest, sending)
    case <-called:
        <-sent
    }
}

客户端

​ 首先我们先不谈client的结构体,我们都知道调用client.Call()函数能够向服务器端发送一次Call调用,那么问题来了,我们作为Client端要怎么样向服务器传递一次Call调用呢?

​ 显然我们可以给服务器传递一个Call结构体,Call结构体中就包含了调用的所有信息

// Call 代表一个活跃的RPC调用
type Call struct {
    Seq           uint64      // 调用唯一序号
    ServiceMethod string      // 格式 "<服务名>.<方法名>"
    Args          interface{} // 函数参数
    Reply         interface{} // 函数返回结果
    Error         error       // 调用错误信息
    Done          chan *Call  // 调用完成通知通道
}

// 把当前Call放入Done管道,Done存在Call即证明调用已完成
func (call *Call) done() {
    call.Done <- call
}

​ 接着我们设计出client,还有实现call的三个方法

// Client 表示一个RPC客户端
// 单个客户端可以关联多个未完成调用,且支持多协程并发使用
type Client struct {
    cc       codec.Codec      // 编解码器
    opt      *server.Option   // 协议选项
    sending  sync.Mutex       // 发送锁,保证请求原子性
    header   codec.Header     // 请求头(复用减少内存分配)
    mu       sync.Mutex       // 全局锁(保护以下字段)
    seq      uint64           // 请求序号生成器(原子递增)
    pending  map[uint64]*Call // 未完成调用映射表,键是编号,值是 Call 实例。
    closing  bool             // 用户主动关闭标记
    shutdown bool             // 服务端要求关闭标记
}

// 注册RPC调用到待处理队列
func (client *Client) registerCall(call *Call) (uint64, error) {
    client.mu.Lock()
    defer client.mu.Unlock()
    if client.closing || client.shutdown {
        return 0, ErrShutdown
    }
    call.Seq = client.seq
    client.pending[call.Seq] = call // 往当前序号的pending数组中加入call调用
    client.seq++ // 序号递增
    return call.Seq, nil
}

// 从待处理队列移除调用
func (client *Client) removeCall(seq uint64) *Call {
    client.mu.Lock()
    defer client.mu.Unlock()
    call := client.pending[seq]
    delete(client.pending, seq)
    return call
}

// 终止所有未完成调用(异常处理)
func (client *Client) terminateCalls(err error) {
    client.sending.Lock()
    defer client.sending.Unlock()
    client.mu.Lock()
    defer client.mu.Unlock()
    client.shutdown = true // 服务端要求关闭标记置为true
    for _, call := range client.pending { 
        call.Error = err
        call.done()
    }
}

对一个客户端端来说,接收响应、发送请求是最重要的 2 个功能。那么首先实现接收功能,接收到的响应有三种情况:

  • call 不存在,可能是请求没有发送完整,或者因为其他原因被取消,但是服务端仍旧处理了。
  • call 存在,但服务端处理出错,即 h.Error 不为空。
  • call 存在,服务端处理正常,那么需要从 body 中读取 Reply 的值。
// 接收响应协程(核心事件循环)
func (client *Client) receive() {
    var err error
    for err == nil { // 持续处理响应
        var h codec.Header
        if err = client.cc.ReadHeader(&h); err != nil {
            break
        }
        call := client.removeCall(h.Seq)
        switch {
        case call == nil: // 写部分失败后的无效响应
            err = client.cc.ReadBody(nil)
        case h.Error != "": // 服务端返回错误
            call.Error = fmt.Errorf(h.Error)
            err = client.cc.ReadBody(nil)
            call.done()
        default: // 正常处理响应体
            err = client.cc.ReadBody(call.Reply)
            if err != nil {
                call.Error = errors.New("reading body " + err.Error())
            }
            call.done()
        }
    }
    // 发生错误时终止所有调用
    client.terminateCalls(err)
}

​ 然后实现发送请求的能力,Call传入方法、参数、返回指针、完成管道给Go()函数,Go()函数封装为Call消息调用结构体,并且通过send()处理结构体

send()中使用sending.Lock()加锁保证同一时间只有一个Call能返回消息,不然就会造成消息的紊乱;随后准备消息头部,和消息参数一起发送给codec的conn

// Call 调用指定的服务方法,并等待调用完成。
// 如果调用失败,返回错误。
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    // 发起调用,并等待完成
    call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
    select {
    case <-ctx.Done(): // 如果上下文完成或超时
        client.removeCall(call.Seq) // 移除调用
        return errors.New("rpc client: call failed: " + ctx.Err().Error())
    case call := <-call.Done: // 等待调用完成
        return call.Error
    }
}

// Go 异步调用入口(非阻塞)
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
    if done == nil {
        done = make(chan *Call, 10) // 缓冲通道优化
    } else if cap(done) == 0 {
        log.Panic("rpc client: done channel need cache")
    }
    call := &Call{
        ServiceMethod: serviceMethod,
        Args:          args,
        Reply:         reply,
        Done:          done,
    }
    client.send(call)
    return call
}

// 发送请求(核心发送逻辑)
func (client *Client) send(call *Call) {
    client.sending.Lock()         // 保证请求原子性
    defer client.sending.Unlock() // 确保锁释放

    // 注册调用
    seq, err := client.registerCall(call)
    if err != nil {
        call.Error = err
        call.done()
        return
    }

    // 准备回复的头部
    client.header.ServiceMethod = call.ServiceMethod
    client.header.Seq = seq
    client.header.Error = ""

    // 编码并发送请求(核心IO操作)
    if err := client.cc.Write(&client.header, call.Args); err != nil {
        call := client.removeCall(seq)
        if call != nil { // 部分写入失败处理
            call.Error = err
            call.done()
        }
    }
}

​ 最后实现 Dial 函数,该函数用来创建客户端和服务器的连接

newClientFunc方法初始化一个client实例

parseOptions()解析Option里的MagicNumber是否正确,CodecType是什么类型

clientResult是一个包括client指针和err错位的封装结构体

// dialTimeout 尝试连接到指定的网络地址,并创建一个 RPC 客户端。
// 如果连接超时,则返回错误。
func dialTimeout(f newClientFunc, network, address string, opts ...*server.Option) (client *Client, err error) {
    // 解析选项
    opt, err := parseOptions(opts...)
    if err != nil {
        return nil, err
    }

    // 尝试建立连接,带超时
    conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
    if err != nil {
        return nil, err
    }

    // 如果客户端创建失败,关闭连接
    defer func() {
        if err != nil {
            _ = conn.Close()
        }
    }()

    // 创建一个通道用于接收客户端创建的结果
    ch := make(chan clientResult)
    go func() {
        client, err := f(conn, opt) // 调用传入的客户端创建函数
        ch <- clientResult{client: client, err: err}
    }()

    // 如果没有设置连接超时,则直接等待结果
    if opt.ConnectTimeout == 0 {
        result := <-ch
        return result.client, result.err
    }

    // 使用 select 等待超时或客户端创建完成
    select {
    case <-time.After(opt.ConnectTimeout):
        return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
    case result := <-ch:
        return result.client, result.err
    }
}

// Dial 连接到指定的 RPC 服务器地址,是dialTimeout 的一个封装
func Dial(network, address string, opts ...*server.Option) (*Client, error) {
    return dialTimeout(NewClient, network, address, opts...)
}

服务注册

​ 之前在服务端中挖坑,说到了service代码会在服务注册中讲到,从服务器的readRequest()函数中,我们可以推见service部分的主要功能,就是服务的查找注册

​ 打个比方,我们可以把RPC过程看做点外卖,服务端是一个奶茶外卖店,而客户端是一个口渴的用户,此时他想要点一杯好喝到咩噗茶,那么,前提条件就是奶茶外卖的商店界面有"好喝到咩噗茶"这个货品,不然根本没有这个商品,客户不可能完成下单

​ 于是service部分的服务注册就类似奶茶店的商品上架,即让客户端能够选择服务,服务器端能够知晓并且处理服务

​ 首先我们先抽象方法结构体,就类似于奶茶店的好喝到咩噗茶;newArgv()newReplyv()都是对方法的规定

// MethodType 用于描述一个方法的元信息,包括方法本身、参数类型、返回类型以及调用次数。
type MethodType struct {
    method    reflect.Method // 方法的反射信息
    ArgType   reflect.Type   // 参数的类型
    ReplyType reflect.Type   // 返回的类型
    numCalls  uint64         // 方法被调用的次数
}

// NumCalls 返回方法被调用的次数。
func (m *MethodType) NumCalls() uint64 {
    return atomic.LoadUint64(&m.numCalls) // 使用原子操作确保线程安全
}

// newArgv 创建一个新的参数值,支持指针类型和值类型。
func (m *MethodType) NewArgv() reflect.Value {
    if m.ArgType.Kind() == reflect.Ptr {
        return reflect.New(m.ArgType.Elem()) // 如果是引用类型,创建一个指向底层类型的指针
    }
    return reflect.New(m.ArgType).Elem() // 如果是值类型,直接创建值
}

// newReplyv 创建一个新的返回值,支持初始化复杂类型(如 map 和 slice)。
func (m *MethodType) NewReplyv() reflect.Value {
    replyv := reflect.New(m.ReplyType.Elem())
    switch m.ReplyType.Elem().Kind() {
    case reflect.Map:
        replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) // MakeMap创建具有指定类型的新映射
    case reflect.Slice:
        replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) // 初始化空 slice
    }
    return replyv
}

​ 之后,定义结构体 service,类似奶茶店,service可以注册多个Method,就像奶茶店可以有多款奶茶

// Service 表示一个 RPC 服务,包含服务名称、类型、接收者以及方法映射。
type Service struct {
    Name   string                 // 服务名称
    typ    reflect.Type           // 结构体的类型
    Rcvr   reflect.Value          // 结构体的实例本身,保留 rcvr 是因为在调用时需要 rcvr 作为第 0 个参数
    Method map[string]*MethodType // 方法映射,键为方法名
}

// newService 创建一个新的 RPC 服务实例,并注册其中的方法。
func NewService(rcvr interface{}) *Service {
    s := &Service{
        Rcvr: reflect.ValueOf(rcvr),
        Name: reflect.Indirect(reflect.ValueOf(rcvr)).Type().Name(),
        typ:  reflect.TypeOf(rcvr),
    }
    if !isExportedOrBuiltinType(s.typ) {
        log.Fatalf("rpc server: %s is not a valid Service name", s.Name)
    }
    s.registerMethods() // 注册服务中的方法
    return s
}

// registerMethods 遍历服务类型的所有方法,并注册符合条件的方法。
func (s *Service) registerMethods() {
    s.Method = make(map[string]*MethodType)
    for i := 0; i < s.typ.NumMethod(); i++ {
        method := s.typ.Method(i)
        mType := method.Type
        // 检查方法签名是否符合 RPC 方法的要求:2 个参数(接收者除外)和 1 个返回值,且返回值为 error 类型
        if mType.NumIn() != 3 || mType.NumOut() != 1 || mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
            continue
        }
        argType, replyType := mType.In(1), mType.In(2)
        if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
            continue // 参数和返回类型必须是导出的或内置类型
        }
        s.Method[method.Name] = &MethodType{
            method:    method,
            ArgType:   argType,
            ReplyType: replyType,
        }
        log.Printf("rpc server: register %s.%sn", s.Name, method.Name)
    }
}

// isExportedOrBuiltinType 检查一个类型是否是导出的或内置类型。
func isExportedOrBuiltinType(t reflect.Type) bool {
    return ast.IsExported(t.Name()) || t.PkgPath() == ""
}

​ 最后,我们还需要实现 call方法,即能够通过反射值调用方法。

// call 调用一个注册的方法,并返回错误(如果有)。
func (s *Service) Call(m *MethodType, argv, replyv reflect.Value) error {
    atomic.AddUint64(&m.numCalls, 1) // 增加方法调用次数
    returnValues := m.method.Func.Call([]reflect.Value{s.Rcvr, argv, replyv}) // 方法调用的第一个参数是接收者
    if errInter := returnValues[0].Interface(); errInter != nil {
        return errInter.(error) // 如果返回值是非 nil 的 error,返回错误
    }
    return nil
}

支持HTTP协议

​ Web 开发中,我们经常使用 HTTP 协议中的 HEAD、GET、POST 等方式发送请求,等待响应。但 RPC 的消息格式与标准的 HTTP 协议并不兼容,在这种情况下,就需要一个协议的转换过程。HTTP 协议的 CONNECT 方法恰好提供了这个能力,CONNECT 一般用于代理服务。

​ 假设浏览器与服务器之间的 HTTPS 通信都是加密的,浏览器通过代理服务器发起 HTTPS 请求时,由于请求的站点地址和端口号都是加密保存在 HTTPS 请求报文头中的,代理服务器如何知道往哪里发送请求呢?为了解决这个问题,浏览器通过 HTTP 明文形式向代理服务器发送一个 CONNECT 请求告诉代理服务器目标地址和端口,代理服务器接收到这个请求后,会在对应端口与目标站点建立一个 TCP 连接,连接建立成功后返回 HTTP 200 状态码告诉浏览器与该站点的加密通道已经完成。接下来代理服务器仅需透传浏览器和服务器之间的加密数据包即可,代理服务器无需解析 HTTPS 报文。

服务器

​ 要支持HTTP 协议的 CONNECT 方法,需要用ServeHTTP()实现 http.Handler 接口,用于处理 HTTP 请求

const (
    connected        = "200 Connected to kRPC" // HTTP 连接成功的响应消息
    defaultRPCPath   = "/_kprc_"               // 默认的 RPC 请求路径
    defaultDebugPath = "/debug/krpc"           // 默认的调试路径
    MagicNumber      = 0x3bef5c
)

// ServeHTTP 实现了 http.Handler 接口,用于处理 HTTP 请求。
// 它只支持 HTTP 的 CONNECT 方法,用于建立 RPC 连接。
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    if req.Method != "CONNECT" {
        // 如果请求方法不是 CONNECT,返回 405 Method Not Allowed
        w.Header().Set("Content-Type", "text/plain; charset=utf-8")
        w.WriteHeader(http.StatusMethodNotAllowed)
        _, _ = io.WriteString(w, "405 must CONNECTn")
        return
    }

    // 使用 Hijacker 接管连接
    conn, _, err := w.(http.Hijacker).Hijack()
    if err != nil {
        log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
        return
    }

    // 向客户端发送连接成功的 HTTP 响应
    _, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"nn")

    // 调用 ServeConn 方法处理 RPC 请求
    server.ServeConn(conn)
}

// HandleHTTP 注册一个 HTTP 处理器,用于处理默认路径上的 RPC 请求。
// 它将 RPC 请求转发到 `ServeHTTP` 方法。
func (server *Server) HandleHTTP() {
    http.Handle(defaultRPCPath, server) // 注册默认的 RPC 路径
}

// HandleHTTP 是一个便捷方法,用于默认服务器注册 HTTP 处理器。
func HandleHTTP() {
    DefaultServer.HandleHTTP() // 调用默认服务器的 HandleHTTP 方法
}

客户端

​ 服务端已经能够接受 CONNECT 请求,并返回了 200 状态码 ,客户端要做的,发起 CONNECT 请求,检查返回状态码 HTTP/1.0 200 Connected to Gee RPC 即可成功建立连接

/ NewHTTPClient 通过 HTTP 协议创建一个新的 RPC 客户端实例。
func NewHTTPClient(conn net.Conn, opt *server.Option) (*Client, error) {
    // 向服务器发送 CONNECT 请求,尝试建立 RPC 连接
    _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0nn", defaultRPCPath))

    // 读取 HTTP 响应,确保连接成功
    resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
    if err == nil && resp.Status == connected {
        return NewClient(conn, opt) // 如果连接成功,创建 RPC 客户端
    }
    if err == nil {
        err = errors.New("unexpected HTTP response: " + resp.Status) // 如果响应状态不匹配,返回错误
    }
    return nil, err
}

// DialHTTP 连接到一个监听在默认 HTTP RPC 路径上的 HTTP RPC 服务器。
func DialHTTP(network, address string, opts ...*server.Option) (*Client, error) {
    return dialTimeout(NewHTTPClient, network, address, opts...) // 使用 NewHTTPClient 创建客户端
}

​ 通过 HTTP CONNECT 请求建立连接之后,后续的通信过程就交给 NewClient 了,为了简化调用,提供了一个统一入口 XDial

// XDial 根据 rpcAddr 的协议部分调用不同的函数来连接到 RPC 服务器。
// rpcAddr 是一个通用格式(protocol@addr),用于表示 RPC 服务器的方式。
// 示例:
//  http@10.0.0.1:7001
//  udp@10.0.0.1:9999
func XDial(rpcAddr string, opts ...*server.Option) (*Client, error) {
    parts := strings.Split(rpcAddr, "@")
    if len(parts) != 2 {
        return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
    }
    protocol, addr := parts[0], parts[1] // 分离协议和地址
    switch protocol {
    case "http":
        return DialHTTP("tcp", addr, opts...) // 对于 HTTP 协议,调用 DialHTTP
    case "tcp":
        return Dial("tcp", addr, opts...) // 对于 TCP 协议,调用 Dial
    case "udp":
        return Dial("udp", addr, opts...) // 对于 UDP 协议,调用 Dial
    default:
        return Dial(protocol, addr, opts...) // 调用 Dial
    }
}

负载均衡

​ 负载均衡是一种通过将网络流量或工作任务分配到多个服务器上来提高系统性能、可靠性和可扩展性的技术,在众多负载均衡策略中,我选择简单实现以下两个:

  • 随机选择策略 - 从服务列表中随机选择一个。
  • 轮询算法(Round Robin) - 依次调度不同的服务器,每次调度执行 i = (i + 1) mode n。

​ 负载均衡的前提是有多个服务实例,那我们首先实现一个最基础的服务发现模块 Discovery,支持多个服务

// SelectMode 定义了选择服务器的模式。
type SelectMode int

const (
    // 随机选择服务器
    RandomSelect SelectMode = iota
    // 使用轮询算法选择服务器
    RoundRobinSelect
)

// Discovery 是服务发现接口,定义了服务发现的基本操作。
type Discovery interface {
    // 刷新服务器列表(从远程注册中心获取)
    Refresh() error
    // 更新服务器列表
    Update(servers []string) error
    // 根据选择模式获取一个服务器
    Get(mode SelectMode) (string, error)
    // 获取所有服务器
    GetAll() ([]string, error)
}

​ 紧接着,我们实现一个不需要注册中心,服务列表由手工维护的服务发现的结构体MultiServersDiscovery

// MultiServersDiscovery 是一个不依赖于注册中心的服务发现实现。
// 用户需要显式提供服务器地址。
type MultiServersDiscovery struct {
    r       *rand.Rand   // 生成随机数
    mu      sync.RWMutex // 保护以下字段
    servers []string     // 服务器列表
    index   int          // 轮询算法的当前位置
}

// NewMultiServerDiscovery 创建一个 MultiServersDiscovery 实例。
func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery {
    d := &MultiServersDiscovery{
        servers: servers,
        r:       rand.New(rand.NewSource(time.Now().UnixNano())), // 初始化随机数生成器
    }
    d.index = d.r.Intn(math.MaxInt32 - 1) // 初始化轮询算法的起始位置
    return d
}

​ 然后,实现 Discovery 接口

// Refresh 方法在 MultiServersDiscovery 中没有实际意义,因此直接返回 nil。
func (d *MultiServersDiscovery) Refresh() error {
    return nil
}

// Update 方法动态更新服务器列表。
func (d *MultiServersDiscovery) Update(servers []string) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.servers = servers // 更新服务器列表
    return nil
}

// Get 方法根据选择模式获取一个服务器。
func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) {
    d.mu.Lock()
    defer d.mu.Unlock()
    n := len(d.servers)
    if n == 0 {
        return "", errors.New("rpc discovery: no available servers") // 如果服务器列表为空,返回错误
    }
    switch mode {
    case RandomSelect:
        return d.servers[d.r.Intn(n)], nil // 随机选择一个服务器
    case RoundRobinSelect:
        s := d.servers[d.index%n]   // 使用轮询算法选择服务器
        d.index = (d.index + 1) % n // 更新轮询算法的当前位置
        return s, nil
    default:
        return "", errors.New("rpc discovery: not supported select mode") // 不支持的选择模式
    }
}

// GetAll 方法返回所有服务器。
func (d *MultiServersDiscovery) GetAll() ([]string, error) {
    d.mu.RLock()
    defer d.mu.RUnlock()
    // 返回服务器列表的副本
    servers := make([]string, len(d.servers))
    copy(servers, d.servers)
    return servers, nil
}

​ 接下来,我们向用户暴露一个支持负载均衡的客户端 XClient,这个Client可以从Diocovery中选择服务器列表中的服务器进行通信

​ XClient 的构造函数需要传入三个参数,服务发现实例 Discovery、负载均衡模式 SelectMode 以及协议选项 Option

// XClient 是一个分布式 RPC 客户端,通过服务发现机制动态选择服务器。
type XClient struct {
    d       Discovery          // 服务发现接口
    mode    SelectMode         // 选择模式(随机选择或轮询选择)
    opt     *server.Option     // RPC 客户端选项
    mu      sync.Mutex         // 保护以下字段
    clients map[string]*Client // 存储已连接的 RPC 客户端
}

​ 接下来,实现客户端最基本的功能 Call

// dial 尝试连接到指定的 RPC 服务器地址。
func (xc *XClient) dial(rpcAddr string) (*Client, error) {
    xc.mu.Lock()
    defer xc.mu.Unlock()
    client, ok := xc.clients[rpcAddr]
    if ok && !client.IsAvailable() {
        // 如果客户端不可用,关闭并移除
        _ = client.Close()
        delete(xc.clients, rpcAddr)
        client = nil
    }
    if client == nil {
        // 如果尚未连接,创建新的客户端
        var err error
        client, err = XDial(rpcAddr, xc.opt)
        if err != nil {
            return nil, err
        }
        xc.clients[rpcAddr] = client
    }
    return client, nil
}

// call 在指定的 RPC 服务器上执行调用。
func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
    client, err := xc.dial(rpcAddr)
    if err != nil {
        return err
    }
    return client.Call(ctx, serviceMethod, args, reply)
}

// Call 在一个合适的服务器上执行调用。
func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    rpcAddr, err := xc.d.Get(xc.mode) // 根据选择模式获取服务器地址
    if err != nil {
        return err
    }
    return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}

​ 另外,我们为 XClient 添加一个常用功能:Broadcast,这里的clonedReply是对Reply指针克隆,因为多个go程用同一个Reply指针肯定会出现写入混乱

// Broadcast 在所有注册的服务器上广播调用。
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    servers, err := xc.d.GetAll() // 获取所有服务器地址
    if err != nil {
        return err
    }
    var wg sync.WaitGroup
    var mu sync.Mutex // 保护 e 和 replyDone
    var e error
    replyDone := reply == nil // 如果 reply 为 nil,则不需要设置值
    ctx, cancel := context.WithCancel(ctx)
    defer cancel() // 确保在函数返回时调用 cancel
    for _, rpcAddr := range servers {
        wg.Add(1)
        go func(rpcAddr string) {
            defer wg.Done()
            var clonedReply interface{}
            if reply != nil {
                // 如果 reply 不为 nil,克隆 reply 以避免并发写入
                clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
            }
            err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
            mu.Lock()
            if err != nil && e == nil {
                e = err
                cancel() // 如果任何调用失败,取消未完成的调用
            }
            if err == nil && !replyDone {
                // 如果调用成功且 reply 不为 nil,设置 reply 的值
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
                replyDone = true
            }
            mu.Unlock()
        }(rpcAddr)
    }
    wg.Wait()
    return e
}

注册中心

​ 注册中心的位置如上图所示。注册中心的好处在于,客户端和服务端都只需要感知注册中心的存在,而无需感知对方的存在。更具体一些:

  1. 服务端启动后,向注册中心发送注册消息,注册中心得知该服务已经启动,处于可用状态。一般来说,服务端还需要定期向注册中心发送心跳,证明自己还活着。
  2. 客户端向注册中心询问,当前哪天服务是可用的,注册中心将可用的服务列表返回客户端。
  3. 客户端根据注册中心得到的服务列表,选择其中一个发起调用。

​ 首先定义Registry 结构体,默认超时时间设置为 5 min,也就是说,任何注册的服务超过 5 min,即视为不可用状态

// Registry 是一个简单的注册中心,提供以下功能:
// 1. 添加服务器并接收心跳以保持其活跃状态。
// 2. 同时返回所有活跃的服务器,并同步删除已死亡的服务器。
type Registry struct {
    timeout time.Duration          // 心跳超时时间
    mu      sync.Mutex             // 保护以下字段
    servers map[string]*ServerItem // 存储服务器信息
}

// ServerItem 表示一个服务器的信息。
type ServerItem struct {
    Addr  string    // 服务器地址
    start time.Time // 服务器注册时间
}

const (
    defaultPath    = "/_krpc_/registry" // 默认的注册路径
    defaultTimeout = time.Minute * 5    // 默认的心跳超时时间(5分钟)
)

​ 为 Registry 实现添加服务实例和返回服务列表的方法

// putServer 添加或更新一个服务器到注册中心。
func (r *Registry) putServer(addr string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    s := r.servers[addr]
    if s == nil {
        r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()} // 新增服务器
    } else {
        s.start = time.Now() // 如果服务器已存在,更新其启动时间以保持活跃
    }
}

// aliveServers 返回所有活跃的服务器地址。
func (r *Registry) aliveServers() []string {
    r.mu.Lock()
    defer r.mu.Unlock()
    var alive []string
    for addr, s := range r.servers {
        if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) { // 检查是否超时
            alive = append(alive, addr)
        } else {
            delete(r.servers, addr) // 删除超时的服务器
        }
    }
    sort.Strings(alive) // 对活跃服务器地址进行排序
    return alive
}

​ 为了实现上的简单,Registry 采用 HTTP 协议提供服务,且所有的有用信息都承载在 HTTP Header 中

// ServeHTTP 实现了 http.Handler 接口,用于处理注册中心的 HTTP 请求。
func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    switch req.Method {
    case "GET":
        // 返回所有活跃的服务器地址,通过 HTTP 头部返回
        w.Header().Set("X-krpc-Servers", strings.Join(r.aliveServers(), ","))
    case "POST":
        // 注册一个服务器,服务器地址通过 HTTP 头部传递
        addr := req.Header.Get("X-krpc-Server")
        if addr == "" {
            w.WriteHeader(http.StatusInternalServerError) // 如果地址为空,返回 500 错误
            return
        }
        r.putServer(addr) // 添加或更新服务器
    default:
        w.WriteHeader(http.StatusMethodNotAllowed) // 不支持的 HTTP 方法
    }
}

// HandleHTTP 注册一个 HTTP 处理器,用于处理注册中心的消息。
func (r *Registry) HandleHTTP(registryPath string) {
    http.Handle(registryPath, r)                    // 注册 HTTP 处理器
    log.Println("rpc registry path:", registryPath) // 打印注册路径
}

// HandleHTTP 是一个便捷方法,用于默认注册中心实例注册 HTTP 处理器。
func HandleHTTP() {
    DefaultGeeRegister.HandleHTTP(defaultPath) // 调用默认注册中心的 HandleHTTP 方法
}

​ 另外,提供 Heartbeat 方法,便于服务启动时定时向注册中心发送心跳,默认周期比注册中心设置的过期时间少 1 min

// Heartbeat 定期向注册中心发送心跳消息。
// 它是一个辅助函数,用于服务器注册或发送心跳。
func Heartbeat(registry, addr string, duration time.Duration) {
    if duration == 0 {
        // 确保在被注册中心移除之前有足够的时间发送心跳。
        // 默认心跳间隔为超时时间减去 1 分钟。
        duration = defaultTimeout - time.Duration(1)*time.Minute
    }
    var err error
    err = sendHeartbeat(registry, addr) // 首次发送心跳
    go func() {
        t := time.NewTicker(duration) // 创建一个定时器
        for err == nil {
            <-t.C                               // 等待下一次心跳时间
            err = sendHeartbeat(registry, addr) // 发送心跳
        }
    }()
}

// sendHeartbeat 向注册中心发送心跳消息。
func sendHeartbeat(registry, addr string) error {
    log.Println(addr, "send heart beat to registry", registry)
    httpClient := &http.Client{}                     // 创建 HTTP 客户端
    req, _ := http.NewRequest("POST", registry, nil) // 创建 POST 请求
    req.Header.Set("X-krpc-Server", addr)            // 设置服务器地址
    if _, err := httpClient.Do(req); err != nil {    // 发送请求
        log.Println("rpc server: heart beat err:", err)
        return err
    }
    return nil
}

​ 有了注册中心,我们就在 xclient 中对应实现服务发现,不同与上面负载均衡中的MultiServersDiscovery, RegistryDiscovery是服务列表由注册中心维护的,同时它也嵌套了 MultiServersDiscovery,很多能力可以复用。

type RegistryDiscovery struct {
    *MultiServersDiscovery
    registry   string
    timeout    time.Duration
    lastUpdate time.Time
}

const defaultUpdateTimeout = time.Second * 10

func NewRegistryDiscovery(registerAddr string, timeout time.Duration) *RegistryDiscovery {
    if timeout == 0 {
        timeout = defaultUpdateTimeout
    }
    d := &RegistryDiscovery{
        MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)),
        registry:              registerAddr,
        timeout:               timeout,
    }
    return d
}

​ 实现 UpdateRefresh方法,超时重新获取的逻辑在 Refresh中实现, GetGetAll与 MultiServersDiscovery 相似,唯一的不同在于,RegistryDiscovery 需要先调用 Refresh 确保服务列表没有过期。

// Update 更新服务发现的服务器列表。
func (d *RegistryDiscovery) Update(servers []string) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.servers = servers       // 更新服务器列表
    d.lastUpdate = time.Now() // 更新最后更新时间
    return nil
}

// Refresh 从注册中心刷新服务器列表。
func (d *RegistryDiscovery) Refresh() error {
    d.mu.Lock()
    defer d.mu.Unlock()
    // 如果上次更新时间加上超时时间大于当前时间,则不需要刷新
    if d.lastUpdate.Add(d.timeout).After(time.Now()) {
        return nil
    }
    log.Println("rpc registry: refresh servers from registry", d.registry)
    resp, err := http.Get(d.registry) // 从注册中心获取服务器列表
    if err != nil {
        log.Println("rpc registry refresh err:", err)
        return err
    }
    // 从响应头中解析服务器列表
    servers := strings.Split(resp.Header.Get("X-krpc-Servers"), ",")
    d.servers = make([]string, 0, len(servers)) // 重新初始化服务器列表
    for _, server := range servers {
        if strings.TrimSpace(server) != "" { // 去除空字符串
            d.servers = append(d.servers, strings.TrimSpace(server))
        }
    }
    d.lastUpdate = time.Now() // 更新最后更新时间
    return nil
}

// Get 根据选择模式获取一个服务器。
func (d *RegistryDiscovery) Get(mode SelectMode) (string, error) {
    if err := d.Refresh(); err != nil { // 刷新服务器列表
        return "", err
    }
    return d.MultiServersDiscovery.Get(mode) // 调用 MultiServersDiscovery 获取服务器
}

// GetAll 返回所有服务器。
func (d *RegistryDiscovery) GetAll() ([]string, error) {
    if err := d.Refresh(); err != nil { // 刷新服务器列表
        return nil, err
    }
    return d.MultiServersDiscovery.GetAll() // 调用 MultiServersDiscovery 获取所有服务器
}

总结

​ 能够帮助理解RPC框架和RPC通信的步骤~