feat: add parallel download
marcotuna committed Jul 24, 2023
1 parent b38de07 commit 0f1477d
Showing 1 changed file with 85 additions and 66 deletions.
151 changes: 85 additions & 66 deletions cmd/wp-go-static/commands/root.go
@@ -31,12 +31,24 @@ func (c *URLCache) Add(url string) {
 // Get checks if a URL is in the cache
 func (c *URLCache) Get(url string) bool {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	_, ok := c.urls[url]
 	return ok
 }
 
+type Scrape struct {
+	urlCache *URLCache
+	c        *colly.Collector
+	domain   string
+}
+
+func NewScrape() *Scrape {
+	return &Scrape{
+		urlCache: &URLCache{urls: make(map[string]bool)},
+		c:        colly.NewCollector(),
+	}
+}
+
 // Run ...
 func Run(args []string) error {
 	RootCmd.SetArgs(args)
@@ -56,11 +68,14 @@ func init() {
 	RootCmd.PersistentFlags().String("dir", "dump", "directory to save downloaded files")
 	RootCmd.PersistentFlags().String("url", "", "URL to scrape")
 	RootCmd.PersistentFlags().String("cache", "", "Cache directory")
+	RootCmd.PersistentFlags().Bool("parallel", false, "Fetch in parallel")
+	RootCmd.PersistentFlags().Bool("images", true, "Download images")
 
 	// Bind command-line flags to Viper
 	viper.BindPFlag("dir", RootCmd.PersistentFlags().Lookup("dir"))
 	viper.BindPFlag("url", RootCmd.PersistentFlags().Lookup("url"))
 	viper.BindPFlag("cache", RootCmd.PersistentFlags().Lookup("cache"))
+	viper.BindPFlag("images", RootCmd.PersistentFlags().Lookup("images"))
 
 	viper.AutomaticEnv()
 	viper.EnvKeyReplacer(strings.NewReplacer("-", "_"))
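
A caveat worth noting on this hunk: the new parallel flag is defined but, unlike images, never bound to Viper, and rootCmdF below reads it through viper.GetBool("parallel") rather than from the flag set. As written, passing --parallel on the command line should have no effect unless the value arrives some other way (for example through the environment via AutomaticEnv). A one-line binding mirroring the existing ones would presumably fix it:

	// Hypothetical fix, not part of this commit: bind the new flag like the others.
	viper.BindPFlag("parallel", RootCmd.PersistentFlags().Lookup("parallel"))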
@@ -76,15 +91,20 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	commandDir := viper.GetString("dir")
 	commandURL := viper.GetString("url")
 	cacheDir := viper.GetString("cache")
+	parallel := viper.GetBool("parallel")
 
-	c := colly.NewCollector()
+	scrape := NewScrape()
+
+	scrape.domain = commandURL
 
 	if cacheDir != "" {
-		c.CacheDir = cacheDir
+		scrape.c.CacheDir = cacheDir
 	}
 
+	scrape.c.Async = parallel
+
 	// Ignore SSL errors
-	c.WithTransport(&http.Transport{
+	scrape.c.WithTransport(&http.Transport{
 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
 	})
 
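For context on the scrape.c.Async = parallel line above: when a colly collector runs async, Visit only queues the request and returns, so the crawl has to be closed out with Wait(), which is exactly what this commit adds at the end of rootCmdF below. A minimal, self-contained sketch of the pattern (the URL and the parallelism limit here are illustrative, not taken from this repo):

	package main

	import (
		"fmt"

		"github.com/gocolly/colly" // or gocolly/colly/v2, depending on module version
	)

	func main() {
		// Async collector: Visit schedules requests instead of blocking.
		c := colly.NewCollector(colly.Async(true))

		// Optionally bound the number of concurrent requests.
		c.Limit(&colly.LimitRule{DomainGlob: "*", Parallelism: 4})

		c.OnResponse(func(r *colly.Response) {
			fmt.Println("fetched", r.Request.URL, len(r.Body), "bytes")
		})

		c.Visit("https://example.com") // returns almost immediately in async mode

		// Block until every queued request has completed.
		c.Wait()
	}

The commit itself sets no limit rule, so with parallel enabled the number of in-flight requests is effectively unbounded.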
@@ -95,89 +115,42 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	domain := parsedURL.Hostname()
 
 	// Visit only pages that are part of the website
-	c.AllowedDomains = []string{domain}
-
-	// Create URL cache
-	urlCache := &URLCache{urls: make(map[string]bool)}
+	scrape.c.AllowedDomains = []string{domain}
 
 	// On every a element which has href attribute call callback
-	c.OnHTML("a[href]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("a[href]", func(e *colly.HTMLElement) {
 		link := e.Attr("href")
 		// Visit link found on page if it hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every link element call callback
-	c.OnHTML("link[href]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("link[href]", func(e *colly.HTMLElement) {
 		link := e.Attr("href")
 		// Download file found on page if it has a supported extension and hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every script element call callback
-	c.OnHTML("script[src]", func(e *colly.HTMLElement) {
+	scrape.c.OnHTML("script[src]", func(e *colly.HTMLElement) {
 		link := e.Attr("src")
 		// Download file found on page if it has a supported extension and hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// On every img element call callback
-	c.OnHTML("img", func(e *colly.HTMLElement) {
-
+	scrape.c.OnHTML("img", func(e *colly.HTMLElement) {
 		link := e.Attr("src")
 		// Download image found on page if it hasn't been visited before
-		if !urlCache.Get(link) {
-			urlCache.Add(link)
-			c.Visit(e.Request.AbsoluteURL(link))
-		}
+		scrape.visitURL(e.Request.AbsoluteURL(link))
 	})
 
 	// Before making a request print "Visiting ..."
-	c.OnRequest(func(r *colly.Request) {
+	scrape.c.OnRequest(func(r *colly.Request) {
 		fmt.Println("Visiting", r.URL.String())
 	})
 
 	// On response
-	c.OnResponse(func(r *colly.Response) {
+	scrape.c.OnResponse(func(r *colly.Response) {
 		dir, fileName := file.HandleFile(r, commandDir)
 
-		// Find all URLs in the CSS file
-		cssUrls := regexp.MustCompile(`url\((https?://[^\s]+)\)`).FindAllStringSubmatch(string(r.Body), -1)
-
-		// Download each referenced file if it hasn't been visited before
-		for _, cssUrl := range cssUrls {
-			url := strings.Trim(cssUrl[1], "'\"")
-			if url == "" {
-				continue
-			}
-			if !urlCache.Get(url) {
-				urlCache.Add(url)
-				fmt.Printf("Visiting from CSS: '%s'\n", url)
-				c.Visit(url)
-			}
-		}
-
-		optionList := []string{
-			fmt.Sprintf(`http://%s`, domain),
-			fmt.Sprintf(`http:\/\/%s`, domain),
-			fmt.Sprintf(`https://%s`, domain),
-			fmt.Sprintf(`https:\/\/%s`, domain),
-			domain,
-		}
-
-		for _, v := range optionList {
-			// Replace all occurrences of the base URL with a relative URL
-			replaceBody := strings.ReplaceAll(string(r.Body), v, "")
-			r.Body = []byte(replaceBody)
-		}
+		r.Body = scrape.parseBody(r.Body)
 
 		err := file.SaveFile(r, dir, fileName)
 		if err != nil {
@@ -187,5 +160,51 @@ func rootCmdF(command *cobra.Command, args []string) error {
 	})
 
 	// Start scraping
-	return c.Visit(commandURL)
+	err = scrape.c.Visit(commandURL)
+
+	if err != nil {
+		return err
+	}
+
+	scrape.c.Wait()
+
+	return nil
 }
+
+func (s *Scrape) visitURL(link string) {
+	// Download image found on page if it hasn't been visited before
+	if !s.urlCache.Get(link) {
+		s.urlCache.Add(link)
+		s.c.Visit(link)
+	}
+}
+
+func (s *Scrape) parseBody(body []byte) []byte {
+	// Find all URLs in the CSS file
+	cssUrls := regexp.MustCompile(`url\((https?://[^\s]+)\)`).FindAllStringSubmatch(string(body), -1)
+
+	// Download each referenced file if it hasn't been visited before
+	for _, cssUrl := range cssUrls {
+		link := strings.Trim(cssUrl[1], "'\"")
+		if link == "" {
+			continue
+		}
+		s.visitURL(link)
+	}
+
+	optionList := []string{
+		fmt.Sprintf(`http://%s`, s.domain),
+		fmt.Sprintf(`http:\/\/%s`, s.domain),
+		fmt.Sprintf(`https://%s`, s.domain),
+		fmt.Sprintf(`https:\/\/%s`, s.domain),
+		s.domain,
+	}
+
+	for _, v := range optionList {
+		// Replace all occurrences of the base URL with a relative URL
+		replaceBody := strings.ReplaceAll(string(body), v, "")
+		body = []byte(replaceBody)
+	}
+
+	return body
+}
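
One more parallelism note on the new visitURL helper: Get and Add each take the lock separately, so once the collector runs async, two goroutines can both pass the !s.urlCache.Get(link) check for the same URL before either calls Add. Colly's own visited-URL tracking makes the duplicate Visit mostly harmless, but a combined check-and-set on URLCache would close the window. A sketch; the method name Seen is mine, not from this repo:

	// Seen marks url as visited and reports whether it had been seen before.
	// Doing the lookup and the insert under a single lock closes the race
	// that separate Get/Add calls leave open under an async collector.
	func (c *URLCache) Seen(url string) bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.urls[url] {
			return true
		}
		c.urls[url] = true
		return false
	}

With that, visitURL would reduce to: if !s.urlCache.Seen(link), call s.c.Visit(link).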
