prober.go

package main

import (
    "cnprober/log"
    "flag"
    "fmt"
    "io"
    "strings"

    "os"
    "time"

    "github.com/miekg/dns"
    "github.com/remeh/sizedwaitgroup"
    "github.com/spf13/viper"
)

var (
	// logger is the process-wide structured logger used by the probing and
	// file-writing code.
	logger = log.NewLogger()

	// timeout bounds every DNS exchange. Lookup also reports it as the RTT
	// of a failed exchange.
	timeout = 5 * time.Second
)

// Result is the outcome of probing one authoritative DNS server.
//
// Servername is the server's name from the config (e.g. "a.dns.cn.").
// Serveraddr is the address that was queried (IPv6 literals in brackets).
// Rtt is the measured round-trip time as a time.Duration; it is only
// converted to whole milliseconds when the metric line is written.
// DomainType and IPType are the two halves of the config section name,
// e.g. "cn" and "ipv4".
type Result struct {
	Servername, Serveraddr string
	Rtt                    time.Duration
	DomainType, IPType     string
}

// Lookup sends a single NS query for fqdn to the DNS server at serverAddr,
// port 53, and measures the round-trip time.
//
// serverAddr is used verbatim as the host part of "host:53", so IPv6
// literals must already be enclosed in brackets by the caller.
//
// Only Result.Rtt is filled in. If the exchange fails, Rtt is set to the
// package-level timeout and the error is returned together with that
// result, so callers can still record the probe (as a timeout).
// RTTs below 1ms are reported as exactly 1ms. NOTE: presumably this keeps
// them from being exported as 0, since Duration.Milliseconds truncates.
func Lookup(fqdn, serverAddr string) (Result, error) {
	var result Result

	msg := new(dns.Msg)
	msg.SetQuestion(dns.Fqdn(fqdn), dns.TypeNS)

	client := dns.Client{Timeout: timeout}
	_, rtt, err := client.Exchange(msg, serverAddr+":53")
	if err != nil {
		result.Rtt = timeout
		// Include the error itself; without it a failure cannot be told
		// apart from a genuine timeout in the logs.
		logger.Log("level", "error", "serverAddr", serverAddr, "resolveTime", result.Rtt, "msg", err)
		return result, err
	}
	logger.Log("level", "info", "serverAddr", serverAddr, "resolveTime", rtt)

	if rtt < time.Millisecond {
		rtt = time.Millisecond
	}
	result.Rtt = rtt
	return result, nil
}

// worker probes one server via Lookup and sends the annotated Result on
// resultChannel.
//
// flag is the config section name the server came from, of the form
// "<domainType>.<ipType>" (e.g. "cn.ipv4"). Its two halves become
// Result.DomainType and Result.IPType.
//
// swg is not used here: in run, the collector goroutine that receives from
// resultChannel calls swg.Done after storing each result.
func worker(fqdn string, serverName, serverAddr string, resultChannel chan Result, swg *sizedwaitgroup.SizedWaitGroup, flag string) {
	// Lookup already logs failures and sets result.Rtt to the timeout, so
	// the error needs no further handling here.
	result, _ := Lookup(fqdn, serverAddr)

	// SplitN plus a length check: indexing [1] of a plain Split would panic
	// (and crash the whole prober) for a section name without a dot.
	parts := strings.SplitN(flag, ".", 2)
	result.DomainType = parts[0]
	if len(parts) == 2 {
		result.IPType = parts[1]
	}
	result.Serveraddr = serverAddr
	result.Servername = serverName
	resultChannel <- result
}

// openFile opens filename for writing, creating it (mode 0666 before umask)
// if it does not exist and truncating any existing contents, so callers
// always start from an empty file.
//
// O_APPEND is intentionally not set: combined with O_TRUNC it is
// contradictory, and the intent here is to overwrite.
func openFile(filename string) (*os.File, error) {
	return os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
}

// writeFile writes line followed by a newline to f. A write failure is
// logged at error level and otherwise ignored; nothing is returned.
func writeFile(line string, f *os.File) {
	if _, werr := io.WriteString(f, line+"\n"); werr != nil {
		logger.Log("level", "error", "msg", werr)
	}
}

// run performs one probing round.
//
// It queries every server listed in the "cn.ipv4", "cn.ipv6", "root.ipv4"
// and "root.ipv6" config sections (server name -> IP address) for the NS
// records of *flDomain, with at most *flWorkerCount lookups in flight.
// It then rewrites the Prometheus textfile *flResultProm with one
// authority_response_time sample per server.
func run(flDomain *string, flWorkerCount *int, flResultProm *string, config *viper.Viper) {
	swg := sizedwaitgroup.New(*flWorkerCount)
	resultsChannel := make(chan Result)
	var results []Result

	// A single collector goroutine owns `results`. It (not the worker) calls
	// swg.Done after appending, so swg.Wait below only returns once every
	// result has been stored.
	go func() {
		for r := range resultsChannel {
			results = append(results, r)
			swg.Done()
		}
	}()

	for _, section := range []string{"cn.ipv4", "cn.ipv6", "root.ipv4", "root.ipv6"} {
		isIPv6 := strings.HasSuffix(section, ".ipv6")
		// GetStringMapString converts the values to strings, so a
		// non-string entry in the config cannot panic the prober.
		for serverName, ip := range config.GetStringMapString(section) {
			serverAddr := ip
			if isIPv6 {
				// Bracket IPv6 literals so Lookup can append ":53".
				serverAddr = fmt.Sprintf("[%s]", ip)
			}
			swg.Add()
			go worker(*flDomain, serverName, serverAddr, resultsChannel, &swg, section)
		}
	}

	swg.Wait()
	// Every send has been received by now, so the collector can be released
	// immediately instead of at function exit.
	close(resultsChannel)

	writeResults(*flResultProm, results)
}

// writeResults renders results in the Prometheus text exposition format and
// replaces the file at path with it. Errors are logged, not returned.
//
// The data is written to path+".tmp" first and then renamed over path, so a
// reader (e.g. a node_exporter textfile collector) never observes a
// half-written file. When path ends in ".prom", the temporary name also does
// not match "*.prom".
func writeResults(path string, results []Result) {
	tmpPath := path + ".tmp"
	f, err := openFile(tmpPath)
	if err != nil {
		// Nothing can be written without a file; bail out instead of
		// writing to (and closing) a nil *os.File.
		logger.Log("level", "error", "msg", err)
		return
	}

	writeFile(`# HELP authority_response_time Response time of authoritative DNS servers in milliseconds.
# TYPE authority_response_time gauge`, f)
	for _, r := range results {
		// To leave timed-out probes out of the prom file, skip results whose
		// r.Rtt has reached timeout here.
		line := fmt.Sprintf(`authority_response_time{server="%s", server_addr="%s", domain_type="%s", ip_type="%s"}  %d`,
			r.Servername, r.Serveraddr, r.DomainType, r.IPType, r.Rtt.Milliseconds())
		writeFile(line, f)
	}

	// Close can report a failed flush of written data, so check it before
	// publishing the file.
	if err := f.Close(); err != nil {
		logger.Log("level", "error", "msg", err)
		_ = os.Remove(tmpPath) // best effort cleanup
		return
	}
	if err := os.Rename(tmpPath, path); err != nil {
		logger.Log("level", "error", "msg", err)
	}
}

// main parses the command-line flags and loads the server list from the
// config file. It probes all servers once immediately and then every
// -interval seconds; each round rewrites the -prom file.
//
// Exit codes: 1 = missing -domain, 2 = empty -config, 3 = non-positive
// -interval.
func main() {
	var (
		flDomain      = flag.String("domain", "", "The domain whose NS records are queried (e.g. cn).")
		flWorkerCount = flag.Int("c", 100, "The maximum number of concurrent DNS lookups")
		flResultProm  = flag.String("prom", "response.prom", "The Prometheus textfile the results are written to")
		flInterval    = flag.Int("interval", 60, "The number of seconds between probing rounds")
		flConfig      = flag.String("config", "config.toml", "The config file")
	)

	flag.Parse()
	if *flDomain == "" {
		fmt.Fprintln(os.Stderr, "domain is required, for example: cn")
		logger.Log("level", "error", "msg", "domain is required")
		os.Exit(1)
	}
	if *flConfig == "" {
		fmt.Fprintln(os.Stderr, "config file is required")
		logger.Log("level", "error", "msg", "config file is required")
		os.Exit(2)
	}
	if *flInterval <= 0 {
		// time.NewTicker panics on a non-positive duration.
		fmt.Fprintln(os.Stderr, "interval must be a positive number of seconds")
		logger.Log("level", "error", "msg", "interval must be positive", "interval", *flInterval)
		os.Exit(3)
	}
	config := GetConfig(*flConfig)

	// Probe right away so the prom file exists without waiting a full
	// interval, then keep probing on every tick.
	run(flDomain, flWorkerCount, flResultProm, config)

	ticker := time.NewTicker(time.Duration(*flInterval) * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		run(flDomain, flWorkerCount, flResultProm, config)
	}
}

config.go

package main

import (
    "fmt"

    "github.com/spf13/viper"
)

// GetConfig loads the TOML configuration identified by configFile and returns
// the resulting viper instance. It panics if no configuration can be read.
//
// configFile is first used as a file path (absolute, or relative to the
// working directory). If that read fails, it falls back to searching the
// working directory for a config named configFile. The fallback also accepts
// a bare name without an extension (e.g. "config" -> ./config.toml).
// If both attempts fail, the error from the path-based attempt is reported.
func GetConfig(configFile string) *viper.Viper {
	v := viper.New()
	v.SetConfigFile(configFile)
	v.SetConfigType("toml")
	err := v.ReadInConfig()
	if err == nil {
		return v
	}

	// Legacy lookup. SetConfigName joins the name onto every search path,
	// so it can never find an absolute path; it is kept only as a fallback.
	legacy := viper.New()
	legacy.SetConfigName(configFile)
	legacy.SetConfigType("toml")
	legacy.AddConfigPath(".")
	if legacy.ReadInConfig() == nil {
		return legacy
	}

	panic(fmt.Errorf("Fatal error config file: %w \n", err))
}

logger.go

package log

import (
    "os"
    "time"

    "github.com/go-kit/log"
)

var (
	// timestampFormat renders the log timestamp from time.Now (local time)
	// using the layout "2006-01-02 15:04:05": second precision, with no
	// fractional seconds and no time-zone offset.
	timestampFormat = log.TimestampFormat(time.Now, "2006-01-02 15:04:05")
)

// NewLogger returns a go-kit logfmt logger that writes to stderr through a
// synchronized writer. Every record carries a "ts" field (formatted by
// timestampFormat) and a "caller" field (log.DefaultCaller).
func NewLogger() log.Logger {
	out := log.NewSyncWriter(os.Stderr)
	return log.With(log.NewLogfmtLogger(out), "ts", timestampFormat, "caller", log.DefaultCaller)
}

config.toml

[cn]
[cn.ipv4]
"a.dns.cn." = "203.119.25.1"
"b.dns.cn." = "203.119.26.1"
"c.dns.cn." = "203.119.27.1"
"d.dns.cn." = "203.119.28.1"
"e.dns.cn." = "203.119.29.1"
"f.dns.cn." = "195.219.8.90"
"g.dns.cn." = "66.198.183.65"
"ns.cernet.net." = "202.112.0.44"
[cn.ipv6]
"a.dns.cn." = "2001:dc7:0:0:0:0:0:1"
"d.dns.cn." = "2001:dc7:1000:0:0:0:0:1"
[root]
[root.ipv4]
"a.root-servers.net." = "198.41.0.4"
"b.root-servers.net." = "199.9.14.201"
"c.root-servers.net." = "192.33.4.12"
"d.root-servers.net." = "199.7.91.13"
"e.root-servers.net." = "192.203.230.10"
"f.root-servers.net." = "192.5.5.241"
"g.root-servers.net." = "192.112.36.4"
"h.root-servers.net." = "198.97.190.53"
"i.root-servers.net." = "192.36.148.17"
"j.root-servers.net." = "192.58.128.30"
"k.root-servers.net." = "193.0.14.129"
"l.root-servers.net." = "199.7.83.42"
"m.root-servers.net." = "202.12.27.33"
[root.ipv6]
"a.root-servers.net." = "2001:503:ba3e::2:30"
"b.root-servers.net." = "2001:500:200::b"
"c.root-servers.net." = "2001:500:2::c"
"d.root-servers.net." = "2001:500:2d::d"
"e.root-servers.net." = "2001:500:a8::e"
"f.root-servers.net." = "2001:500:2f::f"
"g.root-servers.net." = "2001:500:12::d0d"
"h.root-servers.net." = "2001:500:1::53"
"i.root-servers.net." = "2001:7fe::53"
"j.root-servers.net." = "2001:503:c27::2:30"
"k.root-servers.net." = "2001:7fd::1"
"l.root-servers.net." = "2001:500:9f::42"
"m.root-servers.net." = "2001:dc3::35"