diff --git a/cmd/wp-go-static/commands/root.go b/cmd/wp-go-static/commands/root.go
index f914ee2..9389909 100644
--- a/cmd/wp-go-static/commands/root.go
+++ b/cmd/wp-go-static/commands/root.go
@@ -31,12 +31,24 @@ func (c *URLCache) Add(url string) {
 // Get checks if a URL is in the cache
 func (c *URLCache) Get(url string) bool {
 	c.mu.Lock()
-	defer c.mu.Unl
-	ock()
+	defer c.mu.Unlock()
 	_, ok := c.urls[url]
 	return ok
 }
 
+type Scrape struct {
+	urlCache *URLCache
+	c        *colly.Collector
+	domain   string
+}
+
+func NewScrape() *Scrape {
+	return &Scrape{
+		urlCache: &URLCache{urls: make(map[string]bool)},
+		c:        colly.NewCollector(),
+	}
+}
+
 // Run ...
 func Run(args []string) error {
 	RootCmd.SetArgs(args)
@@ -56,11 +68,15 @@ func init() {
 	RootCmd.PersistentFlags().String("dir", "dump", "directory to save downloaded files")
 	RootCmd.PersistentFlags().String("url", "", "URL to scrape")
 	RootCmd.PersistentFlags().String("cache", "", "Cache directory")
+	RootCmd.PersistentFlags().Bool("parallel", false, "Fetch in parallel")
+	RootCmd.PersistentFlags().Bool("images", true, "Download images")
 
 	// Bind command-line flags to Viper
 	viper.BindPFlag("dir", RootCmd.PersistentFlags().Lookup("dir"))
 	viper.BindPFlag("url", RootCmd.PersistentFlags().Lookup("url"))
 	viper.BindPFlag("cache", RootCmd.PersistentFlags().Lookup("cache"))
+	viper.BindPFlag("parallel", RootCmd.PersistentFlags().Lookup("parallel"))
+	viper.BindPFlag("images", RootCmd.PersistentFlags().Lookup("images"))
 
 	viper.AutomaticEnv()
 	viper.EnvKeyReplacer(strings.NewReplacer("-", "_"))
@@ -76,15 +92,18 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	commandDir := viper.GetString("dir")
 	commandURL := viper.GetString("url")
 	cacheDir := viper.GetString("cache")
+	parallel := viper.GetBool("parallel")
 
-	c := colly.NewCollector()
+	scrape := NewScrape()
 
 	if cacheDir != "" {
-		c.CacheDir = cacheDir
+		scrape.c.CacheDir = cacheDir
 	}
 
+	scrape.c.Async = parallel
+
 	// Ignore SSL errors
-	c.WithTransport(&http.Transport{
+	scrape.c.WithTransport(&http.Transport{
 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
 	})
 
@@ -95,89 +114,42 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	domain := parsedURL.Hostname()
+	scrape.domain = domain
 
 	// Visit only pages that are part of the website
-	c.AllowedDomains = []string{domain}
-
-	// Create URL cache
-	urlCache := &URLCache{urls: make(map[string]bool)}
+	scrape.c.AllowedDomains = []string{domain}
 
 	// On every a element which has href attribute call callback
-	c.OnHTML("a[href]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("a[href]", func(e *colly.HTMLElement) {
 		link := e.Attr("href")
-		// Visit link found on page if it hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every link element call callback
-	c.OnHTML("link[href]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("link[href]", func(e *colly.HTMLElement) {
 		link := e.Attr("href")
-		// Download file found on page if it has a supported extension and hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every script element call callback
-	c.OnHTML("script[src]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("script[src]", func(e *colly.HTMLElement) {
 		link := e.Attr("src")
-		// Download file found on page if it has a supported extension and hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every img element call callback
-	c.OnHTML("img", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("img", func(e *colly.HTMLElement) {
 		link := e.Attr("src")
-		// Download image found on page if it hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// Before making a request print "Visiting ..."
-	c.OnRequest(func(r *colly.Request) {
+	scrape.c.OnRequest(func(r *colly.Request) {
 		fmt.Println("Visiting", r.URL.String())
 	})
 
 	// On response
-	c.OnResponse(func(r *colly.Response) {
+	scrape.c.OnResponse(func(r *colly.Response) {
 		dir, fileName := file.HandleFile(r, commandDir)
-
-		// Find all URLs in the CSS file
-		cssUrls := regexp.MustCompile(`url\((https?://[^\s]+)\)`).FindAllStringSubmatch(string(r.Body), -1)
-
-		// Download each referenced file if it hasn't been visited before
-		for _, cssUrl := range cssUrls {
-			url := strings.Trim(cssUrl[1], "'\"")
-			if url == "" {
-				continue
-			}
-			if !urlCache.Get(url) {
-				urlCache.Add(url)
-				fmt.Printf("Visiting from CSS: '%s'\n", url)
-				c.Visit(url)
-			}
-		}
-
-		optionList := []string{
-			fmt.Sprintf(`http://%s`, domain),
-			fmt.Sprintf(`http:\/\/%s`, domain),
-			fmt.Sprintf(`https://%s`, domain),
-			fmt.Sprintf(`https:\/\/%s`, domain),
-			domain,
-		}
-
-		for _, v := range optionList {
-			// Replace all occurrences of the base URL with a relative URL
-			replaceBody := strings.ReplaceAll(string(r.Body), v, "")
-			r.Body = []byte(replaceBody)
-		}
+		r.Body = scrape.parseBody(r.Body)
 
 		err := file.SaveFile(r, dir, fileName)
 		if err != nil {
@@ -187,5 +159,50 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	})
 
 	// Start scraping
-	return c.Visit(commandURL)
+	err = scrape.c.Visit(commandURL)
+	if err != nil {
+		return err
+	}
+
+	scrape.c.Wait()
+
+	return nil
+}
+
+func (s *Scrape) visitURL(link string) {
+	// Visit the link only if it hasn't been visited before
+	if !s.urlCache.Get(link) {
+		s.urlCache.Add(link)
+		s.c.Visit(link)
+	}
+}
+
+func (s *Scrape) parseBody(body []byte) []byte {
+	// Find all URLs in the CSS file
+	cssUrls := regexp.MustCompile(`url\((https?://[^\s]+)\)`).FindAllStringSubmatch(string(body), -1)
+
+	// Download each referenced file if it hasn't been visited before
+	for _, cssUrl := range cssUrls {
+		link := strings.Trim(cssUrl[1], "'\"")
+		if link == "" {
+			continue
+		}
+		s.visitURL(link)
+	}
+
+	optionList := []string{
+		fmt.Sprintf(`http://%s`, s.domain),
+		fmt.Sprintf(`http:\/\/%s`, s.domain),
+		fmt.Sprintf(`https://%s`, s.domain),
+		fmt.Sprintf(`https:\/\/%s`, s.domain),
+		s.domain,
+	}
+
+	for _, v := range optionList {
+		// Replace all occurrences of the base URL with a relative URL
+		replaceBody := strings.ReplaceAll(string(body), v, "")
+		body = []byte(replaceBody)
+	}
+
+	return body
 }
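Note: with --parallel flipping the collector into Async mode, the OnHTML callbacks above can run concurrently, and visitURL's separate Get/Add calls on URLCache leave a window in which two goroutines both see a link as unvisited and fetch it twice. Below is a minimal sketch of folding the check and the insert into one critical section, building on the URLCache type from this diff; AddIfAbsent is a hypothetical helper name, not part of the patch:

// AddIfAbsent records the URL and reports whether it was new,
// doing the check and the insert under a single lock so that
// concurrent callers cannot both claim the same URL.
// (Hypothetical helper, not part of the diff above.)
func (c *URLCache) AddIfAbsent(url string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.urls[url] {
		return false
	}
	c.urls[url] = true
	return true
}

visitURL could then reduce to a single guarded call, e.g. if s.urlCache.AddIfAbsent(link) { s.c.Visit(link) }, keeping deduplication atomic even when the crawl runs in parallel.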