问题描述

Go 语言没有内置 set 数据结构，请基于 map 实现一个 set。

要点

需要支持以下方法：

  • Add 添加元素
  • Remove 删除元素
  • Cardinality 获取 Set 中元素的个数
  • Clear 清空 Set
  • Contains 检测元素是否在 Set 中
  • Pop() 删除任意一个元素并返回被删除的元素
  • ToSlice() []interface{} 转换成 slice 返回

拓展

  • Clone 复制 Set
  • Difference(other Set) Set 返回和另一个Set的差集
  • Equal(other Set) bool 判断和另一个Set是否相等
  • Intersect(other Set) Set 返回和另一个Set的交集
  • SymmetricDifference(other Set) Set 返回对称差集，即只属于其中一个 Set（不在交集中）的元素组成的集合
  • Union(other Set) Set 返回并集
  • 实现一个线程安全的版本

实现


//set.go

// Package mapset implements a simple and generic set collection.
// Items stored within it are unordered and unique. It supports
// typical set operations: membership testing, intersection, union,
// difference, symmetric difference and cloning.
//
// Package mapset provides two implementations of the Set
// interface. The default implementation is safe for concurrent
// access, but a non-thread-safe implementation is also provided for
// programs that can benefit from the slight speed improvement and
// that can enforce mutual exclusion through other means.
package mapset

// Set is the primary interface provided by the mapset package. It
// represents an unordered collection of unique elements and the
// operations that can be applied to it. Both implementations store
// elements as map keys, so elements must be comparable (hashable).
type Set interface {
    // Add adds an element to the set. It returns true if the item
    // was added and false if it was already present.
    Add(i interface{}) bool

    // Cardinality returns the number of elements in the set.
    Cardinality() int

    // Clear removes all elements from the set, leaving
    // the empty set.
    Clear()

    // Clone returns a clone of the set using the same
    // implementation, duplicating all keys.
    Clone() Set

    // Contains returns whether the given items are all in the set.
    // Called with no items, it returns true.
    Contains(i ...interface{}) bool

    // Difference returns the difference between this set
    // and other. The returned set will contain
    // all elements of this set that are not also
    // elements of other.
    //
    // Note that the argument to Difference
    // must be of the same type as the receiver
    // of the method. Otherwise, Difference will
    // panic.
    Difference(other Set) Set

    // Equal determines if two sets are equal to each
    // other. If they have the same cardinality
    // and contain the same elements, they are
    // considered equal. The order in which
    // the elements were added is irrelevant.
    //
    // Note that the argument to Equal must be
    // of the same type as the receiver of the
    // method. Otherwise, Equal will panic.
    Equal(other Set) bool

    // Intersect returns a new set containing only the elements
    // that exist in both sets.
    //
    // Note that the argument to Intersect
    // must be of the same type as the receiver
    // of the method. Otherwise, Intersect will
    // panic.
    Intersect(other Set) Set

    // IsProperSubset determines if every element in this set is in
    // the other set but the two sets are not equal.
    //
    // Note that the argument to IsProperSubset
    // must be of the same type as the receiver
    // of the method. Otherwise, IsProperSubset
    // will panic.
    IsProperSubset(other Set) bool

    // IsProperSuperset determines if every element in the other set
    // is in this set but the two sets are not
    // equal.
    //
    // Note that the argument to IsProperSuperset
    // must be of the same type as the receiver
    // of the method. Otherwise, IsProperSuperset will
    // panic.
    IsProperSuperset(other Set) bool

    // IsSubset determines if every element in this set is in
    // the other set.
    //
    // Note that the argument to IsSubset
    // must be of the same type as the receiver
    // of the method. Otherwise, IsSubset will
    // panic.
    IsSubset(other Set) bool

    // IsSuperset determines if every element in the other set
    // is in this set.
    //
    // Note that the argument to IsSuperset
    // must be of the same type as the receiver
    // of the method. Otherwise, IsSuperset will
    // panic.
    IsSuperset(other Set) bool

    // Each iterates over the elements in unspecified order and calls the
    // passed func on each one. Iteration stops as soon as the func
    // returns true. The thread-safe implementation holds its read lock
    // while the func runs, so the func must not modify the same set.
    Each(func(interface{}) bool)

    // Iter returns a channel of elements that you can
    // range over. The channel is fed by a goroutine that only finishes
    // once every element has been received, so callers must drain it.
    // For the thread-safe implementation, the read lock stays held until
    // then. Use Iterator when early exit is needed.
    Iter() <-chan interface{}

    // Iterator returns an Iterator object that you can
    // use to range over the set. Call its Stop method to end
    // iteration early.
    Iterator() *Iterator

    // Remove removes a single element from the set. Removing an
    // element that is not present is a no-op.
    Remove(i interface{})

    // String provides a convenient string representation
    // of the current state of the set.
    String() string

    // SymmetricDifference returns a new set with all elements which are
    // in either this set or the other set but not in both.
    //
    // Note that the argument to SymmetricDifference
    // must be of the same type as the receiver
    // of the method. Otherwise, SymmetricDifference
    // will panic.
    SymmetricDifference(other Set) Set

    // Union returns a new set with all elements in both sets.
    //
    // Note that the argument to Union must be of the
    // same type as the receiver of the method.
    // Otherwise, Union will panic.
    Union(other Set) Set

    // Pop removes and returns an arbitrary item from the set.
    // It returns nil if the set is empty.
    Pop() interface{}

    // PowerSet returns all subsets of the set (the power set),
    // including the empty set. Each subset is itself stored as a Set
    // of the same implementation.
    PowerSet() Set

    // CartesianProduct returns the Cartesian product of this set and
    // other as a set of OrderedPair values. First comes from this set
    // and Second from other. The argument must be of the same type as
    // the receiver; otherwise CartesianProduct will panic.
    CartesianProduct(other Set) Set

    // ToSlice returns the members of the set as a slice, in
    // unspecified order.
    ToSlice() []interface{}
}

// NewSet creates and returns a reference to a new set containing the given
// items. Duplicates are stored once, and with no arguments the set is empty.
// Operations on the resulting set are thread-safe.
func NewSet(s ...interface{}) Set {
    fresh := newThreadSafeSet()
    result := &fresh
    for idx := range s {
        result.Add(s[idx])
    }
    return result
}

// NewSetWith creates and returns a new thread-safe set holding the given
// elements. It is equivalent to NewSet(elts...).
func NewSetWith(elts ...interface{}) Set {
    return NewSet(elts...)
}

// NewSetFromSlice creates and returns a reference to a thread-safe set
// populated with the elements of an existing slice. The slice is only read,
// never retained.
func NewSetFromSlice(s []interface{}) Set {
    return NewSet(s...)
}

// NewThreadUnsafeSet creates and returns a reference to an empty set.
// Operations on the resulting set are not thread-safe; callers must
// serialize concurrent access themselves.
func NewThreadUnsafeSet() Set {
    empty := make(threadUnsafeSet)
    return &empty
}

// NewThreadUnsafeSetFromSlice creates and returns a reference to a set
// holding the elements of an existing slice (duplicates stored once).
// Operations on the resulting set are not thread-safe.
func NewThreadUnsafeSetFromSlice(s []interface{}) Set {
    result := NewThreadUnsafeSet()
    for idx := 0; idx < len(s); idx++ {
        result.Add(s[idx])
    }
    return result
}



// iterator.go

package mapset

// Iterator defines an iterator over a Set. Its C channel can be used to range
// over the Set's elements; call Stop to end the iteration early.
type Iterator struct {
    C    <-chan interface{} // yields the elements; closed by the producer goroutine when it finishes
    stop chan struct{}      // closed by Stop to tell the producer goroutine to quit
}

// Stop ends the iteration. It closes the stop channel and then discards
// whatever is still sent on C until the producer closes C, so the producer
// goroutine is never left blocked on a send. Stop may be called more than
// once: a repeated close panics, and that panic is recovered and ignored.
func (i *Iterator) Stop() {
    defer func() {
        _ = recover()
    }()

    close(i.stop)

    for {
        if _, open := <-i.C; !open {
            return
        }
    }
}

// newIterator builds an Iterator and also returns the two channel ends the
// producer goroutine needs. The first is the send side of the (unbuffered)
// item channel that consumers see as Iterator.C. The second is the receive
// side of the (unbuffered) stop channel that Stop closes.
func newIterator() (*Iterator, chan<- interface{}, <-chan struct{}) {
    items := make(chan interface{})
    stop := make(chan struct{})
    it := &Iterator{C: items, stop: stop}
    return it, items, stop
}


// threadunsafe.go

package mapset

import (
    "bytes"
    "encoding/json"
    "fmt"
    "reflect"
    "strings"
)

// threadUnsafeSet is the lock-free Set implementation. Elements are the keys
// of the map, so they must be comparable; the zero-size struct{} value means
// membership needs no per-entry payload. It is not safe for concurrent use.
type threadUnsafeSet map[interface{}]struct{}

// An OrderedPair represents a 2-tuple of values. CartesianProduct stores its
// results as OrderedPair values, with First taken from the receiver set and
// Second from the argument set.
type OrderedPair struct {
    First  interface{} // first component of the tuple
    Second interface{} // second component of the tuple
}

// newThreadUnsafeSet allocates an empty, ready-to-use threadUnsafeSet.
func newThreadUnsafeSet() threadUnsafeSet {
    fresh := threadUnsafeSet(make(map[interface{}]struct{}))
    return fresh
}

// Equal reports whether pair and other hold the same values in the same
// order. Components are compared with ==, so comparing components that share
// an uncomparable dynamic type (such as slices) panics at run time.
func (pair *OrderedPair) Equal(other OrderedPair) bool {
    sameFirst := pair.First == other.First
    return sameFirst && pair.Second == other.Second
}

// Add inserts i and reports whether it was newly inserted. When i is already
// a member, the set is left unchanged and false is returned.
func (set *threadUnsafeSet) Add(i interface{}) bool {
    members := *set
    if _, present := members[i]; present {
        return false
    }
    members[i] = struct{}{}
    return true
}

// Contains reports whether every given item is a member of the set.
// With no items, it returns true.
func (set *threadUnsafeSet) Contains(i ...interface{}) bool {
    members := *set
    for idx := range i {
        if _, present := members[i[idx]]; !present {
            return false
        }
    }
    return true
}

// IsSubset reports whether every element of set is also in other.
// It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) IsSubset(other Set) bool {
    o := other.(*threadUnsafeSet)
    // A larger set can never fit inside a smaller one.
    if len(*set) > len(*o) {
        return false
    }
    for elem := range *set {
        if _, ok := (*o)[elem]; !ok {
            return false
        }
    }
    return true
}

// IsProperSubset reports whether set is a subset of other without being
// equal to it. It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) IsProperSubset(other Set) bool {
    if !set.IsSubset(other) {
        return false
    }
    return !set.Equal(other)
}

// IsSuperset reports whether every element of other is in set. The check is
// delegated to other.IsSubset, so other's implementation decides how a
// mismatched set type is rejected.
func (set *threadUnsafeSet) IsSuperset(other Set) bool {
    var self Set = set
    return other.IsSubset(self)
}

// IsProperSuperset reports whether set contains every element of other while
// not being equal to it.
func (set *threadUnsafeSet) IsProperSuperset(other Set) bool {
    if !set.IsSuperset(other) {
        return false
    }
    return !set.Equal(other)
}

// Union returns a new set holding every element found in set, in other, or
// in both. It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) Union(other Set) Set {
    o := other.(*threadUnsafeSet)
    combined := newThreadUnsafeSet()
    for _, src := range []threadUnsafeSet{*set, *o} {
        for elem := range src {
            combined[elem] = struct{}{}
        }
    }
    return &combined
}

// Intersect returns a new set of the elements present in both set and other.
// It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) Intersect(other Set) Set {
    o := other.(*threadUnsafeSet)

    // Walk the smaller operand and probe the larger one.
    small, large := *o, *set
    if len(*set) < len(*o) {
        small, large = *set, *o
    }

    common := newThreadUnsafeSet()
    for elem := range small {
        if _, ok := large[elem]; ok {
            common[elem] = struct{}{}
        }
    }
    return &common
}

// Difference returns a new set of the elements of set that are absent from
// other. It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) Difference(other Set) Set {
    o := other.(*threadUnsafeSet)
    remaining := newThreadUnsafeSet()
    for elem := range *set {
        if _, shared := (*o)[elem]; shared {
            continue
        }
        remaining[elem] = struct{}{}
    }
    return &remaining
}

// SymmetricDifference returns a new set of the elements that belong to
// exactly one of set and other. It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) SymmetricDifference(other Set) Set {
    o := other.(*threadUnsafeSet)
    result := newThreadUnsafeSet()
    // Keep each side's elements that the opposite side lacks.
    for elem := range *set {
        if _, ok := (*o)[elem]; !ok {
            result[elem] = struct{}{}
        }
    }
    for elem := range *o {
        if _, ok := (*set)[elem]; !ok {
            result[elem] = struct{}{}
        }
    }
    return &result
}

// Clear empties the set by pointing it at a brand-new map instead of
// deleting the existing keys one by one.
func (set *threadUnsafeSet) Clear() {
    fresh := make(threadUnsafeSet)
    *set = fresh
}

// Remove deletes i from the set; removing a non-member is a no-op.
func (set *threadUnsafeSet) Remove(i interface{}) {
    members := *set
    delete(members, i)
}

// Cardinality returns how many elements the set holds.
func (set *threadUnsafeSet) Cardinality() int {
    members := *set
    return len(members)
}

// Each calls cb once per element in map iteration order, stopping early as
// soon as cb returns true.
func (set *threadUnsafeSet) Each(cb func(interface{}) bool) {
    for elem := range *set {
        if stop := cb(elem); stop {
            return
        }
    }
}

// Iter streams the elements over an unbuffered channel fed by a background
// goroutine, which closes the channel after the last element. That goroutine
// exits only once every element has been received, so callers should drain
// the channel (use Iterator when early exit is needed).
func (set *threadUnsafeSet) Iter() <-chan interface{} {
    out := make(chan interface{})
    go func() {
        defer close(out)
        for elem := range *set {
            out <- elem
        }
    }()
    return out
}

// Iterator returns an Iterator whose C channel yields the set's elements from
// a background goroutine. Calling Stop on it makes that goroutine quit early;
// C is closed either way.
func (set *threadUnsafeSet) Iterator() *Iterator {
    it, out, stop := newIterator()

    go func() {
        defer close(out)
        for elem := range *set {
            select {
            case <-stop:
                return
            case out <- elem:
            }
        }
    }()

    return it
}

// Equal reports whether set and other contain exactly the same elements.
// It panics if other is not a *threadUnsafeSet.
func (set *threadUnsafeSet) Equal(other Set) bool {
    o := other.(*threadUnsafeSet)
    if len(*set) != len(*o) {
        return false
    }
    // Equal sizes, so set being contained in o means the two are equal.
    for elem := range *set {
        if _, ok := (*o)[elem]; !ok {
            return false
        }
    }
    return true
}

// Clone returns an independent copy of the set; later changes to either set
// do not affect the other.
func (set *threadUnsafeSet) Clone() Set {
    dup := newThreadUnsafeSet()
    for elem := range *set {
        dup[elem] = struct{}{}
    }
    return &dup
}

// String renders the set as "Set{a, b, c}", formatting each element with %v.
// Element order follows map iteration and can differ between calls.
func (set *threadUnsafeSet) String() string {
    parts := make([]string, 0, len(*set))
    for elem := range *set {
        parts = append(parts, fmt.Sprintf("%v", elem))
    }
    return "Set{" + strings.Join(parts, ", ") + "}"
}

// String formats the pair as "(First, Second)", rendering each component in
// its default %v form.
func (pair OrderedPair) String() string {
    return "(" + fmt.Sprint(pair.First) + ", " + fmt.Sprint(pair.Second) + ")"
}

// Pop removes and returns whichever element map iteration yields first; no
// particular element is guaranteed. It returns nil when the set is empty.
func (set *threadUnsafeSet) Pop() interface{} {
    members := *set
    var (
        picked interface{}
        found  bool
    )
    for elem := range members {
        picked, found = elem, true
        break
    }
    if found {
        delete(members, picked)
    }
    return picked
}

// PowerSet returns the set of all subsets of set, including the empty set.
// Each subset is stored as a *threadUnsafeSet pointer, so membership in the
// result is by pointer identity rather than by subset contents. The result
// has 2^n entries for an n-element input.
func (set *threadUnsafeSet) PowerSet() Set {
    powSet := NewThreadUnsafeSet()
    nullset := newThreadUnsafeSet()
    // Seed the result with the empty subset.
    powSet.Add(&nullset)

    // For each element es, extend every subset collected so far with es and
    // merge those extended copies back into powSet (doubling its size).
    for es := range *set {
        u := newThreadUnsafeSet()
        j := powSet.Iter()
        // j is drained completely, so the Iter goroutine always finishes.
        for er := range j {
            p := newThreadUnsafeSet()
            // Only *threadUnsafeSet values are ever added to powSet, and pointer
            // types have an empty Name(), so this branch is the one taken in
            // practice; the else branch is defensive.
            if reflect.TypeOf(er).Name() == "" {
                k := er.(*threadUnsafeSet)
                for ek := range *(k) {
                    p.Add(ek)
                }
            } else {
                p.Add(er)
            }
            p.Add(es)
            // p is declared fresh each iteration, so &p is a distinct key.
            u.Add(&p)
        }

        powSet = powSet.Union(&u)
    }

    return powSet
}

// CartesianProduct returns a set containing OrderedPair{First: a, Second: b}
// for every a in set and every b in other. It panics if other is not a
// *threadUnsafeSet.
func (set *threadUnsafeSet) CartesianProduct(other Set) Set {
    o := other.(*threadUnsafeSet)
    pairs := NewThreadUnsafeSet()
    for a := range *set {
        for b := range *o {
            pairs.Add(OrderedPair{First: a, Second: b})
        }
    }
    return pairs
}

// ToSlice returns the members as a new slice of exactly Cardinality()
// elements, in map iteration order.
func (set *threadUnsafeSet) ToSlice() []interface{} {
    out := make([]interface{}, len(*set))
    idx := 0
    for elem := range *set {
        out[idx] = elem
        idx++
    }
    return out
}

// MarshalJSON encodes the set as a JSON array. Each element is marshaled with
// json.Marshal, and the first marshaling error aborts the whole encoding.
// Element order follows map iteration.
func (set *threadUnsafeSet) MarshalJSON() ([]byte, error) {
    encoded := make([]string, 0, len(*set))
    for elem := range *set {
        raw, err := json.Marshal(elem)
        if err != nil {
            return nil, err
        }
        encoded = append(encoded, string(raw))
    }
    return []byte("[" + strings.Join(encoded, ",") + "]"), nil
}

// UnmarshalJSON adds the elements of a JSON array to the set. Only scalar
// values are kept. Nested arrays and objects are skipped because the
// []interface{} and map[string]interface{} values they decode into cannot
// be map keys. Numbers are decoded as json.Number.
//
// Existing members are kept; the set is not cleared first. A nil map, such as
// a zero-value threadUnsafeSet passed to json.Unmarshal, is allocated before
// the first insertion instead of panicking.
func (set *threadUnsafeSet) UnmarshalJSON(b []byte) error {
    var items []interface{}

    d := json.NewDecoder(bytes.NewReader(b))
    d.UseNumber()
    if err := d.Decode(&items); err != nil {
        return err
    }

    for _, v := range items {
        switch v.(type) {
        case []interface{}, map[string]interface{}:
            continue
        default:
            // Writing to a nil map panics, so allocate lazily. Inputs with
            // nothing to add leave a nil set untouched.
            if *set == nil {
                *set = newThreadUnsafeSet()
            }
            set.Add(v)
        }
    }

    return nil
}



// threadsafe.go

package mapset

import "sync"

// threadSafeSet wraps a threadUnsafeSet with an embedded sync.RWMutex.
// Query methods take the read lock; Add, Remove, Clear and Pop take the write
// lock. Because the mutex is embedded, its Lock/Unlock/RLock/RUnlock methods
// are also promoted onto *threadSafeSet. The struct contains a mutex, so it
// must not be copied after first use; the constructors hand out pointers.
type threadSafeSet struct {
    s threadUnsafeSet // element storage, guarded by the embedded RWMutex
    sync.RWMutex
}

// newThreadSafeSet returns an empty threadSafeSet with its storage allocated
// and its mutex unlocked. Callers should take its address rather than copy it
// once it is in use.
func newThreadSafeSet() threadSafeSet {
    storage := newThreadUnsafeSet()
    return threadSafeSet{s: storage}
}

// Add inserts i under the write lock and reports whether it was newly added.
// The unlock is deferred because inserting an unhashable key (e.g. a slice)
// panics inside the map operation. Without the defer, that panic would leave
// the set locked forever, and every later call would deadlock.
func (set *threadSafeSet) Add(i interface{}) bool {
    set.Lock()
    defer set.Unlock()
    return set.s.Add(i)
}

// Contains reports, under the read lock, whether every given item is a member.
// The unlock is deferred so that a panic from looking up an unhashable key
// does not leak the read lock and block all future writers.
func (set *threadSafeSet) Contains(i ...interface{}) bool {
    set.RLock()
    defer set.RUnlock()
    return set.s.Contains(i...)
}

// IsSubset reports whether every element of set is also in other, holding
// both read locks during the check. It panics if other is not a
// *threadSafeSet.
//
// When other is set itself, the lock is taken only once. sync.RWMutex forbids
// recursive read locking: a writer queued between two RLock calls on the
// same mutex would deadlock.
func (set *threadSafeSet) IsSubset(other Set) bool {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    return set.s.IsSubset(&o.s)
}

// IsProperSubset reports whether set is a subset of other without being equal
// to it, holding both read locks during the check. It panics if other is not
// a *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) IsProperSubset(other Set) bool {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    return set.s.IsProperSubset(&o.s)
}

// IsSuperset reports whether every element of other is in set by asking
// other whether it is a subset of set. Locking is done inside that call.
func (set *threadSafeSet) IsSuperset(other Set) bool {
    var self Set = set
    return other.IsSubset(self)
}

// IsProperSuperset reports whether set strictly contains other by asking
// other whether it is a proper subset of set. Locking is done inside that
// call.
func (set *threadSafeSet) IsProperSuperset(other Set) bool {
    var self Set = set
    return other.IsProperSubset(self)
}

// Union returns a new thread-safe set holding every element of set and other,
// built while both read locks are held. It panics if other is not a
// *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) Union(other Set) Set {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    unsafeUnion := set.s.Union(&o.s).(*threadUnsafeSet)
    return &threadSafeSet{s: *unsafeUnion}
}

// Intersect returns a new thread-safe set of the elements present in both set
// and other, built while both read locks are held. It panics if other is not
// a *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) Intersect(other Set) Set {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    unsafeIntersection := set.s.Intersect(&o.s).(*threadUnsafeSet)
    return &threadSafeSet{s: *unsafeIntersection}
}

// Difference returns a new thread-safe set of the elements of set that are not
// in other, built while both read locks are held. It panics if other is not a
// *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) Difference(other Set) Set {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    unsafeDifference := set.s.Difference(&o.s).(*threadUnsafeSet)
    return &threadSafeSet{s: *unsafeDifference}
}

// SymmetricDifference returns a new thread-safe set of the elements that are
// in exactly one of set and other, built while both read locks are held.
// It panics if other is not a *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) SymmetricDifference(other Set) Set {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    unsafeDifference := set.s.SymmetricDifference(&o.s).(*threadUnsafeSet)
    return &threadSafeSet{s: *unsafeDifference}
}

// Clear swaps in fresh, empty storage while holding the write lock.
func (set *threadSafeSet) Clear() {
    set.Lock()
    defer set.Unlock()
    set.s = newThreadUnsafeSet()
}

// Remove deletes i under the write lock; removing a non-member is a no-op.
// The unlock is deferred because deleting an unhashable key panics inside
// the map operation, which would otherwise leave the set locked forever.
func (set *threadSafeSet) Remove(i interface{}) {
    set.Lock()
    defer set.Unlock()
    delete(set.s, i)
}

// Cardinality returns the number of elements, read under the read lock.
func (set *threadSafeSet) Cardinality() int {
    set.RLock()
    n := len(set.s)
    set.RUnlock()
    return n
}

// Each calls cb once per element while holding the read lock, stopping as
// soon as cb returns true. The unlock is deferred so a panicking cb does not
// leak the read lock. cb must not call locking methods on the same set:
// write methods would deadlock, and read methods count as recursive read
// locking, which sync.RWMutex forbids.
func (set *threadSafeSet) Each(cb func(interface{}) bool) {
    set.RLock()
    defer set.RUnlock()
    for elem := range set.s {
        if cb(elem) {
            break
        }
    }
}

// Iter streams the elements over an unbuffered channel from a goroutine that
// holds the read lock until the last element has been received. It then
// closes the channel and releases the lock. An undrained channel therefore
// blocks writers indefinitely; use Iterator for early exit.
func (set *threadSafeSet) Iter() <-chan interface{} {
    out := make(chan interface{})
    go func() {
        set.RLock()
        defer set.RUnlock()
        defer close(out)
        for elem := range set.s {
            out <- elem
        }
    }()
    return out
}

// Iterator returns an Iterator fed by a goroutine that holds the read lock
// while sending. The goroutine closes C and releases the lock after the last
// element or as soon as Stop is called.
func (set *threadSafeSet) Iterator() *Iterator {
    it, out, stop := newIterator()

    go func() {
        set.RLock()
        defer set.RUnlock()
        defer close(out)
        for elem := range set.s {
            select {
            case <-stop:
                return
            case out <- elem:
            }
        }
    }()

    return it
}

// Equal reports whether set and other hold exactly the same elements, checked
// while both read locks are held. It panics if other is not a *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) Equal(other Set) bool {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    return set.s.Equal(&o.s)
}

// Clone returns an independent thread-safe copy, taken under the read lock.
// The copy gets its own, unlocked mutex.
func (set *threadSafeSet) Clone() Set {
    set.RLock()
    defer set.RUnlock()
    copied := set.s.Clone().(*threadUnsafeSet)
    return &threadSafeSet{s: *copied}
}

// String renders the set as "Set{...}" while holding the read lock.
func (set *threadSafeSet) String() string {
    set.RLock()
    defer set.RUnlock()
    return set.s.String()
}

// PowerSet returns every subset of set as a thread-safe set whose elements
// are themselves *threadSafeSet values. The unsafe power set is computed under
// the read lock; the lock is released before the subsets are re-wrapped.
func (set *threadSafeSet) PowerSet() Set {
    set.RLock()
    subsets := set.s.PowerSet().(*threadUnsafeSet)
    set.RUnlock()

    result := &threadSafeSet{s: newThreadUnsafeSet()}
    for subset := range *subsets {
        inner := subset.(*threadUnsafeSet)
        result.Add(&threadSafeSet{s: *inner})
    }
    return result
}

// Pop removes and returns an arbitrary element under the write lock, or nil
// when the set is empty.
func (set *threadSafeSet) Pop() interface{} {
    set.Lock()
    defer set.Unlock()
    item := set.s.Pop()
    return item
}

// CartesianProduct returns a thread-safe set of OrderedPair{First: a, Second: b}
// for every a in set and b in other, built while both read locks are held.
// It panics if other is not a *threadSafeSet.
//
// When other is set itself, the lock is taken only once, because sync.RWMutex
// forbids recursive read locking.
func (set *threadSafeSet) CartesianProduct(other Set) Set {
    o := other.(*threadSafeSet)

    set.RLock()
    defer set.RUnlock()
    if o != set {
        o.RLock()
        defer o.RUnlock()
    }

    unsafeCartProduct := set.s.CartesianProduct(&o.s).(*threadUnsafeSet)
    return &threadSafeSet{s: *unsafeCartProduct}
}

// ToSlice returns a snapshot of the members as a new slice, taken under the
// read lock, in unspecified order.
func (set *threadSafeSet) ToSlice() []interface{} {
    set.RLock()
    defer set.RUnlock()
    return set.s.ToSlice()
}

// MarshalJSON encodes the set as a JSON array while holding the read lock.
func (set *threadSafeSet) MarshalJSON() ([]byte, error) {
    set.RLock()
    defer set.RUnlock()
    return set.s.MarshalJSON()
}

// UnmarshalJSON adds the scalar elements of a JSON array to the set.
// Decoding inserts elements, so it must hold the write lock. The original
// read lock let concurrent readers, or other decoders, race with the map
// writes. Nil storage, as in a zero-value threadSafeSet, is allocated first
// so the insertions cannot panic.
func (set *threadSafeSet) UnmarshalJSON(p []byte) error {
    set.Lock()
    defer set.Unlock()

    if set.s == nil {
        set.s = newThreadUnsafeSet()
    }
    return set.s.UnmarshalJSON(p)
}

zsxq.png