//------------------------------------------------------------------------------
// rbtree.go
// 红黑树
//------------------------------------------------------------------------------

package rbmap

// RED and BLACK are the two values stored in node.color.
const RED, BLACK bool = true, false

// node is one element of the red-black tree. Absent children and the root's
// parent are represented by the shared nilnode sentinel rather than Go nil.
type node struct {
	key   int64
	value interface{}

	left, right, parent *node

	color bool
}

// nilnode is the black sentinel used for every missing child and for the
// root's parent.
// NOTE(review): it is shared by every tree in the package, and remove writes
// its parent field, so even two distinct Maps must not be mutated concurrently.
var nilnode = &node{color: BLACK}

// rbtree holds the root of a red-black tree; an empty tree has root == nilnode.
type rbtree struct{ root *node }

// search returns the node whose key equals k in the subtree rooted at x, or
// nilnode when no such node exists.
func search(x *node, k int64) *node {
	for x != nilnode && x.key != k {
		if k > x.key {
			x = x.right
		} else {
			x = x.left
		}
	}
	return x
}

// maximum returns the node with the largest key in the subtree rooted at x,
// i.e. its rightmost node. It returns nilnode when x is nilnode.
func maximum(x *node) *node {
	// nilnode's own children are Go nil, so walking from it would panic.
	if x == nilnode {
		return nilnode
	}
	for x.right != nilnode {
		x = x.right
	}
	return x
}

// minimum returns the node with the smallest key in the subtree rooted at x,
// i.e. its leftmost node. It returns nilnode when x is nilnode.
func minimum(x *node) *node {
	// nilnode's own children are Go nil, so walking from it would panic.
	if x == nilnode {
		return nilnode
	}
	for x.left != nilnode {
		x = x.left
	}
	return x
}

// successor returns the in-order successor of x: the node with the smallest
// key greater than x.key, or nilnode when x holds the largest key.
func successor(x *node) *node {
	// With a right subtree, the successor is that subtree's leftmost node.
	// The walk is written inline to keep this function self-contained.
	if x.right != nilnode {
		y := x.right
		for y.left != nilnode {
			y = y.left
		}
		return y
	}

	// Otherwise climb until we step up from a left child; that parent is the
	// successor (nilnode if we run off the root).
	y := x.parent
	for y != nilnode && x == y.right {
		x = y
		y = y.parent
	}
	return y
}

// remove unlinks node z from tree T and returns the node that was physically
// detached.
//
// When z has two children, its in-order successor is detached instead. Before
// that, the key/value pairs of z and the successor are swapped, so the
// returned node always carries the key and value that were removed.
//
// Passing nilnode (e.g. the result of a failed search) is a no-op that
// returns nilnode.
func remove(T *rbtree, z *node) *node {
	if z == nilnode {
		return nilnode
	}

	// y is the node actually spliced out. If z has at most one child, y is z
	// itself. Otherwise y is z's successor, the leftmost node of z.right,
	// which has no left child.
	y := z
	if z.left != nilnode && z.right != nilnode {
		y = z.right
		for y.left != nilnode {
			y = y.left
		}
	}

	// x is y's only child (possibly nilnode) and takes y's place.
	x := y.right
	if y.left != nilnode {
		x = y.left
	}

	// Assigned even when x is nilnode: removefixup navigates via x.parent.
	x.parent = y.parent

	switch {
	case y.parent == nilnode:
		T.root = x
	case y == y.parent.left:
		y.parent.left = x
	default:
		y.parent.right = x
	}

	if y != z {
		// Keep the successor's entry in the tree (now at z), and let the
		// detached node y carry z's entry back to the caller.
		z.key, y.key = y.key, z.key
		z.value, y.value = y.value, z.value
	}

	// Splicing out a black node shortens one path's black height.
	if y.color == BLACK {
		removefixup(T, x)
	}
	return y
}

// removefixup restores the red-black properties after remove spliced out a
// black node.
//
// x is the node that took the removed node's place. It may be nilnode, in
// which case remove has set nilnode.parent. x carries an "extra black" that
// the loop either pushes up the tree or eliminates with rotations.
func removefixup(T *rbtree, x *node) {
	for x != T.root && x.color == BLACK {
		if x == x.parent.left {
			w := x.parent.right // x's sibling
			if w.color == RED {
				// Case 1: red sibling. Rotate so the sibling becomes black.
				w.color = BLACK
				x.parent.color = RED
				lrotated(T, x.parent)
				w = x.parent.right
			}

			if w.left.color == BLACK && w.right.color == BLACK {
				// Case 2: both of w's children black. Move the extra black up.
				w.color = RED
				x = x.parent
			} else {
				if w.right.color == BLACK {
					// Case 3: w's far child black. Rotate its near child up.
					w.left.color = BLACK
					w.color = RED
					rrotated(T, w)
					w = x.parent.right
				}

				// Case 4: w's far child red. Recolor, rotate, and finish.
				w.color = x.parent.color
				x.parent.color = BLACK
				w.right.color = BLACK
				lrotated(T, x.parent)
				x = T.root
			}
		} else {
			// Mirror image: x is a right child, so its sibling is x.parent.left.
			w := x.parent.left
			if w.color == RED {
				w.color = BLACK
				x.parent.color = RED
				rrotated(T, x.parent)
				w = x.parent.left
			}

			if w.right.color == BLACK && w.left.color == BLACK {
				w.color = RED
				x = x.parent
			} else {
				if w.left.color == BLACK {
					w.right.color = BLACK
					w.color = RED
					lrotated(T, w)
					w = x.parent.left
				}

				w.color = x.parent.color
				x.parent.color = BLACK
				w.left.color = BLACK
				rrotated(T, x.parent)
				x = T.root
			}
		}
	}

	// x is now the root, or a red node absorbing the extra black.
	// Either way it must be black.
	x.color = BLACK
}

// inorder walks the subtree rooted at x in ascending key order, passing each
// node's key and value to handler.
func inorder(x *node, handler func(k int64, v interface{})) {
	if x == nilnode {
		return
	}
	inorder(x.left, handler)
	handler(x.key, x.value)
	inorder(x.right, handler)
}

// insert links the new node z into T as a plain binary-search-tree leaf
// (keys equal to an existing key descend to the right). It then paints z red
// and calls insertfixup to restore the red-black properties.
func insert(T *rbtree, z *node) {
	parent := nilnode
	for cur := T.root; cur != nilnode; {
		parent = cur
		if z.key < cur.key {
			cur = cur.left
		} else {
			cur = cur.right
		}
	}

	z.parent = parent
	switch {
	case parent == nilnode:
		T.root = z
	case z.key < parent.key:
		parent.left = z
	default:
		parent.right = z
	}

	z.left, z.right = nilnode, nilnode
	z.color = RED

	insertfixup(T, z)
}

// insertfixup restores the red-black properties after insert has linked the
// red node z into the tree.
//
// The only possible violation is a red node (z) with a red parent. Each loop
// iteration either recolors and moves z two levels up (case 1), or performs
// one or two rotations after which z's parent is black and the loop ends
// (cases 2 and 3). Finally the root is painted black.
func insertfixup(T *rbtree, z *node) {
	// The root's parent is the black sentinel nilnode, so the loop stops at the
	// root. The root is kept black (last line), so a red parent is never the
	// root and z.parent.parent is a real node inside the loop.
	for z.parent.color == RED {
		if z.parent == z.parent.parent.left {
			// y is z's uncle.
			y := z.parent.parent.right
			if y.color == RED {
				// Case 1: red uncle. Recolor parent, uncle and grandparent,
				// then continue checking from the grandparent.
				z.parent.color = BLACK
				y.color = BLACK
				z.parent.parent.color = RED
				z = z.parent.parent
			} else {
				if z == z.parent.right {
					// Case 2: z is an inner (right) child. Rotate it into the
					// outer position, turning this into case 3.
					z = z.parent
					lrotated(T, z)
				}

				// Case 3: z is an outer child. Recolor and rotate the
				// grandparent; z.parent is black afterwards, ending the loop.
				z.parent.color = BLACK
				z.parent.parent.color = RED
				rrotated(T, z.parent.parent)
			}
		} else {
			// Mirror image of the branch above, with left and right swapped.
			y := z.parent.parent.left
			if y.color == RED {
				z.parent.color = BLACK
				y.color = BLACK
				z.parent.parent.color = RED
				z = z.parent.parent
			} else {
				if z == z.parent.left {
					z = z.parent
					rrotated(T, z)
				}

				z.parent.color = BLACK
				z.parent.parent.color = RED
				lrotated(T, z.parent.parent)
			}
		}
	}

	// Case 1 may have painted the root red.
	T.root.color = BLACK
}

// lrotated performs a left rotation around x. x's right child y takes x's
// position, x becomes y's left child, and y's former left subtree becomes x's
// right subtree. x.right must be a real node (not nilnode).
func lrotated(T *rbtree, x *node) {
	y := x.right

	// Hand y's inner (left) subtree over to x.
	inner := y.left
	x.right = inner
	if inner != nilnode {
		inner.parent = x
	}

	// Hook y into the slot x used to occupy.
	p := x.parent
	y.parent = p
	switch {
	case p == nilnode:
		T.root = y
	case p.left == x:
		p.left = y
	default:
		p.right = y
	}

	y.left = x
	x.parent = y
}

// rrotated performs a right rotation around x. x's left child y takes x's
// position, x becomes y's right child, and y's former right subtree becomes
// x's left subtree. x.left must be a real node (not nilnode).
func rrotated(T *rbtree, x *node) {
	y := x.left

	// Hand y's inner (right) subtree over to x.
	inner := y.right
	x.left = inner
	if inner != nilnode {
		inner.parent = x
	}

	// Hook y into the slot x used to occupy.
	p := x.parent
	y.parent = p
	switch {
	case p == nilnode:
		T.root = y
	case p.left == x:
		p.left = y
	default:
		p.right = y
	}

	y.right = x
	x.parent = y
}
//------------------------------------------------------------------------------
// rb_map.go
//
// 红黑树map: an ordered map from int64 keys to arbitrary values, backed by the
// red-black tree in rbtree.go. Range visits entries in ascending key order.
//
// Usage:
//
// func main() {
// 	  m := rbmap.NewMap()
// 	  m.Put(10, "10")
// 	  m.Put(11, "11")
// 	  m.Put(15, "15")
// 	  m.Put(12, "12")
// 	  m.Put(12, "12-1")
// 	  m.Put(13, "13")
//
// 	  fmt.Println(m.Get(12))
// 	  fmt.Println(m.Get(18))
// 	  m.Range(func(k int64, v interface{}) {
// 		   fmt.Printf("%d %s\n", k, v.(string))
// 	  })
// }
//
// 2019/8/27
//------------------------------------------------------------------------------

package rbmap

// Map is an ordered map from int64 keys to arbitrary values, implemented as a
// red-black tree. Create one with NewMap; the zero value has no tree and is
// not usable.
type Map struct{ tree *rbtree }

// NewMap returns an empty Map whose tree root is the nilnode sentinel.
func NewMap() *Map {
	t := &rbtree{root: nilnode}
	return &Map{tree: t}
}

// Put stores v under key k, replacing the value if k is already present.
func (m *Map) Put(k int64, v interface{}) {
	if existing := search(m.tree.root, k); existing != nilnode {
		existing.value = v
		return
	}
	insert(m.tree, &node{
		key:    k,
		value:  v,
		parent: nilnode,
		left:   nilnode,
		right:  nilnode,
	})
}

// Get returns the value stored under k and true, or nil and false when k is
// not present.
func (m *Map) Get(k int64) (interface{}, bool) {
	if n := search(m.tree.root, k); n != nilnode {
		return n.value, true
	}
	return nil, false
}

// Delete removes key k from the map and returns the value that was stored
// under it, or nil when k was not present.
func (m *Map) Delete(k int64) interface{} {
	z := search(m.tree.root, k)
	if z == nilnode {
		// Key absent: nothing to unlink, and remove must not see the sentinel.
		return nil
	}

	// Capture the value before unlinking. When z has two children, remove
	// moves its successor's key/value into z, so z.value changes afterwards.
	v := z.value
	remove(m.tree, z)
	return v
}

// Range calls handle for every entry in ascending key order. An empty map
// produces no calls.
func (m *Map) Range(handle func(k int64, v interface{})) {
	inorder(m.tree.root, handle)
}