1/*
2Package bitset implements bitsets, a mapping
3between non-negative integers and boolean values. It should be more
efficient than map[uint]bool.
5
6It provides methods for setting, clearing, flipping, and testing
7individual integers.
8
It also provides set intersection, union, difference, complement,
and symmetric difference operations, as well as tests to check
whether any, all, or no bits are set, and methods for querying a
bitset's current length and number of set bits.
13
14BitSets are expanded to the size of the largest set bit; the
15memory allocation is approximately Max bits, where Max is
16the largest set bit. BitSets are never shrunk. On creation,
17a hint can be given for the number of bits that will be used.
18
Many of the methods, including Set, Clear, and Flip, return
a BitSet pointer, which allows for chaining.
21
Example use:

	import "bitset"
	var b bitset.BitSet
	b.Set(10).Set(11)
	if b.Test(1000) {
		b.Clear(1000)
	}
	if b.Intersection(bitset.New(100).Set(10)).Count() == 1 {
		fmt.Println("Intersection works.")
	}
33
As an alternative to BitSets, one should check out the 'math/big' package,
which provides a (less set-theoretical) view of bitsets.
36
37*/
38package bitset
39
40import (
41	"bufio"
42	"bytes"
43	"encoding/base64"
44	"encoding/binary"
45	"encoding/json"
46	"errors"
47	"fmt"
48	"io"
49	"strconv"
50)
51
// Word geometry shared by every BitSet operation.
const (
	// wordSize is the number of bits held by one word of a bit set.
	wordSize uint = 64

	// log2WordSize is lg(wordSize).
	log2WordSize uint = 6

	// allBits is a word with every bit set.
	allBits = ^uint64(0)
)
60
type (
	// A BitSet is a set of bits. The zero value of a BitSet is an empty
	// set of length 0.
	//
	// The field order matters: the package builds BitSet values with
	// positional composite literals.
	BitSet struct {
		length uint     // number of bits the set currently spans
		set    []uint64 // packed bit storage, 64 bits per element
	}

	// Error is used to distinguish errors (panics) generated in this
	// package from other panic values.
	Error string
)
69
70// safeSet will fixup b.set to be non-nil and return the field value
71func (b *BitSet) safeSet() []uint64 {
72	if b.set == nil {
73		b.set = make([]uint64, wordsNeeded(0))
74	}
75	return b.set
76}
77
78// From is a constructor used to create a BitSet from an array of integers
79func From(buf []uint64) *BitSet {
80	return &BitSet{uint(len(buf)) * 64, buf}
81}
82
83// Bytes returns the bitset as array of integers
84func (b *BitSet) Bytes() []uint64 {
85	return b.set
86}
87
88// wordsNeeded calculates the number of words needed for i bits
89func wordsNeeded(i uint) int {
90	if i > ((^uint(0)) - wordSize + 1) {
91		return int((^uint(0)) >> log2WordSize)
92	}
93	return int((i + (wordSize - 1)) >> log2WordSize)
94}
95
96// New creates a new BitSet with a hint that length bits will be required
97func New(length uint) (bset *BitSet) {
98	defer func() {
99		if r := recover(); r != nil {
100			bset = &BitSet{
101				0,
102				make([]uint64, 0),
103			}
104		}
105	}()
106
107	bset = &BitSet{
108		length,
109		make([]uint64, wordsNeeded(length)),
110	}
111
112	return bset
113}
114
// Cap returns the total possible capacity, or number of bits: the largest
// value representable by uint.
func Cap() uint {
	var zero uint
	return ^zero
}
119
120// Len returns the length of the BitSet in words
121func (b *BitSet) Len() uint {
122	return b.length
123}
124
125// extendSetMaybe adds additional words to incorporate new bits if needed
126func (b *BitSet) extendSetMaybe(i uint) {
127	if i >= b.length { // if we need more bits, make 'em
128		nsize := wordsNeeded(i + 1)
129		if b.set == nil {
130			b.set = make([]uint64, nsize)
131		} else if cap(b.set) >= nsize {
132			b.set = b.set[:nsize] // fast resize
133		} else if len(b.set) < nsize {
134			newset := make([]uint64, nsize, 2*nsize) // increase capacity 2x
135			copy(newset, b.set)
136			b.set = newset
137		}
138		b.length = i + 1
139	}
140}
141
142// Test whether bit i is set.
143func (b *BitSet) Test(i uint) bool {
144	if i >= b.length {
145		return false
146	}
147	return b.set[i>>log2WordSize]&(1<<(i&(wordSize-1))) != 0
148}
149
150// Set bit i to 1
151func (b *BitSet) Set(i uint) *BitSet {
152	b.extendSetMaybe(i)
153	b.set[i>>log2WordSize] |= 1 << (i & (wordSize - 1))
154	return b
155}
156
157// Clear bit i to 0
158func (b *BitSet) Clear(i uint) *BitSet {
159	if i >= b.length {
160		return b
161	}
162	b.set[i>>log2WordSize] &^= 1 << (i & (wordSize - 1))
163	return b
164}
165
166// SetTo sets bit i to value
167func (b *BitSet) SetTo(i uint, value bool) *BitSet {
168	if value {
169		return b.Set(i)
170	}
171	return b.Clear(i)
172}
173
174// Flip bit at i
175func (b *BitSet) Flip(i uint) *BitSet {
176	if i >= b.length {
177		return b.Set(i)
178	}
179	b.set[i>>log2WordSize] ^= 1 << (i & (wordSize - 1))
180	return b
181}
182
183// String creates a string representation of the Bitmap
184func (b *BitSet) String() string {
185	// follows code from https://github.com/RoaringBitmap/roaring
186	var buffer bytes.Buffer
187	start := []byte("{")
188	buffer.Write(start)
189	counter := 0
190	i, e := b.NextSet(0)
191	for e {
192		counter = counter + 1
193		// to avoid exhausting the memory
194		if counter > 0x40000 {
195			buffer.WriteString("...")
196			break
197		}
198		buffer.WriteString(strconv.FormatInt(int64(i), 10))
199		i, e = b.NextSet(i + 1)
200		if e {
201			buffer.WriteString(",")
202		}
203	}
204	buffer.WriteString("}")
205	return buffer.String()
206}
207
208// NextSet returns the next bit set from the specified index,
209// including possibly the current index
210// along with an error code (true = valid, false = no set bit found)
211// for i,e := v.NextSet(0); e; i,e = v.NextSet(i + 1) {...}
212func (b *BitSet) NextSet(i uint) (uint, bool) {
213	x := int(i >> log2WordSize)
214	if x >= len(b.set) {
215		return 0, false
216	}
217	w := b.set[x]
218	w = w >> (i & (wordSize - 1))
219	if w != 0 {
220		return i + trailingZeroes64(w), true
221	}
222	x = x + 1
223	for x < len(b.set) {
224		if b.set[x] != 0 {
225			return uint(x)*wordSize + trailingZeroes64(b.set[x]), true
226		}
227		x = x + 1
228
229	}
230	return 0, false
231}
232
233// NextClear returns the next clear bit from the specified index,
234// including possibly the current index
235// along with an error code (true = valid, false = no bit found i.e. all bits are set)
236func (b *BitSet) NextClear(i uint) (uint, bool) {
237	x := int(i >> log2WordSize)
238	if x >= len(b.set) {
239		return 0, false
240	}
241	w := b.set[x]
242	w = w >> (i & (wordSize - 1))
243	wA := allBits >> (i & (wordSize - 1))
244	if w != wA {
245		return i + trailingZeroes64(^w), true
246	}
247	x++
248	for x < len(b.set) {
249		if b.set[x] != allBits {
250			return uint(x)*wordSize + trailingZeroes64(^b.set[x]), true
251		}
252		x++
253	}
254	return 0, false
255}
256
257// ClearAll clears the entire BitSet
258func (b *BitSet) ClearAll() *BitSet {
259	if b != nil && b.set != nil {
260		for i := range b.set {
261			b.set[i] = 0
262		}
263	}
264	return b
265}
266
267// wordCount returns the number of words used in a bit set
268func (b *BitSet) wordCount() int {
269	return len(b.set)
270}
271
272// Clone this BitSet
273func (b *BitSet) Clone() *BitSet {
274	c := New(b.length)
275	if b.set != nil { // Clone should not modify current object
276		copy(c.set, b.set)
277	}
278	return c
279}
280
281// Copy into a destination BitSet
282// Returning the size of the destination BitSet
283// like array copy
284func (b *BitSet) Copy(c *BitSet) (count uint) {
285	if c == nil {
286		return
287	}
288	if b.set != nil { // Copy should not modify current object
289		copy(c.set, b.set)
290	}
291	count = c.length
292	if b.length < c.length {
293		count = b.length
294	}
295	return
296}
297
298// Count (number of set bits)
299func (b *BitSet) Count() uint {
300	if b != nil && b.set != nil {
301		return uint(popcntSlice(b.set))
302	}
303	return 0
304}
305
// deBruijn is the lookup table used by trailingZeroes64: it is indexed by
// the top six bits of the product of an isolated bit and the multiplier
// 0x03f79d71b4ca8b09, and holds the position of that bit.
var deBruijn = [...]byte{
	0, 1, 56, 2, 57, 49, 28, 3,
	61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22,
	53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16,
	46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10,
	25, 14, 19, 9, 13, 8, 7, 6,
}

// trailingZeroes64 returns the number of trailing zero bits in v, i.e. the
// index of its lowest set bit. For v == 0 the lookup lands on table entry 0
// and the result is 0.
func trailingZeroes64(v uint64) uint {
	lowest := v & -v // isolates the lowest set bit
	slot := (lowest * 0x03f79d71b4ca8b09) >> 58
	return uint(deBruijn[slot])
}
316
317// Equal tests the equvalence of two BitSets.
318// False if they are of different sizes, otherwise true
319// only if all the same bits are set
320func (b *BitSet) Equal(c *BitSet) bool {
321	if c == nil {
322		return false
323	}
324	if b.length != c.length {
325		return false
326	}
327	if b.length == 0 { // if they have both length == 0, then could have nil set
328		return true
329	}
330	// testing for equality shoud not transform the bitset (no call to safeSet)
331
332	for p, v := range b.set {
333		if c.set[p] != v {
334			return false
335		}
336	}
337	return true
338}
339
340func panicIfNull(b *BitSet) {
341	if b == nil {
342		panic(Error("BitSet must not be null"))
343	}
344}
345
346// Difference of base set and other set
347// This is the BitSet equivalent of &^ (and not)
348func (b *BitSet) Difference(compare *BitSet) (result *BitSet) {
349	panicIfNull(b)
350	panicIfNull(compare)
351	result = b.Clone() // clone b (in case b is bigger than compare)
352	l := int(compare.wordCount())
353	if l > int(b.wordCount()) {
354		l = int(b.wordCount())
355	}
356	for i := 0; i < l; i++ {
357		result.set[i] = b.set[i] &^ compare.set[i]
358	}
359	return
360}
361
362// DifferenceCardinality computes the cardinality of the differnce
363func (b *BitSet) DifferenceCardinality(compare *BitSet) uint {
364	panicIfNull(b)
365	panicIfNull(compare)
366	l := int(compare.wordCount())
367	if l > int(b.wordCount()) {
368		l = int(b.wordCount())
369	}
370	cnt := uint64(0)
371	cnt += popcntMaskSlice(b.set[:l], compare.set[:l])
372	cnt += popcntSlice(b.set[l:])
373	return uint(cnt)
374}
375
376// InPlaceDifference computes the difference of base set and other set
377// This is the BitSet equivalent of &^ (and not)
378func (b *BitSet) InPlaceDifference(compare *BitSet) {
379	panicIfNull(b)
380	panicIfNull(compare)
381	l := int(compare.wordCount())
382	if l > int(b.wordCount()) {
383		l = int(b.wordCount())
384	}
385	for i := 0; i < l; i++ {
386		b.set[i] &^= compare.set[i]
387	}
388}
389
390// Convenience function: return two bitsets ordered by
391// increasing length. Note: neither can be nil
392func sortByLength(a *BitSet, b *BitSet) (ap *BitSet, bp *BitSet) {
393	if a.length <= b.length {
394		ap, bp = a, b
395	} else {
396		ap, bp = b, a
397	}
398	return
399}
400
401// Intersection of base set and other set
402// This is the BitSet equivalent of & (and)
403func (b *BitSet) Intersection(compare *BitSet) (result *BitSet) {
404	panicIfNull(b)
405	panicIfNull(compare)
406	b, compare = sortByLength(b, compare)
407	result = New(b.length)
408	for i, word := range b.set {
409		result.set[i] = word & compare.set[i]
410	}
411	return
412}
413
414// IntersectionCardinality computes the cardinality of the union
415func (b *BitSet) IntersectionCardinality(compare *BitSet) uint {
416	panicIfNull(b)
417	panicIfNull(compare)
418	b, compare = sortByLength(b, compare)
419	cnt := popcntAndSlice(b.set, compare.set)
420	return uint(cnt)
421}
422
423// InPlaceIntersection destructively computes the intersection of
424// base set and the compare set.
425// This is the BitSet equivalent of & (and)
426func (b *BitSet) InPlaceIntersection(compare *BitSet) {
427	panicIfNull(b)
428	panicIfNull(compare)
429	l := int(compare.wordCount())
430	if l > int(b.wordCount()) {
431		l = int(b.wordCount())
432	}
433	for i := 0; i < l; i++ {
434		b.set[i] &= compare.set[i]
435	}
436	for i := l; i < len(b.set); i++ {
437		b.set[i] = 0
438	}
439	if compare.length > 0 {
440		b.extendSetMaybe(compare.length - 1)
441	}
442	return
443}
444
445// Union of base set and other set
446// This is the BitSet equivalent of | (or)
447func (b *BitSet) Union(compare *BitSet) (result *BitSet) {
448	panicIfNull(b)
449	panicIfNull(compare)
450	b, compare = sortByLength(b, compare)
451	result = compare.Clone()
452	for i, word := range b.set {
453		result.set[i] = word | compare.set[i]
454	}
455	return
456}
457
458// UnionCardinality computes the cardinality of the uniton of the base set
459// and the compare set.
460func (b *BitSet) UnionCardinality(compare *BitSet) uint {
461	panicIfNull(b)
462	panicIfNull(compare)
463	b, compare = sortByLength(b, compare)
464	cnt := popcntOrSlice(b.set, compare.set)
465	if len(compare.set) > len(b.set) {
466		cnt += popcntSlice(compare.set[len(b.set):])
467	}
468	return uint(cnt)
469}
470
471// InPlaceUnion creates the destructive union of base set and compare set.
472// This is the BitSet equivalent of | (or).
473func (b *BitSet) InPlaceUnion(compare *BitSet) {
474	panicIfNull(b)
475	panicIfNull(compare)
476	l := int(compare.wordCount())
477	if l > int(b.wordCount()) {
478		l = int(b.wordCount())
479	}
480	if compare.length > 0 {
481		b.extendSetMaybe(compare.length - 1)
482	}
483	for i := 0; i < l; i++ {
484		b.set[i] |= compare.set[i]
485	}
486	if len(compare.set) > l {
487		for i := l; i < len(compare.set); i++ {
488			b.set[i] = compare.set[i]
489		}
490	}
491}
492
493// SymmetricDifference of base set and other set
494// This is the BitSet equivalent of ^ (xor)
495func (b *BitSet) SymmetricDifference(compare *BitSet) (result *BitSet) {
496	panicIfNull(b)
497	panicIfNull(compare)
498	b, compare = sortByLength(b, compare)
499	// compare is bigger, so clone it
500	result = compare.Clone()
501	for i, word := range b.set {
502		result.set[i] = word ^ compare.set[i]
503	}
504	return
505}
506
507// SymmetricDifferenceCardinality computes the cardinality of the symmetric difference
508func (b *BitSet) SymmetricDifferenceCardinality(compare *BitSet) uint {
509	panicIfNull(b)
510	panicIfNull(compare)
511	b, compare = sortByLength(b, compare)
512	cnt := popcntXorSlice(b.set, compare.set)
513	if len(compare.set) > len(b.set) {
514		cnt += popcntSlice(compare.set[len(b.set):])
515	}
516	return uint(cnt)
517}
518
519// InPlaceSymmetricDifference creates the destructive SymmetricDifference of base set and other set
520// This is the BitSet equivalent of ^ (xor)
521func (b *BitSet) InPlaceSymmetricDifference(compare *BitSet) {
522	panicIfNull(b)
523	panicIfNull(compare)
524	l := int(compare.wordCount())
525	if l > int(b.wordCount()) {
526		l = int(b.wordCount())
527	}
528	if compare.length > 0 {
529		b.extendSetMaybe(compare.length - 1)
530	}
531	for i := 0; i < l; i++ {
532		b.set[i] ^= compare.set[i]
533	}
534	if len(compare.set) > l {
535		for i := l; i < len(compare.set); i++ {
536			b.set[i] = compare.set[i]
537		}
538	}
539}
540
541// Is the length an exact multiple of word sizes?
542func (b *BitSet) isLenExactMultiple() bool {
543	return b.length%wordSize == 0
544}
545
546// Clean last word by setting unused bits to 0
547func (b *BitSet) cleanLastWord() {
548	if !b.isLenExactMultiple() {
549		b.set[len(b.set)-1] &= allBits >> (wordSize - b.length%wordSize)
550	}
551}
552
553// Complement computes the (local) complement of a biset (up to length bits)
554func (b *BitSet) Complement() (result *BitSet) {
555	panicIfNull(b)
556	result = New(b.length)
557	for i, word := range b.set {
558		result.set[i] = ^word
559	}
560	result.cleanLastWord()
561	return
562}
563
564// All returns true if all bits are set, false otherwise. Returns true for
565// empty sets.
566func (b *BitSet) All() bool {
567	panicIfNull(b)
568	return b.Count() == b.length
569}
570
571// None returns true if no bit is set, false otherwise. Retursn true for
572// empty sets.
573func (b *BitSet) None() bool {
574	panicIfNull(b)
575	if b != nil && b.set != nil {
576		for _, word := range b.set {
577			if word > 0 {
578				return false
579			}
580		}
581		return true
582	}
583	return true
584}
585
586// Any returns true if any bit is set, false otherwise
587func (b *BitSet) Any() bool {
588	panicIfNull(b)
589	return !b.None()
590}
591
592// IsSuperSet returns true if this is a superset of the other set
593func (b *BitSet) IsSuperSet(other *BitSet) bool {
594	for i, e := other.NextSet(0); e; i, e = other.NextSet(i + 1) {
595		if !b.Test(i) {
596			return false
597		}
598	}
599	return true
600}
601
602// IsStrictSuperSet returns true if this is a strict superset of the other set
603func (b *BitSet) IsStrictSuperSet(other *BitSet) bool {
604	return b.Count() > other.Count() && b.IsSuperSet(other)
605}
606
607// DumpAsBits dumps a bit set as a string of bits
608func (b *BitSet) DumpAsBits() string {
609	if b.set == nil {
610		return "."
611	}
612	buffer := bytes.NewBufferString("")
613	i := len(b.set) - 1
614	for ; i >= 0; i-- {
615		fmt.Fprintf(buffer, "%064b.", b.set[i])
616	}
617	return string(buffer.Bytes())
618}
619
620// BinaryStorageSize returns the binary storage requirements
621func (b *BitSet) BinaryStorageSize() int {
622	return binary.Size(uint64(0)) + binary.Size(b.set)
623}
624
625// WriteTo writes a BitSet to a stream
626func (b *BitSet) WriteTo(stream io.Writer) (int64, error) {
627	length := uint64(b.length)
628
629	// Write length
630	err := binary.Write(stream, binary.BigEndian, length)
631	if err != nil {
632		return 0, err
633	}
634
635	// Write set
636	err = binary.Write(stream, binary.BigEndian, b.set)
637	return int64(b.BinaryStorageSize()), err
638}
639
640// ReadFrom reads a BitSet from a stream written using WriteTo
641func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
642	var length uint64
643
644	// Read length first
645	err := binary.Read(stream, binary.BigEndian, &length)
646	if err != nil {
647		return 0, err
648	}
649	newset := New(uint(length))
650
651	if uint64(newset.length) != length {
652		return 0, errors.New("Unmarshalling error: type mismatch")
653	}
654
655	// Read remaining bytes as set
656	err = binary.Read(stream, binary.BigEndian, newset.set)
657	if err != nil {
658		return 0, err
659	}
660
661	*b = *newset
662	return int64(b.BinaryStorageSize()), nil
663}
664
665// MarshalBinary encodes a BitSet into a binary form and returns the result.
666func (b *BitSet) MarshalBinary() ([]byte, error) {
667	var buf bytes.Buffer
668	writer := bufio.NewWriter(&buf)
669
670	_, err := b.WriteTo(writer)
671	if err != nil {
672		return []byte{}, err
673	}
674
675	err = writer.Flush()
676
677	return buf.Bytes(), err
678}
679
680// UnmarshalBinary decodes the binary form generated by MarshalBinary.
681func (b *BitSet) UnmarshalBinary(data []byte) error {
682	buf := bytes.NewReader(data)
683	reader := bufio.NewReader(buf)
684
685	_, err := b.ReadFrom(reader)
686
687	return err
688}
689
690// MarshalJSON marshals a BitSet as a JSON structure
691func (b *BitSet) MarshalJSON() ([]byte, error) {
692	buffer := bytes.NewBuffer(make([]byte, 0, b.BinaryStorageSize()))
693	_, err := b.WriteTo(buffer)
694	if err != nil {
695		return nil, err
696	}
697
698	// URLEncode all bytes
699	return json.Marshal(base64.URLEncoding.EncodeToString(buffer.Bytes()))
700}
701
702// UnmarshalJSON unmarshals a BitSet from JSON created using MarshalJSON
703func (b *BitSet) UnmarshalJSON(data []byte) error {
704	// Unmarshal as string
705	var s string
706	err := json.Unmarshal(data, &s)
707	if err != nil {
708		return err
709	}
710
711	// URLDecode string
712	buf, err := base64.URLEncoding.DecodeString(s)
713	if err != nil {
714		return err
715	}
716
717	_, err = b.ReadFrom(bytes.NewReader(buf))
718	return err
719}
720