diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index 954394e..0012e0d 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -4,12 +4,12 @@ import ( "fmt" "os" "path/filepath" - "time" - "github.com/spf13/viper" + viperlib "github.com/spf13/viper" ) var initialised bool = false +var viper *viperlib.Viper = viperlib.New() // EnsureInitialised reads the config. Will quit if config is invalid func EnsureInitialised() { @@ -25,7 +25,7 @@ func EnsureInitialised() { // TODO - allow env var for config if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + if _, ok := err.(viperlib.ConfigFileNotFoundError); ok { // Config file not found; ignore error if desired } else { fmt.Printf("Error loading config file: %s\n", err) @@ -36,7 +36,10 @@ func EnsureInitialised() { } } func getConfigPath() string { - var path string + path := os.Getenv("DEVCONTAINERX_CONFIG_PATH") + if path != "" { + return path + } if os.Getenv("HOME") != "" { path = filepath.Join("$HOME", ".devcontainer-cli/") } else { @@ -58,14 +61,6 @@ func GetExperimentalFeaturesEnabled() bool { EnsureInitialised() return viper.GetBool("experimental") } -func GetLastUpdateCheck() time.Time { - EnsureInitialised() - return viper.GetTime("lastUpdateCheck") -} -func SetLastUpdateCheck(t time.Time) { - EnsureInitialised() - viper.Set("lastUpdateCheck", t) -} func GetAll() map[string]interface{} { EnsureInitialised() return viper.AllSettings() diff --git a/internal/pkg/status/status.go b/internal/pkg/status/status.go new file mode 100644 index 0000000..e8e43b6 --- /dev/null +++ b/internal/pkg/status/status.go @@ -0,0 +1,72 @@ +package status + +import ( + "fmt" + "os" + "path/filepath" + "time" + + viperlib "github.com/spf13/viper" +) + +var initialised bool = false +var viper *viperlib.Viper = viperlib.New() + +// EnsureInitialised reads the config. Will quit if config is invalid +func EnsureInitialised() { + if !initialised { + viper.SetConfigName("devcontainer-cli-status") + viper.SetConfigType("json") + + viper.AddConfigPath(getConfigPath()) + + // TODO - allow env var for config + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viperlib.ConfigFileNotFoundError); ok { + // Config file not found; ignore error if desired + } else { + fmt.Printf("Error loading status file: %s\n", err) + os.Exit(1) + } + } + initialised = true + } +} +func getConfigPath() string { + path := os.Getenv("DEVCONTAINERX_STATUS_PATH") + if path != "" { + return path + } + if os.Getenv("HOME") != "" { + path = filepath.Join("$HOME", ".devcontainer-cli/") + } else { + // if HOME not set, assume Windows and use USERPROFILE env var + path = filepath.Join("$USERPROFILE", ".devcontainer-cli/") + } + return os.ExpandEnv(path) +} + +func GetLastUpdateCheck() time.Time { + EnsureInitialised() + return viper.GetTime("lastUpdateCheck") +} +func SetLastUpdateCheck(t time.Time) { + EnsureInitialised() + viper.Set("lastUpdateCheck", t) +} +func GetAll() map[string]interface{} { + EnsureInitialised() + return viper.AllSettings() +} + +func SaveStatus() error { + EnsureInitialised() + configPath := getConfigPath() + configPath = os.ExpandEnv(configPath) + if err := os.MkdirAll(configPath, 0755); err != nil { + return err + } + configFilePath := filepath.Join(configPath, "devcontainer-cli-status.json") + fmt.Printf("HERE: %q\n", configFilePath) + return viper.WriteConfigAs(configFilePath) +} diff --git a/internal/pkg/update/update.go b/internal/pkg/update/update.go index 7154ca5..3444739 100644 --- a/internal/pkg/update/update.go +++ b/internal/pkg/update/update.go @@ -2,11 +2,12 @@ package update import ( "fmt" + "os" "time" "github.com/blang/semver" "github.com/rhysd/go-github-selfupdate/selfupdate" - "github.com/stuartleeks/devcontainer-cli/internal/pkg/config" + "github.com/stuartleeks/devcontainer-cli/internal/pkg/status" ) func CheckForUpdate(currentVersion string) (*selfupdate.Release, error) { @@ -29,7 +30,12 @@ func CheckForUpdate(currentVersion string) (*selfupdate.Release, error) { func PeriodicCheckForUpdate(currentVersion string) { const checkInterval time.Duration = 24 * time.Hour - lastCheck := config.GetLastUpdateCheck() + if os.Getenv("DEVCONTAINERX_SKIP_UPDATE") != "" { + // Skip update check + return + } + + lastCheck := status.GetLastUpdateCheck() if time.Now().Before(lastCheck.Add(checkInterval)) { return @@ -40,8 +46,8 @@ func PeriodicCheckForUpdate(currentVersion string) { fmt.Printf("Error checking for updates: %s", err) } - config.SetLastUpdateCheck(time.Now()) - if err = config.SaveConfig(); err != nil { + status.SetLastUpdateCheck(time.Now()) + if err = status.SaveStatus(); err != nil { fmt.Printf("Error saving last update check time: :%s\n", err) }