package main
import (
	"fmt"
	"sync"
)
// Red-black tree in Go.
// Only node insertion and a function that prints the nodes level by level are implemented.
// ref: https://www.jianshu.com/p/e136ec79235c
const (
	// Red is the color given to every node created by addLeft/addRight.
	Red = 1
	// Black is the color assigned to the root node.
	Black = 2
)
// node is one element of the tree: an integer key, a color and the links to
// its two children and its parent.
type node struct {
	right  *node // right child; find descends here for keys larger than key
	left   *node // left child; find descends here for keys smaller than key
	parent *node // parent node; nil for the root
	key    int   // value the tree is ordered by
	color  int   // Red or Black
}
// RBTree is a red-black tree of int keys. The zero value is an empty tree
// that insert can be called on directly.
type RBTree struct {
	// root is the top node of the tree; nil while the tree is empty.
	root *node
	// len is the number of keys stored; insert increments it once per new key.
	len  int
}
// addLeft creates a new red node holding key, links it as the left child of
// n and returns it.
//
// Whatever n.left pointed to before is simply overwritten: insert only calls
// this on the node returned by find, whose left slot is empty at that point.
func (n *node) addLeft(key int) *node {
	child := &node{
		parent: n,
		key:    key,
		color:  Red, // a freshly added node is always red
	}
	n.left = child
	return child
}
// addRight creates a new red node holding key, links it as the right child
// of n and returns it.
//
// Whatever n.right pointed to before is simply overwritten: insert only
// calls this on the node returned by find, whose right slot is empty at that
// point.
func (n *node) addRight(key int) *node {
	child := &node{
		parent: n,
		key:    key,
		color:  Red, // a freshly added node is always red
	}
	n.right = child
	return child
}
// rotateLeft rotates the subtree rooted at n to the left: n's right child
// takes n's place and n becomes that child's left child. rb.root is updated
// when n was the root.
//
// n.right must not be nil; it is dereferenced unconditionally.
func (rb *RBTree) rotateLeft(n *node) {
	pivot := n.right
	above := n.parent

	// The pivot's left subtree moves over to become n's right subtree.
	n.right = pivot.left
	if n.right != nil {
		n.right.parent = n
	}

	// Hook the pivot into the slot n used to occupy.
	pivot.parent = above
	switch {
	case above == nil:
		rb.root = pivot
	case above.left == n:
		above.left = pivot
	default:
		above.right = pivot
	}

	// Finally hang n below the pivot.
	pivot.left = n
	n.parent = pivot
}
// rotateRight rotates the subtree rooted at n to the right: n's left child
// takes n's place and n becomes that child's right child. rb.root is updated
// when n was the root.
//
// n.left must not be nil; it is dereferenced unconditionally.
func (rb *RBTree) rotateRight(n *node) {
	pivot := n.left
	above := n.parent

	// The pivot's right subtree moves over to become n's left subtree.
	n.left = pivot.right
	if n.left != nil {
		n.left.parent = n
	}

	// Hook the pivot into the slot n used to occupy.
	pivot.parent = above
	switch {
	case above == nil:
		rb.root = pivot
	case above.left == n:
		above.left = pivot
	default:
		above.right = pivot
	}

	// Finally hang n below the pivot.
	pivot.right = n
	n.parent = pivot
}
// insert adds key to the tree and increments rb.len. A key that is already
// present is left untouched and the length does not change.
func (rb *RBTree) insert(key int) {
	// Empty tree: the new node becomes the (black) root.
	if rb.root == nil {
		rb.root = &node{key: key, color: Black}
		rb.len++
		return
	}

	parent, found := rb.find(key)
	if found {
		// TODO: replace node value
		return
	}

	// find did not report a match, so parent is the node under which the key
	// belongs; it is never nil here because the tree is not empty.
	var added *node
	if key < parent.key {
		added = parent.addLeft(key)
	} else {
		added = parent.addRight(key)
	}
	rb.ajustAdd(added)
	rb.len++
}
// ajustAdd repairs the coloring of the tree after the red node add has been
// linked in. It walks upward for as long as the current node has a red
// parent, either recoloring (red uncle) or rotating (missing or black uncle).
func (rb *RBTree) ajustAdd(add *node) {
	cur, parent := add, add.parent
	for parent != nil && parent.color != Black {
		// A red parent is never the root, so the grandparent exists.
		grand := parent.parent
		parentIsLeft := grand.left == parent

		uncle := grand.left
		if parentIsLeft {
			uncle = grand.right
		}

		if uncle != nil && uncle.color != Black {
			// Red uncle: push the blackness down from the grandparent and
			// continue the check from the grandparent, which is now red.
			parent.color = Black
			uncle.color = Black
			grand.color = Red
			cur = grand
			parent = cur.parent
			continue
		}

		// Missing or black uncle: rotate. When cur sits on the inner side of
		// its parent, a first rotation around the parent straightens the
		// path before the rotation around the grandparent.
		if parentIsLeft {
			if cur == parent.right {
				rb.rotateLeft(parent)
			}
			rb.rotateRight(grand)
		} else {
			if cur == parent.left {
				rb.rotateRight(parent)
			}
			rb.rotateLeft(grand)
		}

		// The grandparent has moved down one level; its new parent is the
		// top of the rotated subtree and takes the black color.
		grand.color = Red
		grand.parent.color = Black
		return
	}

	// The walk ran past the top of the tree: cur is the root and must be
	// black.
	if parent == nil {
		rb.root = cur
		rb.root.color = Black
		rb.root.parent = nil
	}
}
// find looks up key in the tree.
//
// When the key is present it returns the node holding it and true.
// Otherwise it returns false together with the last node visited on the way
// down, i.e. the node a new node for key would be attached to. For an empty
// tree that node is nil.
func (rb *RBTree) find(key int) (*node, bool) {
	var last *node
	cur := rb.root
	for cur != nil {
		switch {
		case key < cur.key:
			last, cur = cur, cur.left
		case key > cur.key:
			last, cur = cur, cur.right
		default:
			return cur, true
		}
	}
	return last, false
}
// print writes the tree to standard output level by level, one line per
// level. It is meant for debugging only.
//
// Each node is printed as " key(p parentKey,c color,p position)", where
// position is "root", "left" or "right". The root has an empty parent key.
//
// The traversal runs in a worker goroutine that uses a channel as its queue:
// node entries are interleaved with a "next" marker that separates levels.
// print returns once the worker has finished.
func (rb *RBTree) print() {
	// The worker is the only receiver on ch and also sends to it, so a send
	// that blocks would leave the worker waiting on itself forever. Over the
	// whole run at most 1+2*rb.len node entries are sent (the root plus two
	// children, possibly nil, per node) and at most one marker is queued at a
	// time, so this capacity guarantees that no send can block.
	ch := make(chan interface{}, 2*rb.len+2)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		var dealed int // number of non-nil nodes printed so far
		for {
			n, isNode := (<-ch).(*node)
			if !isNode {
				// Level marker: end the current line. Stop once every node
				// has been printed, otherwise queue a marker for the level
				// whose nodes were enqueued while this one was printed.
				fmt.Println()
				if dealed == rb.len {
					return
				}
				ch <- "next"
				continue
			}
			if n == nil {
				// Missing child of a node on the previous level.
				continue
			}
			dealed++

			color := "red"
			if n.color == Black {
				color = "black"
			}
			if n.parent == nil {
				fmt.Printf(" %d(p %s,c %s,p %s)", n.key, "", color, "root")
			} else {
				pos := ""
				if n == n.parent.left {
					pos = "left"
				} else if n == n.parent.right {
					pos = "right"
				}
				fmt.Printf(" %d(p %d,c %s,p %s)", n.key, n.parent.key, color, pos)
			}

			// Children are queued unconditionally; nil ones are skipped when
			// they are received.
			ch <- n.left
			ch <- n.right
		}
	}()
	ch <- rb.root
	ch <- "next"
	wg.Wait()
}
// main builds a tree from the keys 1 through 13, inserted in ascending
// order, and prints the result level by level.
func main() {
	rb := &RBTree{}
	for key := 1; key <= 13; key++ {
		rb.insert(key)
	}
	rb.print()
}