generated from templates/go-generic
fixed .gitignore; added concurrency
parent: b76aeba7b5
commit: a8a7bcaced
5 changed files with 390 additions and 16 deletions
.gitignore (vendored, 2 changed lines)

@@ -1 +1 @@
-crawler
+/crawler
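A note on the .gitignore fix: an unanchored pattern like crawler matches a file or directory with that name at any depth, so it would also match the internal/crawler/ package directory added in this commit; the anchored /crawler only matches the compiled binary at the repository root.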
internal/crawler/crawler.go (new file, 155 lines)

@@ -0,0 +1,155 @@
package crawler

import (
	"fmt"
	"maps"
	"net/url"
	"sync"

	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
)

type Crawler struct {
	pages              map[string]int
	baseURL            *url.URL
	mu                 *sync.Mutex
	concurrencyControl chan struct{}
	wg                 *sync.WaitGroup
}

func NewCrawler(rawBaseURL string) (*Crawler, error) {
	baseURL, err := url.Parse(rawBaseURL)
	if err != nil {
		return nil, fmt.Errorf("unable to parse the base URL: %w", err)
	}

	var waitGroup sync.WaitGroup

	waitGroup.Add(1)

	crawler := Crawler{
		pages:              make(map[string]int),
		baseURL:            baseURL,
		mu:                 &sync.Mutex{},
		concurrencyControl: make(chan struct{}, 2),
		wg:                 &waitGroup,
	}

	return &crawler, nil
}

func (c *Crawler) Crawl(rawCurrentURL string) {
	var err error

	// Acquire a slot in the concurrency-limiting channel.
	c.concurrencyControl <- struct{}{}

	// Free up the channel slot and decrement the wait group counter when
	// finished crawling.
	defer func() {
		<-c.concurrencyControl
		c.wg.Done()
	}()

	// If the current URL is not on the same domain as the base URL then return early.
	hasEqualDomain, err := c.HasEqualDomain(rawCurrentURL)
	if err != nil {
		fmt.Printf(
			"WARNING: Unable to determine if %q has the same domain as %q; %v.\n",
			rawCurrentURL,
			c.baseURL.Hostname(),
			err,
		)

		return
	}

	if !hasEqualDomain {
		return
	}

	// Get the normalised version of rawCurrentURL.
	normalisedCurrentURL, err := util.NormaliseURL(rawCurrentURL)
	if err != nil {
		fmt.Printf("WARNING: Error normalising %q: %v.\n", rawCurrentURL, err)

		return
	}

	// Add (or update) a record of the URL in the pages map.
	// If there's already an entry for the URL in the map then return early.
	if existed := c.AddPageVisit(normalisedCurrentURL); existed {
		return
	}

	// Fetch the HTML document from the current URL.
	fmt.Printf("Crawling %q\n", rawCurrentURL)

	htmlDoc, err := getHTML(rawCurrentURL)
	if err != nil {
		fmt.Printf(
			"WARNING: Error retrieving the HTML document from %q: %v.\n",
			rawCurrentURL,
			err,
		)

		return
	}

	// Get all the URLs from the HTML document.
	links, err := util.GetURLsFromHTML(htmlDoc, c.baseURL.String())
	if err != nil {
		fmt.Printf(
			"WARNING: Error retrieving the links from the HTML document: %v.\n",
			err,
		)

		return
	}

	// Recursively crawl each URL found on the page.
	for ind := range len(links) {
		c.wg.Add(1)
		go c.Crawl(links[ind])
	}
}

func (c *Crawler) HasEqualDomain(rawURL string) (bool, error) {
	parsedRawURL, err := url.Parse(rawURL)
	if err != nil {
		return false, fmt.Errorf("error parsing the URL %q: %w", rawURL, err)
	}

	return c.baseURL.Hostname() == parsedRawURL.Hostname(), nil
}

// AddPageVisit adds a record of the visited page's URL to the pages map.
// If there is already a record of the URL then its count is incremented
// and the method returns true. If the URL is not already recorded then a
// record is created and the method returns false.
func (c *Crawler) AddPageVisit(normalisedURL string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	_, exists := c.pages[normalisedURL]

	if exists {
		c.pages[normalisedURL]++
	} else {
		c.pages[normalisedURL] = 1
	}

	return exists
}

func (c *Crawler) Wait() {
	c.wg.Wait()
}

func (c *Crawler) PrintReport() {
	fmt.Printf("\n\nREPORT:\n")

	for page, count := range maps.All(c.pages) {
		fmt.Printf("%s: %d\n", page, count)
	}
}
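The concurrency added here is a buffered channel used as a counting semaphore (capacity 2) paired with a sync.WaitGroup. A minimal, standalone sketch of the same pattern, with illustrative names that are not part of the commit:

package main

import (
	"fmt"
	"sync"
)

func main() {
	// Buffered channel as a semaphore: at most 2 goroutines may hold a
	// slot at once, mirroring crawler.concurrencyControl.
	sem := make(chan struct{}, 2)

	var wg sync.WaitGroup

	for task := 1; task <= 5; task++ {
		wg.Add(1) // increment the counter before spawning, as Crawl does

		go func(id int) {
			sem <- struct{}{} // acquire a slot (blocks while 2 are in flight)
			defer func() {
				<-sem     // release the slot
				wg.Done() // decrement the counter
			}()

			fmt.Println("processing task", id)
		}(task)
	}

	wg.Wait() // block until every task has finished, like Crawler.Wait
}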
internal/crawler/crawler_test.go (new file, 172 lines)

@@ -0,0 +1,172 @@
package crawler_test

import (
	"fmt"
	"slices"
	"testing"

	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/util"
)

func TestCrawler(t *testing.T) {
	testBaseURL := "https://example.com"

	testCrawler, err := crawler.NewCrawler(testBaseURL)
	if err != nil {
		t.Fatalf("Test 'TestCrawler' FAILED: unexpected error creating the crawler: %v", err)
	}

	testCasesForEqualDomains := []struct {
		name   string
		rawURL string
		want   bool
	}{
		{
			name:   "Same domain",
			rawURL: "https://example.com",
			want:   true,
		},
		{
			name:   "Same domain, different path",
			rawURL: "https://example.com/about/contact",
			want:   true,
		},
		{
			name:   "Same domain, different protocol",
			rawURL: "http://example.com",
			want:   true,
		},
		{
			name:   "Different domain",
			rawURL: "https://blog.person.me.uk",
			want:   false,
		},
		{
			name:   "Different domain, same path",
			rawURL: "https://example.org/blog",
			want:   false,
		},
	}

	for ind, tc := range slices.All(testCasesForEqualDomains) {
		t.Run(tc.name, testHasEqualDomains(
			testCrawler,
			ind+1,
			tc.name,
			tc.rawURL,
			tc.want,
		))
	}

	testCasesForPages := []struct {
		rawURL      string
		wantVisited bool
	}{
		{
			rawURL:      "https://example.com/tags/linux",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/blog",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/about/contact.html",
			wantVisited: false,
		},
		{
			rawURL:      "https://example.com/blog",
			wantVisited: true,
		},
	}

	for ind, tc := range slices.All(testCasesForPages) {
		name := fmt.Sprintf("Adding %s to the pages map", tc.rawURL)
		t.Run(name, testAddPageVisit(
			testCrawler,
			ind+1,
			name,
			tc.rawURL,
			tc.wantVisited,
		))
	}
}

func testHasEqualDomains(
	testCrawler *crawler.Crawler,
	testNum int,
	testName string,
	rawURL string,
	want bool,
) func(t *testing.T) {
	return func(t *testing.T) {
		t.Parallel()

		got, err := testCrawler.HasEqualDomain(rawURL)
		if err != nil {
			t.Fatalf(
				"Test %d - '%s' FAILED: unexpected error: %v",
				testNum,
				testName,
				err,
			)
		}

		if got != want {
			t.Errorf(
				"Test %d - '%s' FAILED: unexpected domain comparison received: want %t, got %t",
				testNum,
				testName,
				want,
				got,
			)
		} else {
			t.Logf(
				"Test %d - '%s' PASSED: expected domain comparison received: got %t",
				testNum,
				testName,
				got,
			)
		}
	}
}

func testAddPageVisit(
	testCrawler *crawler.Crawler,
	testNum int,
	testName string,
	rawURL string,
	wantVisited bool,
) func(t *testing.T) {
	return func(t *testing.T) {
		normalisedURL, err := util.NormaliseURL(rawURL)
		if err != nil {
			t.Fatalf(
				"Test %d - '%s' FAILED: unexpected error: %v",
				testNum,
				testName,
				err,
			)
		}

		gotVisited := testCrawler.AddPageVisit(normalisedURL)

		if gotVisited != wantVisited {
			t.Errorf(
				"Test %d - '%s' FAILED: unexpected bool returned after updating the pages record: want %t, got %t",
				testNum,
				testName,
				wantVisited,
				gotVisited,
			)
		} else {
			t.Logf(
				"Test %d - '%s' PASSED: expected bool returned after updating the pages record: got %t",
				testNum,
				testName,
				gotVisited,
			)
		}
	}
}
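Two details worth noting in the tests: the subtests built by testAddPageVisit deliberately do not call t.Parallel(), because the duplicate https://example.com/blog entry only returns wantVisited: true if it runs after the first visit, so the table's ordering is part of the test; and slices.All plus maps.All (used here and in crawler.go) require Go 1.23 or later.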
internal/crawler/gethtml.go (new file, 50 lines)

@@ -0,0 +1,50 @@
package crawler

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

func getHTML(rawURL string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return "", fmt.Errorf("error creating the HTTP request: %w", err)
	}

	client := http.Client{}

	resp, err := client.Do(request)
	if err != nil {
		return "", fmt.Errorf("error getting the response: %w", err)
	}

	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return "", fmt.Errorf(
			"received a bad status from %s: (%d) %s",
			rawURL,
			resp.StatusCode,
			resp.Status,
		)
	}

	contentType := resp.Header.Get("content-type")
	if !strings.Contains(contentType, "text/html") {
		return "", fmt.Errorf("unexpected content type received: want text/html, got %s", contentType)
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading the data from the response: %w", err)
	}

	return string(data), nil
}
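getHTML is unexported, so one way to exercise its status and content-type checks is a same-package test against a local server. A hedged sketch, assuming a hypothetical internal/crawler/gethtml_test.go that is not part of this commit:

package crawler

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

func TestGetHTML(t *testing.T) {
	// Spin up a throwaway server that returns a small HTML document.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		fmt.Fprint(w, "<html><body>ok</body></html>")
	}))
	defer srv.Close()

	doc, err := getHTML(srv.URL)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(doc, "ok") {
		t.Errorf("unexpected document returned: %s", doc)
	}
}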
main.go (27 changed lines)

@@ -3,8 +3,9 @@ package main
 import (
 	"errors"
 	"fmt"
-	"maps"
 	"os"
+
+	"codeflow.dananglin.me.uk/apollo/web-crawler/internal/crawler"
 )
 
 var (
@@ -31,22 +32,18 @@ func run() error {
 		return errTooManyArgs
 	}
 
-	//baseURL := args[0]
+	baseURL := args[0]
 
-	pages := make(map[string]int)
-
-	//var err error
-
-	//pages, err = crawlPage(baseURL, baseURL, pages)
-	//if err != nil {
-	//	return fmt.Errorf("received an error while crawling the website: %w", err)
-	//}
-
-	fmt.Printf("\n\nRESULTS:\n")
-
-	for page, count := range maps.All(pages) {
-		fmt.Printf("%s: %d\n", page, count)
-	}
+	c, err := crawler.NewCrawler(baseURL)
+	if err != nil {
+		return fmt.Errorf("unable to create the crawler: %w", err)
+	}
+
+	go c.Crawl(baseURL)
+
+	c.Wait()
+
+	c.PrintReport()
 
 	return nil
 }
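One subtlety in the new main.go flow: NewCrawler seeds the WaitGroup with a count of 1, which the go c.Crawl(baseURL) goroutine consumes via its deferred wg.Done(); every nested goroutine calls c.wg.Add(1) before it is spawned, so c.Wait() reliably blocks until the whole crawl has finished.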