diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go
index 7de6325..1263861 100644
--- a/internal/crawler/crawler.go
+++ b/internal/crawler/crawler.go
@@ -15,9 +15,10 @@ type Crawler struct {
 	mu                 *sync.Mutex
 	concurrencyControl chan struct{}
 	wg                 *sync.WaitGroup
+	maxPages           int
 }
 
-func NewCrawler(rawBaseURL string) (*Crawler, error) {
+func NewCrawler(rawBaseURL string, maxConcurrency, maxPages int) (*Crawler, error) {
 	baseURL, err := url.Parse(rawBaseURL)
 	if err != nil {
 		return nil, fmt.Errorf("unable to parse the base URL: %w", err)
 	}
@@ -31,16 +32,15 @@ func NewCrawler(rawBaseURL string) (*Crawler, error) {
 		pages:              make(map[string]int),
 		baseURL:            baseURL,
 		mu:                 &sync.Mutex{},
-		concurrencyControl: make(chan struct{}, 2),
+		concurrencyControl: make(chan struct{}, maxConcurrency),
 		wg:                 &waitGroup,
+		maxPages:           maxPages,
 	}
 
 	return &crawler, nil
 }
 
 func (c *Crawler) Crawl(rawCurrentURL string) {
-	var err error
-
 	// Add an empty struct to channel here
 	c.concurrencyControl <- struct{}{}
 
@@ -51,6 +51,10 @@ func (c *Crawler) Crawl(rawCurrentURL string) {
 		c.wg.Done()
 	}()
 
+	if c.reachedMaxPages() {
+		return
+	}
+
 	// if current URL is not on the same domain as the base URL then return early.
 	hasEqualDomain, err := c.HasEqualDomain(rawCurrentURL)
 	if err != nil {
@@ -147,9 +151,19 @@ func (c *Crawler) Wait() {
 }
 
 func (c *Crawler) PrintReport() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	fmt.Printf("\n\nREPORT:\n")
 
 	for page, count := range maps.All(c.pages) {
 		fmt.Printf("%s: %d\n", page, count)
 	}
 }
+
+func (c *Crawler) reachedMaxPages() bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	return len(c.pages) >= c.maxPages
+}
diff --git a/main.go b/main.go
index 45689bd..41df9ce 100644
--- a/main.go
+++ b/main.go
@@ -1,18 +1,13 @@
 package main
 
 import (
-	"errors"
 	"fmt"
 	"os"
+	"strconv"
 
 	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
 )
 
-var (
-	errNoWebsiteProvided = errors.New("no website provided")
-	errTooManyArgs       = errors.New("too many arguments provided")
-)
-
 func main() {
 	if err := run(); err != nil {
 		os.Stderr.WriteString("ERROR: " + err.Error() + "\n")
 	}
 }
@@ -24,17 +19,23 @@ func run() error {
 	args := os.Args[1:]
 
-	if len(args) == 0 {
-		return errNoWebsiteProvided
-	}
-
-	if len(args) > 1 {
-		return errTooManyArgs
+	if len(args) != 3 {
+		return fmt.Errorf("unexpected number of arguments received: want 3, got %d", len(args))
 	}
 
 	baseURL := args[0]
 
-	c, err := crawler.NewCrawler(baseURL)
+	maxConcurrency, err := strconv.Atoi(args[1])
+	if err != nil {
+		return fmt.Errorf("unable to convert the max concurrency (%s) to an integer: %w", args[1], err)
+	}
+
+	maxPages, err := strconv.Atoi(args[2])
+	if err != nil {
+		return fmt.Errorf("unable to convert the max pages (%s) to an integer: %w", args[2], err)
+	}
+
+	c, err := crawler.NewCrawler(baseURL, maxConcurrency, maxPages)
 	if err != nil {
 		return fmt.Errorf("unable to create the crawler: %w", err)
 	}