From 320ae1c60a4b034642c4f50ddf437eee45e8b8df Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Sat, 29 Oct 2022 02:28:58 +0300 Subject: [PATCH] lib/envflag: small refactoring after 518c340ae3ed726227ed73d5e7e57027a654deac and 02096e06d020b1828c285ca68b16074d2e6d03c3 --- lib/envflag/envflag.go | 73 ++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/lib/envflag/envflag.go b/lib/envflag/envflag.go index 1430ca8b4..598cf4dc0 100644 --- a/lib/envflag/envflag.go +++ b/lib/envflag/envflag.go @@ -22,9 +22,46 @@ var ( // // This function must be called instead of flag.Parse() before using any flags in the program. func Parse() { - // Substitute %{ENV_VAR} inside args with the corresponding environment variable values - args := os.Args[1:] - dstArgs := args[:0] + ParseFlagSet(flag.CommandLine, os.Args[1:]) +} + +// ParseFlagSet parses the given args into the given fs. +func ParseFlagSet(fs *flag.FlagSet, args []string) { + args = expandArgs(args) + if err := fs.Parse(args); err != nil { + // Do not use lib/logger here, since it is uninitialized yet. + log.Fatalf("cannot parse flags %q: %s", args, err) + } + if !*enable { + return + } + // Remember explicitly set command-line flags. + flagsSet := make(map[string]bool) + fs.Visit(func(f *flag.Flag) { + flagsSet[f.Name] = true + }) + + // Obtain the remaining flag values from environment vars. + fs.VisitAll(func(f *flag.Flag) { + if flagsSet[f.Name] { + // The flag is explicitly set via command-line. + return + } + // Get flag value from environment var. + fname := getEnvFlagName(f.Name) + if v, ok := envtemplate.LookupEnv(fname); ok { + if err := fs.Set(f.Name, v); err != nil { + // Do not use lib/logger here, since it is uninitialized yet. + log.Fatalf("cannot set flag %s to %q, which is read from env var %q: %s", f.Name, v, fname, err) + } + } + }) +} + +// expandArgs substitutes %{ENV_VAR} placeholders inside args +// with the corresponding environment variable values. +func expandArgs(args []string) []string { + dstArgs := make([]string, 0, len(args)) for _, arg := range args { s, err := envtemplate.ReplaceString(arg) if err != nil { @@ -35,35 +72,7 @@ func Parse() { dstArgs = append(dstArgs, s) } } - os.Args = os.Args[:1+len(dstArgs)] - - // Parse flags - flag.Parse() - if !*enable { - return - } - - // Remember explicitly set command-line flags. - flagsSet := make(map[string]bool) - flag.Visit(func(f *flag.Flag) { - flagsSet[f.Name] = true - }) - - // Obtain the remaining flag values from environment vars. - flag.VisitAll(func(f *flag.Flag) { - if flagsSet[f.Name] { - // The flag is explicitly set via command-line. - return - } - // Get flag value from environment var. - fname := getEnvFlagName(f.Name) - if v, ok := envtemplate.LookupEnv(fname); ok { - if err := flag.Set(f.Name, v); err != nil { - // Do not use lib/logger here, since it is uninitialized yet. - log.Fatalf("cannot set flag %s to %q, which is read from environment variable %q: %s", f.Name, v, fname, err) - } - } - }) + return dstArgs } func getEnvFlagName(s string) string {