diff --git a/cmd/main.go b/cmd/main.go index 10d6dc3..303e1d3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -60,9 +60,9 @@ func main() { fmt.Printf("Unable to update version: %v\n", err) } + fileHandle, _ := os.OpenFile(configDirString+"/"+configFileLocation, os.O_RDWR, 0666) + defer fileHandle.Close() if needsUpdate { - fileHandle, _ := os.OpenFile(configDirString+"/"+configFileLocation, os.O_RDWR, 0666) - defer fileHandle.Close() versionedTerraform.UpdateConfig(*fileHandle) } @@ -79,7 +79,7 @@ func main() { } // Check if stable version of terraform is required - needsStable, err = versionedTerraform.ConfigRequiresStable(configDir, configFileLocation) + needsStable, err = versionedTerraform.ConfigRequiresStable(*fileHandle) if err != nil { fmt.Fprintf(os.Stderr, "Unable to open config file, defaulting to stable versions of terraform only") } diff --git a/configManagement.go b/configManagement.go index a93d643..9beb7a8 100644 --- a/configManagement.go +++ b/configManagement.go @@ -18,8 +18,8 @@ type configStruct struct { } //ConfigRequiresStable returns bool, error only false if StableOnly: false is set in configuration file -func ConfigRequiresStable(fileSystem fs.FS, configFile string) (bool, error) { - fileHandle, err := fileSystem.Open(configFile) +func ConfigRequiresStable(File os.File) (bool, error) { + fileHandle, err := os.Open(File.Name()) if err != nil { return true, err } @@ -125,19 +125,28 @@ func LoadInstalledVersions(fileSystem fs.FS) ([]SemVersion, error) { // adding: // a new date to the last updated field // the available versions listed on terraforms website -// teh status of if the user wants only stable releases -func UpdateConfig(File os.File) error { +// the status of if the user wants only stable releases +func UpdateConfig(File os.File, timeNow ...time.Time) error { configValues := new(configStruct) configValues.AvailableVersions, _ = GetVersionList() + configValues.StableOnly, _ = ConfigRequiresStable(File) - timeNow := time.Now() - configValues.LastUpdate = timeNow.Unix() + var t time.Time + if len(timeNow) > 0 { + t = timeNow[0] + } else { + t = time.Now() + } + configValues.LastUpdate = t.Unix() File.Truncate(0) File.Seek(0, 0) - lineToByte := []byte(fmt.Sprintf("LastUpdate: %d\n", configValues.LastUpdate)) + lineToByte := []byte(fmt.Sprintf("StableOnly: %+v\n", configValues.StableOnly)) + File.Write(lineToByte) + + lineToByte = []byte(fmt.Sprintf("LastUpdate: %d\n", configValues.LastUpdate)) File.Write(lineToByte) lineToByte = []byte(fmt.Sprintf("AvailableVersions: %+v\n", configValues.AvailableVersions)) File.Write(lineToByte) diff --git a/configManagement_test.go b/configManagement_test.go index e97dca2..2fa713f 100644 --- a/configManagement_test.go +++ b/configManagement_test.go @@ -2,8 +2,11 @@ package versionedTerraform import ( "fmt" + "io" + "os" "reflect" "sort" + "strings" "testing" "testing/fstest" "time" @@ -114,3 +117,73 @@ func TestInstalledVersions(t *testing.T) { }) } + +func TestConfigRequiresStable(t *testing.T) { + availableVersions, _ := GetVersionList() + versions := strings.Join(availableVersions, " ") + cases := []struct { + name, content, want string + timeNow time.Time + }{ + {"StableOnly True", "StableOnly: true\n" + + "LastUpdate: 1674481203\n" + + "AvailableVersions: [1.3.7]", + "StableOnly: true\n" + + "LastUpdate: 1286705410\n" + + "AvailableVersions: [" + + versions + "]\n", + time.Date(2010, 10, 10, 10, 10, 10, 10, time.UTC)}, + {"StableOnly False", "StableOnly: false\n" + + "LastUpdate: 1674481203\n" + + "AvailableVersions: [1.3.7]", + "StableOnly: false\n" + + "LastUpdate: 1286705410\n" + + "AvailableVersions: [" + + versions + "]\n", + time.Date(2010, 10, 10, 10, 10, 10, 10, time.UTC)}, + {"StableOnly not found", "LastUpdate: 1674481203\n" + + "AvailableVersions: [1.3.7]", + "StableOnly: true\n" + + "LastUpdate: 1286705410\n" + + "AvailableVersions: [" + + versions + "]\n", + time.Date(2010, 10, 10, 10, 10, 10, 10, time.UTC)}, + } + + for _, c := range cases { + t.Run("Test: "+c.name, func(t *testing.T) { + t.Parallel() + + tempDir := os.TempDir() + tempFile, err := os.Create(tempDir + "/config") + defer tempFile.Close() + + if err != nil { + t.Errorf("Unable to execute test : %v", err) + } + + UpdateConfig(*tempFile, c.timeNow) + + tempFile.Seek(0, 0) + data := make([]byte, 1024) + var got string + for { + n, err := tempFile.Read(data) + if err == io.EOF { + break + } + if err != nil { + t.Errorf("File reading error : %v", err) + return + } + got += string(data[:n]) + } + + if !reflect.DeepEqual(got, c.want) { + t.Errorf("%v test failed to meet conditions", c.name) + fmt.Fprintf(os.Stdout, "%v\n", c.want) + fmt.Fprintf(os.Stdout, "%v\n", got) + } + }) + } +}