fixed .gitignore; added concurrency

Dan Anglin 2024-08-27 13:11:16 +01:00
parent b76aeba7b5
commit a8a7bcaced
Signed by: dananglin
GPG key ID: 0C1D44CFBEE68638
5 changed files with 390 additions and 16 deletions

.gitignore (vendored, 2 changes)

@@ -1 +1 @@
-crawler
+/crawler

internal/crawler/crawler.go (new file, 155 additions)

@@ -0,0 +1,155 @@
package crawler

import (
	"fmt"
	"maps"
	"net/url"
	"sync"

	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
)

type Crawler struct {
	pages              map[string]int
	baseURL            *url.URL
	mu                 *sync.Mutex
	concurrencyControl chan struct{}
	wg                 *sync.WaitGroup
}
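
// NewCrawler creates a Crawler for the given base URL. The concurrency control
// channel is buffered to allow up to two concurrent crawls, and the wait group
// counter starts at 1 to account for the initial Crawl goroutine.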
func NewCrawler(rawBaseURL string) (*Crawler, error) {
	baseURL, err := url.Parse(rawBaseURL)
	if err != nil {
		return nil, fmt.Errorf("unable to parse the base URL: %w", err)
	}

	var waitGroup sync.WaitGroup

	waitGroup.Add(1)

	crawler := Crawler{
		pages:              make(map[string]int),
		baseURL:            baseURL,
		mu:                 &sync.Mutex{},
		concurrencyControl: make(chan struct{}, 2),
		wg:                 &waitGroup,
	}

	return &crawler, nil
}
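
// Crawl recursively crawls the website starting from rawCurrentURL, recording
// every visited page on the base URL's domain in the pages map. It is intended
// to run in its own goroutine; the wait group counter is incremented before
// each new Crawl goroutine is started.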
func (c *Crawler) Crawl(rawCurrentURL string) {
	var err error

	// Acquire a slot on the concurrency control channel; this blocks while the
	// maximum number of concurrent crawls is already running.
	c.concurrencyControl <- struct{}{}

	// Decrement the wait group counter and free up the channel when finished
	// crawling.
	defer func() {
		<-c.concurrencyControl
		c.wg.Done()
	}()

	// If the current URL is not on the same domain as the base URL then return early.
	hasEqualDomain, err := c.HasEqualDomain(rawCurrentURL)
	if err != nil {
		fmt.Printf(
			"WARNING: Unable to determine if %q has the same domain as %q; %v.\n",
			rawCurrentURL,
			c.baseURL.Hostname(),
			err,
		)

		return
	}

	if !hasEqualDomain {
		return
	}

	// Get the normalised version of rawCurrentURL.
	normalisedCurrentURL, err := util.NormaliseURL(rawCurrentURL)
	if err != nil {
		fmt.Printf("WARNING: Error normalising %q: %v.\n", rawCurrentURL, err)

		return
	}

	// Add (or update) a record of the URL in the pages map.
	// If there's already an entry for the URL in the map then return early.
	if existed := c.AddPageVisit(normalisedCurrentURL); existed {
		return
	}

	// Get the HTML document from the current URL.
	fmt.Printf("Crawling %q\n", rawCurrentURL)

	htmlDoc, err := getHTML(rawCurrentURL)
	if err != nil {
		fmt.Printf(
			"WARNING: Error retrieving the HTML document from %q: %v.\n",
			rawCurrentURL,
			err,
		)

		return
	}

	// Get all the URLs from the HTML document.
	links, err := util.GetURLsFromHTML(htmlDoc, c.baseURL.String())
	if err != nil {
		fmt.Printf(
			"WARNING: Error retrieving the links from the HTML document: %v.\n",
			err,
		)

		return
	}

	// Recursively crawl each URL found on the page.
	for ind := range len(links) {
		c.wg.Add(1)

		go c.Crawl(links[ind])
	}
}
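
// HasEqualDomain reports whether rawURL has the same hostname as the crawler's
// base URL.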
func (c *Crawler) HasEqualDomain(rawURL string) (bool, error) {
	parsedRawURL, err := url.Parse(rawURL)
	if err != nil {
		return false, fmt.Errorf("error parsing the URL %q: %w", rawURL, err)
	}

	return c.baseURL.Hostname() == parsedRawURL.Hostname(), nil
}

// AddPageVisit adds a record of the visited page's URL to the pages map.
// If there is already a record of the URL then its count is incremented and
// the method returns true. If the URL has not been recorded yet then a new
// record is created and the method returns false.
func (c *Crawler) AddPageVisit(normalisedURL string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	_, exists := c.pages[normalisedURL]

	if exists {
		c.pages[normalisedURL]++
	} else {
		c.pages[normalisedURL] = 1
	}

	return exists
}
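
// Wait blocks until all Crawl goroutines have finished.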
func (c *Crawler) Wait() {
	c.wg.Wait()
}
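
// PrintReport prints each crawled page along with the number of times it was
// seen during the crawl.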
func (c *Crawler) PrintReport() {
	fmt.Printf("\n\nREPORT:\n")

	for page, count := range maps.All(c.pages) {
		fmt.Printf("%s: %d\n", page, count)
	}
}


@@ -0,0 +1,172 @@
package crawler_test

import (
	"fmt"
	"slices"
	"testing"

	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
)

func TestCrawler(t *testing.T) {
	testBaseURL := "https://example.com"

	testCrawler, err := crawler.NewCrawler(testBaseURL)
	if err != nil {
		t.Fatalf("Test 'TestCrawler' FAILED: unexpected error creating the crawler: %v", err)
	}

	testCasesForEqualDomains := []struct {
		name   string
		rawURL string
		want   bool
	}{
		{
			name:   "Same domain",
			rawURL: "https://example.com",
			want:   true,
		},
		{
			name:   "Same domain, different path",
			rawURL: "https://example.com/about/contact",
			want:   true,
		},
		{
			name:   "Same domain, different protocol",
			rawURL: "http://example.com",
			want:   true,
		},
		{
			name:   "Different domain",
			rawURL: "https://blog.person.me.uk",
			want:   false,
		},
		{
			name:   "Different domain, same path",
			rawURL: "https://example.org/blog",
			want:   false,
		},
	}

	for ind, tc := range slices.All(testCasesForEqualDomains) {
		t.Run(tc.name, testHasEqualDomains(
			testCrawler,
			ind+1,
			tc.name,
			tc.rawURL,
			tc.want,
		))
	}
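
	// The last test case revisits https://example.com/blog, so AddPageVisit is
	// expected to report that the page was already recorded.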
	testCasesForPages := []struct {
		rawURL      string
		wantVisited bool
	}{
		{
			rawURL:      "https://example.com/tags/linux",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/blog",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/about/contact.html",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/blog",
			wantVisited: true,
		},
	}

	for ind, tc := range slices.All(testCasesForPages) {
		name := fmt.Sprintf("Adding %s to the pages map", tc.rawURL)

		t.Run(name, testAddPageVisit(
			testCrawler,
			ind+1,
			name,
			tc.rawURL,
			tc.wantVisited,
		))
	}
}
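
// testHasEqualDomains returns a subtest that verifies Crawler.HasEqualDomain
// against the expected result for a single URL.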
func testHasEqualDomains(
	testCrawler *crawler.Crawler,
	testNum int,
	testName string,
	rawURL string,
	want bool,
) func(t *testing.T) {
	return func(t *testing.T) {
		t.Parallel()

		got, err := testCrawler.HasEqualDomain(rawURL)
		if err != nil {
			t.Fatalf(
				"Test %d - '%s' FAILED: unexpected error: %v",
				testNum,
				testName,
				err,
			)
		}

		if got != want {
			t.Errorf(
				"Test %d - '%s' FAILED: unexpected domain comparison received: want %t, got %t",
				testNum,
				testName,
				want,
				got,
			)
		} else {
			t.Logf(
				"Test %d - '%s' PASSED: expected domain comparison received: got %t",
				testNum,
				testName,
				got,
			)
		}
	}
}
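
// testAddPageVisit returns a subtest that normalises rawURL, records a visit to
// it and verifies whether Crawler.AddPageVisit reports the page as already visited.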
func testAddPageVisit(
	testCrawler *crawler.Crawler,
	testNum int,
	testName string,
	rawURL string,
	wantVisited bool,
) func(t *testing.T) {
	return func(t *testing.T) {
		normalisedURL, err := util.NormaliseURL(rawURL)
		if err != nil {
			t.Fatalf(
				"Test %d - '%s' FAILED: unexpected error: %v",
				testNum,
				testName,
				err,
			)
		}

		gotVisited := testCrawler.AddPageVisit(normalisedURL)

		if gotVisited != wantVisited {
			t.Errorf(
				"Test %d - '%s' FAILED: unexpected bool returned after updating the pages record: want %t, got %t",
				testNum,
				testName,
				wantVisited,
				gotVisited,
			)
		} else {
			t.Logf(
				"Test %d - '%s' PASSED: expected bool returned after updating the pages record: got %t",
				testNum,
				testName,
				gotVisited,
			)
		}
	}
}


@@ -0,0 +1,50 @@
package crawler

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
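
// getHTML fetches the document at rawURL and returns its body as a string.
// The request is cancelled after 10 seconds, and an error is returned for
// status codes of 400 and above or for non text/html content types.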
func getHTML(rawURL string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return "", fmt.Errorf("error creating the HTTP request: %w", err)
	}

	client := http.Client{}

	resp, err := client.Do(request)
	if err != nil {
		return "", fmt.Errorf("error getting the response: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return "", fmt.Errorf(
			"received a bad status from %s: (%d) %s",
			rawURL,
			resp.StatusCode,
			resp.Status,
		)
	}

	contentType := resp.Header.Get("content-type")
	if !strings.Contains(contentType, "text/html") {
		return "", fmt.Errorf("unexpected content type received: want text/html, got %s", contentType)
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading the data from the response: %w", err)
	}

	return string(data), nil
}

main.go (27 changes)

@@ -3,8 +3,9 @@ package main
 import (
 	"errors"
 	"fmt"
-	"maps"
 	"os"
+
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
 )
 
 var (
@@ -31,22 +32,18 @@ func run() error {
 		return errTooManyArgs
 	}
 
-	//baseURL := args[0]
-	pages := make(map[string]int)
-	//var err error
-	//pages, err = crawlPage(baseURL, baseURL, pages)
-	//if err != nil {
-	//	return fmt.Errorf("received an error while crawling the website: %w", err)
-	//}
-	fmt.Printf("\n\nRESULTS:\n")
-	for page, count := range maps.All(pages) {
-		fmt.Printf("%s: %d\n", page, count)
+	baseURL := args[0]
+
+	c, err := crawler.NewCrawler(baseURL)
+	if err != nil {
+		return fmt.Errorf("unable to create the crawler: %w", err)
 	}
+
+	go c.Crawl(baseURL)
+	c.Wait()
+	c.PrintReport()
 
 	return nil
 }