diff --git a/default_test.go b/default_test.go index 58f1392..26971a6 100644 --- a/default_test.go +++ b/default_test.go @@ -1,6 +1,7 @@ package yaconf import ( + "reflect" "testing" "time" @@ -28,3 +29,48 @@ func TestFillDefaultValues(t *testing.T) { require.Equal(t, 2*time.Second, config.DB.Timeout) require.Equal(t, true, config.DB.Debug) } + +func TestSetDefaultValue(t *testing.T) { + type db struct { + Host string `yaconf:"default=localhost"` + } + config := &struct { + LogLevel string `yaconf:"default=info"` + LogFile string + private string `yaconf:"default=test"` + DB *db + }{ + DB: &db{}, + } + + err := setDefaultValue(reflect.ValueOf(config)) + require.Nil(t, err) + require.Empty(t, config.private) + require.Equal(t, "localhost", config.DB.Host) +} + +func TestIsDuration(t *testing.T) { + require.True(t, isDuration(reflect.ValueOf(time.Duration(1000)))) + require.False(t, isDuration(reflect.ValueOf(int64(22)))) +} + +func TestIsInt(t *testing.T) { + require.True(t, isInt(reflect.ValueOf(int8(1)))) + require.True(t, isInt(reflect.ValueOf(int16(2)))) + require.True(t, isInt(reflect.ValueOf(int32(3)))) + require.True(t, isInt(reflect.ValueOf(int64(4)))) + require.True(t, isInt(reflect.ValueOf(int8(5)))) + + require.False(t, isInt(reflect.ValueOf(byte('6')))) + require.False(t, isInt(reflect.ValueOf("22"))) +} + +func TestIsUint(t *testing.T) { + require.True(t, isUint(reflect.ValueOf(uint8(1)))) + require.True(t, isUint(reflect.ValueOf(uint16(2)))) + require.True(t, isUint(reflect.ValueOf(uint32(3)))) + require.True(t, isUint(reflect.ValueOf(uint64(4)))) + require.True(t, isUint(reflect.ValueOf(uint8(5)))) + + require.False(t, isUint(reflect.ValueOf("22"))) +} diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..035cf13 --- /dev/null +++ b/validate.go @@ -0,0 +1,64 @@ +package yaconf + +import ( + "fmt" + "reflect" + "strings" +) + +type validator interface { + Validate() error +} + +func addPrefix(prefix, name string) string { + if prefix == "" { + return name + } + + return fmt.Sprintf("%s.%s", prefix, name) +} + +func validate(config interface{}, prefix string) []string { + t := reflect.TypeOf(config) + v := reflect.ValueOf(config) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + + errors := []string{} + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + + name := f.Tag.Get("yaml") + if name == "" { + name = f.Name + } + + if strings.Contains(f.Tag.Get("yaconf"), "required") && v.Field(i).IsZero() { + errors = append(errors, fmt.Sprintf("%s is required", addPrefix(prefix, name))) + continue + } + + if f.Type.Kind() == reflect.Struct { + errors = append(errors, validate(v.Field(i).Interface(), addPrefix(prefix, name))...) + continue + } + + if f.Type.Kind() == reflect.Ptr { + if f.Type.Elem().Kind() != reflect.Struct { + continue + } + if v.Field(i).IsNil() { + errors = append(errors, validate(reflect.New(v.Field(i).Type().Elem()).Interface(), addPrefix(prefix, name))...) + } else { + errors = append(errors, validate(v.Field(i).Elem().Interface(), addPrefix(prefix, name))...) + } + } + + } + return errors +} diff --git a/validate_test.go b/validate_test.go new file mode 100644 index 0000000..dbeb259 --- /dev/null +++ b/validate_test.go @@ -0,0 +1,43 @@ +package yaconf + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAddPrefix(t *testing.T) { + s := addPrefix("", "bar") + require.Equal(t, "bar", s) + + s = addPrefix("foo", "bar") + require.Equal(t, "foo.bar", s) +} + +func TestValidate(t *testing.T) { + type db struct { + Host string `yaconf:"required"` + } + pid := "/var/run/test.pid" + config := struct { + LogLevel string `yaml:"log_level" yaconf:"required,default=info"` + LogFile string `yaml:"log_file" yaconf:"required"` + PID *string `yaconf:"required"` + DB1 db + DB2 *db + DB3 *db + private string + }{ + PID: &pid, + DB3: &db{ + Host: "127.0.0.1", + }, + } + + errors := validate(&config, "") + require.Equal(t, 4, len(errors), errors) + require.Equal(t, "log_level is required", errors[0]) + require.Equal(t, "log_file is required", errors[1]) + require.Equal(t, "DB1.Host is required", errors[2]) + require.Equal(t, "DB2.Host is required", errors[3]) +} diff --git a/yaconf.go b/yaconf.go index a21d1e5..c4ba006 100644 --- a/yaconf.go +++ b/yaconf.go @@ -3,69 +3,11 @@ package yaconf import ( "fmt" "os" - "reflect" "strings" "gopkg.in/yaml.v3" ) -type validator interface { - Validate() error -} - -func addPrefix(prefix, name string) string { - if prefix == "" { - return name - } - - return fmt.Sprintf("%s.%s", prefix, name) -} - -func validate(config interface{}, prefix string) []string { - t := reflect.TypeOf(config) - v := reflect.ValueOf(config) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - - errors := []string{} - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - if !f.IsExported() { - continue - } - - name := f.Tag.Get("yaml") - if name == "" { - name = f.Name - } - - if f.Tag.Get("yaconf") == "required" && v.Field(i).IsZero() { - errors = append(errors, fmt.Sprintf("%s is required", addPrefix(prefix, name))) - continue - } - - if f.Type.Kind() == reflect.Struct { - errors = append(errors, validate(v.Field(i).Interface(), addPrefix(prefix, name))...) - continue - } - - if f.Type.Kind() == reflect.Ptr { - if f.Type.Elem().Kind() != reflect.Struct { - continue - } - if v.Field(i).IsNil() { - errors = append(errors, validate(reflect.New(f.Type.Elem()), addPrefix(prefix, name))...) - } else { - errors = append(errors, validate(v.Field(i).Elem().Interface(), addPrefix(prefix, name))...) - } - } - - } - return errors -} - // Read config from file func Read(filename string, config interface{}) error { data, err := os.ReadFile(filename) diff --git a/yaconf_test.go b/yaconf_test.go new file mode 100644 index 0000000..c6bf6bc --- /dev/null +++ b/yaconf_test.go @@ -0,0 +1,51 @@ +package yaconf + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +type testConfig3 struct { +} + +func (c *testConfig3) Validate() error { + return fmt.Errorf("invalid") +} + +func TestRead(t *testing.T) { + file, err := os.CreateTemp("", "") + require.Nil(t, err) + + defer func() { + os.Remove(file.Name()) + }() + + _, err = file.Write([]byte(`log_level: info`)) + require.Nil(t, err) + + err = file.Close() + require.Nil(t, err) + + // Test success + config := struct { + LogLevel string `yaml:"log_level"` + }{} + + err = Read(file.Name(), &config) + require.Nil(t, err) + + // Test with errors + config2 := struct { + LogFile string `yaml:"log_file" yaconf:"required"` + }{} + + err = Read(file.Name(), &config2) + require.NotNil(t, err) + + // Test with custom validator + err = Read(file.Name(), &testConfig3{}) + require.NotNil(t, err) +}