Skip to content

Commit

Permalink
Add ability to use a flag.FlagSet (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
BoRuDar authored Nov 4, 2021
1 parent ebbd61a commit 25dac63
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

.idea
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ Flags:
-first_name "Description (default: default_value)"
```
And program execution will be terminated.
#### Options for _NewFlagProvider_
* WithFlagSet - set a custom FlagSet
* WithErrorHandler - to catch and handle errors from the init phase (before actually getting data from flags)

### File provider
Doesn't require any specific tags. JSON and YAML formats of files are supported.
Expand Down
24 changes: 24 additions & 0 deletions configurator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,27 @@ func TestFallBackToDefault(t *testing.T) {

assert.Equal(t, "default_val", cfg.NameFlag)
}

func TestSetOnFailFn(t *testing.T) {
var (
cfg = struct {
Name string `default:"test_name"`
}{}
onFailFn = func(err error) {
if err.Error() != "configurator: field [Name] with tags [default:\"test_name\"] cannot be set" {
t.Fatalf("unexpected error: %v", err)
}
}
)

c, err := New(
&cfg,
NewEnvProvider(),
)
if err != nil {
t.Fatal("unexpected err: ", err)
}

c.SetOnFailFn(onFailFn)
c.InitValues()
}
8 changes: 6 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package configuration

import "errors"
import (
"errors"
"fmt"
)

var (
ErrEmptyValue = errors.New("empty value")
ErrEmptyValue = errors.New("empty value")
ErrNotAPointer = fmt.Errorf("not a pointer to a struct")
)
54 changes: 44 additions & 10 deletions flagProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,61 @@ package configuration
import (
"flag"
"fmt"
"os"
"reflect"
"strings"
)

const flagSeparator = "|"

type FlagProviderOption func(*flagProvider)

// NewFlagProvider creates a new provider to fetch data from flags like: --flag_name some_value
func NewFlagProvider(ptrToCfg interface{}) flagProvider {
func NewFlagProvider(ptrToCfg interface{}, opts ...FlagProviderOption) flagProvider {
fp := flagProvider{
flagsValues: map[string]func() *string{},
flags: map[string]*flagData{},
flagsValues: map[string]func() *string{},
flags: map[string]*flagData{},
flagSet: flag.CommandLine,
errorHandler: func(err error) {},
}

for _, f := range opts {
f(&fp)
}
fp.initFlagProvider(ptrToCfg)

flag.Parse()
fp.errorHandler(fp.initFlagProvider(ptrToCfg))

fp.errorHandler(fp.flagSet.Parse(os.Args[1:]))

return fp
}

// FlagSet is the part of flag.FlagSet that NewFlagProvider uses
type FlagSet interface {
Parse([]string) error
String(string, string, string) *string
}

// WithFlagSet allows the flag.FlagSet to be provided to NewFlagProvider.
// This allows compatability with other flag parsing utilities.
func WithFlagSet(s FlagSet) FlagProviderOption {
return func(fp *flagProvider) {
fp.flagSet = s
}
}

// WithErrorHandler captures errors from fp.initFlagProvider and fp.flagSet.Parse
func WithErrorHandler(fn func(err error)) FlagProviderOption {
return func(fp *flagProvider) {
fp.errorHandler = fn
}
}

type flagProvider struct {
flagsValues map[string]func() *string
flags map[string]*flagData
flagsValues map[string]func() *string
flags map[string]*flagData
flagSet FlagSet
errorHandler func(err error)
}

type flagData struct {
Expand All @@ -41,7 +75,7 @@ func (fp flagProvider) initFlagProvider(i interface{}) error {
t = t.Elem()
v = v.Elem()
default:
return fmt.Errorf("not a pointer to a struct")
return ErrNotAPointer
}

for i := 0; i < t.NumField(); i++ {
Expand All @@ -62,7 +96,7 @@ func (fp flagProvider) initFlagProvider(i interface{}) error {
continue
}

fp.setFlagCallbacks(tField)
fp.errorHandler(fp.setFlagCallbacks(tField))
}
return nil
}
Expand All @@ -78,7 +112,7 @@ func (fp flagProvider) setFlagCallbacks(field reflect.StructField) error {
}
fp.flags[fd.key] = fd

valStr := flag.String(fd.key, fd.defaultVal, fd.usage)
valStr := fp.flagSet.String(fd.key, fd.defaultVal, fd.usage)
fp.flagsValues[fd.key] = func() *string {
return valStr
}
Expand Down
84 changes: 84 additions & 0 deletions flagProvider_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package configuration

import (
"flag"
"os"
"reflect"
"testing"
Expand Down Expand Up @@ -130,3 +131,86 @@ func TestGetFlagData(t *testing.T) {
})
}
}

func TestFlagProvider_CustomFlagSet(t *testing.T) {
type testStruct struct {
Name string `flag:"flag_name3||Description"`
}
testObj := testStruct{}
os.Args = []string{"smth", "-flag_name3=flag_value"}

fieldType := reflect.TypeOf(&testObj).Elem().Field(0)
fieldVal := reflect.ValueOf(&testObj).Elem().Field(0)

fs := flag.NewFlagSet("test", flag.ContinueOnError)
provider := NewFlagProvider(&testObj, WithFlagSet(fs))
testValue := "flag_value"

if err := provider.Provide(fieldType, fieldVal); err != nil {
t.Fatalf("cannot set value: %v", err)
}

assert.Equal(t, testValue, testObj.Name)
}

func TestFlagProvider_WithErrorHandler(t *testing.T) {
type testStruct struct {
Name string `flag:"flag_name4||Description"`
}
testObj := testStruct{}
os.Args = []string{"smth", "-flag_name4=flag_value"}

fieldType := reflect.TypeOf(&testObj).Elem().Field(0)
fieldVal := reflect.ValueOf(&testObj).Elem().Field(0)

eh := func(err error) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
provider := NewFlagProvider(&testObj, WithErrorHandler(eh))
testValue := "flag_value"

if err := provider.Provide(fieldType, fieldVal); err != nil {
t.Fatalf("cannot set value: %v", err)
}

assert.Equal(t, testValue, testObj.Name)
}

func TestFlagProvider_WithErrorHandlerAndErr(t *testing.T) {
type testStruct struct {
Name string `flag:"flag_name5||||"`
}
testObj := testStruct{}
os.Args = []string{""}
counter := 0

eh := func(err error) {
counter++

if err != nil && err.Error() != "flagProvider: wrong flag definition [flag_name5||||]" {
t.Fatalf("unexpected error")
}
}
_ = NewFlagProvider(&testObj, WithErrorHandler(eh))

if counter != 3 {
t.Fatal("error must be called 3 times")
}
}

func TestFlagProvider_Error(t *testing.T) {
type testStruct struct {
Name string `flag:"flag_name5||||"`
}
testObj := testStruct{}
os.Args = []string{""}

eh := func(err error) {
if err != nil && err.Error() != ErrNotAPointer.Error() {
t.Fatalf("unexpected error: %v", err)
}
}
_ = NewFlagProvider(testObj, WithErrorHandler(eh))
}

0 comments on commit 25dac63

Please sign in to comment.