package main

import (
    "context"
    "errors"
    "fmt"
    "io"
    "maps"
    "net/http"
    "net/url" // assumption: needed only by the hedged helper sketches below
    "os"
    "strings"
    "time"

    "golang.org/x/net/html" // assumption: needed only by the getURLsFromHTML sketch below
)

var (
    errNoWebsiteProvided = errors.New("no website provided")
    errTooManyArgs       = errors.New("too many arguments provided")
)

func main() {
    if err := run(); err != nil {
        os.Stderr.WriteString("ERROR: " + err.Error() + "\n")

        os.Exit(1)
    }
}
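
// Example invocation (the URL is illustrative):
//
//	go run . https://example.com
//
// The crawler stays on the start site's domain, pauses between requests,
// and prints how often each normalised page URL was encountered.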

func run() error {
    args := os.Args[1:]

    if len(args) == 0 {
        return errNoWebsiteProvided
    }

    if len(args) > 1 {
        return errTooManyArgs
    }

    baseURL := args[0]

    pages := make(map[string]int)

    pages, err := crawlPage(baseURL, baseURL, pages)
    if err != nil {
        return fmt.Errorf("received an error while crawling the website: %w", err)
    }

    fmt.Print("\n\nRESULTS:\n")

    // maps.All (Go 1.23+) yields the map's key/value pairs.
    for page, count := range maps.All(pages) {
        fmt.Printf("%s: %d\n", page, count)
    }

    return nil
}

func crawlPage(rawBaseURL, rawCurrentURL string, pages map[string]int) (map[string]int, error) {
    // If the current URL is not on the same domain as the base URL,
    // return the pages map unchanged.
    sameDomain, err := equalDomains(rawBaseURL, rawCurrentURL)
    if err != nil {
        return pages, err
    }

    if !sameDomain {
        return pages, nil
    }

    // Get the normalised version of rawCurrentURL.
    normalisedCurrentURL, err := normaliseURL(rawCurrentURL)
    if err != nil {
        return pages, err
    }

    // If the normalised URL already has an entry in pages, increment
    // its count and return: the page has been crawled before.
    if _, exists := pages[normalisedCurrentURL]; exists {
        pages[normalisedCurrentURL]++

        return pages, nil
    }

    // First visit: create an entry for the page.
    pages[normalisedCurrentURL] = 1

    // Report progress, then fetch the HTML document from the current URL.
    fmt.Printf("Crawling %q\n", rawCurrentURL)

    htmlDoc, err := getHTML(rawCurrentURL)
    if err != nil {
        return pages, fmt.Errorf("error retrieving the HTML document from %q: %w", rawCurrentURL, err)
    }

    // Get all the URLs from the HTML document.
    links, err := getURLsFromHTML(htmlDoc, rawBaseURL)
    if err != nil {
        return pages, fmt.Errorf("error retrieving the links from the HTML document: %w", err)
    }

    // Recursively crawl each URL on the page, pausing between requests
    // to avoid hammering the server. (add an overall timeout?)
    for _, link := range links {
        time.Sleep(time.Second)

        pages, err = crawlPage(rawBaseURL, link, pages)
        if err != nil {
            fmt.Printf("WARNING: error received while crawling %q: %v\n", link, err)
        }
    }

    return pages, nil
}
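
// equalDomains and normaliseURL are referenced above but defined in
// another file of this package. The sketches below are hedged
// assumptions of plausible implementations, not the author's code.

// equalDomains reports whether two raw URLs share a hostname.
// (Sketch; assumes a plain hostname comparison is sufficient.)
func equalDomains(rawURLA, rawURLB string) (bool, error) {
    urlA, err := url.Parse(rawURLA)
    if err != nil {
        return false, fmt.Errorf("error parsing %q: %w", rawURLA, err)
    }

    urlB, err := url.Parse(rawURLB)
    if err != nil {
        return false, fmt.Errorf("error parsing %q: %w", rawURLB, err)
    }

    return urlA.Hostname() == urlB.Hostname(), nil
}

// normaliseURL reduces a raw URL to a comparable form. (Sketch;
// assumes host + path with any trailing slash stripped is the
// normalisation the rest of the package expects.)
func normaliseURL(rawURL string) (string, error) {
    parsed, err := url.Parse(rawURL)
    if err != nil {
        return "", fmt.Errorf("error parsing %q: %w", rawURL, err)
    }

    return strings.TrimSuffix(parsed.Host+parsed.Path, "/"), nil
}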

func getHTML(rawURL string) (string, error) {
    // Bound the whole request with a 10-second timeout.
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()

    request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
    if err != nil {
        return "", fmt.Errorf("error creating the HTTP request: %w", err)
    }

    client := http.Client{}

    resp, err := client.Do(request)
    if err != nil {
        return "", fmt.Errorf("error getting the response: %w", err)
    }

    defer resp.Body.Close()

    if resp.StatusCode >= 400 {
        return "", fmt.Errorf(
            "received a bad status from %s: (%d) %s",
            rawURL,
            resp.StatusCode,
            resp.Status,
        )
    }

    contentType := resp.Header.Get("Content-Type")
    if !strings.Contains(contentType, "text/html") {
        return "", fmt.Errorf("unexpected content type received: want text/html, got %s", contentType)
    }

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("error reading the data from the response: %w", err)
    }

    return string(data), nil
}
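
// getURLsFromHTML is likewise defined elsewhere in the package; this is
// a hedged sketch of one plausible implementation using the
// golang.org/x/net/html parser (an assumed dependency), resolving
// relative links against the base URL.
func getURLsFromHTML(htmlBody, rawBaseURL string) ([]string, error) {
    baseURL, err := url.Parse(rawBaseURL)
    if err != nil {
        return nil, fmt.Errorf("error parsing base URL %q: %w", rawBaseURL, err)
    }

    doc, err := html.Parse(strings.NewReader(htmlBody))
    if err != nil {
        return nil, fmt.Errorf("error parsing the HTML document: %w", err)
    }

    var links []string

    // Walk the node tree depth-first, collecting href values from
    // anchor elements and resolving them against the base URL.
    var walk func(n *html.Node)
    walk = func(n *html.Node) {
        if n.Type == html.ElementNode && n.Data == "a" {
            for _, attr := range n.Attr {
                if attr.Key != "href" {
                    continue
                }

                ref, err := url.Parse(attr.Val)
                if err != nil {
                    continue // skip malformed hrefs
                }

                links = append(links, baseURL.ResolveReference(ref).String())
            }
        }

        for child := n.FirstChild; child != nil; child = child.NextSibling {
            walk(child)
        }
    }
    walk(doc)

    return links, nil
}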