client.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package rpc
  5. import (
  6. "bufio"
  7. "encoding/gob"
  8. "errors"
  9. "io"
  10. "log"
  11. "net"
  12. "sync"
  13. )
  14. // ServerError represents an error that has been returned from
  15. // the remote side of the RPC connection.
  16. type ServerError string
  17. func (e ServerError) Error() string {
  18. return string(e)
  19. }
  20. var ErrShutdown = errors.New("connection is shut down")
// Call represents an active RPC.
//
// Client.Go returns a *Call right away; when the RPC finishes, Error is
// filled in and the same *Call is sent on Done.
type Call struct {
	ServiceMethod string     // The name of the service and method to call.
	Args          any        // The argument to the function (*struct).
	Reply         any        // The reply from the function (*struct); the response body is decoded into it.
	Error         error      // After completion, the error status.
	Done          chan *Call // Receives *Call when Go is complete.
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	codec ClientCodec // transport used to write requests and read responses

	// reqMutex is held by send for the entire time a request is being
	// registered and written, so requests reach the codec one at a time.
	// input also takes it (before mutex) while failing pending calls.
	reqMutex sync.Mutex // protects following
	request  Request    // request header reused by every send

	mutex    sync.Mutex       // protects following
	seq      uint64           // sequence number to assign to the next call
	pending  map[uint64]*Call // calls awaiting a response, keyed by sequence number
	closing  bool             // user has called Close
	shutdown bool             // input has stopped reading responses (any read error, including EOF)
}
  43. // A ClientCodec implements writing of RPC requests and
  44. // reading of RPC responses for the client side of an RPC session.
  45. // The client calls WriteRequest to write a request to the connection
  46. // and calls ReadResponseHeader and ReadResponseBody in pairs
  47. // to read responses. The client calls Close when finished with the
  48. // connection. ReadResponseBody may be called with a nil
  49. // argument to force the body of the response to be read and then
  50. // discarded.
  51. // See NewClient's comment for information about concurrent access.
  52. type ClientCodec interface {
  53. WriteRequest(*Request, any) error
  54. ReadResponseHeader(*Response) error
  55. ReadResponseBody(any) error
  56. Close() error
  57. }
  58. func (client *Client) send(call *Call) {
  59. client.reqMutex.Lock()
  60. defer client.reqMutex.Unlock()
  61. // Register this call.
  62. client.mutex.Lock()
  63. if client.shutdown || client.closing {
  64. client.mutex.Unlock()
  65. call.Error = ErrShutdown
  66. call.done()
  67. return
  68. }
  69. seq := client.seq
  70. client.seq++
  71. client.pending[seq] = call
  72. client.mutex.Unlock()
  73. // Encode and send the request.
  74. client.request.Seq = seq
  75. client.request.ServiceMethod = call.ServiceMethod
  76. err := client.codec.WriteRequest(&client.request, call.Args)
  77. if err != nil {
  78. client.mutex.Lock()
  79. call = client.pending[seq]
  80. delete(client.pending, seq)
  81. client.mutex.Unlock()
  82. if call != nil {
  83. call.Error = err
  84. call.done()
  85. }
  86. }
  87. }
  88. func (client *Client) input() {
  89. var err error
  90. var response Response
  91. for err == nil {
  92. response = Response{}
  93. err = client.codec.ReadResponseHeader(&response)
  94. if err != nil {
  95. break
  96. }
  97. seq := response.Seq
  98. client.mutex.Lock()
  99. call := client.pending[seq]
  100. delete(client.pending, seq)
  101. client.mutex.Unlock()
  102. switch {
  103. case call == nil:
  104. // We've got no pending call. That usually means that
  105. // WriteRequest partially failed, and call was already
  106. // removed; response is a server telling us about an
  107. // error reading request body. We should still attempt
  108. // to read error body, but there's no one to give it to.
  109. err = client.codec.ReadResponseBody(nil)
  110. if err != nil {
  111. err = errors.New("reading error body: " + err.Error())
  112. }
  113. case response.Error != "":
  114. // We've got an error response. Give this to the request;
  115. // any subsequent requests will get the ReadResponseBody
  116. // error if there is one.
  117. call.Error = ServerError(response.Error)
  118. err = client.codec.ReadResponseBody(nil)
  119. if err != nil {
  120. err = errors.New("reading error body: " + err.Error())
  121. }
  122. call.done()
  123. default:
  124. err = client.codec.ReadResponseBody(call.Reply)
  125. if err != nil {
  126. call.Error = errors.New("reading body " + err.Error())
  127. }
  128. call.done()
  129. }
  130. }
  131. // Terminate pending calls.
  132. client.reqMutex.Lock()
  133. client.mutex.Lock()
  134. client.shutdown = true
  135. closing := client.closing
  136. if err == io.EOF {
  137. if closing {
  138. err = ErrShutdown
  139. } else {
  140. err = io.ErrUnexpectedEOF
  141. }
  142. }
  143. for _, call := range client.pending {
  144. call.Error = err
  145. call.done()
  146. }
  147. client.mutex.Unlock()
  148. client.reqMutex.Unlock()
  149. if debugLog && err != io.EOF && !closing {
  150. log.Println("rpc: client protocol error:", err)
  151. }
  152. }
  153. func (call *Call) done() {
  154. select {
  155. case call.Done <- call:
  156. // ok
  157. default:
  158. // We don't want to block here. It is the caller's responsibility to make
  159. // sure the channel has enough buffer space. See comment in Go().
  160. if debugLog {
  161. log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
  162. }
  163. }
  164. }
  165. // NewClient returns a new Client to handle requests to the
  166. // set of services at the other end of the connection.
  167. // It adds a buffer to the write side of the connection so
  168. // the header and payload are sent as a unit.
  169. //
  170. // The read and write halves of the connection are serialized independently,
  171. // so no interlocking is required. However each half may be accessed
  172. // concurrently so the implementation of conn should protect against
  173. // concurrent reads or concurrent writes.
  174. func NewClient(conn io.ReadWriteCloser) *Client {
  175. encBuf := bufio.NewWriter(conn)
  176. client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
  177. return NewClientWithCodec(client)
  178. }
  179. // NewClientWithCodec is like NewClient but uses the specified
  180. // codec to encode requests and decode responses.
  181. func NewClientWithCodec(codec ClientCodec) *Client {
  182. client := &Client{
  183. codec: codec,
  184. pending: make(map[uint64]*Call),
  185. }
  186. go client.input()
  187. return client
  188. }
// gobClientCodec is the ClientCodec built by NewClient. It encodes requests
// and decodes responses with encoding/gob over a single connection.
// Field order matters: NewClient builds it with a positional literal.
type gobClientCodec struct {
	rwc    io.ReadWriteCloser // underlying connection; closed by Close
	dec    *gob.Decoder       // reads responses directly from rwc
	enc    *gob.Encoder       // writes requests into encBuf
	encBuf *bufio.Writer      // buffers writes to rwc; flushed after each complete request
}
  195. func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
  196. if err = c.enc.Encode(r); err != nil {
  197. return
  198. }
  199. if err = c.enc.Encode(body); err != nil {
  200. return
  201. }
  202. return c.encBuf.Flush()
  203. }
  204. func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
  205. return c.dec.Decode(r)
  206. }
  207. func (c *gobClientCodec) ReadResponseBody(body any) error {
  208. return c.dec.Decode(body)
  209. }
  210. func (c *gobClientCodec) Close() error {
  211. return c.rwc.Close()
  212. }
  213. // Dial connects to an RPC server at the specified network address.
  214. func Dial(network, address string) (*Client, error) {
  215. conn, err := net.Dial(network, address)
  216. if err != nil {
  217. return nil, err
  218. }
  219. return NewClient(conn), nil
  220. }
  221. // Close calls the underlying codec's Close method. If the connection is already
  222. // shutting down, ErrShutdown is returned.
  223. func (client *Client) Close() error {
  224. client.mutex.Lock()
  225. if client.closing {
  226. client.mutex.Unlock()
  227. return ErrShutdown
  228. }
  229. client.closing = true
  230. client.mutex.Unlock()
  231. return client.codec.Close()
  232. }
  233. // Go invokes the function asynchronously. It returns the Call structure representing
  234. // the invocation. The done channel will signal when the call is complete by returning
  235. // the same Call object. If done is nil, Go will allocate a new channel.
  236. // If non-nil, done must be buffered or Go will deliberately crash.
  237. func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
  238. call := new(Call)
  239. call.ServiceMethod = serviceMethod
  240. call.Args = args
  241. call.Reply = reply
  242. if done == nil {
  243. done = make(chan *Call, 10) // buffered.
  244. } else {
  245. // If caller passes done != nil, it must arrange that
  246. // done has enough buffer for the number of simultaneous
  247. // RPCs that will be using that channel. If the channel
  248. // is totally unbuffered, it's best not to run at all.
  249. if cap(done) == 0 {
  250. log.Panic("rpc: done channel is unbuffered")
  251. }
  252. }
  253. call.Done = done
  254. client.send(call)
  255. return call
  256. }
  257. // Call invokes the named function, waits for it to complete, and returns its error status.
  258. func (client *Client) Call(serviceMethod string, args any, reply any) error {
  259. call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
  260. return call.Error
  261. }