package main

import (
    "os"
    "io"
    "fmt"
    "net"
    "strings"
    "strconv"
    "syscall"
    "encoding/binary"
)

// CSPair holds both legs of one proxied TCP session: the accepted client
// connection and the connection dialed to the destination the client was
// originally heading for, plus their addresses (used for logging and for
// the port-843 check in handle_cs).
type CSPair struct {
    clientaddr, serveraddr net.Addr     // remote address of each leg
    clientconn, serverconn *net.TCPConn // the client leg and the server leg
}

// SO_ORIGINAL_DST is the getsockopt option number that construct_connection
// queries at level IPPROTO_IP to recover a redirected connection's original
// destination. The value presumably mirrors SO_ORIGINAL_DST from Linux's
// <linux/netfilter_ipv4.h> — confirm when porting.
const SO_ORIGINAL_DST = 80

// connection_count counts the sessions set up so far; construct_connection
// increments it and includes it in its log line. It starts at zero.
var connection_count int

func main() {
    laddr := &net.TCPAddr{}
    laddr.Port = 8838
    ln, err := net.ListenTCP("tcp4", laddr)
    handle_error(err)
    fmt.Printf("listen on %d\n", laddr.Port)
    defer ln.Close()

    for {
        conn, err := ln.AcceptTCP()
        handle_error(err)
        pair := construct_connection(conn)
        handle_data(pair)
    }
}

func handle_data(pair *CSPair) {
    go handle_cs(pair)
    go handle_sc(pair)
}

func handle_cs(pair *CSPair) {
    defer pair.clientconn.Close()
    if strings.Index(pair.serveraddr.String(), ":843") != -1 {
        fmt.Println(":843 connection.")
        io.Copy(pair.serverconn, pair.clientconn)
        return
    }

    var remain_data []byte
    for {
        bs, err := readPacket(pair.clientconn)
        handle_error(err)
        remain_data = append(remain_data, bs...)
        packet_len := int(binary.LittleEndian.Uint32(remain_data))
        packet_len += 4 //fixed len.
        fmt.Printf("remain_data: 0x%x, packet_len: 0x%x\n", len(remain_data), packet_len)
        if packet_len > len(remain_data) {
            continue
        }
        packet_data := remain_data[:packet_len]
        remain_data = remain_data[packet_len:]
        //packet_data = append(packet_data, 0)
        fmt.Printf("receive 0x%x:%s\n", packet_len, string(packet_data))
        n, err := pair.serverconn.Write(packet_data)
        handle_error(err)
        fmt.Printf("handle_cs write 0x%x bytes\n", n)
    }
}

func handle_sc(pair *CSPair) {
    defer pair.serverconn.Close()
    io.Copy(pair.clientconn, pair.serverconn)
    fmt.Println("handle_sc close pair.serverconn")
    /*
    bs, err := readPacket(pair.serverconn)
    handle_error(err)
    fmt.Printf("handle_sc:%s\n", string(bs))
    pair.clientconn.Write(bs)
    */
}

func construct_connection(c *net.TCPConn) *CSPair {
    var pair = &CSPair{}
    pair.clientconn = c
    pair.clientaddr = (*c).RemoteAddr()
    f, err := c.File()
    handle_error(err)

    addr, err :=  syscall.GetsockoptIPv6Mreq(int(f.Fd()), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
    handle_error(err)

    ipv4 := strconv.Itoa(int(addr.Multiaddr[4])) + "." +
            strconv.Itoa(int(addr.Multiaddr[5])) + "." +
            strconv.Itoa(int(addr.Multiaddr[6])) + "." +
            strconv.Itoa(int(addr.Multiaddr[7]))

    port := uint16(addr.Multiaddr[2]) << 8 + uint16(addr.Multiaddr[3])
    origin_ipv4 := ipv4
    origin_port := port

    sa, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ipv4, port))
    handle_error(err)
    pair.serveraddr = sa
    pair.serverconn, err = net.DialTCP("tcp4", nil, sa)
    handle_error(err)

    connection_count++
    fmt.Printf("accept %d, %s and create a new connection to server %s(%s:%d)\n",
        connection_count, pair.clientaddr.String(), pair.serveraddr.String(), origin_ipv4, origin_port)
    return pair
}

// handle_error is the program's fatal-error policy. A nil err is a no-op.
// Any other error is printed to standard output and the process exits
// with status 1.
func handle_error(err error) {
    if err == nil {
        return
    }
    fmt.Println(err)
    os.Exit(1)
}