diff --git a/.gitignore b/.gitignore index b53949b..39d691b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /environment/ -/enbas +/*.enbas +/__build/* +!__build/.gitkeep diff --git a/__build/.gitkeep b/__build/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/cmd/enbas/account.go b/cmd/enbas/account.go index a7d6053..91d29f6 100644 --- a/cmd/enbas/account.go +++ b/cmd/enbas/account.go @@ -8,7 +8,7 @@ import ( "codeflow.dananglin.me.uk/apollo/enbas/internal/model" ) -func getAccountID(gtsClient *client.Client, myAccount bool, accountName string) (string, error) { +func getAccountID(gtsClient *client.Client, myAccount bool, accountName, configDir string) (string, error) { var ( accountID string err error @@ -16,7 +16,7 @@ func getAccountID(gtsClient *client.Client, myAccount bool, accountName string) switch { case myAccount: - accountID, err = getMyAccountID(gtsClient) + accountID, err = getMyAccountID(gtsClient, configDir) if err != nil { return "", fmt.Errorf("unable to get your account ID; %w", err) } @@ -41,8 +41,8 @@ func getTheirAccountID(gtsClient *client.Client, accountURI string) (string, err return account.ID, nil } -func getMyAccountID(gtsClient *client.Client) (string, error) { - account, err := getMyAccount(gtsClient) +func getMyAccountID(gtsClient *client.Client, configDir string) (string, error) { + account, err := getMyAccount(gtsClient, configDir) if err != nil { return "", fmt.Errorf("received an error while getting your account details; %w", err) } @@ -50,8 +50,8 @@ func getMyAccountID(gtsClient *client.Client) (string, error) { return account.ID, nil } -func getMyAccount(gtsClient *client.Client) (model.Account, error) { - authConfig, err := config.NewAuthenticationConfigFromFile() +func getMyAccount(gtsClient *client.Client, configDir string) (model.Account, error) { + authConfig, err := config.NewCredentialsConfigFromFile(configDir) if err != nil { return model.Account{}, fmt.Errorf("unable to retrieve the authentication configuration; %w", err) } diff --git a/cmd/enbas/add.go b/cmd/enbas/add.go index b860b75..40bbdeb 100644 --- a/cmd/enbas/add.go +++ b/cmd/enbas/add.go @@ -11,6 +11,7 @@ import ( type addCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string toResourceType string listID string @@ -18,12 +19,13 @@ type addCommand struct { content string } -func newAddCommand(name, summary string) *addCommand { +func newAddCommand(tlf topLevelFlags, name, summary string) *addCommand { emptyArr := make([]string, 0, 3) command := addCommand{ - FlagSet: flag.NewFlagSet(name, flag.ExitOnError), - accountNames: accountNames(emptyArr), + FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + accountNames: accountNames(emptyArr), + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to add (e.g. account, note)") @@ -52,7 +54,7 @@ func (c *addCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.toResourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } @@ -126,7 +128,7 @@ func (c *addCommand) addNoteToAccount(gtsClient *client.Client) error { return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames)) } - accountID, err := getAccountID(gtsClient, false, c.accountNames[0]) + accountID, err := getAccountID(gtsClient, false, c.accountNames[0], c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } diff --git a/cmd/enbas/block.go b/cmd/enbas/block.go index 03a7069..c7650d3 100644 --- a/cmd/enbas/block.go +++ b/cmd/enbas/block.go @@ -10,15 +10,17 @@ import ( type blockCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string accountName string unblock bool } -func newBlockCommand(name, summary string, unblock bool) *blockCommand { +func newBlockCommand(tlf topLevelFlags, name, summary string, unblock bool) *blockCommand { command := blockCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + topLevelFlags: tlf, unblock: unblock, } @@ -40,7 +42,7 @@ func (c *blockCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.resourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } @@ -49,7 +51,7 @@ func (c *blockCommand) Execute() error { } func (c *blockCommand) blockAccount(gtsClient *client.Client) error { - accountID, err := getAccountID(gtsClient, false, c.accountName) + accountID, err := getAccountID(gtsClient, false, c.accountName, c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } diff --git a/cmd/enbas/create.go b/cmd/enbas/create.go index d6a831c..5468599 100644 --- a/cmd/enbas/create.go +++ b/cmd/enbas/create.go @@ -11,14 +11,17 @@ import ( type createCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string listTitle string listRepliesPolicy string } -func newCreateCommand(name, summary string) *createCommand { +func newCreateCommand(tlf topLevelFlags, name, summary string) *createCommand { command := createCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to create") @@ -35,7 +38,7 @@ func (c *createCommand) Execute() error { return flagNotSetError{flagText: resourceTypeFlag} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } diff --git a/cmd/enbas/delete.go b/cmd/enbas/delete.go index e0fc0f7..8708947 100644 --- a/cmd/enbas/delete.go +++ b/cmd/enbas/delete.go @@ -10,13 +10,16 @@ import ( type deleteCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string listID string } -func newDeleteCommand(name, summary string) *deleteCommand { +func newDeleteCommand(tlf topLevelFlags, name, summary string) *deleteCommand { command := deleteCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to delete") @@ -41,7 +44,7 @@ func (c *deleteCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.resourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } diff --git a/cmd/enbas/flags.go b/cmd/enbas/flags.go index df46ae5..912c9df 100644 --- a/cmd/enbas/flags.go +++ b/cmd/enbas/flags.go @@ -15,3 +15,7 @@ func (a *accountNames) Set(value string) error { return nil } + +type topLevelFlags struct { + configDir string +} diff --git a/cmd/enbas/follow.go b/cmd/enbas/follow.go index ff0148e..47b4f98 100644 --- a/cmd/enbas/follow.go +++ b/cmd/enbas/follow.go @@ -10,18 +10,19 @@ import ( type followCommand struct { *flag.FlagSet - resourceType string - accountName string - showReposts bool - notify bool - unfollow bool + topLevelFlags topLevelFlags + resourceType string + accountName string + showReposts bool + notify bool + unfollow bool } -func newFollowCommand(name, summary string, unfollow bool) *followCommand { +func newFollowCommand(tlf topLevelFlags, name, summary string, unfollow bool) *followCommand { command := followCommand{ - FlagSet: flag.NewFlagSet(name, flag.ExitOnError), - - unfollow: unfollow, + FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + unfollow: unfollow, + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to follow") @@ -44,7 +45,7 @@ func (c *followCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.resourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } @@ -53,7 +54,7 @@ func (c *followCommand) Execute() error { } func (c *followCommand) followAccount(gtsClient *client.Client) error { - accountID, err := getAccountID(gtsClient, false, c.accountName) + accountID, err := getAccountID(gtsClient, false, c.accountName, c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } diff --git a/cmd/enbas/login.go b/cmd/enbas/login.go index 46cde0e..00aab30 100644 --- a/cmd/enbas/login.go +++ b/cmd/enbas/login.go @@ -12,12 +12,15 @@ import ( type loginCommand struct { *flag.FlagSet + + topLevelFlags topLevelFlags instance string } -func newLoginCommand(name, summary string) *loginCommand { +func newLoginCommand(tlf topLevelFlags, name, summary string) *loginCommand { command := loginCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + topLevelFlags: tlf, instance: "", } @@ -88,7 +91,7 @@ Once you have the code please copy and paste it below. return fmt.Errorf("unable to verify the credentials; %w", err) } - loginName, err := config.SaveCredentials(account.Username, gtsClient.Authentication) + loginName, err := config.SaveCredentials(c.topLevelFlags.configDir, account.Username, gtsClient.Authentication) if err != nil { return fmt.Errorf("unable to save the authentication details; %w", err) } diff --git a/cmd/enbas/main.go b/cmd/enbas/main.go index 898da31..43690ca 100644 --- a/cmd/enbas/main.go +++ b/cmd/enbas/main.go @@ -88,6 +88,10 @@ func run() error { unblock: "unblock a resource (e.g. an account)", } + tlf := topLevelFlags{} + + flag.StringVar(&tlf.configDir, "config-dir", "", "specify your config directory") + flag.Usage = enbasUsageFunc(summaries) flag.Parse() @@ -105,33 +109,33 @@ func run() error { switch subcommand { case login: - executor = newLoginCommand(login, summaries[login]) + executor = newLoginCommand(tlf, login, summaries[login]) case version: executor = newVersionCommand(version, summaries[version]) case showResource: - executor = newShowCommand(showResource, summaries[showResource]) + executor = newShowCommand(tlf, showResource, summaries[showResource]) case switchAccount: - executor = newSwitchCommand(switchAccount, summaries[switchAccount]) + executor = newSwitchCommand(tlf, switchAccount, summaries[switchAccount]) case createResource: - executor = newCreateCommand(createResource, summaries[createResource]) + executor = newCreateCommand(tlf, createResource, summaries[createResource]) case deleteResource: - executor = newDeleteCommand(deleteResource, summaries[deleteResource]) + executor = newDeleteCommand(tlf, deleteResource, summaries[deleteResource]) case updateResource: - executor = newUpdateCommand(updateResource, summaries[updateResource]) + executor = newUpdateCommand(tlf, updateResource, summaries[updateResource]) case whoami: - executor = newWhoAmICommand(whoami, summaries[whoami]) + executor = newWhoAmICommand(tlf, whoami, summaries[whoami]) case add: - executor = newAddCommand(add, summaries[add]) + executor = newAddCommand(tlf, add, summaries[add]) case remove: - executor = newRemoveCommand(remove, summaries[remove]) + executor = newRemoveCommand(tlf, remove, summaries[remove]) case follow: - executor = newFollowCommand(follow, summaries[follow], false) + executor = newFollowCommand(tlf, follow, summaries[follow], false) case unfollow: - executor = newFollowCommand(unfollow, summaries[unfollow], true) + executor = newFollowCommand(tlf, unfollow, summaries[unfollow], true) case block: - executor = newBlockCommand(block, summaries[block], false) + executor = newBlockCommand(tlf, block, summaries[block], false) case unblock: - executor = newBlockCommand(unblock, summaries[unblock], true) + executor = newBlockCommand(tlf, unblock, summaries[unblock], true) default: flag.Usage() diff --git a/cmd/enbas/remove.go b/cmd/enbas/remove.go index f6e4c4a..f985c77 100644 --- a/cmd/enbas/remove.go +++ b/cmd/enbas/remove.go @@ -10,18 +10,20 @@ import ( type removeCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string fromResourceType string listID string accountNames accountNames } -func newRemoveCommand(name, summary string) *removeCommand { +func newRemoveCommand(tlf topLevelFlags, name, summary string) *removeCommand { emptyArr := make([]string, 0, 3) command := removeCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), accountNames: accountNames(emptyArr), + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the resource type to remove (e.g. account, note)") @@ -49,7 +51,7 @@ func (c *removeCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.fromResourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } @@ -123,7 +125,7 @@ func (c *removeCommand) removeNoteFromAccount(gtsClient *client.Client) error { return fmt.Errorf("unexpected number of accounts specified; want 1, got %d", len(c.accountNames)) } - accountID, err := getAccountID(gtsClient, false, c.accountNames[0]) + accountID, err := getAccountID(gtsClient, false, c.accountNames[0], c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } diff --git a/cmd/enbas/show.go b/cmd/enbas/show.go index 79014e4..e2b703f 100644 --- a/cmd/enbas/show.go +++ b/cmd/enbas/show.go @@ -11,6 +11,7 @@ import ( type showCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags myAccount bool showAccountRelationship bool showUserPreferences bool @@ -23,9 +24,10 @@ type showCommand struct { limit int } -func newShowCommand(name, summary string) *showCommand { +func newShowCommand(tlf topLevelFlags, name, summary string) *showCommand { command := showCommand{ - FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + topLevelFlags: tlf, } command.BoolVar(&command.myAccount, myAccountFlag, false, "set to true to lookup your account") @@ -65,7 +67,7 @@ func (c *showCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.resourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } @@ -91,7 +93,7 @@ func (c *showCommand) showAccount(gtsClient *client.Client) error { ) if c.myAccount { - account, err = getMyAccount(gtsClient) + account, err = getMyAccount(gtsClient, c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account details; %w", err) } @@ -234,7 +236,7 @@ func (c *showCommand) showLists(gtsClient *client.Client) error { } func (c *showCommand) showFollowers(gtsClient *client.Client) error { - accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName) + accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName, c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } @@ -254,7 +256,7 @@ func (c *showCommand) showFollowers(gtsClient *client.Client) error { } func (c *showCommand) showFollowing(gtsClient *client.Client) error { - accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName) + accountID, err := getAccountID(gtsClient, c.myAccount, c.accountName, c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("received an error while getting the account ID; %w", err) } diff --git a/cmd/enbas/switch.go b/cmd/enbas/switch.go index 152f3f4..fcef938 100644 --- a/cmd/enbas/switch.go +++ b/cmd/enbas/switch.go @@ -9,13 +9,15 @@ import ( type switchCommand struct { *flag.FlagSet + + topLevelFlags topLevelFlags toAccount string } -func newSwitchCommand(name, summary string) *switchCommand { +func newSwitchCommand(tlf topLevelFlags, name, summary string) *switchCommand { command := switchCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), - toAccount: "", + topLevelFlags: tlf, } command.StringVar(&command.toAccount, toAccountFlag, "", "the account to switch to") @@ -30,7 +32,7 @@ func (c *switchCommand) Execute() error { return flagNotSetError{flagText: toAccountFlag} } - if err := config.UpdateCurrentAccount(c.toAccount); err != nil { + if err := config.UpdateCurrentAccount(c.toAccount, c.topLevelFlags.configDir); err != nil { return fmt.Errorf("unable to switch accounts; %w", err) } diff --git a/cmd/enbas/update.go b/cmd/enbas/update.go index 0509411..0eae167 100644 --- a/cmd/enbas/update.go +++ b/cmd/enbas/update.go @@ -11,15 +11,17 @@ import ( type updateCommand struct { *flag.FlagSet + topLevelFlags topLevelFlags resourceType string listID string listTitle string listRepliesPolicy string } -func newUpdateCommand(name, summary string) *updateCommand { +func newUpdateCommand(tlf topLevelFlags, name, summary string) *updateCommand { command := updateCommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + topLevelFlags: tlf, } command.StringVar(&command.resourceType, resourceTypeFlag, "", "specify the type of resource to update") @@ -46,7 +48,7 @@ func (c *updateCommand) Execute() error { return unsupportedResourceTypeError{resourceType: c.resourceType} } - gtsClient, err := client.NewClientFromConfig() + gtsClient, err := client.NewClientFromConfig(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to create the GoToSocial client; %w", err) } diff --git a/cmd/enbas/usage.go b/cmd/enbas/usage.go index 4536ed3..e055937 100644 --- a/cmd/enbas/usage.go +++ b/cmd/enbas/usage.go @@ -53,7 +53,7 @@ func enbasUsageFunc(summaries map[string]string) func() { builder.WriteString("SUMMARY:\n enbas - A GoToSocial client for the terminal.\n\n") if binaryVersion != "" { - builder.WriteString("VERSION:\n " + binaryVersion + "\n\n") + builder.WriteString("VERSION:\n " + binaryVersion + "\n\n") } builder.WriteString("USAGE:\n enbas [flags]\n enbas [command]\n\nCOMMANDS:") diff --git a/cmd/enbas/whoami.go b/cmd/enbas/whoami.go index d87174a..60fb06c 100644 --- a/cmd/enbas/whoami.go +++ b/cmd/enbas/whoami.go @@ -9,11 +9,14 @@ import ( type whoAmICommand struct { *flag.FlagSet + + topLevelFlags topLevelFlags } -func newWhoAmICommand(name, summary string) *whoAmICommand { +func newWhoAmICommand(tlf topLevelFlags, name, summary string) *whoAmICommand { command := whoAmICommand{ FlagSet: flag.NewFlagSet(name, flag.ExitOnError), + topLevelFlags: tlf, } command.Usage = commandUsageFunc(name, summary, command.FlagSet) @@ -22,7 +25,7 @@ func newWhoAmICommand(name, summary string) *whoAmICommand { } func (c *whoAmICommand) Execute() error { - config, err := config.NewAuthenticationConfigFromFile() + config, err := config.NewCredentialsConfigFromFile(c.topLevelFlags.configDir) if err != nil { return fmt.Errorf("unable to load the credential config; %w", err) } diff --git a/internal/build/magefiles/mage.go b/internal/build/magefiles/mage.go index 3271229..933cf27 100644 --- a/internal/build/magefiles/mage.go +++ b/internal/build/magefiles/mage.go @@ -14,7 +14,7 @@ import ( ) const ( - binary = "enbas" + app = "enbas" defaultInstallPrefix = "/usr/local" envInstallPrefix = "ENBAS_INSTALL_PREFIX" envTestVerbose = "ENBAS_TEST_VERBOSE" @@ -65,7 +65,8 @@ func Build() error { return fmt.Errorf("unable to change to the project's root directory; %w", err) } - main := "./cmd/" + binary + main := "./cmd/" + app + binary := "./__build/" + app flags := ldflags() build := sh.RunCmd("go", "build") args := []string{"-ldflags=" + flags, "-o", binary} @@ -93,13 +94,13 @@ func Install() error { installPrefix = defaultInstallPrefix } - dest := filepath.Join(installPrefix, "bin", binary) + dest := filepath.Join(installPrefix, "bin", app) - if err := sh.Copy(dest, binary); err != nil { + if err := sh.Copy(dest, app); err != nil { return fmt.Errorf("unable to install %s; %w", dest, err) } - fmt.Printf("%s successfully installed to %s\n", binary, dest) + fmt.Printf("%s successfully installed to %s\n", app, dest) return nil } @@ -110,7 +111,7 @@ func Clean() error { return fmt.Errorf("unable to change to the project's root directory; %w", err) } - if err := sh.Rm(binary); err != nil { + if err := sh.Rm(app); err != nil { return err } diff --git a/internal/client/client.go b/internal/client/client.go index 6512178..11fd963 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -20,8 +20,8 @@ type Client struct { Timeout time.Duration } -func NewClientFromConfig() (*Client, error) { - config, err := config.NewAuthenticationConfigFromFile() +func NewClientFromConfig(configDir string) (*Client, error) { + config, err := config.NewCredentialsConfigFromFile(configDir) if err != nil { return nil, fmt.Errorf("unable to get the authentication configuration; %w", err) } @@ -70,7 +70,7 @@ func (g *Client) sendRequest(method string, url string, requestBody io.Reader, o request.Header.Set("User-Agent", g.UserAgent) if len(g.Authentication.AccessToken) > 0 { - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", g.Authentication.AccessToken)) + request.Header.Set("Authorization", "Bearer "+g.Authentication.AccessToken) } response, err := g.HTTPClient.Do(request) diff --git a/internal/config/credentials.go b/internal/config/credentials.go index b228fa0..f575239 100644 --- a/internal/config/credentials.go +++ b/internal/config/credentials.go @@ -7,8 +7,6 @@ import ( "os" "path/filepath" "strings" - - "codeflow.dananglin.me.uk/apollo/enbas/internal" ) const ( @@ -27,14 +25,25 @@ type Credentials struct { AccessToken string `json:"accessToken"` } -func SaveCredentials(username string, credentials Credentials) (string, error) { - if err := ensureConfigDir(); err != nil { +type CredentialsNotFoundError struct { + AccountName string +} + +func (e CredentialsNotFoundError) Error() string { + return "unable to find the credentials for the account '" + e.AccountName + "'" +} + +// SaveCredentials saves the credentials into the credentials file within the specified configuration +// directory. If the directory is not specified then the default directory is used. If the directory +// is not present, it will be created. +func SaveCredentials(configDir, username string, credentials Credentials) (string, error) { + if err := ensureConfigDir(calculateConfigDir(configDir)); err != nil { return "", fmt.Errorf("unable to ensure the configuration directory; %w", err) } var authConfig CredentialsConfig - filepath := credentialsConfigFile() + filepath := credentialsConfigFile(configDir) if _, err := os.Stat(filepath); err != nil { if !errors.Is(err, os.ErrNotExist) { @@ -43,7 +52,7 @@ func SaveCredentials(username string, credentials Credentials) (string, error) { authConfig.Credentials = make(map[string]Credentials) } else { - authConfig, err = NewAuthenticationConfigFromFile() + authConfig, err = NewCredentialsConfigFromFile(configDir) if err != nil { return "", fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err) } @@ -63,15 +72,34 @@ func SaveCredentials(username string, credentials Credentials) (string, error) { authConfig.Credentials[authenticationName] = credentials - if err := saveCredentialsConfigFile(authConfig); err != nil { + if err := saveCredentialsConfigFile(authConfig, configDir); err != nil { return "", fmt.Errorf("unable to save the authentication configuration to file; %w", err) } return authenticationName, nil } -func NewAuthenticationConfigFromFile() (CredentialsConfig, error) { - path := credentialsConfigFile() +func UpdateCurrentAccount(account string, configDir string) error { + credentialsConfig, err := NewCredentialsConfigFromFile(configDir) + if err != nil { + return fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err) + } + + if _, ok := credentialsConfig.Credentials[account]; !ok { + return CredentialsNotFoundError{account} + } + + credentialsConfig.CurrentAccount = account + + if err := saveCredentialsConfigFile(credentialsConfig, configDir); err != nil { + return fmt.Errorf("unable to save the authentication configuration to file; %w", err) + } + + return nil +} + +func NewCredentialsConfigFromFile(configDir string) (CredentialsConfig, error) { + path := credentialsConfigFile(configDir) file, err := os.Open(path) if err != nil { @@ -88,58 +116,12 @@ func NewAuthenticationConfigFromFile() (CredentialsConfig, error) { return authConfig, nil } -func UpdateCurrentAccount(account string) error { - authConfig, err := NewAuthenticationConfigFromFile() +func saveCredentialsConfigFile(authConfig CredentialsConfig, configDir string) error { + path := credentialsConfigFile(configDir) + + file, err := os.Create(path) if err != nil { - return fmt.Errorf("unable to retrieve the existing authentication configuration; %w", err) - } - - if _, ok := authConfig.Credentials[account]; !ok { - return fmt.Errorf("account %s is not found", account) - } - - authConfig.CurrentAccount = account - - if err := saveCredentialsConfigFile(authConfig); err != nil { - return fmt.Errorf("unable to save the authentication configuration to file; %w", err) - } - - return nil -} - -func credentialsConfigFile() string { - return filepath.Join(configDir(), credentialsFileName) -} - -func configDir() string { - rootDir, err := os.UserConfigDir() - if err != nil { - rootDir = "." - } - - return filepath.Join(rootDir, internal.ApplicationName) -} - -func ensureConfigDir() error { - dir := configDir() - - if _, err := os.Stat(dir); err != nil { - if errors.Is(err, os.ErrNotExist) { - if err := os.MkdirAll(dir, 0o750); err != nil { - return fmt.Errorf("unable to create %s; %w", dir, err) - } - } else { - return fmt.Errorf("unknown error received when running stat on %s; %w", dir, err) - } - } - - return nil -} - -func saveCredentialsConfigFile(authConfig CredentialsConfig) error { - file, err := os.Create(credentialsConfigFile()) - if err != nil { - return fmt.Errorf("unable to open the config file; %w", err) + return fmt.Errorf("unable to open %s; %w", path, err) } defer file.Close() @@ -153,3 +135,7 @@ func saveCredentialsConfigFile(authConfig CredentialsConfig) error { return nil } + +func credentialsConfigFile(configDir string) string { + return filepath.Join(calculateConfigDir(configDir), credentialsFileName) +} diff --git a/internal/config/directory.go b/internal/config/directory.go new file mode 100644 index 0000000..bcd1bad --- /dev/null +++ b/internal/config/directory.go @@ -0,0 +1,37 @@ +package config + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "codeflow.dananglin.me.uk/apollo/enbas/internal" +) + +func calculateConfigDir(configDir string) string { + if configDir != "" { + return configDir + } + + rootDir, err := os.UserConfigDir() + if err != nil { + rootDir = "." + } + + return filepath.Join(rootDir, internal.ApplicationName) +} + +func ensureConfigDir(configDir string) error { + if _, err := os.Stat(configDir); err != nil { + if errors.Is(err, os.ErrNotExist) { + if err := os.MkdirAll(configDir, 0o750); err != nil { + return fmt.Errorf("unable to create %s; %w", configDir, err) + } + } else { + return fmt.Errorf("unknown error received when running stat on %s; %w", configDir, err) + } + } + + return nil +}