package main
import (
"os"
"io"
"fmt"
"net"
"strings"
"strconv"
"syscall"
"encoding/binary"
)
// CSPair ties together the two halves of one proxied connection: the accepted
// client connection and the connection dialed to the client's original
// destination, along with the remote address recorded for each.
type CSPair struct {
	// Address of the client peer and of the original destination.
	clientaddr, serveraddr net.Addr
	// Accepted client socket and the socket dialed to the destination.
	clientconn, serverconn *net.TCPConn
}
// SO_ORIGINAL_DST is the IPPROTO_IP-level getsockopt option that returns the
// destination a connection was addressed to before netfilter redirected it
// to this proxy (Linux only).
const SO_ORIGINAL_DST = 80

// connection_count is the number of client connections paired with an
// upstream connection so far; construct_connection increments it.
var connection_count int
// main starts the proxy. It listens on TCP port 8838 on all IPv4 interfaces
// and, for every accepted connection, builds a client/server pair and starts
// relaying traffic. A listen or accept failure terminates the process via
// handle_error.
func main() {
	listenAddr := &net.TCPAddr{Port: 8838}
	listener, err := net.ListenTCP("tcp4", listenAddr)
	handle_error(err)
	fmt.Printf("listen on %d\n", listenAddr.Port)
	defer listener.Close()

	for {
		client, acceptErr := listener.AcceptTCP()
		handle_error(acceptErr)
		// construct_connection dials the upstream synchronously, so new
		// accepts wait until the current dial completes.
		handle_data(construct_connection(client))
	}
}
// handle_data starts the two relay goroutines for a pair: client-to-server
// (handle_cs) and server-to-client (handle_sc). It returns immediately.
func handle_data(pair *CSPair) {
	for _, relay := range []func(*CSPair){handle_cs, handle_sc} {
		go relay(pair)
	}
}
// handle_cs forwards client-to-server traffic for one proxied connection and
// closes the client connection when it returns.
//
// If the original destination's address ends in ":843", bytes are copied
// through unchanged. Otherwise the stream is parsed as frames made of a 4-byte
// little-endian length followed by that many bytes. Each complete frame,
// header included, is logged and written to the server with a single Write.
// A read may deliver a partial frame or several frames. Data is therefore
// buffered until a full frame is present, and every full frame is flushed
// before the next read.
//
// A read or write error ends only this connection. It is logged, the server
// side is half-closed so the server sees end-of-stream, and the function
// returns. Any incomplete trailing frame is dropped.
//
// NOTE(review): readPacket is defined outside this view; it is assumed to
// return the next chunk of bytes read from the connection — confirm.
// NOTE(review): the length header comes from the client and is not capped, so
// a large declared length makes this buffer grow without bound.
func handle_cs(pair *CSPair) {
	defer pair.clientconn.Close()
	// HasSuffix matches port 843 exactly; a substring test would also match
	// ports 8430-8439.
	if strings.HasSuffix(pair.serveraddr.String(), ":843") {
		fmt.Println(":843 connection.")
		io.Copy(pair.serverconn, pair.clientconn)
		// Best effort: signal the server that the client has finished sending.
		pair.serverconn.CloseWrite()
		return
	}
	var remain_data []byte
	for {
		bs, err := readPacket(pair.clientconn)
		if err != nil {
			fmt.Println("handle_cs read:", err)
			// Best effort: no more client data will arrive, so let the server
			// see EOF instead of leaving it (and handle_sc) waiting forever.
			pair.serverconn.CloseWrite()
			return
		}
		remain_data = append(remain_data, bs...)
		// Flush every complete frame currently buffered. The length header
		// can only be decoded once at least 4 bytes are present.
		for len(remain_data) >= 4 {
			// uint64 keeps "length + header" from overflowing where int is 32 bits.
			packet_len := uint64(binary.LittleEndian.Uint32(remain_data)) + 4 // 4 = length header
			fmt.Printf("remain_data: 0x%x, packet_len: 0x%x\n", len(remain_data), packet_len)
			if packet_len > uint64(len(remain_data)) {
				break // frame incomplete; wait for more data
			}
			end := int(packet_len) // safe: packet_len <= len(remain_data)
			packet_data := remain_data[:end]
			remain_data = remain_data[end:]
			fmt.Printf("receive 0x%x:%s\n", packet_len, string(packet_data))
			n, werr := pair.serverconn.Write(packet_data)
			if werr != nil {
				fmt.Println("handle_cs write:", werr)
				return
			}
			fmt.Printf("handle_cs write 0x%x bytes\n", n)
		}
	}
}
// handle_sc copies server-to-client traffic for one proxied connection until
// the server stops sending or a read/write fails, then closes the server
// connection.
//
// When the copy ends, the client side is half-closed (CloseWrite) so the client
// sees end-of-stream instead of waiting on a connection that will never carry
// more data. The client connection itself is closed by handle_cs.
func handle_sc(pair *CSPair) {
	defer pair.serverconn.Close()
	if _, err := io.Copy(pair.clientconn, pair.serverconn); err != nil {
		fmt.Println("handle_sc:", err)
	}
	// Best effort: the client may already be closed, and nothing further can
	// be done if the half-close fails.
	pair.clientconn.CloseWrite()
	fmt.Println("handle_sc close pair.serverconn")
}
// construct_connection pairs an accepted client connection with a new
// connection to the destination the client originally addressed.
//
// The original IPv4 destination comes from getsockopt(IPPROTO_IP,
// SO_ORIGINAL_DST), which the kernel fills with a sockaddr_in for connections
// redirected by netfilter (Linux only). The upstream connection is dialed
// synchronously, connection_count is incremented, and the pairing is logged.
// Any failure terminates the process via handle_error.
func construct_connection(c *net.TCPConn) *CSPair {
	pair := &CSPair{
		clientconn: c,
		clientaddr: c.RemoteAddr(),
	}

	// Query the socket through RawConn.Control rather than c.File(). File
	// returns a duplicated descriptor that must be closed separately, and
	// leaving it open leaks one fd per accepted connection.
	raw, err := c.SyscallConn()
	handle_error(err)
	var mreq *syscall.IPv6Mreq
	var sockErr error
	err = raw.Control(func(fd uintptr) {
		// IPv6Mreq serves only as a 16-byte buffer big enough for the
		// sockaddr_in the kernel writes back.
		mreq, sockErr = syscall.GetsockoptIPv6Mreq(int(fd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
	})
	handle_error(err)
	handle_error(sockErr)

	// sockaddr_in layout: family (bytes 0-1), port in network byte order
	// (bytes 2-3), IPv4 address (bytes 4-7).
	sin := mreq.Multiaddr
	ip := net.IPv4(sin[4], sin[5], sin[6], sin[7])
	port := binary.BigEndian.Uint16(sin[2:4])

	sa, err := net.ResolveTCPAddr("tcp4", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
	handle_error(err)
	pair.serveraddr = sa
	pair.serverconn, err = net.DialTCP("tcp4", nil, sa)
	handle_error(err)

	connection_count++
	fmt.Printf("accept %d, %s and create a new connection to server %s(%s:%d)\n",
		connection_count, pair.clientaddr.String(), pair.serveraddr.String(), ip, port)
	return pair
}
// handle_error is the program's fatal-error check. A nil err is a no-op.
// Otherwise the error is printed to standard output and the process exits
// with status 1.
func handle_error(err error) {
	if err == nil {
		return
	}
	fmt.Println(err)
	os.Exit(1)
}