checkpoint: fully implemented, needs extensive testing

This commit is contained in:
Dan Anglin 2024-05-22 21:30:42 +01:00
parent 1b76ceb4da
commit 08df511506
Signed by: dananglin
GPG key ID: 0C1D44CFBEE68638
16 changed files with 172 additions and 127 deletions

View file

@ -8,7 +8,7 @@ import (
"codeflow.dananglin.me.uk/apollo/enbas/internal/model" "codeflow.dananglin.me.uk/apollo/enbas/internal/model"
) )
func getAccountID(gtsClient *client.Client, myAccount bool, accountName string) (string, error) { func getAccountID(gtsClient *client.Client, myAccount bool, accountName, configDir string) (string, error) {
var ( var (
accountID string accountID string
err error err error
@ -16,7 +16,7 @@ func getAccountID(gtsClient *client.Client, myAccount bool, accountName string)
switch { switch {
case myAccount: case myAccount:
accountID, err = getMyAccountID(gtsClient) accountID, err = getMyAccountID(gtsClient, configDir)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to get your account ID; %w", err) return "", fmt.Errorf("unable to get your account ID; %w", err)
} }
@ -41,8 +41,8 @@ func getTheirAccountID(gtsClient *client.Client, accountURI string) (string, err
return account.ID, nil return account.ID, nil
} }
func getMyAccountID(gtsClient *client.Client) (string, error) { func getMyAccountID(gtsClient *client.Client, configDir string) (string, error) {
account, err := getMyAccount(gtsClient) account, err := getMyAccount(gtsClient, configDir)
if err != nil { if err != nil {
return "", fmt.Errorf("received an error while getting your account details; %w", err) return "", fmt.Errorf("received an error while getting your account details; %w", err)
} }
@ -50,8 +50,8 @@ func getMyAccountID(gtsClient *client.Client) (string, error) {
return account.ID, nil return account.ID, nil
} }
func getMyAccount(gtsClient *client.Client) (model.Account, error) { func getMyAccount(gtsClient *client.Client, configDir string) (model.Account, error) {
authConfig, err := config.NewAuthenticationConfigFromFile() authConfig, err := config.NewCredentialsConfigFromFile(configDir)
if err != nil { if err != nil {
return model.Account{}, fmt.Errorf("unable to retrieve the authentication configuration; %w", err) return model.Account{}, fmt.Errorf("unable to retrieve the authentication configuration; %w", err)
} }

View file

@ -11,6 +11,7 @@ import (
type addCommand struct { type addCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
toResourceType string toResourceType string
listID string listID string
@ -18,12 +19,13 @@ type addCommand struct {
content string content string
} }
func newAddCommand(name, summary string) *addCommand { func newAddCommand(tlf topLevelFlags, name, summary string) *addCommand {
emptyArr := make([]string, 0, 3) emptyArr := make([]string, 0, 3)
command := addCommand{ command := addCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
accountNames: accountNames(emptyArr), accountNames: accountNames(emptyArr),
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to add (e.g. account, note)") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to add (e.g. account, note)")
@ -52,7 +54,7 @@ func (c *addCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.toResourceType} return unsupportedResourceTypeError{resourceType: c.toResourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }
@ -126,7 +128,7 @@ func (c *addCommand) addNoteToAccount(gtsClient *client.Client) error {
return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames)) return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames))
} }
accountID, err := getAccountID(gtsClient, false, c.accountNames[0]) accountID, err := getAccountID(gtsClient, false, c.accountNames[0], c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }

View file

@ -10,15 +10,17 @@ import (
type blockCommand struct { type blockCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
accountName string accountName string
unblock bool unblock bool
} }
func newBlockCommand(name, summary string, unblock bool) *blockCommand { func newBlockCommand(tlf topLevelFlags, name, summary string, unblock bool) *blockCommand {
command := blockCommand{ command := blockCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
unblock: unblock, unblock: unblock,
} }
@ -40,7 +42,7 @@ func (c *blockCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.resourceType} return unsupportedResourceTypeError{resourceType: c.resourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }
@ -49,7 +51,7 @@ func (c *blockCommand) Execute() error {
} }
func (c *blockCommand) blockAccount(gtsClient *client.Client) error { func (c *blockCommand) blockAccount(gtsClient *client.Client) error {
accountID, err := getAccountID(gtsClient, false, c.accountName) accountID, err := getAccountID(gtsClient, false, c.accountName, c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }

View file

@ -11,14 +11,17 @@ import (
type createCommand struct { type createCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
listTitle string listTitle string
listRepliesPolicy string listRepliesPolicy string
} }
func newCreateCommand(name, summary string) *createCommand { func newCreateCommand(tlf topLevelFlags, name, summary string) *createCommand {
command := createCommand{ command := createCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to create") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to create")
@ -35,7 +38,7 @@ func (c *createCommand) Execute() error {
return flagNotSetError{flagText: resourceTypeFlag} return flagNotSetError{flagText: resourceTypeFlag}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }

View file

@ -10,13 +10,16 @@ import (
type deleteCommand struct { type deleteCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
listID string listID string
} }
func newDeleteCommand(name, summary string) *deleteCommand { func newDeleteCommand(tlf topLevelFlags, name, summary string) *deleteCommand {
command := deleteCommand{ command := deleteCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to delete") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to delete")
@ -41,7 +44,7 @@ func (c *deleteCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.resourceType} return unsupportedResourceTypeError{resourceType: c.resourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }

View file

@ -10,6 +10,7 @@ import (
type followCommand struct { type followCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
accountName string accountName string
showReposts bool showReposts bool
@ -17,11 +18,11 @@ type followCommand struct {
unfollow bool unfollow bool
} }
func newFollowCommand(name, summary string, unfollow bool) *followCommand { func newFollowCommand(tlf topLevelFlags, name, summary string, unfollow bool) *followCommand {
command := followCommand{ command := followCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
unfollow: unfollow, unfollow: unfollow,
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to follow") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to follow")
@ -44,7 +45,7 @@ func (c *followCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.resourceType} return unsupportedResourceTypeError{resourceType: c.resourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }
@ -53,7 +54,7 @@ func (c *followCommand) Execute() error {
} }
func (c *followCommand) followAccount(gtsClient *client.Client) error { func (c *followCommand) followAccount(gtsClient *client.Client) error {
accountID, err := getAccountID(gtsClient, false, c.accountName) accountID, err := getAccountID(gtsClient, false, c.accountName, c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }

View file

@ -12,12 +12,15 @@ import (
type loginCommand struct { type loginCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
instance string instance string
} }
func newLoginCommand(name, summary string) *loginCommand { func newLoginCommand(tlf topLevelFlags, name, summary string) *loginCommand {
command := loginCommand{ command := loginCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
instance: "", instance: "",
} }
@ -88,7 +91,7 @@ Once you have the code please copy and paste it below.
return fmt.Errorf("unable to verify the credentials; %w", err) return fmt.Errorf("unable to verify the credentials; %w", err)
} }
loginName, err := config.SaveCredentials(account.Username, gtsClient.Authentication) loginName, err := config.SaveCredentials(c.topLevelFlags.configDir, account.Username, gtsClient.Authentication)
if err != nil { if err != nil {
return fmt.Errorf("unable to save the authentication details; %w", err) return fmt.Errorf("unable to save the authentication details; %w", err)
} }

View file

@ -88,9 +88,9 @@ func run() error {
unblock: "unblock a resource (e.g. an account)", unblock: "unblock a resource (e.g. an account)",
} }
globals := topLevelFlags{} tlf := topLevelFlags{}
flag.StringVar(&globals.configDir, "config-dir", "", "specify your config directory") flag.StringVar(&tlf.configDir, "config-dir", "", "specify your config directory")
flag.Usage = enbasUsageFunc(summaries) flag.Usage = enbasUsageFunc(summaries)
@ -109,33 +109,33 @@ func run() error {
switch subcommand { switch subcommand {
case login: case login:
executor = newLoginCommand(login, summaries[login]) executor = newLoginCommand(tlf, login, summaries[login])
case version: case version:
executor = newVersionCommand(version, summaries[version]) executor = newVersionCommand(version, summaries[version])
case showResource: case showResource:
executor = newShowCommand(showResource, summaries[showResource]) executor = newShowCommand(tlf, showResource, summaries[showResource])
case switchAccount: case switchAccount:
executor = newSwitchCommand(switchAccount, summaries[switchAccount]) executor = newSwitchCommand(tlf, switchAccount, summaries[switchAccount])
case createResource: case createResource:
executor = newCreateCommand(createResource, summaries[createResource]) executor = newCreateCommand(tlf, createResource, summaries[createResource])
case deleteResource: case deleteResource:
executor = newDeleteCommand(deleteResource, summaries[deleteResource]) executor = newDeleteCommand(tlf, deleteResource, summaries[deleteResource])
case updateResource: case updateResource:
executor = newUpdateCommand(updateResource, summaries[updateResource]) executor = newUpdateCommand(tlf, updateResource, summaries[updateResource])
case whoami: case whoami:
executor = newWhoAmICommand(globals, whoami, summaries[whoami]) executor = newWhoAmICommand(tlf, whoami, summaries[whoami])
case add: case add:
executor = newAddCommand(add, summaries[add]) executor = newAddCommand(tlf, add, summaries[add])
case remove: case remove:
executor = newRemoveCommand(remove, summaries[remove]) executor = newRemoveCommand(tlf, remove, summaries[remove])
case follow: case follow:
executor = newFollowCommand(follow, summaries[follow], false) executor = newFollowCommand(tlf, follow, summaries[follow], false)
case unfollow: case unfollow:
executor = newFollowCommand(unfollow, summaries[unfollow], true) executor = newFollowCommand(tlf, unfollow, summaries[unfollow], true)
case block: case block:
executor = newBlockCommand(block, summaries[block], false) executor = newBlockCommand(tlf, block, summaries[block], false)
case unblock: case unblock:
executor = newBlockCommand(unblock, summaries[unblock], true) executor = newBlockCommand(tlf, unblock, summaries[unblock], true)
default: default:
flag.Usage() flag.Usage()

View file

@ -10,18 +10,20 @@ import (
type removeCommand struct { type removeCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
fromResourceType string fromResourceType string
listID string listID string
accountNames accountNames accountNames accountNames
} }
func newRemoveCommand(name, summary string) *removeCommand { func newRemoveCommand(tlf topLevelFlags, name, summary string) *removeCommand {
emptyArr := make([]string, 0, 3) emptyArr := make([]string, 0, 3)
command := removeCommand{ command := removeCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
accountNames: accountNames(emptyArr), accountNames: accountNames(emptyArr),
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to remove (e.g. account, note)") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to remove (e.g. account, note)")
@ -49,7 +51,7 @@ func (c *removeCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.fromResourceType} return unsupportedResourceTypeError{resourceType: c.fromResourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }
@ -123,7 +125,7 @@ func (c *removeCommand) removeNoteFromAccount(gtsClient *client.Client) error {
return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames)) return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames))
} }
accountID, err := getAccountID(gtsClient, false, c.accountNames[0]) accountID, err := getAccountID(gtsClient, false, c.accountNames[0], c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }

View file

@ -11,6 +11,7 @@ import (
type showCommand struct { type showCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
myAccount bool myAccount bool
showAccountRelationship bool showAccountRelationship bool
showUserPreferences bool showUserPreferences bool
@ -23,9 +24,10 @@ type showCommand struct {
limit int limit int
} }
func newShowCommand(name, summary string) *showCommand { func newShowCommand(tlf topLevelFlags, name, summary string) *showCommand {
command := showCommand{ command := showCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
} }
command.BoolVar(&command.myAccount, myAccountFlag, false, "set to true to lookup your account") command.BoolVar(&command.myAccount, myAccountFlag, false, "set to true to lookup your account")
@ -65,7 +67,7 @@ func (c *showCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.resourceType} return unsupportedResourceTypeError{resourceType: c.resourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }
@ -91,7 +93,7 @@ func (c *showCommand) showAccount(gtsClient *client.Client) error {
) )
if c.myAccount { if c.myAccount {
account, err = getMyAccount(gtsClient) account, err = getMyAccount(gtsClient, c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account details; %w", err) return fmt.Errorf("received an error while getting the account details; %w", err)
} }
@ -234,7 +236,7 @@ func (c *showCommand) showLists(gtsClient *client.Client) error {
} }
func (c *showCommand) showFollowers(gtsClient *client.Client) error { func (c *showCommand) showFollowers(gtsClient *client.Client) error {
accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName) accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName, c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }
@ -254,7 +256,7 @@ func (c *showCommand) showFollowers(gtsClient *client.Client) error {
} }
func (c *showCommand) showFollowing(gtsClient *client.Client) error { func (c *showCommand) showFollowing(gtsClient *client.Client) error {
accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName) accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName, c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("received an error while getting the account ID; %w", err) return fmt.Errorf("received an error while getting the account ID; %w", err)
} }

View file

@ -9,13 +9,15 @@ import (
type switchCommand struct { type switchCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
toAccount string toAccount string
} }
func newSwitchCommand(name, summary string) *switchCommand { func newSwitchCommand(tlf topLevelFlags, name, summary string) *switchCommand {
command := switchCommand{ command := switchCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
toAccount: "", topLevelFlags: tlf,
} }
command.StringVar(&command.toAccount, toAccountFlag, "", "the account to switch to") command.StringVar(&command.toAccount, toAccountFlag, "", "the account to switch to")
@ -30,7 +32,7 @@ func (c *switchCommand) Execute() error {
return flagNotSetError{flagText: toAccountFlag} return flagNotSetError{flagText: toAccountFlag}
} }
if err := config.UpdateCurrentAccount(c.toAccount); err != nil { if err := config.UpdateCurrentAccount(c.toAccount, c.topLevelFlags.configDir); err != nil {
return fmt.Errorf("unable to switch accounts; %w", err) return fmt.Errorf("unable to switch accounts; %w", err)
} }

View file

@ -11,15 +11,17 @@ import (
type updateCommand struct { type updateCommand struct {
*flag.FlagSet *flag.FlagSet
topLevelFlags topLevelFlags
resourceType string resourceType string
listID string listID string
listTitle string listTitle string
listRepliesPolicy string listRepliesPolicy string
} }
func newUpdateCommand(name, summary string) *updateCommand { func newUpdateCommand(tlf topLevelFlags, name, summary string) *updateCommand {
command := updateCommand{ command := updateCommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
topLevelFlags: tlf,
} }
command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to update") command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to update")
@ -46,7 +48,7 @@ func (c *updateCommand) Execute() error {
return unsupportedResourceTypeError{resourceType: c.resourceType} return unsupportedResourceTypeError{resourceType: c.resourceType}
} }
gtsClient, err := client.NewClientFromConfig() gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to create the GoToSocial client; %w", err) return fmt.Errorf("unable to create the GoToSocial client; %w", err)
} }

View file

@ -10,13 +10,13 @@ import (
type whoAmICommand struct { type whoAmICommand struct {
*flag.FlagSet *flag.FlagSet
globals topLevelFlags topLevelFlags topLevelFlags
} }
func newWhoAmICommand(globals topLevelFlags, name, summary string) *whoAmICommand { func newWhoAmICommand(tlf topLevelFlags, name, summary string) *whoAmICommand {
command := whoAmICommand{ command := whoAmICommand{
FlagSet: flag.NewFlagSet(name, flag.ExitOnError), FlagSet: flag.NewFlagSet(name, flag.ExitOnError),
globals: globals, topLevelFlags: tlf,
} }
command.Usage = commandUsageFunc(name, summary, command.FlagSet) command.Usage = commandUsageFunc(name, summary, command.FlagSet)
@ -25,7 +25,7 @@ func newWhoAmICommand(globals topLevelFlags, name, summary string) *whoAmIComman
} }
func (c *whoAmICommand) Execute() error { func (c *whoAmICommand) Execute() error {
config, err := config.NewAuthenticationConfigFromFile() config, err := config.NewCredentialsConfigFromFile(c.topLevelFlags.configDir)
if err != nil { if err != nil {
return fmt.Errorf("unable to load the credential config; %w", err) return fmt.Errorf("unable to load the credential config; %w", err)
} }

View file

@ -20,8 +20,8 @@ type Client struct {
Timeout time.Duration Timeout time.Duration
} }
func NewClientFromConfig() (*Client, error) { func NewClientFromConfig(configDir string) (*Client, error) {
config, err := config.NewAuthenticationConfigFromFile() config, err := config.NewCredentialsConfigFromFile(configDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to get the authentication configuration; %w", err) return nil, fmt.Errorf("unable to get the authentication configuration; %w", err)
} }
@ -70,7 +70,7 @@ func (g *Client) sendRequest(method string, url string, requestBody io.Reader, o
request.Header.Set("User-Agent", g.UserAgent) request.Header.Set("User-Agent", g.UserAgent)
if len(g.Authentication.AccessToken) > 0 { if len(g.Authentication.AccessToken) > 0 {
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", g.Authentication.AccessToken)) request.Header.Set("Authorization", "Bearer "+g.Authentication.AccessToken)
} }
response, err := g.HTTPClient.Do(request) response, err := g.HTTPClient.Do(request)

View file

@ -7,8 +7,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"codeflow.dananglin.me.uk/apollo/enbas/internal"
) )
const ( const (
@ -27,14 +25,25 @@ type Credentials struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
} }
func SaveCredentials(username string, credentials Credentials) (string, error) { type CredentialsNotFoundError struct {
if err := ensureConfigDir(); err != nil { AccountName string
}
func (e CredentialsNotFoundError) Error() string {
return "unable to find the credentials for the account '" + e.AccountName + "'"
}
// SaveCredentials saves the credentials into the credentials file within the specified configuration
// directory. If the directory is not specified then the default directory is used. If the directory
// is not present, it will be created.
func SaveCredentials(configDir, username string, credentials Credentials) (string, error) {
if err := ensureConfigDir(calculateConfigDir(configDir)); err != nil {
return "", fmt.Errorf("unable to ensure the configuration directory; %w", err) return "", fmt.Errorf("unable to ensure the configuration directory; %w", err)
} }
var authConfig CredentialsConfig var authConfig CredentialsConfig
filepath := credentialsConfigFile() filepath := credentialsConfigFile(configDir)
if _, err := os.Stat(filepath); err != nil { if _, err := os.Stat(filepath); err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
@ -43,7 +52,7 @@ func SaveCredentials(username string, credentials Credentials) (string, error) {
authConfig.Credentials = make(map[string]Credentials) authConfig.Credentials = make(map[string]Credentials)
} else { } else {
authConfig, err = NewAuthenticationConfigFromFile() authConfig, err = NewCredentialsConfigFromFile(configDir)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err) return "", fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err)
} }
@ -63,15 +72,34 @@ func SaveCredentials(username string, credentials Credentials) (string, error) {
authConfig.Credentials[authenticationName] = credentials authConfig.Credentials[authenticationName] = credentials
if err := saveCredentialsConfigFile(authConfig); err != nil { if err := saveCredentialsConfigFile(authConfig, configDir); err != nil {
return "", fmt.Errorf("unable to save the authentication configuration to file; %w", err) return "", fmt.Errorf("unable to save the authentication configuration to file; %w", err)
} }
return authenticationName, nil return authenticationName, nil
} }
func NewAuthenticationConfigFromFile() (CredentialsConfig, error) { func UpdateCurrentAccount(account string, configDir string) error {
path := credentialsConfigFile() credentialsConfig, err := NewCredentialsConfigFromFile(configDir)
if err != nil {
return fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err)
}
if _, ok := credentialsConfig.Credentials[account]; !ok {
return CredentialsNotFoundError{account}
}
credentialsConfig.CurrentAccount = account
if err := saveCredentialsConfigFile(credentialsConfig, configDir); err != nil {
return fmt.Errorf("unable to save the authentication configuration to file; %w", err)
}
return nil
}
func NewCredentialsConfigFromFile(configDir string) (CredentialsConfig, error) {
path := credentialsConfigFile(configDir)
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
@ -88,58 +116,12 @@ func NewAuthenticationConfigFromFile() (CredentialsConfig, error) {
return authConfig, nil return authConfig, nil
} }
func UpdateCurrentAccount(account string) error { func saveCredentialsConfigFile(authConfig CredentialsConfig, configDir string) error {
authConfig, err := NewAuthenticationConfigFromFile() path := credentialsConfigFile(configDir)
file, err := os.Create(path)
if err != nil { if err != nil {
return fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err) return fmt.Errorf("unable to open %s; %w", path, err)
}
if _, ok := authConfig.Credentials[account]; !ok {
return fmt.Errorf("account %s is not found", account)
}
authConfig.CurrentAccount = account
if err := saveCredentialsConfigFile(authConfig); err != nil {
return fmt.Errorf("unable to save the authentication configuration to file; %w", err)
}
return nil
}
func credentialsConfigFile() string {
return filepath.Join(configDir(), credentialsFileName)
}
func configDir() string {
rootDir, err := os.UserConfigDir()
if err != nil {
rootDir = "."
}
return filepath.Join(rootDir, internal.ApplicationName)
}
func ensureConfigDir() error {
dir := configDir()
if _, err := os.Stat(dir); err != nil {
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(dir, 0o750); err != nil {
return fmt.Errorf("unable to create %s; %w", dir, err)
}
} else {
return fmt.Errorf("unknown error received when running stat on %s; %w", dir, err)
}
}
return nil
}
func saveCredentialsConfigFile(authConfig CredentialsConfig) error {
file, err := os.Create(credentialsConfigFile())
if err != nil {
return fmt.Errorf("unable to open the config file; %w", err)
} }
defer file.Close() defer file.Close()
@ -153,3 +135,7 @@ func saveCredentialsConfigFile(authConfig CredentialsConfig) error {
return nil return nil
} }
func credentialsConfigFile(configDir string) string {
return filepath.Join(calculateConfigDir(configDir), credentialsFileName)
}

View file

@ -0,0 +1,37 @@
package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"codeflow.dananglin.me.uk/apollo/enbas/internal"
)
func calculateConfigDir(configDir string) string {
if configDir != "" {
return configDir
}
rootDir, err := os.UserConfigDir()
if err != nil {
rootDir = "."
}
return filepath.Join(rootDir, internal.ApplicationName)
}
func ensureConfigDir(configDir string) error {
if _, err := os.Stat(configDir); err != nil {
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(configDir, 0o750); err != nil {
return fmt.Errorf("unable to create %s; %w", configDir, err)
}
} else {
return fmt.Errorf("unknown error received when running stat on %s; %w", configDir, err)
}
}
return nil
}