Skip to content

Commit 3cdb63e

Browse files
authored
Merge pull request #80 from heetch/rog-005-fix-flags
backend/flags: avoiding setting values that aren't specified
2 parents c6edf60 + 0782bac commit 3cdb63e

File tree

2 files changed

+154
-148
lines changed

2 files changed

+154
-148
lines changed

backend/flags/flags.go

+38-31
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"flag"
66
"fmt"
7+
"os"
78
"reflect"
89
"time"
910

@@ -12,11 +13,15 @@ import (
1213
)
1314

1415
// Backend that loads configuration from the command line flags.
15-
type Backend struct{}
16+
type Backend struct {
17+
flags *flag.FlagSet
18+
}
1619

1720
// NewBackend creates a flags backend.
1821
func NewBackend() *Backend {
19-
return new(Backend)
22+
return &Backend{
23+
flags: flag.CommandLine,
24+
}
2025
}
2126

2227
// LoadStruct takes a struct config, define flags based on it and parse the command line args.
@@ -34,80 +39,90 @@ func (b *Backend) LoadStruct(ctx context.Context, cfg *confita.StructConfig) err
3439
switch {
3540
case f.Value.Type().String() == "time.Duration":
3641
var val time.Duration
37-
flag.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description)
42+
b.flags.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description)
3843
if f.Short != "" {
39-
flag.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description))
44+
b.flags.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description))
4045
}
4146
// this function must be executed after the flag.Parse call.
4247
defer func() {
4348
// if the user has set the flag, save the value in the field.
44-
if isFlagSet(f) {
49+
if b.isFlagSet(f) {
4550
f.Value.SetInt(int64(val))
4651
}
4752
}()
4853
case k == reflect.Bool:
4954
var val bool
50-
flag.BoolVar(&val, f.Key, f.Default.Bool(), f.Description)
55+
b.flags.BoolVar(&val, f.Key, f.Default.Bool(), f.Description)
5156
if f.Short != "" {
52-
flag.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description))
57+
b.flags.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description))
5358
}
5459
defer func() {
55-
if isFlagSet(f) {
60+
if b.isFlagSet(f) {
5661
f.Value.SetBool(val)
5762
}
5863
}()
5964
case k >= reflect.Int && k <= reflect.Int64:
6065
var val int
61-
flag.IntVar(&val, f.Key, int(f.Default.Int()), f.Description)
66+
b.flags.IntVar(&val, f.Key, int(f.Default.Int()), f.Description)
6267
if f.Short != "" {
63-
flag.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description))
68+
b.flags.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description))
6469
}
6570
defer func() {
66-
if isFlagSet(f) {
71+
if b.isFlagSet(f) {
6772
f.Value.SetInt(int64(val))
6873
}
6974
}()
7075
case k >= reflect.Uint && k <= reflect.Uint64:
7176
var val uint64
72-
flag.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description)
77+
b.flags.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description)
7378
if f.Short != "" {
74-
flag.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description))
79+
b.flags.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description))
7580
}
7681
defer func() {
77-
if isFlagSet(f) {
82+
if b.isFlagSet(f) {
7883
f.Value.SetUint(val)
7984
}
8085
}()
8186
case k >= reflect.Float32 && k <= reflect.Float64:
8287
var val float64
83-
flag.Float64Var(&val, f.Key, f.Default.Float(), f.Description)
88+
b.flags.Float64Var(&val, f.Key, f.Default.Float(), f.Description)
8489
if f.Short != "" {
85-
flag.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description))
90+
b.flags.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description))
8691
}
8792
defer func() {
88-
if isFlagSet(f) {
93+
if b.isFlagSet(f) {
8994
f.Value.SetFloat(val)
9095
}
9196
}()
9297
case k == reflect.String:
9398
var val string
94-
flag.StringVar(&val, f.Key, f.Default.String(), f.Description)
99+
b.flags.StringVar(&val, f.Key, f.Default.String(), f.Description)
95100
if f.Short != "" {
96-
flag.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description))
101+
b.flags.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description))
97102
}
98103
defer func() {
99-
if isFlagSet(f) {
104+
if b.isFlagSet(f) {
100105
f.Value.SetString(val)
101106
}
102107
}()
103108
default:
104-
flag.Var(&flagValue{f}, f.Key, f.Description)
109+
b.flags.Var(&flagValue{f}, f.Key, f.Description)
105110
}
106111
}
107112

108-
flag.Parse()
113+
// Note: in the usual case, when b.flags is flag.CommandLine, this will exit
114+
// rather than returning an error.
115+
return b.flags.Parse(os.Args[1:])
116+
}
109117

110-
return nil
118+
func (b *Backend) isFlagSet(config *confita.FieldConfig) bool {
119+
ok := false
120+
b.flags.Visit(func(f *flag.Flag) {
121+
if f.Name == config.Key || f.Name == config.Short {
122+
ok = true
123+
}
124+
})
125+
return ok
111126
}
112127

113128
type flagValue struct {
@@ -139,11 +154,3 @@ func (b *Backend) Name() string {
139154
func shortDesc(description string) string {
140155
return fmt.Sprintf("%s (short)", description)
141156
}
142-
143-
func isFlagSet(config *confita.FieldConfig) bool {
144-
flagset := make(map[*confita.FieldConfig]bool)
145-
flag.Visit(func(f *flag.Flag) { flagset[config] = true })
146-
147-
_, ok := flagset[config]
148-
return ok
149-
}

0 commit comments

Comments
 (0)