From a8a7bcaced425c385eed5b4497cf7b5da7706e76 Mon Sep 17 00:00:00 2001
From: Dan Anglin
Date: Tue, 27 Aug 2024 13:11:16 +0100
Subject: [PATCH] fixed .gitignore; added concurrency

---
 .gitignore                       |   2 +-
 internal/crawler/crawler.go      | 155 ++++++++++++++++++++++++++++
 internal/crawler/crawler_test.go | 172 +++++++++++++++++++++++++++++++
 internal/crawler/gethtml.go      |  50 +++++++++
 main.go                          |  27 +++--
 5 files changed, 390 insertions(+), 16 deletions(-)
 create mode 100644 internal/crawler/crawler.go
 create mode 100644 internal/crawler/crawler_test.go
 create mode 100644 internal/crawler/gethtml.go

diff --git a/.gitignore b/.gitignore
index 8bd4ac1..74d6f60 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1 @@
-crawler
+/crawler
diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go
new file mode 100644
index 0000000..7de6325
--- /dev/null
+++ b/internal/crawler/crawler.go
@@ -0,0 +1,155 @@
+package crawler
+
+import (
+	"fmt"
+	"maps"
+	"net/url"
+	"sync"
+
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
+)
+
+type Crawler struct {
+	pages              map[string]int
+	baseURL            *url.URL
+	mu                 *sync.Mutex
+	concurrencyControl chan struct{}
+	wg                 *sync.WaitGroup
+}
+
+func NewCrawler(rawBaseURL string) (*Crawler, error) {
+	baseURL, err := url.Parse(rawBaseURL)
+	if err != nil {
+		return nil, fmt.Errorf("unable to parse the base URL: %w", err)
+	}
+
+	var waitGroup sync.WaitGroup
+
+	waitGroup.Add(1)
+
+	crawler := Crawler{
+		pages:              make(map[string]int),
+		baseURL:            baseURL,
+		mu:                 &sync.Mutex{},
+		concurrencyControl: make(chan struct{}, 2),
+		wg:                 &waitGroup,
+	}
+
+	return &crawler, nil
+}
+
+func (c *Crawler) Crawl(rawCurrentURL string) {
+	var err error
+
+	// Add an empty struct to the concurrency control channel to claim a slot.
+	c.concurrencyControl <- struct{}{}
+
+	// Decrement the wait group counter and free up a slot in the channel when
+	// finished crawling.
+	defer func() {
+		<-c.concurrencyControl
+		c.wg.Done()
+	}()
+
+	// If the current URL is not on the same domain as the base URL then return early.
+	hasEqualDomain, err := c.HasEqualDomain(rawCurrentURL)
+	if err != nil {
+		fmt.Printf(
+			"WARNING: Unable to determine if %q has the same domain as %q; %v.\n",
+			rawCurrentURL,
+			c.baseURL.Hostname(),
+			err,
+		)
+
+		return
+	}
+
+	if !hasEqualDomain {
+		return
+	}
+
+	// Get the normalised version of rawCurrentURL.
+	normalisedCurrentURL, err := util.NormaliseURL(rawCurrentURL)
+	if err != nil {
+		fmt.Printf("WARNING: Error normalising %q: %v.\n", rawCurrentURL, err)
+
+		return
+	}
+
+	// Add (or update) a record of the URL in the pages map.
+	// If there's already an entry for the URL in the map then return early.
+	if existed := c.AddPageVisit(normalisedCurrentURL); existed {
+		return
+	}
+
+	// Announce the crawl and fetch the HTML document from the current URL.
+	fmt.Printf("Crawling %q\n", rawCurrentURL)
+
+	htmlDoc, err := getHTML(rawCurrentURL)
+	if err != nil {
+		fmt.Printf(
+			"WARNING: Error retrieving the HTML document from %q: %v.\n",
+			rawCurrentURL,
+			err,
+		)
+
+		return
+	}
+
+	// Get all the URLs from the HTML doc.
+	links, err := util.GetURLsFromHTML(htmlDoc, c.baseURL.String())
+	if err != nil {
+		fmt.Printf(
+			"WARNING: Error retrieving the links from the HTML document: %v.\n",
+			err,
+		)
+
+		return
+	}
+
+	// Recursively crawl each URL on the page.
+	for ind := range len(links) {
+		c.wg.Add(1)
+		go c.Crawl(links[ind])
+	}
+}
+
+func (c *Crawler) HasEqualDomain(rawURL string) (bool, error) {
+	parsedRawURL, err := url.Parse(rawURL)
+	if err != nil {
+		return false, fmt.Errorf("error parsing the URL %q: %w", rawURL, err)
+	}
+
+	return c.baseURL.Hostname() == parsedRawURL.Hostname(), nil
+}
+
+// AddPageVisit adds a record of the visited page's URL to the pages map.
+// If there is already a record of the URL then its count is incremented
+// and the method returns true. If the URL has not been recorded before
+// then a new record is created and the method returns false.
+func (c *Crawler) AddPageVisit(normalisedURL string) bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	_, exists := c.pages[normalisedURL]
+
+	if exists {
+		c.pages[normalisedURL]++
+	} else {
+		c.pages[normalisedURL] = 1
+	}
+
+	return exists
+}
+
+func (c *Crawler) Wait() {
+	c.wg.Wait()
+}
+
+func (c *Crawler) PrintReport() {
+	fmt.Printf("\n\nREPORT:\n")
+
+	for page, count := range maps.All(c.pages) {
+		fmt.Printf("%s: %d\n", page, count)
+	}
+}
diff --git a/internal/crawler/crawler_test.go b/internal/crawler/crawler_test.go
new file mode 100644
index 0000000..91dc9ea
--- /dev/null
+++ b/internal/crawler/crawler_test.go
@@ -0,0 +1,172 @@
+package crawler_test
+
+import (
+	"fmt"
+	"slices"
+	"testing"
+
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
+)
+
+func TestCrawler(t *testing.T) {
+	testBaseURL := "https://example.com"
+
+	testCrawler, err := crawler.NewCrawler(testBaseURL)
+	if err != nil {
+		t.Fatalf("Test 'TestCrawler' FAILED: unexpected error creating the crawler: %v", err)
+	}
+
+	testCasesForEqualDomains := []struct {
+		name   string
+		rawURL string
+		want   bool
+	}{
+		{
+			name:   "Same domain",
+			rawURL: "https://example.com",
+			want:   true,
+		},
+		{
+			name:   "Same domain, different path",
+			rawURL: "https://example.com/about/contact",
+			want:   true,
+		},
+		{
+			name:   "Same domain, different protocol",
+			rawURL: "http://example.com",
+			want:   true,
+		},
+		{
+			name:   "Different domain",
+			rawURL: "https://blog.person.me.uk",
+			want:   false,
+		},
+		{
+			name:   "Different domain, same path",
+			rawURL: "https://example.org/blog",
+			want:   false,
+		},
+	}
+
+	for ind, tc := range slices.All(testCasesForEqualDomains) {
+		t.Run(tc.name, testHasEqualDomains(
+			testCrawler,
+			ind+1,
+			tc.name,
+			tc.rawURL,
+			tc.want,
+		))
+	}
+
+	testCasesForPages := []struct {
+		rawURL      string
+		wantVisited bool
+	}{
+		{
+			rawURL:      "https://example.com/tags/linux",
+			wantVisited: false,
+		},
+		{
+			rawURL:      "https://example.com/blog",
+			wantVisited: false,
+		},
+		{
+			rawURL:      "https://example.com/about/contact.html",
+			wantVisited: false,
+		},
+		{
+			rawURL:      "https://example.com/blog",
+			wantVisited: true,
+		},
+	}
+
+	for ind, tc := range slices.All(testCasesForPages) {
+		name := fmt.Sprintf("Adding %s to the pages map", tc.rawURL)
+		t.Run(name, testAddPageVisit(
+			testCrawler,
+			ind+1,
+			name,
+			tc.rawURL,
+			tc.wantVisited,
+		))
+	}
+}
+
+func testHasEqualDomains(
+	testCrawler *crawler.Crawler,
+	testNum int,
+	testName string,
+	rawURL string,
+	want bool,
+) func(t *testing.T) {
+	return func(t *testing.T) {
+		t.Parallel()
+
+		got, err := testCrawler.HasEqualDomain(rawURL)
+		if err != nil {
+			t.Fatalf(
+				"Test %d - '%s' FAILED: unexpected error: %v",
+				testNum,
+				testName,
+				err,
+			)
+		}
+
+		if got != want {
+			t.Errorf(
+				"Test %d - '%s' FAILED: unexpected domain comparison received: want %t, got %t",
+				testNum,
+				testName,
+				want,
+				got,
+			)
+		} else {
+			t.Logf(
+				"Test %d - '%s' PASSED: expected domain comparison received: got %t",
+				testNum,
+				testName,
+				got,
+			)
+		}
+	}
+}
+
+func testAddPageVisit(
+	testCrawler *crawler.Crawler,
+	testNum int,
+	testName string,
+	rawURL string,
+	wantVisited bool,
+) func(t *testing.T) {
+	return func(t *testing.T) {
+		normalisedURL, err := util.NormaliseURL(rawURL)
+		if err != nil {
+			t.Fatalf(
+				"Test %d - '%s' FAILED: unexpected error: %v",
+				testNum,
+				testName,
+				err,
+			)
+		}
+
+		gotVisited := testCrawler.AddPageVisit(normalisedURL)
+
+		if gotVisited != wantVisited {
+			t.Errorf(
+				"Test %d - '%s' FAILED: unexpected bool returned after updated pages record: want %t, got %t",
+				testNum,
+				testName,
+				wantVisited,
+				gotVisited,
+			)
+		} else {
+			t.Logf(
+				"Test %d - '%s' PASSED: expected bool returned after updated pages record: got %t",
+				testNum,
+				testName,
+				gotVisited,
+			)
+		}
+	}
+}
diff --git a/internal/crawler/gethtml.go b/internal/crawler/gethtml.go
new file mode 100644
index 0000000..c475ae1
--- /dev/null
+++ b/internal/crawler/gethtml.go
@@ -0,0 +1,50 @@
+package crawler
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+)
+
+func getHTML(rawURL string) (string, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10*time.Second))
+	defer cancel()
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
+	if err != nil {
+		return "", fmt.Errorf("error creating the HTTP request: %w", err)
+	}
+
+	client := http.Client{}
+
+	resp, err := client.Do(request)
+	if err != nil {
+		return "", fmt.Errorf("error getting the response: %w", err)
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode >= 400 {
+		return "", fmt.Errorf(
+			"received a bad status from %s: (%d) %s",
+			rawURL,
+			resp.StatusCode,
+			resp.Status,
+		)
+	}
+
+	contentType := resp.Header.Get("content-type")
+	if !strings.Contains(contentType, "text/html") {
+		return "", fmt.Errorf("unexpected content type received: want text/html, got %s", contentType)
+	}
+
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("error reading the data from the response: %w", err)
+	}
+
+	return string(data), nil
+}
diff --git a/main.go b/main.go
index 2633e44..45689bd 100644
--- a/main.go
+++ b/main.go
@@ -3,8 +3,9 @@ package main
 import (
 	"errors"
 	"fmt"
-	"maps"
 	"os"
+
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
 )
 
 var (
@@ -31,22 +32,18 @@ func run() error {
 		return errTooManyArgs
 	}
 
-	//baseURL := args[0]
+	baseURL := args[0]
 
-	pages := make(map[string]int)
-
-	//var err error
-
-	//pages, err = crawlPage(baseURL, baseURL, pages)
-	//if err != nil {
-	//	return fmt.Errorf("received an error while crawling the website: %w", err)
-	//}
-
-	fmt.Printf("\n\nRESULTS:\n")
-
-	for page, count := range maps.All(pages) {
-		fmt.Printf("%s: %d\n", page, count)
+	c, err := crawler.NewCrawler(baseURL)
+	if err != nil {
+		return fmt.Errorf("unable to create the crawler: %w", err)
 	}
 
+	go c.Crawl(baseURL)
+
+	c.Wait()
+
+	c.PrintReport()
+
 	return nil
 }
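
The crawl concurrency added above is bounded by the buffered concurrencyControl
channel (capacity 2) created in NewCrawler, while the WaitGroup tracks every
in-flight Crawl goroutine so main can block in Wait() before printing the report.
For reference, here is a minimal, self-contained sketch of the same
semaphore-plus-WaitGroup pattern; the link list and the hard-coded limit of two
workers are illustrative only and are not part of the patch:

    package main

    import (
    	"fmt"
    	"sync"
    )

    func main() {
    	// Illustrative input only; the real crawler discovers links from the
    	// fetched HTML document.
    	links := []string{"/", "/blog", "/about", "/tags/linux"}

    	// A buffered channel acts as a semaphore: at most two goroutines can
    	// hold a slot at once, mirroring make(chan struct{}, 2) in NewCrawler.
    	concurrencyControl := make(chan struct{}, 2)

    	var wg sync.WaitGroup

    	for _, link := range links {
    		wg.Add(1)

    		go func(link string) {
    			concurrencyControl <- struct{}{} // acquire a slot

    			defer func() {
    				<-concurrencyControl // release the slot
    				wg.Done()
    			}()

    			fmt.Println("crawling", link)
    		}(link)
    	}

    	// Block until every goroutine has called Done, like Crawler.Wait().
    	wg.Wait()
    }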