1package cbft
2
3import (
4	"encoding/json"
5	"fmt"
6	"io"
7	"time"
8	"unsafe"
9
10	"github.com/blevesearch/bleve"
11	"github.com/blevesearch/bleve/search"
12	"github.com/blevesearch/bleve/search/query"
13	jsoniter "github.com/json-iterator/go"
14)
15
16func init() {
17	// registers the custom json encoders with jsoniter
18	registerCustomJSONEncoders()
19}
20
21// Marshal abstracts the underlying json lib used
22func (p *CustomJSONImpl) Marshal(v interface{}) ([]byte, error) {
23	return jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(v)
24}
25
26// Encode abstracts the underlying json lib used
27func (p *CustomJSONImpl) Encode(w io.Writer, v interface{}) error {
28	return jsoniter.ConfigCompatibleWithStandardLibrary.NewEncoder(w).Encode(v)
29}
30
31// MarshalJSON abstracts the underlying json lib used
32func MarshalJSON(v interface{}) ([]byte, error) {
33	if JSONImpl != nil && JSONImpl.GetManagerOptions()["jsonImpl"] != "std" {
34		return JSONImpl.Marshal(v)
35	}
36	return json.Marshal(v)
37}
38
39func registerCustomJSONEncoders() {
40	// adding all the custom encoders that bleve has implemented,
41	// and need to extend as bleve introduces new custom encoders.
42	jsoniter.RegisterTypeEncoderFunc("bleve.IndexErrMap", encodeBleveIndexErrMap, nil)
43
44	jsoniter.RegisterTypeEncoderFunc("bleve.dateTimeRange", encodeBleveDateTimeRange, nil)
45
46	jsoniter.RegisterTypeEncoderFunc("search.SortDocID", encodeSearchSortDocID, nil)
47
48	jsoniter.RegisterTypeEncoderFunc("search.SortScore", encodeSearchSortScore, nil)
49
50	jsoniter.RegisterTypeEncoderFunc("search.SortField", encodeSearchSortField, nil)
51
52	jsoniter.RegisterTypeEncoderFunc("query.BleveQueryTime", encodeBleveQueryTime, nil)
53
54	jsoniter.RegisterTypeEncoderFunc("search.SortGeoDistance", encodeSortGeoDistance, nil)
55
56	jsoniter.RegisterTypeEncoderFunc("query.MatchAllQuery", encodeMatchAllQuery, nil)
57
58	jsoniter.RegisterTypeEncoderFunc("query.MatchNoneQuery", encodeMatchNoneQuery, nil)
59
60	jsoniter.RegisterTypeEncoderFunc("query.MatchQueryOperator", encodeMatchQueryOperator, nil)
61}
62
// encodeBleveIndexErrMap writes a bleve.IndexErrMap as a JSON object that
// maps each key to the string returned by calling Error() on its value.
//
// NOTE(review): unlike every other encoder in this file, ptr is re-addressed
// (&ptr) before being reinterpreted as a bleve.IndexErrMap. That is only
// correct if jsoniter passes the map value itself rather than a pointer to
// the map variable. The convention may differ between jsoniter versions and
// between top-level values and struct fields, so confirm it against the
// vendored jsoniter.
// NOTE(review): if the map's values are an interface type (presumably
// error), a nil value would make v.Error() panic — confirm whether that can
// occur.
func encodeBleveIndexErrMap(ptr unsafe.Pointer, stream *jsoniter.Stream) {
	mapPtr := unsafe.Pointer(&ptr)
	iem := *((*bleve.IndexErrMap)(mapPtr))
	// Convert to map[string]string so each value encodes as its message text.
	tmp := make(map[string]string, len(iem))
	for k, v := range iem {
		tmp[k] = v.Error()
	}
	stream.WriteVal(tmp)
}
72
73func encodeBleveDateTimeRange(ptr unsafe.Pointer, stream *jsoniter.Stream) {
74	type temp struct {
75		Name        string    `json:"name,omitempty"`
76		Start       time.Time `json:"start,omitempty"`
77		End         time.Time `json:"end,omitempty"`
78		startString *string
79		endString   *string
80	}
81	dr := *((*temp)(ptr))
82	rv := map[string]interface{}{
83		"name":  dr.Name,
84		"start": dr.Start,
85		"end":   dr.End,
86	}
87	if dr.Start.IsZero() && dr.startString != nil {
88		rv["start"] = dr.startString
89	}
90	if dr.End.IsZero() && dr.endString != nil {
91		rv["end"] = dr.endString
92	}
93
94	stream.WriteVal(rv)
95}
96
97func encodeSearchSortDocID(ptr unsafe.Pointer, stream *jsoniter.Stream) {
98	sid := *((*search.SortDocID)(ptr))
99	if sid.Desc {
100		stream.WriteString("-_id")
101		return
102	}
103	stream.WriteString("_id")
104}
105
106func encodeSearchSortScore(ptr unsafe.Pointer, stream *jsoniter.Stream) {
107	ss := *((*search.SortScore)(ptr))
108	if ss.Desc {
109		stream.WriteString("-_score")
110		return
111	}
112	stream.WriteString("_score")
113}
114
115func encodeSearchSortField(ptr unsafe.Pointer, stream *jsoniter.Stream) {
116	s := *((*search.SortField)(ptr))
117	if s.Missing == search.SortFieldMissingLast &&
118		s.Mode == search.SortFieldDefault &&
119		s.Type == search.SortFieldAuto {
120		if s.Desc {
121			stream.WriteString("-" + s.Field)
122			return
123
124		}
125		stream.WriteString(s.Field)
126		return
127	}
128	sfm := map[string]interface{}{
129		"by":    "field",
130		"field": s.Field,
131	}
132	if s.Desc {
133		sfm["desc"] = true
134	}
135	if s.Missing > search.SortFieldMissingLast {
136		switch s.Missing {
137		case search.SortFieldMissingFirst:
138			sfm["missing"] = "first"
139		}
140	}
141	if s.Mode > search.SortFieldDefault {
142		switch s.Mode {
143		case search.SortFieldMin:
144			sfm["mode"] = "min"
145		case search.SortFieldMax:
146			sfm["mode"] = "max"
147		}
148	}
149	if s.Type > search.SortFieldAuto {
150		switch s.Type {
151		case search.SortFieldAsString:
152			sfm["type"] = "string"
153		case search.SortFieldAsNumber:
154			sfm["type"] = "number"
155		case search.SortFieldAsDate:
156			sfm["type"] = "date"
157		}
158	}
159	stream.WriteVal(sfm)
160}
161
162func encodeBleveQueryTime(ptr unsafe.Pointer, stream *jsoniter.Stream) {
163	temp := *((*query.BleveQueryTime)(ptr))
164	tt := time.Time(temp.Time)
165	stream.WriteString(tt.Format(query.QueryDateTimeFormat))
166}
167
168func encodeSortGeoDistance(ptr unsafe.Pointer, stream *jsoniter.Stream) {
169	s := *((*search.SortGeoDistance)(ptr))
170	sfm := map[string]interface{}{
171		"by":    "geo_distance",
172		"field": s.Field,
173		"location": map[string]interface{}{
174			"lon": s.Lon,
175			"lat": s.Lat,
176		},
177	}
178	if s.Unit != "" {
179		sfm["unit"] = s.Unit
180	}
181	if s.Desc {
182		sfm["desc"] = true
183	}
184	stream.WriteVal(sfm)
185}
186
187func encodeMatchAllQuery(ptr unsafe.Pointer, stream *jsoniter.Stream) {
188	q := *((*query.MatchAllQuery)(ptr))
189	tmp := map[string]interface{}{
190		"boost":     q.BoostVal,
191		"match_all": map[string]interface{}{},
192	}
193	stream.WriteVal(tmp)
194}
195
196func encodeMatchNoneQuery(ptr unsafe.Pointer, stream *jsoniter.Stream) {
197	q := *((*query.MatchNoneQuery)(ptr))
198	tmp := map[string]interface{}{
199		"boost":      q.BoostVal,
200		"match_none": map[string]interface{}{},
201	}
202	stream.WriteVal(tmp)
203}
204
205func encodeMatchQueryOperator(ptr unsafe.Pointer, stream *jsoniter.Stream) {
206	o := *((*query.MatchQueryOperator)(ptr))
207	switch o {
208	case query.MatchQueryOperatorOr:
209		stream.WriteString("or")
210	case query.MatchQueryOperatorAnd:
211		stream.WriteString("and")
212	default:
213		stream.Error = fmt.Errorf("cannot marshal match operator %d to JSON", o)
214	}
215}
216