1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
6// https://www.jsonrpc.org/specification
7// It is intended to be compatible with other implementations at the wire level.
8package jsonrpc2
9
10import (
11	"context"
12	"encoding/json"
13	"fmt"
14	"sync"
15	"sync/atomic"
16	"time"
17)
18
// Conn is a JSON RPC 2 client server connection.
// Conn is bidirectional; it does not have a designated server or client end.
//
// A Conn is built by NewConn, which also starts the goroutine that reads
// incoming messages from stream. The handle, cancel and log hooks are filled
// in by NewConn (with do-nothing or "method not found" defaults when the
// caller supplies none), so they are never nil once NewConn has returned.
// done is closed when the read loop exits, after err has been set to the
// error that ended it; Wait relies on that ordering to read err safely.
// pending holds the response channel of every outgoing Call that is still
// waiting, and handling holds the cancel function of every incoming call
// that is still being handled; both are keyed by request ID.
//
// NOTE(review): seq is updated with 64-bit atomic operations, so take care
// when reordering fields; sync/atomic requires 64-bit alignment of such
// words on 32-bit platforms.
type Conn struct {
	handle     Handler
	cancel     Canceler
	log        Logger
	stream     Stream
	done       chan struct{}
	err        error
	seq        int64      // must only be accessed using atomic operations
	pendingMu  sync.Mutex // protects the pending map
	pending    map[ID]chan *Response
	handlingMu sync.Mutex // protects the handling map
	handling   map[ID]context.CancelFunc
}
34
// Handler is an option you can pass to NewConn to handle incoming requests.
// If the request returns true from IsNotify then the Handler should not return a
// result or error (an error returned for a notification is only logged, it is
// never sent to the peer). Otherwise it should handle the Request and return
// either a result, which the connection marshals to JSON for the response, or
// an error.
// Every incoming request is dispatched on its own goroutine, so Handlers must
// be concurrency-safe.
// When no Handler is supplied, NewConn installs one that answers every
// request with a CodeMethodNotFound error.
type Handler = func(context.Context, *Conn, *Request) (interface{}, *Error)
41
// Canceler is an option you can pass to NewConn which is invoked for
// cancelled outgoing requests: Call invokes it when its context is done
// before a response has arrived.
// The request will have the ID filled in, which can be used to propagate the
// cancel to the other process if needed.
// It is okay to use the connection to send notifications, but the context
// passed in will be in the cancelled state, so you must do it with the
// background context instead.
// When no Canceler is supplied, NewConn installs one that does nothing.
type Canceler = func(context.Context, *Conn, *Request)
50
51// NewErrorf builds a Error struct for the suppied message and code.
52// If args is not empty, message and args will be passed to Sprintf.
53func NewErrorf(code int64, format string, args ...interface{}) *Error {
54	return &Error{
55		Code:    code,
56		Message: fmt.Sprintf(format, args...),
57	}
58}
59
60// NewConn creates a new connection object that reads and writes messages from
61// the supplied stream and dispatches incoming messages to the supplied handler.
62func NewConn(ctx context.Context, s Stream, options ...interface{}) *Conn {
63	conn := &Conn{
64		stream:   s,
65		done:     make(chan struct{}),
66		pending:  make(map[ID]chan *Response),
67		handling: make(map[ID]context.CancelFunc),
68	}
69	for _, opt := range options {
70		switch opt := opt.(type) {
71		case Handler:
72			if conn.handle != nil {
73				panic("Duplicate Handler function in options list")
74			}
75			conn.handle = opt
76		case Canceler:
77			if conn.cancel != nil {
78				panic("Duplicate Canceler function in options list")
79			}
80			conn.cancel = opt
81		case Logger:
82			if conn.log != nil {
83				panic("Duplicate Logger function in options list")
84			}
85			conn.log = opt
86		default:
87			panic(fmt.Errorf("Unknown option type %T in options list", opt))
88		}
89	}
90	if conn.handle == nil {
91		// the default handler reports a method error
92		conn.handle = func(ctx context.Context, c *Conn, r *Request) (interface{}, *Error) {
93			return nil, NewErrorf(CodeMethodNotFound, "method %q not found", r.Method)
94		}
95	}
96	if conn.cancel == nil {
97		// the default canceller does nothing
98		conn.cancel = func(context.Context, *Conn, *Request) {}
99	}
100	if conn.log == nil {
101		// the default logger does nothing
102		conn.log = func(Direction, *ID, time.Duration, string, *json.RawMessage, *Error) {}
103	}
104	go func() {
105		conn.err = conn.run(ctx)
106		close(conn.done)
107	}()
108	return conn
109}
110
111// Wait blocks until the connection is terminated, and returns any error that
112// cause the termination.
113func (c *Conn) Wait(ctx context.Context) error {
114	select {
115	case <-c.done:
116		return c.err
117	case <-ctx.Done():
118		return ctx.Err()
119	}
120}
121
122// Cancel cancels a pending Call on the server side.
123// The call is identified by its id.
124// JSON RPC 2 does not specify a cancel message, so cancellation support is not
125// directly wired in. This method allows a higher level protocol to choose how
126// to propagate the cancel.
127func (c *Conn) Cancel(id ID) {
128	c.handlingMu.Lock()
129	cancel := c.handling[id]
130	c.handlingMu.Unlock()
131	if cancel != nil {
132		cancel()
133	}
134}
135
136// Notify is called to send a notification request over the connection.
137// It will return as soon as the notification has been sent, as no response is
138// possible.
139func (c *Conn) Notify(ctx context.Context, method string, params interface{}) error {
140	jsonParams, err := marshalToRaw(params)
141	if err != nil {
142		return fmt.Errorf("marshalling notify parameters: %v", err)
143	}
144	request := &Request{
145		Method: method,
146		Params: jsonParams,
147	}
148	data, err := json.Marshal(request)
149	if err != nil {
150		return fmt.Errorf("marshalling notify request: %v", err)
151	}
152	c.log(Send, nil, -1, request.Method, request.Params, nil)
153	return c.stream.Write(ctx, data)
154}
155
156// Call sends a request over the connection and then waits for a response.
157// If the response is not an error, it will be decoded into result.
158// result must be of a type you an pass to json.Unmarshal.
159func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) error {
160	jsonParams, err := marshalToRaw(params)
161	if err != nil {
162		return fmt.Errorf("marshalling call parameters: %v", err)
163	}
164	// generate a new request identifier
165	id := ID{Number: atomic.AddInt64(&c.seq, 1)}
166	request := &Request{
167		ID:     &id,
168		Method: method,
169		Params: jsonParams,
170	}
171	// marshal the request now it is complete
172	data, err := json.Marshal(request)
173	if err != nil {
174		return fmt.Errorf("marshalling call request: %v", err)
175	}
176	// we have to add ourselves to the pending map before we send, otherwise we
177	// are racing the response
178	rchan := make(chan *Response)
179	c.pendingMu.Lock()
180	c.pending[id] = rchan
181	c.pendingMu.Unlock()
182	defer func() {
183		// clean up the pending response handler on the way out
184		c.pendingMu.Lock()
185		delete(c.pending, id)
186		c.pendingMu.Unlock()
187	}()
188	// now we are ready to send
189	before := time.Now()
190	c.log(Send, request.ID, -1, request.Method, request.Params, nil)
191	if err := c.stream.Write(ctx, data); err != nil {
192		// sending failed, we will never get a response, so don't leave it pending
193		return err
194	}
195	// now wait for the response
196	select {
197	case response := <-rchan:
198		elapsed := time.Since(before)
199		c.log(Send, response.ID, elapsed, request.Method, response.Result, response.Error)
200		// is it an error response?
201		if response.Error != nil {
202			return response.Error
203		}
204		if result == nil || response.Result == nil {
205			return nil
206		}
207		if err := json.Unmarshal(*response.Result, result); err != nil {
208			return fmt.Errorf("unmarshalling result: %v", err)
209		}
210		return nil
211	case <-ctx.Done():
212		// allow the handler to propagate the cancel
213		c.cancel(ctx, c, request)
214		return ctx.Err()
215	}
216}
217
// combined has all the fields of both Request and Response.
// We can decode this and then work out which it is: the read loop treats a
// message with a non-empty Method as a request (a notification or a call),
// and any other message with a non-nil ID as a response. A message with
// neither is logged and ignored.
type combined struct {
	VersionTag VersionTag       `json:"jsonrpc"`
	ID         *ID              `json:"id,omitempty"`
	Method     string           `json:"method"`
	Params     *json.RawMessage `json:"params,omitempty"`
	Result     *json.RawMessage `json:"result,omitempty"`
	Error      *Error           `json:"error,omitempty"`
}
228
229// Run starts a read loop on the supplied reader.
230// It must be called exactly once for each Conn.
231// It returns only when the reader is closed or there is an error in the stream.
232func (c *Conn) run(ctx context.Context) error {
233	ctx, cancelRun := context.WithCancel(ctx)
234	for {
235		// get the data for a message
236		data, err := c.stream.Read(ctx)
237		if err != nil {
238			// the stream failed, we cannot continue
239			return err
240		}
241		// read a combined message
242		msg := &combined{}
243		if err := json.Unmarshal(data, msg); err != nil {
244			// a badly formed message arrived, log it and continue
245			// we trust the stream to have isolated the error to just this message
246			c.log(Receive, nil, -1, "", nil, NewErrorf(0, "unmarshal failed: %v", err))
247			continue
248		}
249		// work out which kind of message we have
250		switch {
251		case msg.Method != "":
252			// if method is set it must be a request
253			request := &Request{
254				Method: msg.Method,
255				Params: msg.Params,
256				ID:     msg.ID,
257			}
258			if request.IsNotify() {
259				c.log(Receive, request.ID, -1, request.Method, request.Params, nil)
260				// we have a Notify, forward to the handler in a go routine
261				go func() {
262					if _, err := c.handle(ctx, c, request); err != nil {
263						// notify produced an error, we can't forward it to the other side
264						// because there is no id, so we just log it
265						c.log(Receive, nil, -1, request.Method, nil, err)
266					}
267				}()
268			} else {
269				// we have a Call, forward to the handler in another go routine
270				reqCtx, cancelReq := context.WithCancel(ctx)
271				c.handlingMu.Lock()
272				c.handling[*request.ID] = cancelReq
273				c.handlingMu.Unlock()
274				go func() {
275					defer func() {
276						// clean up the cancel handler on the way out
277						c.handlingMu.Lock()
278						delete(c.handling, *request.ID)
279						c.handlingMu.Unlock()
280						cancelReq()
281					}()
282					c.log(Receive, request.ID, -1, request.Method, request.Params, nil)
283					before := time.Now()
284					resp, callErr := c.handle(reqCtx, c, request)
285					elapsed := time.Since(before)
286					var result *json.RawMessage
287					if result, err = marshalToRaw(resp); err != nil {
288						callErr = &Error{Message: err.Error()}
289					}
290					response := &Response{
291						Result: result,
292						Error:  callErr,
293						ID:     request.ID,
294					}
295					data, err := json.Marshal(response)
296					if err != nil {
297						// failure to marshal leaves the call without a response
298						// possibly we could attempt to respond with a different message
299						// but we can probably rely on timeouts instead
300						c.log(Send, request.ID, elapsed, request.Method, nil, NewErrorf(0, "%s", err))
301						return
302					}
303					c.log(Send, response.ID, elapsed, request.Method, response.Result, response.Error)
304					if err = c.stream.Write(ctx, data); err != nil {
305						// if a stream write fails, we really need to shut down the whole
306						// stream and return from the run
307						c.log(Send, request.ID, elapsed, request.Method, nil, NewErrorf(0, "%s", err))
308						cancelRun()
309						return
310					}
311				}()
312			}
313		case msg.ID != nil:
314			// we have a response, get the pending entry from the map
315			c.pendingMu.Lock()
316			rchan := c.pending[*msg.ID]
317			if rchan != nil {
318				delete(c.pending, *msg.ID)
319			}
320			c.pendingMu.Unlock()
321			// and send the reply to the channel
322			response := &Response{
323				Result: msg.Result,
324				Error:  msg.Error,
325				ID:     msg.ID,
326			}
327			rchan <- response
328			close(rchan)
329		default:
330			c.log(Receive, nil, -1, "", nil, NewErrorf(0, "message not a call, notify or response, ignoring"))
331		}
332	}
333}
334
// marshalToRaw encodes obj with json.Marshal and returns the encoded bytes as
// a *json.RawMessage. If encoding fails it returns a nil message together
// with the error from json.Marshal.
func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
	encoded, err := json.Marshal(obj)
	if err != nil {
		return nil, err
	}
	msg := new(json.RawMessage)
	*msg = encoded
	return msg, nil
}
343