This series collects 10 small Golang projects of moderate workload and difficulty, suitable for engineers who already know Go syntax and want more practice with the language and its common libraries.

Problem statement:
Implement a web crawler that starts from a given URL and visits pages in breadth-first order.
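
Before looking at the concurrent implementation below, it may help to see the breadth-first traversal in its simplest form. This is only a sequential sketch; extractLinks is a hypothetical stand-in for the Extract function defined later:

// bfsCrawl visits pages level by level up to maxDepth, never visiting the
// same URL twice. extractLinks stands in for the real page fetcher.
func bfsCrawl(start string, maxDepth int, extractLinks func(string) []string) {
   seen := map[string]bool{start: true}
   queue := []string{start}

   for depth := 1; depth <= maxDepth && len(queue) > 0; depth++ {
      var next []string
      for _, u := range queue {
         for _, link := range extractLinks(u) {
            if !seen[link] {
               seen[link] = true
               next = append(next, link)
            }
         }
      }
      queue = next // the next BFS level
   }
}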

Key points:

  • Fetch multiple pages concurrently; the number of pages visited at the same time is set by the -concurrency flag, defaulting to 20 (one way to cap concurrent fetches is sketched right after this list).
    Use -depth to limit how deep the crawl goes, defaulting to 3.
    Make sure pages that have already been visited are not visited again.
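
The implementation below bounds concurrency with a fixed pool of worker goroutines. An alternative worth knowing is a buffered channel used as a counting semaphore; the sketch below assumes the sync package is imported and takes the page fetcher as a callback:

// crawlWithSemaphore fetches every URL concurrently while keeping at most
// `concurrency` fetches in flight, using a buffered channel as a semaphore.
func crawlWithSemaphore(urls []string, concurrency int, fetchPage func(string)) {
   var wg sync.WaitGroup
   tokens := make(chan struct{}, concurrency)

   for _, u := range urls {
      wg.Add(1)
      go func(u string) {
         defer wg.Done()
         tokens <- struct{}{}        // acquire a slot
         defer func() { <-tokens }() // release it when done
         fetchPage(u)
      }(u)
   }
   wg.Wait()
}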

Extension:

  • Write the fetched pages to disk to build a local mirror of the target site. Note that only pages under the specified domain should be collected, and the href values of elements in the saved pages must be rewritten to point at the mirrored pages rather than the original ones (the rewriting idea is illustrated right after this list).
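
The href rewriting relies on a property of net/url: clearing Scheme, Host and User on a parsed URL and serializing it again yields a host-relative link that points into the mirror. A small standalone illustration (example.com is just a placeholder):

package main

import (
   "fmt"
   "net/url"
)

func main() {
   base, _ := url.Parse("http://example.com/docs/")

   // Resolve an href found on the page against the page's own URL...
   link, _ := base.Parse("/docs/install.html")

   // ...then drop the scheme, host and user info so the link becomes relative.
   link.Scheme = ""
   link.Host = ""
   link.User = nil

   fmt.Println(link.String()) // prints "/docs/install.html"
}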

Implementation:

package main

import (
   "bytes"
   "flag"
   "fmt"
   "golang.org/x/net/html"
   "io"
   "log"
   "net/http"
   "net/url"
   "os"
   "path/filepath"
   "strings"
   "sync"
   "time"
)

type URLInfo struct {
   url string
   depth int
}

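// base is the starting URL of the crawl; only pages on base.Host are mirrored.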
var base *url.URL

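// forEachNode walks the HTML tree rooted at n, calling pre before visiting a
// node's children and post afterwards; either callback may be nil.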
func forEachNode(n *html.Node, pre, post func(n *html.Node)){
   if pre != nil{
      pre(n)
   }

   for c := n.FirstChild; c != nil; c = c.NextSibling{
      forEachNode(c, pre, post)
   }

   if post != nil{
      post(n)
   }
}

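// linkNodes returns every <a> element node in the document tree.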
func linkNodes(n *html.Node) []*html.Node {
   var links []*html.Node
   visitNode := func(n *html.Node) {
      if n.Type == html.ElementNode && n.Data == "a" {
         links = append(links, n)
      }
   }
   forEachNode(n, visitNode, nil)
   return links
}

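// linkURLs resolves the href attribute of each link node against base and
// returns the absolute URLs, keeping only well-formed links on the same host.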
func linkURLs(linkNodes []*html.Node, base *url.URL) []string {
   var urls []string
   for _, n := range linkNodes {
      for _, a := range n.Attr {
         if a.Key != "href" {
            continue
         }
         link, err := base.Parse(a.Val)

         // ignore bad and non-local URLs
         if err != nil {
            log.Printf("skipping %q: %s", a.Val, err)
            continue
         }
         if link.Host != base.Host {

            //log.Printf("skipping %q: non-local host", a.Val)
            continue
         }

         if strings.HasPrefix(link.String(), "javascript"){
            continue
         }

         urls = append(urls, link.String())
      }
   }
   return urls
}

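// rewriteLocalLinks rewrites every same-host href to a host-relative form so
// that the saved page points into the local mirror instead of the live site.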
func rewriteLocalLinks(linkNodes []*html.Node, base *url.URL) {
   for _, n := range linkNodes {
      for i, a := range n.Attr {
         if a.Key != "href" {
            continue
         }
         link, err := base.Parse(a.Val)
         if err != nil || link.Host != base.Host {
            continue // ignore bad and non-local URLs
         }

         link.Scheme = ""
         link.Host = ""
         link.User = nil
         a.Val = link.String()

         n.Attr[i] = a
      }
   }
}

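// Extract fetches the page at url. For HTML pages it parses the document,
// collects same-host links, rewrites those links for the mirror and renders
// the modified document; the page is then written to disk via save
// (non-local URLs are skipped).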
func Extract(url string) (urls []string, err error) {
   // Use a client with a timeout so a slow server cannot block a worker forever.
   client := http.Client{Timeout: 10 * time.Second}

   resp, err := client.Get(url)
   if err != nil {
      return nil, err
   }

   if resp.StatusCode != http.StatusOK {
      resp.Body.Close()
      return nil, fmt.Errorf("getting %s: %s", url, resp.Status)
   }

   u, err := base.Parse(url)
   if err != nil {
      resp.Body.Close()
      return nil, err
   }
   if base.Host != u.Host {
      resp.Body.Close()
      log.Printf("not saving %s: non-local", url)
      return nil, nil
   }

   var body io.Reader
   contentType := resp.Header.Get("Content-Type")
   if strings.Contains(contentType, "text/html") {
      doc, err := html.Parse(resp.Body)
      resp.Body.Close()

      if err != nil {
         return nil, fmt.Errorf("parsing %s as HTML: %v", u, err)
      }

      nodes := linkNodes(doc)

      urls = linkURLs(nodes, u)

      rewriteLocalLinks(nodes, u)


      b := &bytes.Buffer{}
      err = html.Render(b, doc)
      if err != nil {
         log.Printf("render %s: %s", u, err)
      }
      body = b
   } else {
      // Not an HTML page: leave body nil so save streams resp.Body directly,
      // and close the body once save has finished reading it.
      defer resp.Body.Close()
   }

   err = save(resp, body)

   return urls, err
}

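// crawl wraps Extract, logging any error and returning whatever links were found.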
func crawl(url string) []string{
   list, err := Extract(url)

   if err != nil{
      log.Print(err)
   }

   return list
}

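// save stores the response under a directory named after the URL's host,
// mirroring the URL path; URLs without a file extension are written as
// <host>/<path>/index.html. If body is non-nil (rewritten HTML), it is saved
// in place of the raw response body.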
func save(resp *http.Response, body io.Reader) error {
   u := resp.Request.URL

   filename := filepath.Join(u.Host, u.Path)

   if filepath.Ext(u.Path) == "" {
      filename = filepath.Join(u.Host, u.Path, "index.html")
   }

   err := os.MkdirAll(filepath.Dir(filename), 0777)

   if err != nil {
      return err
   }

   fmt.Println("filename:", filename)

   file, err := os.Create(filename)

   if err != nil {
      return err
   }

   if body != nil {
      _, err = io.Copy(file, body)
   } else {
      _, err = io.Copy(file, resp.Body)
   }


   if err != nil {
      log.Print("save: ", err)
   }

   err = file.Close()

   if err != nil {
      log.Print("save: ", err)
   }

   return nil
}



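// parallellyCrawl drives the crawl. The main goroutine drains worklist,
// filters out URLs that are too deep or already seen, and sends new ones to
// unseenLinks; `concurrency` worker goroutines read from unseenLinks, crawl
// each page and push the discovered links back onto worklist. The seen map
// also tracks progress (1 = queued, 2 = crawled), and a watcher goroutine
// polls it once a second to decide when both channels can be closed.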
func parallellyCrawl(initialLinks string, concurrency, depth int){
   worklist := make(chan []URLInfo, 1)
   unseenLinks := make(chan URLInfo, 1)

   // A value of 1 means the URL has been queued on unseenLinks; 2 means its crawl has finished.
   seen := make(map[string]int)
   seenLock := sync.Mutex{}

   var urlInfos []URLInfo

   for _, url := range strings.Split(initialLinks, " "){
      urlInfos = append(urlInfos, URLInfo{url, 1})
   }

   go func() {worklist <- urlInfos}()

   go func() {
      for{
         time.Sleep(1 * time.Second)

         seenFlag := true

         seenLock.Lock()

         for k := range seen{
            if seen[k] == 1{
               seenFlag = false
            }
         }

         seenLock.Unlock()

         if seenFlag && len(worklist) == 0{
            close(unseenLinks)
            close(worklist)

            break
         }

      }
   }()

   for i := 0; i < concurrency; i++{
      go func() {
         for link := range unseenLinks{
            foundLinks := crawl(link.url)


            var urlInfos []URLInfo

            for _, u := range foundLinks{
               urlInfos = append(urlInfos, URLInfo{u, link.depth + 1})
            }

            go func(finishedUrl string) {

               worklist <- urlInfos

               seenLock.Lock()
               seen[finishedUrl] = 2
               seenLock.Unlock()
            }(link.url)


         }
      }()
   }

   for list := range worklist{
      for _, link := range list {

         if link.depth > depth{
            continue
         }

         // Check and mark the URL in a single critical section.
         seenLock.Lock()
         _, ok := seen[link.url]
         if !ok {
            seen[link.url] = 1
         }
         seenLock.Unlock()

         if !ok {
            unseenLinks <- link
         }
      }
   }

   fmt.Printf("visited %d pages in total\n", len(seen))
}

func main() {
   var maxDepth int
   var concurrency int
   var initialLink string

   flag.IntVar(&maxDepth, "d", 3, "max crawl depth")
   flag.IntVar(&concurrency, "c", 20, "number of crawl goroutines")
   flag.StringVar(&initialLink, "u", "", "initial link")

   flag.Parse()

   if initialLink == "" {
      fmt.Fprintln(os.Stderr, "usage: -u <initial url> [-c concurrency] [-d max depth]")
      os.Exit(1)
   }

   u, err := url.Parse(initialLink)
   if err != nil {
      fmt.Fprintf(os.Stderr, "invalid url: %s\n", err)
      os.Exit(1)
   }

   base = u

   parallellyCrawl(initialLink, concurrency, maxDepth)
}
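
Assuming the source file is saved as crawler.go, a typical run looks like: go run crawler.go -u http://example.com -c 20 -d 3. The crawler then mirrors the reachable pages of that host into a local directory named after it, fetching at most 20 pages at a time and following links up to depth 3.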