diff --git a/get_urls_from_html.go b/html.go
similarity index 100%
rename from get_urls_from_html.go
rename to html.go
diff --git a/get_urls_from_html_test.go b/html_test.go
similarity index 100%
rename from get_urls_from_html_test.go
rename to html_test.go
diff --git a/magefiles/mage.go b/magefiles/mage.go
index 008a8d7..cac393d 100644
--- a/magefiles/mage.go
+++ b/magefiles/mage.go
@@ -57,9 +57,9 @@ func Lint() error {
// To enable verbose mode set PROJECT_BUILD_VERBOSE=1
func Build() error {
main := "."
- flags := ldflags()
+ //flags := ldflags()
build := sh.RunCmd("go", "build")
- args := []string{"-ldflags=" + flags, "-o", binary}
+ args := []string{"-ldflags=-s -w", "-o", binary}
if os.Getenv(envBuildRebuildAll) == "1" {
args = append(args, "-a")
diff --git a/main.go b/main.go
index e3babb4..4034cc8 100644
--- a/main.go
+++ b/main.go
@@ -1,9 +1,11 @@
package main
import (
+ "context"
"errors"
"fmt"
"io"
+ "maps"
"net/http"
"os"
"strings"
@@ -11,15 +13,14 @@ import (
)
var (
- binaryVersion string
- buildTime string
- goVersion string
- gitCommit string
+ errNoWebsiteProvided = errors.New("no website provided")
+ errTooManyArgs = errors.New("too many arguments provided")
)
func main() {
if err := run(); err != nil {
- fmt.Println(err)
+ os.Stderr.WriteString("ERROR: " + err.Error() + "\n")
+
os.Exit(1)
}
}
@@ -28,36 +29,104 @@ func run() error {
args := os.Args[1:]
if len(args) == 0 {
- return errors.New("no website provided")
+ return errNoWebsiteProvided
}
if len(args) > 1 {
- return errors.New("too many arguments provided")
+ return errTooManyArgs
}
baseURL := args[0]
- htmlBody, err := getHTML(baseURL)
+ pages := make(map[string]int)
+
+ var err error
+
+ pages, err = crawlPage(baseURL, baseURL, pages)
if err != nil {
- return err
+ return fmt.Errorf("received an error while crawling the website: %w", err)
}
- fmt.Println(htmlBody)
+ fmt.Printf("\n\nRESULTS:\n")
+
+ for page, count := range maps.All(pages) {
+ fmt.Printf("%s: %d\n", page, count)
+ }
return nil
}
-func getHTML(rawURL string) (string, error) {
- req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+func crawlPage(rawBaseURL, rawCurrentURL string, pages map[string]int) (map[string]int, error) {
+ var err error
+
+ // if current URL is not on the same domain as the base URL, return the current pages.
+ sameDomain, err := equalDomains(rawBaseURL, rawCurrentURL)
if err != nil {
- return "", fmt.Errorf("error creating the request: %w", err)
+ return pages, err
}
- client := http.Client{
- Timeout: time.Duration(10 * time.Second),
+ if !sameDomain {
+ return pages, nil
}
- resp, err := client.Do(req)
+ // get normalised version of rawCurrentURL
+ normalisedCurrentURL, err := normaliseURL(rawCurrentURL)
+ if err != nil {
+ return pages, err
+ }
+
+ // check if normalised URL has an entry in pages.
+ _, exists := pages[normalisedCurrentURL]
+
+ // If it has an entry, increment the count by 1 and return the pages.
+ if exists {
+ pages[normalisedCurrentURL]++
+
+ return pages, nil
+ }
+
+ // Create an entry for the page
+ pages[normalisedCurrentURL] = 1
+
+ // Get the HTML from the current URL, print that you are getting the HTML doc from current URL.
+ fmt.Printf("Crawling %q\n", rawCurrentURL)
+
+ htmlDoc, err := getHTML(rawCurrentURL)
+ if err != nil {
+ return pages, fmt.Errorf("error retrieving the HTML document from %q: %w", rawCurrentURL, err)
+ }
+
+ // Get all the URLs from the HTML doc.
+ links, err := getURLsFromHTML(htmlDoc, rawBaseURL)
+ if err != nil {
+ return pages, fmt.Errorf("error retrieving the links from the HTML document: %w", err)
+ }
+
+ // Recursively crawl each URL on the page. (add a timeout?)
+ for ind := range len(links) {
+ time.Sleep(time.Duration(1 * time.Second))
+
+ pages, err = crawlPage(rawBaseURL, links[ind], pages)
+ if err != nil {
+ fmt.Println("WARNING: error received while crawling %q: %v", links[ind], err)
+ }
+ }
+
+ return pages, nil
+}
+
+func getHTML(rawURL string) (string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10*time.Second))
+ defer cancel()
+
+ request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
+ if err != nil {
+ return "", fmt.Errorf("error creating the HTTP request: %w", err)
+ }
+
+ client := http.Client{}
+
+ resp, err := client.Do(request)
if err != nil {
return "", fmt.Errorf("error getting the response: %w", err)
}
diff --git a/normalise_url.go b/normalise_url.go
deleted file mode 100644
index a6ef682..0000000
--- a/normalise_url.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package main
-
-import (
- "fmt"
- "net/url"
- "strings"
-)
-
-func normaliseURL(input string) (string, error) {
- const normalisedFormat string = "%s%s"
-
- parsedURL, err := url.Parse(input)
- if err != nil {
- return "", fmt.Errorf("error parsing the URL %q: %w", input, err)
- }
-
- return fmt.Sprintf(normalisedFormat, parsedURL.Hostname(), strings.TrimSuffix(parsedURL.Path, "/")), nil
-}
diff --git a/normalise_url_test.go b/normalise_url_test.go
deleted file mode 100644
index 4df7b4f..0000000
--- a/normalise_url_test.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package main
-
-import (
- "slices"
- "testing"
-)
-
-func TestNormaliseURL(t *testing.T) {
- t.Parallel()
-
- wantNormalisedURL := "blog.boot.dev/path"
-
- cases := []struct {
- name string
- inputURL string
- }{
- {
- name: "remove HTTPS scheme",
- inputURL: "https://blog.boot.dev/path",
- },
- {
- name: "remove HTTP scheme",
- inputURL: "http://blog.boot.dev/path",
- },
- {
- name: "remove HTTPS scheme with a trailing slash",
- inputURL: "https://blog.boot.dev/path/",
- },
- {
- name: "remove HTTP scheme with a trailing slash",
- inputURL: "http://blog.boot.dev/path/",
- },
- {
- name: "remove HTTPS scheme with port 443",
- inputURL: "https://blog.boot.dev:443/path",
- },
- {
- name: "remove HTTP scheme with port 80",
- inputURL: "http://blog.boot.dev:80/path",
- },
- {
- name: "normalised URL",
- inputURL: "blog.boot.dev/path",
- },
- }
-
- for ind, tc := range slices.All(cases) {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- got, err := normaliseURL(tc.inputURL)
- if err != nil {
- t.Fatalf(
- "Test %v - '%s' FAILED: unexpected error: %v",
- ind,
- tc.name,
- err,
- )
- }
-
- if got != wantNormalisedURL {
- t.Errorf(
- "Test %d - %s PASSED: unexpected normalised URL returned: want %s, got %s",
- ind,
- tc.name,
- wantNormalisedURL,
- got,
- )
- } else {
- t.Logf(
- "Test %d - %s PASSED: expected normalised URL returned: got %s",
- ind,
- tc.name,
- got,
- )
- }
- })
- }
-}
diff --git a/url.go b/url.go
new file mode 100644
index 0000000..e049d56
--- /dev/null
+++ b/url.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+ "fmt"
+ "net/url"
+ "strings"
+)
+
+func normaliseURL(rawURL string) (string, error) {
+ const normalisedFormat string = "%s%s"
+
+ parsedURL, err := url.Parse(rawURL)
+ if err != nil {
+ return "", fmt.Errorf("error parsing the URL %q: %w", rawURL, err)
+ }
+
+ return fmt.Sprintf(normalisedFormat, parsedURL.Hostname(), strings.TrimSuffix(parsedURL.Path, "/")), nil
+}
+
+func equalDomains(urlA, urlB string) (bool, error) {
+ parsedURLA, err := url.Parse(urlA)
+ if err != nil {
+ return false, fmt.Errorf("error parsing the URL %q: %w", urlA, err)
+ }
+
+ parsedURLB, err := url.Parse(urlB)
+ if err != nil {
+ return false, fmt.Errorf("error parsing the URL %q: %w", urlB, err)
+ }
+
+ return parsedURLA.Hostname() == parsedURLB.Hostname(), nil
+}
diff --git a/url_test.go b/url_test.go
new file mode 100644
index 0000000..e36d64a
--- /dev/null
+++ b/url_test.go
@@ -0,0 +1,146 @@
+package main
+
+import (
+ "slices"
+ "testing"
+)
+
+func TestNormaliseURL(t *testing.T) {
+ t.Parallel()
+
+ wantNormalisedURL := "blog.boot.dev/path"
+
+ cases := []struct {
+ name string
+ inputURL string
+ }{
+ {
+ name: "remove HTTPS scheme",
+ inputURL: "https://blog.boot.dev/path",
+ },
+ {
+ name: "remove HTTP scheme",
+ inputURL: "http://blog.boot.dev/path",
+ },
+ {
+ name: "remove HTTPS scheme with a trailing slash",
+ inputURL: "https://blog.boot.dev/path/",
+ },
+ {
+ name: "remove HTTP scheme with a trailing slash",
+ inputURL: "http://blog.boot.dev/path/",
+ },
+ {
+ name: "remove HTTPS scheme with port 443",
+ inputURL: "https://blog.boot.dev:443/path",
+ },
+ {
+ name: "remove HTTP scheme with port 80",
+ inputURL: "http://blog.boot.dev:80/path",
+ },
+ {
+ name: "normalised URL",
+ inputURL: "blog.boot.dev/path",
+ },
+ }
+
+ for ind, tc := range slices.All(cases) {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := normaliseURL(tc.inputURL)
+ if err != nil {
+ t.Fatalf(
+ "Test %d - '%s' FAILED: unexpected error: %v",
+ ind,
+ tc.name,
+ err,
+ )
+ }
+
+ if got != wantNormalisedURL {
+ t.Errorf(
+ "Test %d - %s FAILED: unexpected normalised URL returned: want %s, got %s",
+ ind,
+ tc.name,
+ wantNormalisedURL,
+ got,
+ )
+ } else {
+ t.Logf(
+ "Test %d - %s PASSED: expected normalised URL returned: got %s",
+ ind,
+ tc.name,
+ got,
+ )
+ }
+ })
+ }
+}
+
+func TestEqualDomains(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ urlA string
+ urlB string
+ want bool
+ }{
+ {
+ name: "Same domain, different paths",
+ urlA: "https://example.com/news",
+ urlB: "https://example.com/about/contact",
+ want: true,
+ },
+ {
+ name: "Different domains, same path",
+ urlA: "http://example.com/blog",
+ urlB: "http://example.org/blog",
+ want: false,
+ },
+ {
+ name: "Same domain, different protocols",
+ urlA: "http://code.person.me.uk/projects/orion",
+ urlB: "https://code.person.me.uk/user/person/README.md",
+ want: true,
+ },
+ }
+
+ for ind, tc := range slices.All(cases) {
+ t.Run(tc.name, testEqualDomains(ind+1, tc.name, tc.urlA, tc.urlB, tc.want))
+ }
+}
+
+func testEqualDomains(testNum int, testName, urlA, urlB string, want bool) func(t *testing.T) {
+ return func(t *testing.T) {
+ t.Parallel()
+
+ got, err := equalDomains(urlA, urlB)
+ if err != nil {
+ t.Fatalf(
+ "Test %d - '%s' FAILED: unexpected error: %v",
+ testNum,
+ testName,
+ err,
+ )
+ }
+
+ if got != want {
+ t.Errorf(
+ "Test %d - '%s' FAILED: unexpected domain comparison received: want %t, got %t",
+ testNum,
+ testName,
+ want,
+ got,
+ )
+ } else {
+ t.Logf(
+ "Test %d - '%s' PASSED: expected domain comparison received: got %t",
+ testNum,
+ testName,
+ got,
+ )
+ }
+ }
+}