Skip to content

Commit

Permalink
test: always use our testing certificates
Browse files Browse the repository at this point in the history
  • Loading branch information
justinclift committed Mar 23, 2019
1 parent f0a09af commit e19bb8f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 14 deletions.
6 changes: 3 additions & 3 deletions cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ func commit(args []string) error {
if _, err = os.Stat(filepath.Join(".dio", db, "db")); os.IsNotExist(err) {
// At the moment, since there's no better way to check for the existence of a remote database, we just
// grab the list of the users databases and check against that
dbList, err := getDatabases(cloud, certUser)
if err != nil {
return err
dbList, errInner := getDatabases(cloud, certUser)
if errInner != nil {
return errInner
}
for _, j := range dbList {
if db == j.Name {
Expand Down
40 changes: 31 additions & 9 deletions cmd/dio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"flag"
Expand Down Expand Up @@ -90,11 +91,11 @@ func (s *DioSuite) SetUpSuite(c *chk.C) {
s.config = filepath.Join(tempDir, "config.toml")
f, err := os.Create(s.config)
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}
d, err := os.Getwd()
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}
origDir = d
mockAddr := "https://localhost:5551"
Expand All @@ -106,45 +107,66 @@ func (s *DioSuite) SetUpSuite(c *chk.C) {
filepath.Join(d, "..", "test_data", "default.cert.pem"),
mockAddr)
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}

// Drop any old config loaded automatically by viper, and use our temporary test config instead
viper.Reset()
viper.SetConfigFile(s.config)
if err = viper.ReadInConfig(); err != nil {
log.Fatalf("Error loading test config file: %s", err.Error())
log.Fatalf("Error loading test config file: %s", err)
return
}
cloud = viper.GetString("general.cloud")

// Use our testing certificates
ourCAPool := x509.NewCertPool()
chainFile, err := ioutil.ReadFile(filepath.Join(d, "..", "test_data", "ca-chain-docker.cert.pem"))
if err != nil {
log.Fatalln(err)
}
ok := ourCAPool.AppendCertsFromPEM(chainFile)
if !ok {
log.Fatalln("Error when loading certificate chain file")
}
testCert := filepath.Join(d, "..", "test_data", "default.cert.pem")
cert, err := tls.LoadX509KeyPair(testCert, testCert)
if err != nil {
log.Fatalln(err)
}
TLSConfig.Certificates = []tls.Certificate{cert}
certUser, _, err = getUserAndServer()
if err != nil {
log.Fatalln(err)
}

// Add test database
s.dbName = "19kB.sqlite"
db, err := ioutil.ReadFile(filepath.Join(d, "..", "test_data", s.dbName))
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}
s.dbFile = filepath.Join(tempDir, s.dbName)
err = ioutil.WriteFile(s.dbFile, db, 0644)
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}

// Set the last modified date of the database file to a known value
err = os.Chtimes(s.dbFile, time.Now(), time.Date(2019, time.March, 15, 18, 1, 0, 0, time.UTC))
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}

// Add a test licence
lic, err := ioutil.ReadFile(filepath.Join(d, "..", "LICENSE"))
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}
licFile = filepath.Join(tempDir, "test.licence")
err = ioutil.WriteFile(licFile, lic, 0644)
if err != nil {
log.Fatalln(err.Error())
log.Fatalln(err)
}

// If not told otherwise, redirect command output to /dev/null
Expand Down
4 changes: 2 additions & 2 deletions cmd/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func dbChanged(db string, meta metaData) (changed bool, err error) {

// Retrieves the list of databases available to the user
var getDatabases = func(url string, user string) (dbList []dbListEntry, err error) {
resp, body, errs := rq.New().TLSClientConfig(&TLSConfig).Get(fmt.Sprintf("%s/%s", url, user)).End()
resp, body, errs := rq.New().TLSClientConfig(&TLSConfig).Get(fmt.Sprintf("%s/%s", url, user)).EndBytes()
if errs != nil {
e := fmt.Sprintln("Errors when retrieving the database list:")
for _, err := range errs {
Expand All @@ -158,7 +158,7 @@ var getDatabases = func(url string, user string) (dbList []dbListEntry, err erro
return
}
defer resp.Body.Close()
err = json.Unmarshal([]byte(body), &dbList)
err = json.Unmarshal(body, &dbList)
if err != nil {
_, errInner := fmt.Fprintf(fOut, "Error retrieving database list: '%v'\n", err.Error())
if errInner != nil {
Expand Down

0 comments on commit e19bb8f

Please sign in to comment.