refactor: project restructuring

- Moved the state implementation to internal/state.
- Moved the executors to internal/executors.
This commit is contained in:
Dan Anglin 2024-10-01 05:44:21 +01:00
parent a832053349
commit 7783c10504
Signed by: dananglin
GPG key ID: 0C1D44CFBEE68638
19 changed files with 639 additions and 504 deletions

View file

@ -22,6 +22,7 @@ linters-settings:
- $gostd - $gostd
- codeflow.dananglin.me.uk/apollo/gator - codeflow.dananglin.me.uk/apollo/gator
- github.com/google/uuid - github.com/google/uuid
- github.com/lib/pq
lll: lll:
line-length: 140 line-length: 140

View file

@ -1,454 +0,0 @@
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"time"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/rss"
"github.com/google/uuid"
"github.com/lib/pq"
)
type commands struct {
commandMap map[string]commandFunc
}
type commandFunc func(*state, command) error
type command struct {
name string
args []string
}
func (c *commands) register(name string, f commandFunc) {
c.commandMap[name] = f
}
func (c *commands) run(s *state, cmd command) error {
runFunc, ok := c.commandMap[cmd.name]
if !ok {
return fmt.Errorf("unrecognised command: %s", cmd.name)
}
return runFunc(s, cmd)
}
func handlerLogin(s *state, cmd command) error {
if len(cmd.args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(cmd.args))
}
username := cmd.args[0]
user, err := s.db.GetUserByName(context.Background(), username)
if err != nil {
return fmt.Errorf("unable to get the user from the database: %w", err)
}
if err := s.config.SetUser(user.Name); err != nil {
return fmt.Errorf("login error: %w", err)
}
fmt.Printf("The current user is set to %q.\n", username)
return nil
}
func handlerRegister(s *state, cmd command) error {
if len(cmd.args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(cmd.args))
}
name := cmd.args[0]
timestamp := time.Now()
args := database.CreateUserParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Name: name,
}
user, err := s.db.CreateUser(context.Background(), args)
if err != nil {
if uniqueViolation(err) {
return errors.New("this user is already registered")
}
return fmt.Errorf("unable to register the user: %w", err)
}
if err := s.config.SetUser(name); err != nil {
return fmt.Errorf("unable to update the configuration: %w", err)
}
fmt.Printf("Successfully registered %s.\n", user.Name)
fmt.Println("DEBUG:", user)
return nil
}
func handlerReset(s *state, _ command) error {
if err := s.db.DeleteAllUsers(context.Background()); err != nil {
fmt.Errorf("unable to delete the users from the database: %w", err)
}
fmt.Println("Successfully removed all users from the database.")
return nil
}
func handlerUsers(s *state, _ command) error {
users, err := s.db.GetAllUsers(context.Background())
if err != nil {
fmt.Errorf("unable to get the users from the database: %w", err)
}
if len(users) == 0 {
fmt.Println("There are no registered users.")
return nil
}
fmt.Printf("Registered users:\n\n")
for _, user := range users {
if user.Name == s.config.CurrentUsername {
fmt.Printf("- %s (current)\n", user.Name)
} else {
fmt.Printf("- %s\n", user.Name)
}
}
return nil
}
func handlerAgg(s *state, cmd command) error {
if len(cmd.args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(cmd.args))
}
intervalArg := cmd.args[0]
interval, err := time.ParseDuration(intervalArg)
if err != nil {
return fmt.Errorf("unable to parse the interval: %w", err)
}
fmt.Printf("Fetching feeds every %s\n", interval.String())
tick := time.Tick(interval)
for range tick {
if err := scrapeFeeds(s); err != nil {
fmt.Println("ERROR: %v", err)
}
}
return nil
}
func handlerAddFeed(s *state, cmd command, user database.User) error {
if len(cmd.args) != 2 {
return fmt.Errorf("unexpected number of arguments: want 2, got %d", len(cmd.args))
}
name, url := cmd.args[0], cmd.args[1]
timestamp := time.Now()
createdFeedArgs := database.CreateFeedParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Name: name,
Url: url,
UserID: user.ID,
}
feed, err := s.db.CreateFeed(context.Background(), createdFeedArgs)
if err != nil {
return fmt.Errorf("unable to add the feed: %w", err)
}
fmt.Println("Successfully added the feed.")
fmt.Println("DEBUG:", feed)
createFeedFollowArgs := database.CreateFeedFollowParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
UserID: user.ID,
FeedID: feed.ID,
}
followRecord, err := s.db.CreateFeedFollow(context.Background(), createFeedFollowArgs)
if err != nil {
return fmt.Errorf("unable to create the feed follow record in the database: %w", err)
}
fmt.Printf("You are now following the feed %q.\n", followRecord.FeedName)
fmt.Println("DEBUG:", followRecord)
return nil
}
func handlerFeeds(s *state, _ command) error {
feeds, err := s.db.GetAllFeeds(context.Background())
if err != nil {
return fmt.Errorf("unable to get the feeds from the database: %w", err)
}
fmt.Printf("Feeds:\n\n")
for _, feed := range feeds {
user, err := s.db.GetUserByID(context.Background(), feed.UserID)
if err != nil {
return fmt.Errorf(
"unable to get the creator of %s: %w",
feed.Name,
err,
)
}
fmt.Printf(
"- Name: %s\n URL: %s\n Created by: %s\n",
feed.Name,
feed.Url,
user.Name,
)
}
return nil
}
func handlerFollow(s *state, cmd command, user database.User) error {
if len(cmd.args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 2, got %d", len(cmd.args))
}
url := cmd.args[0]
feed, err := s.db.GetFeedByUrl(context.Background(), url)
if err != nil {
return fmt.Errorf("unable to get the feed data from the database: %w", err)
}
timestamp := time.Now()
args := database.CreateFeedFollowParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
UserID: user.ID,
FeedID: feed.ID,
}
followRecord, err := s.db.CreateFeedFollow(context.Background(), args)
if err != nil {
if uniqueViolation(err) {
return errors.New("you are already following this feed")
}
return fmt.Errorf("unable to create the feed follow record in the database: %w", err)
}
fmt.Printf("You are now following the feed %q.\n", followRecord.FeedName)
fmt.Println("DEBUG:", followRecord)
return nil
}
func handlerFollowing(s *state, _ command, user database.User) error {
following, err := s.db.GetFeedFollowsForUser(context.Background(), user.ID)
if err != nil {
return fmt.Errorf("unable to get the list of feeds from the database: %w", err)
}
if len(following) == 0 {
fmt.Println("You are not following any feeds.")
return nil
}
fmt.Printf("\nYou are following:\n\n")
for _, feed := range following {
fmt.Printf("- %s\n", feed)
}
return nil
}
func handlerUnfollow(s *state, cmd command, user database.User) error {
if len(cmd.args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 2, got %d", len(cmd.args))
}
url := cmd.args[0]
feed, err := s.db.GetFeedByUrl(context.Background(), url)
if err != nil {
return fmt.Errorf("unable to get the feed data from the database: %w", err)
}
args := database.DeleteFeedFollowParams{
UserID: user.ID,
FeedID: feed.ID,
}
if err := s.db.DeleteFeedFollow(context.Background(), args); err != nil {
return fmt.Errorf("unable to delete the feed follow record from the database: %w", err)
}
fmt.Printf("You have successfully unfollowed %q.\n", feed.Name)
return nil
}
func handlerBrowse(s *state, cmd command, user database.User) error {
if len(cmd.args) > 1 {
return fmt.Errorf("unexpected number of arguments: want 0 or 1, got %d", len(cmd.args))
}
var err error
limit := 2
if len(cmd.args) == 1 {
limit, err = strconv.Atoi(cmd.args[0])
if err != nil {
return fmt.Errorf("unable to convert %s to a number: %w", cmd.args[0], err)
}
}
args := database.GetPostsForUserParams{
UserID: user.ID,
Limit: int32(limit),
}
posts, err := s.db.GetPostsForUser(context.Background(), args)
if err != nil {
return fmt.Errorf("unable to get the posts: %w", err)
}
fmt.Printf("\nPosts:\n\n")
for _, post := range posts {
fmt.Printf(
"- Title: %s\n URL: %s\n Published at: %s\n",
post.Title,
post.Url,
post.PublishedAt,
)
}
return nil
}
func scrapeFeeds(s *state) error {
feed, err := s.db.GetNextFeedToFetch(context.Background())
if err != nil {
return fmt.Errorf("unable to get the next feed from the database: %w", err)
}
fmt.Printf("\nFetching feed from %s\n", feed.Url)
feedDetails, err := rss.FetchFeed(context.Background(), feed.Url)
if err != nil {
return fmt.Errorf("unable to fetch the feed: %w", err)
}
timestamp := time.Now()
lastFetched := sql.NullTime{
Time: timestamp,
Valid: true,
}
markFeedFetchedArgs := database.MarkFeedFetchedParams{
ID: feed.ID,
LastFetchedAt: lastFetched,
UpdatedAt: timestamp,
}
if err := s.db.MarkFeedFetched(context.Background(), markFeedFetchedArgs); err != nil {
return fmt.Errorf("unable to mark the feed as fetched in the database: %w", err)
}
timeParsingFormats := []string{
time.RFC1123Z,
time.RFC1123,
}
for _, item := range feedDetails.Channel.Items {
var (
pubDate time.Time
err error
)
pubDateFormatted := false
for _, format := range timeParsingFormats {
pubDate, err = time.Parse(format, item.PubDate)
if err == nil {
pubDateFormatted = true
break
}
}
if !pubDateFormatted {
fmt.Printf(
"Error: unable to format the publication date (%s) of %q.\n",
item.PubDate,
item.Title,
)
continue
}
timestamp := time.Now()
args := database.CreatePostParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Title: item.Title,
Url: item.Link,
Description: item.Description,
FeedID: feed.ID,
PublishedAt: pubDate,
}
_, err = s.db.CreatePost(context.Background(), args)
if err != nil && !uniqueViolation(err) {
fmt.Printf(
"Error: unable to add the post %q to the database: %v.\n",
item.Title,
err,
)
}
}
return nil
}
func uniqueViolation(err error) bool {
var pqError *pq.Error
if errors.As(err, &pqError) {
if pqError.Code.Name() == "unique_violation" {
return true
}
}
return false
}

View file

@ -0,0 +1,63 @@
package executors
import (
"context"
"fmt"
"time"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
"github.com/google/uuid"
)
func AddFeed(s *state.State, exe Executor, user database.User) error {
wantArgs := 2
if len(exe.Args) != wantArgs {
return fmt.Errorf(
"unexpected number of arguments: want %d, got %d",
wantArgs,
len(exe.Args),
)
}
name, url := exe.Args[0], exe.Args[1]
timestamp := time.Now()
createdFeedArgs := database.CreateFeedParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Name: name,
Url: url,
UserID: user.ID,
}
feed, err := s.DB.CreateFeed(context.Background(), createdFeedArgs)
if err != nil {
return fmt.Errorf("unable to add the feed: %w", err)
}
fmt.Println("Successfully added the feed.")
fmt.Println("DEBUG:", feed)
createFeedFollowArgs := database.CreateFeedFollowParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
UserID: user.ID,
FeedID: feed.ID,
}
followRecord, err := s.DB.CreateFeedFollow(context.Background(), createFeedFollowArgs)
if err != nil {
return fmt.Errorf("unable to create the feed follow record in the database: %w", err)
}
fmt.Printf("You are now following the feed %q.\n", followRecord.FeedName)
fmt.Println("DEBUG:", followRecord)
return nil
}

View file

@ -0,0 +1,126 @@
package executors
import (
"context"
"database/sql"
"fmt"
"time"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/rss"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
"github.com/google/uuid"
)
func Aggregate(s *state.State, exe Executor) error {
if len(exe.Args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(exe.Args))
}
intervalArg := exe.Args[0]
interval, err := time.ParseDuration(intervalArg)
if err != nil {
return fmt.Errorf("unable to parse the interval: %w", err)
}
fmt.Printf("Fetching feeds every %s\n", interval.String())
tick := time.Tick(interval)
for range tick {
if err := scrapeFeeds(s); err != nil {
fmt.Println("ERROR: %v", err)
}
}
return nil
}
func scrapeFeeds(s *state.State) error {
feed, err := s.DB.GetNextFeedToFetch(context.Background())
if err != nil {
return fmt.Errorf("unable to get the next feed from the database: %w", err)
}
fmt.Printf("\nFetching feed from %s\n", feed.Url)
feedDetails, err := rss.FetchFeed(context.Background(), feed.Url)
if err != nil {
return fmt.Errorf("unable to fetch the feed: %w", err)
}
timestamp := time.Now()
lastFetched := sql.NullTime{
Time: timestamp,
Valid: true,
}
markFeedFetchedArgs := database.MarkFeedFetchedParams{
ID: feed.ID,
LastFetchedAt: lastFetched,
UpdatedAt: timestamp,
}
if err := s.DB.MarkFeedFetched(context.Background(), markFeedFetchedArgs); err != nil {
return fmt.Errorf("unable to mark the feed as fetched in the database: %w", err)
}
timeParsingFormats := []string{
time.RFC1123Z,
time.RFC1123,
}
for _, item := range feedDetails.Channel.Items {
var (
pubDate time.Time
err error
)
pubDateFormatted := false
for _, format := range timeParsingFormats {
pubDate, err = time.Parse(format, item.PubDate)
if err == nil {
pubDateFormatted = true
break
}
}
if !pubDateFormatted {
fmt.Printf(
"Error: unable to format the publication date (%s) of %q.\n",
item.PubDate,
item.Title,
)
continue
}
timestamp := time.Now()
args := database.CreatePostParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Title: item.Title,
Url: item.Link,
Description: item.Description,
FeedID: feed.ID,
PublishedAt: pubDate,
}
_, err = s.DB.CreatePost(context.Background(), args)
if err != nil && !uniqueViolation(err) {
fmt.Printf(
"Error: unable to add the post %q to the database: %v.\n",
item.Title,
err,
)
}
}
return nil
}

View file

@ -0,0 +1,50 @@
package executors
import (
"context"
"fmt"
"strconv"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Browse(s *state.State, exe Executor, user database.User) error {
if len(exe.Args) > 1 {
return fmt.Errorf("unexpected number of arguments: want 0 or 1, got %d", len(exe.Args))
}
var err error
limit := 2
if len(exe.Args) == 1 {
limit, err = strconv.Atoi(exe.Args[0])
if err != nil {
return fmt.Errorf("unable to convert %s to a number: %w", exe.Args[0], err)
}
}
args := database.GetPostsForUserParams{
UserID: user.ID,
Limit: int32(limit),
}
posts, err := s.DB.GetPostsForUser(context.Background(), args)
if err != nil {
return fmt.Errorf("unable to get the posts: %w", err)
}
fmt.Printf("\nPosts:\n\n")
for _, post := range posts {
fmt.Printf(
"- Title: %s\n URL: %s\n Published at: %s\n",
post.Title,
post.Url,
post.PublishedAt,
)
}
return nil
}

View file

@ -0,0 +1,31 @@
package executors
import (
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
type ExecutorMap struct {
Map map[string]ExecutorFunc
}
type ExecutorFunc func(*state.State, Executor) error
type Executor struct {
Name string
Args []string
}
func (e *ExecutorMap) Register(name string, f ExecutorFunc) {
e.Map[name] = f
}
func (e *ExecutorMap) Run(s *state.State, exe Executor) error {
runFunc, ok := e.Map[exe.Name]
if !ok {
return fmt.Errorf("unrecognised command: %s", exe.Name)
}
return runFunc(s, exe)
}

View file

@ -0,0 +1,37 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Feeds(s *state.State, _ Executor) error {
feeds, err := s.DB.GetAllFeeds(context.Background())
if err != nil {
return fmt.Errorf("unable to get the feeds from the database: %w", err)
}
fmt.Printf("Feeds:\n\n")
for _, feed := range feeds {
user, err := s.DB.GetUserByID(context.Background(), feed.UserID)
if err != nil {
return fmt.Errorf(
"unable to get the creator of %s: %w",
feed.Name,
err,
)
}
fmt.Printf(
"- Name: %s\n URL: %s\n Created by: %s\n",
feed.Name,
feed.Url,
user.Name,
)
}
return nil
}

View file

@ -0,0 +1,55 @@
package executors
import (
"context"
"errors"
"fmt"
"time"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
"github.com/google/uuid"
)
func Follow(s *state.State, exe Executor, user database.User) error {
wantNumArgs := 1
if len(exe.Args) != wantNumArgs {
return fmt.Errorf(
"unexpected number of arguments: want %d, got %d",
wantNumArgs,
len(exe.Args),
)
}
url := exe.Args[0]
feed, err := s.DB.GetFeedByUrl(context.Background(), url)
if err != nil {
return fmt.Errorf("unable to get the feed data from the database: %w", err)
}
timestamp := time.Now()
args := database.CreateFeedFollowParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
UserID: user.ID,
FeedID: feed.ID,
}
followRecord, err := s.DB.CreateFeedFollow(context.Background(), args)
if err != nil {
if uniqueViolation(err) {
return errors.New("you are already following this feed")
}
return fmt.Errorf("unable to create the feed follow record in the database: %w", err)
}
fmt.Printf("You are now following the feed %q.\n", followRecord.FeedName)
fmt.Println("DEBUG:", followRecord)
return nil
}

View file

@ -0,0 +1,30 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Following(s *state.State, _ Executor, user database.User) error {
following, err := s.DB.GetFeedFollowsForUser(context.Background(), user.ID)
if err != nil {
return fmt.Errorf("unable to get the list of feeds from the database: %w", err)
}
if len(following) == 0 {
fmt.Println("You are not following any feeds.")
return nil
}
fmt.Printf("\nYou are following:\n\n")
for _, feed := range following {
fmt.Printf("- %s\n", feed)
}
return nil
}

View file

@ -0,0 +1,29 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Login(s *state.State, exe Executor) error {
if len(exe.Args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(exe.Args))
}
username := exe.Args[0]
user, err := s.DB.GetUserByName(context.Background(), username)
if err != nil {
return fmt.Errorf("unable to get the user from the database: %w", err)
}
if err := s.Config.SetUser(user.Name); err != nil {
return fmt.Errorf("login error: %w", err)
}
fmt.Printf("The current user is set to %q.\n", username)
return nil
}

View file

@ -0,0 +1,20 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func MiddlewareLoggedIn(handler func(s *state.State, exe Executor, user database.User) error) ExecutorFunc {
return func(s *state.State, exe Executor) error {
user, err := s.DB.GetUserByName(context.Background(), s.Config.CurrentUsername)
if err != nil {
return fmt.Errorf("unable to get the user from the database: %w", err)
}
return handler(s, exe, user)
}
}

View file

@ -0,0 +1,47 @@
package executors
import (
"context"
"errors"
"fmt"
"time"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
"github.com/google/uuid"
)
func Register(s *state.State, exe Executor) error {
if len(exe.Args) != 1 {
return fmt.Errorf("unexpected number of arguments: want 1, got %d", len(exe.Args))
}
name := exe.Args[0]
timestamp := time.Now()
args := database.CreateUserParams{
ID: uuid.New(),
CreatedAt: timestamp,
UpdatedAt: timestamp,
Name: name,
}
user, err := s.DB.CreateUser(context.Background(), args)
if err != nil {
if uniqueViolation(err) {
return errors.New("this user is already registered")
}
return fmt.Errorf("unable to register the user: %w", err)
}
if err := s.Config.SetUser(name); err != nil {
return fmt.Errorf("unable to update the configuration: %w", err)
}
fmt.Printf("Successfully registered %s.\n", user.Name)
fmt.Println("DEBUG:", user)
return nil
}

View file

@ -0,0 +1,18 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Reset(s *state.State, _ Executor) error {
if err := s.DB.DeleteAllUsers(context.Background()); err != nil {
fmt.Errorf("unable to delete the users from the database: %w", err)
}
fmt.Println("Successfully removed all users from the database.")
return nil
}

View file

@ -0,0 +1,41 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Unfollow(s *state.State, exe Executor, user database.User) error {
wantNumArgs := 1
if len(exe.Args) != wantNumArgs {
return fmt.Errorf(
"unexpected number of arguments: want %d, got %d",
wantNumArgs,
len(exe.Args),
)
}
url := exe.Args[0]
feed, err := s.DB.GetFeedByUrl(context.Background(), url)
if err != nil {
return fmt.Errorf("unable to get the feed data from the database: %w", err)
}
args := database.DeleteFeedFollowParams{
UserID: user.ID,
FeedID: feed.ID,
}
if err := s.DB.DeleteFeedFollow(context.Background(), args); err != nil {
return fmt.Errorf("unable to delete the feed follow record from the database: %w", err)
}
fmt.Printf("You have successfully unfollowed %q.\n", feed.Name)
return nil
}

View file

@ -0,0 +1,33 @@
package executors
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
)
func Users(s *state.State, _ Executor) error {
users, err := s.DB.GetAllUsers(context.Background())
if err != nil {
fmt.Errorf("unable to get the users from the database: %w", err)
}
if len(users) == 0 {
fmt.Println("There are no registered users.")
return nil
}
fmt.Printf("Registered users:\n\n")
for _, user := range users {
if user.Name == s.Config.CurrentUsername {
fmt.Printf("- %s (current)\n", user.Name)
} else {
fmt.Printf("- %s\n", user.Name)
}
}
return nil
}

View file

@ -0,0 +1,19 @@
package executors
import (
"errors"
"github.com/lib/pq"
)
func uniqueViolation(err error) bool {
var pqError *pq.Error
if errors.As(err, &pqError) {
if pqError.Code.Name() == "unique_violation" {
return true
}
}
return false
}

11
internal/state/state.go Normal file
View file

@ -0,0 +1,11 @@
package state
import (
"codeflow.dananglin.me.uk/apollo/gator/internal/config"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
)
type State struct {
DB *database.Queries
Config *config.Config
}

59
main.go
View file

@ -8,14 +8,11 @@ import (
"codeflow.dananglin.me.uk/apollo/gator/internal/config" "codeflow.dananglin.me.uk/apollo/gator/internal/config"
"codeflow.dananglin.me.uk/apollo/gator/internal/database" "codeflow.dananglin.me.uk/apollo/gator/internal/database"
"codeflow.dananglin.me.uk/apollo/gator/internal/executors"
"codeflow.dananglin.me.uk/apollo/gator/internal/state"
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
type state struct {
db *database.Queries
config *config.Config
}
var ( var (
binaryVersion string binaryVersion string
buildTime string buildTime string
@ -41,49 +38,49 @@ func run() error {
return fmt.Errorf("unable to open a connection to the database: %w", err) return fmt.Errorf("unable to open a connection to the database: %w", err)
} }
s := state{ s := state.State{
db: database.New(db), DB: database.New(db),
config: &cfg, Config: &cfg,
} }
cmds := commands{ executorMap := executors.ExecutorMap{
commandMap: make(map[string]commandFunc), Map: make(map[string]executors.ExecutorFunc),
} }
cmds.register("login", handlerLogin) executorMap.Register("login", executors.Login)
cmds.register("register", handlerRegister) executorMap.Register("register", executors.Register)
cmds.register("reset", handlerReset) executorMap.Register("reset", executors.Reset)
cmds.register("users", handlerUsers) executorMap.Register("users", executors.Users)
cmds.register("agg", handlerAgg) executorMap.Register("aggregate", executors.Aggregate)
cmds.register("addfeed", middlewareLoggedIn(handlerAddFeed)) executorMap.Register("addfeed", executors.MiddlewareLoggedIn(executors.AddFeed))
cmds.register("feeds", handlerFeeds) executorMap.Register("feeds", executors.Feeds)
cmds.register("follow", middlewareLoggedIn(handlerFollow)) executorMap.Register("follow", executors.MiddlewareLoggedIn(executors.Follow))
cmds.register("unfollow", middlewareLoggedIn(handlerUnfollow)) executorMap.Register("unfollow", executors.MiddlewareLoggedIn(executors.Unfollow))
cmds.register("following", middlewareLoggedIn(handlerFollowing)) executorMap.Register("following", executors.MiddlewareLoggedIn(executors.Following))
cmds.register("browse", middlewareLoggedIn(handlerBrowse)) executorMap.Register("browse", executors.MiddlewareLoggedIn(executors.Browse))
cmd, err := parseArgs(os.Args[1:]) executor, err := parseArgs(os.Args[1:])
if err != nil { if err != nil {
return fmt.Errorf("unable to parse the command: %w", err) return fmt.Errorf("unable to parse the command: %w", err)
} }
return cmds.run(&s, cmd) return executorMap.Run(&s, executor)
} }
func parseArgs(args []string) (command, error) { func parseArgs(args []string) (executors.Executor, error) {
if len(args) == 0 { if len(args) == 0 {
return command{}, errors.New("no arguments given") return executors.Executor{}, errors.New("no arguments given")
} }
if len(args) == 1 { if len(args) == 1 {
return command{ return executors.Executor{
name: args[0], Name: args[0],
args: make([]string, 0), Args: make([]string, 0),
}, nil }, nil
} }
return command{ return executors.Executor{
name: args[0], Name: args[0],
args: args[1:], Args: args[1:],
}, nil }, nil
} }

View file

@ -1,19 +0,0 @@
package main
import (
"context"
"fmt"
"codeflow.dananglin.me.uk/apollo/gator/internal/database"
)
func middlewareLoggedIn(handler func(s *state, cmd command, user database.User) error) commandFunc {
return func(s *state, cmd command) error {
user, err := s.db.GetUserByName(context.Background(), s.config.CurrentUsername)
if err != nil {
return fmt.Errorf("unable to get the user from the database: %w", err)
}
return handler(s, cmd, user)
}
}