1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonrpc2_test
6
7import (
8	"context"
9	"encoding/json"
10	"flag"
11	"fmt"
12	"io"
13	"path"
14	"reflect"
15	"testing"
16
17	"golang.org/x/tools/internal/jsonrpc2"
18)
19
var (
	// logRPC is the -logrpc command-line switch. When it is set, prepare adds
	// jsonrpc2.Log to the options of every connection it builds.
	logRPC = flag.Bool("logrpc", false, "Enable jsonrpc2 communication logging")
)
21
// callTest describes one round trip exercised by the call tests. Its fields
// are filled positionally in places, so their order is significant.
type callTest struct {
	// method is the RPC method name handed to Call.
	method string
	// params is passed to Call as the request params; expect is the value the
	// decoded result must equal (nil means the result is not checked).
	params, expect interface{}
}
27
28var callTests = []callTest{
29	{"no_args", nil, true},
30	{"one_string", "fish", "got:fish"},
31	{"one_number", 10, "got:10"},
32	{"join", []string{"a", "b", "c"}, "a/b/c"},
33	//TODO: expand the test cases
34}
35
36func (test *callTest) newResults() interface{} {
37	switch e := test.expect.(type) {
38	case []interface{}:
39		var r []interface{}
40		for _, v := range e {
41			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
42		}
43		return r
44	case nil:
45		return nil
46	default:
47		return reflect.New(reflect.TypeOf(test.expect)).Interface()
48	}
49}
50
51func (test *callTest) verifyResults(t *testing.T, results interface{}) {
52	if results == nil {
53		return
54	}
55	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
56	if !reflect.DeepEqual(val, test.expect) {
57		t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect)
58	}
59}
60
61func TestPlainCall(t *testing.T) {
62	ctx := context.Background()
63	a, b := prepare(ctx, t, false)
64	for _, test := range callTests {
65		results := test.newResults()
66		if err := a.Call(ctx, test.method, test.params, results); err != nil {
67			t.Fatalf("%v:Call failed: %v", test.method, err)
68		}
69		test.verifyResults(t, results)
70		if err := b.Call(ctx, test.method, test.params, results); err != nil {
71			t.Fatalf("%v:Call failed: %v", test.method, err)
72		}
73		test.verifyResults(t, results)
74	}
75}
76
77func TestHeaderCall(t *testing.T) {
78	ctx := context.Background()
79	a, b := prepare(ctx, t, true)
80	for _, test := range callTests {
81		results := test.newResults()
82		if err := a.Call(ctx, test.method, test.params, results); err != nil {
83			t.Fatalf("%v:Call failed: %v", test.method, err)
84		}
85		test.verifyResults(t, results)
86		if err := b.Call(ctx, test.method, test.params, results); err != nil {
87			t.Fatalf("%v:Call failed: %v", test.method, err)
88		}
89		test.verifyResults(t, results)
90	}
91}
92
93func prepare(ctx context.Context, t *testing.T, withHeaders bool) (*testHandler, *testHandler) {
94	a := &testHandler{t: t}
95	b := &testHandler{t: t}
96	a.reader, b.writer = io.Pipe()
97	b.reader, a.writer = io.Pipe()
98	for _, h := range []*testHandler{a, b} {
99		h := h
100		if withHeaders {
101			h.stream = jsonrpc2.NewHeaderStream(h.reader, h.writer)
102		} else {
103			h.stream = jsonrpc2.NewStream(h.reader, h.writer)
104		}
105		args := []interface{}{handle}
106		if *logRPC {
107			args = append(args, jsonrpc2.Log)
108		}
109		h.Conn = jsonrpc2.NewConn(ctx, h.stream, args...)
110		go func() {
111			defer func() {
112				h.reader.Close()
113				h.writer.Close()
114			}()
115			if err := h.Conn.Wait(ctx); err != nil {
116				t.Fatalf("Stream failed: %v", err)
117			}
118		}()
119	}
120	return a, b
121}
122
123type testHandler struct {
124	t      *testing.T
125	reader *io.PipeReader
126	writer *io.PipeWriter
127	stream jsonrpc2.Stream
128	*jsonrpc2.Conn
129}
130
131func handle(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) (interface{}, *jsonrpc2.Error) {
132	switch r.Method {
133	case "no_args":
134		if r.Params != nil {
135			return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeInvalidParams, "Expected no params")
136		}
137		return true, nil
138	case "one_string":
139		var v string
140		if err := json.Unmarshal(*r.Params, &v); err != nil {
141			return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())
142		}
143		return "got:" + v, nil
144	case "one_number":
145		var v int
146		if err := json.Unmarshal(*r.Params, &v); err != nil {
147			return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())
148		}
149		return fmt.Sprintf("got:%d", v), nil
150	case "join":
151		var v []string
152		if err := json.Unmarshal(*r.Params, &v); err != nil {
153			return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())
154		}
155		return path.Join(v...), nil
156	default:
157		return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeMethodNotFound, "method %q not found", r.Method)
158	}
159}
160