golang set 集合类的实现

简介

Set 是一种常见的集合类不允许储存同样的元素, 并提供交集, 并集, 差集的方法(常用作推荐系统)

setzset

Golang 中没有官方的集合类

下面基于 golang 的map 实现基本的集合的, 可以配置是否需要线程安全

接口方法

type SetIn interface {
	//添加
	Add(element interface{})
	//批量添加
	AddALL(elements []interface{})
	//删除
	Del(element interface{})
	//查看集合中是否含有元素
	Exists(element interface{}) (exists bool)
	//集合是否为空
	IsEmpty() (isEmpty bool)
	//返回集合长度
	Len() (length int)
	//返回集合所有元素,乱序
	All() (elements []interface{})
	//Inter 交集
	Inter(sets ...*Set) (resultSet *Set)
	//Union 并集
	Union(sets ...*Set) (resultSet *Set)
	//Diff 差集
	Diff(sets ...*Set) (resultSet *Set)
}

核心代码

set.go

package set

import (
	"sync"

	"github.com/dengjiawen8955/gostl/util/gosync"
)

// Options holds the Set's options
type Options struct {
	locker gosync.Locker
}

// Option is a function  type used to set Options
type Option func(option *Options)

func WithSync() Option {
	return func(option *Options) {
		option.locker = &sync.RWMutex{}
	}
}

type Set struct {
	//集合的底层用 map 实现
	setMap map[interface{}]struct{}
	//可选: 锁
	locker gosync.Locker
}

//返回一个空的 set 对象
//opts 支持线程安全
func New(opts ...Option) (set *Set) {
	option := &Options{
		// 默认使用假锁,线程不安全
		locker: gosync.FakeLocker{},
	}
	//如果 opts 选项中有锁,将会在这里加锁
	for _, opt := range opts {
		opt(option)
	}
	return &Set{
		setMap: make(map[interface{}]struct{}),
		locker: option.locker,
	}
}

//添加
func (s *Set) Add(element interface{}) {
	s.locker.Lock()
	defer s.locker.Unlock()
	s.setMap[element] = struct{}{}
}

//删除
func (s *Set) Del(element interface{}) {
	s.locker.Lock()
	defer s.locker.Unlock()
	delete(s.setMap, element)
}

//查看集合中是否含有元素
func (s *Set) Exists(element interface{}) (has bool) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	_, has = s.setMap[element]
	return
}

//集合是否为空
func (s *Set) IsEmpty() (isEmpty bool) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	return len(s.setMap) == 0
}

//返回集合长度
func (s *Set) Len() (length int) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	return len(s.setMap)
}

//返回集合所有元素,乱序
func (s *Set) All() (elements []interface{}) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	elements = make([]interface{}, 0)
	for element, _ := range s.setMap {
		elements = append(elements, element)
	}
	return
}

//Inter 交集(默认返回线程不安全的集合)
func (s *Set) Inter(sets ...*Set) (resultSet *Set) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	resultSet = New()
	for e1, _ := range s.setMap {
		isInter := true
		for _, set := range sets {
			if !set.Exists(e1) {
				isInter = false
				break
			}
		}
		if isInter {
			resultSet.Add(e1)
		}
	}
	return
}

//Union 并集(默认返回线程不安全的集合)
//todo 使用迭代器
func (s *Set) Union(sets ...*Set) (resultSet *Set) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	resultSet = New()
	for e1 := range s.setMap {
		resultSet.Add(e1)
	}
	for _, set := range sets {
		for e2 := range set.setMap {
			resultSet.Add(e2)
		}
	}
	return
}

//Diff 差集(默认返回线程不安全的集合)
func (s *Set) Diff(sets ...*Set) (resultSet *Set) {
	s.locker.RLock()
	defer s.locker.RUnlock()
	resultSet = New()
	for e1 := range s.setMap {
		isDiff := true
		for _, set := range sets {
			if set.Exists(e1) {
				isDiff = false
				break
			}
		}
		if isDiff {
			resultSet.Add(e1)
		}
	}
	return
}

fake_locker.go 假锁

fake_locker.go 主要是提供 Locker 的读写锁接口和 FackerLocker 假锁

如果用户不选择线程安全的 set , 就默认使用假锁

package gosync

// Locker define an abstract locker interface
type Locker interface {
	Lock()
	Unlock()
	RLock()
	RUnlock()
}

// FakeLocker is a fake locker
type FakeLocker struct {
}

// Lock does nothing
func (l FakeLocker) Lock() {

}

// Unlock does nothing
func (l FakeLocker) Unlock() {

}

// RLock does nothing
func (l FakeLocker) RLock() {

}

// RUnlock does nothing
func (l FakeLocker) RUnlock() {

}

测试

下面是 set 集合的测试用例

package set

import (
	"fmt"
	"testing"
)

//功能测试
// s1.Exists(1): true
// s1.Exists(3): false
// s1.Exists(1): false
// is: [2]
// 交集: [2]
// 并集=[]interface {}{1, 2, 3}
// 差集=[]interface {}{1}
func TestSet(t *testing.T) {
	s1 := New()
	s1.Add(1)
	s1.Add(2)
	fmt.Printf("s1.Exists(1): %v\n", s1.Exists(1))
	fmt.Printf("s1.Exists(3): %v\n", s1.Exists(3))
	s1.Del(1)
	fmt.Printf("s1.Exists(1): %v\n", s1.Exists(1))
	is := s1.All()
	fmt.Printf("is: %v\n", is)
	s1.Add(1)
	s2 := New()
	s2.Add(2)
	s2.Add(3)
	//交集
	s3 := s1.Inter(s2)
	s3s := s3.All()
	fmt.Printf("交集: %v\n", s3s)
	//并集
	s4 := s1.Union(s2)
	s4s := s4.All()
	fmt.Printf("并集=%#v\n", s4s)
	//差集
	s5 := s1.Diff(s2)
	s5s := s5.All()
	fmt.Printf("差集=%#v\n", s5s)
}

func TestWithSync(t *testing.T) {
	s1 := New(WithSync())
	go func() {
		for i := 0; i < 100000; i++ {
			s1.Add(1)
		}
	}()
	for i := 0; i < 100000; i++ {
		s1.Exists(1)
	}
	fmt.Printf("%s\n", "OK")
}

//会发生panic=
func TestWithoutSync(t *testing.T) {
	s1 := New()
	go func() {
		for i := 0; i < 100000; i++ {
			s1.Add(1)
		}
	}()
	for i := 0; i < 100000; i++ {
		s1.Exists(1)
	}
	fmt.Printf("%s\n", "OK")
}

refrence

仓库地址, gostl 会持续更新常见的 stl 标准库, 类似 C++ stl

https://github.com/dengjiawen8955/gostl