flag.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package flag
  2. import (
  3. "errors"
  4. "flag"
  5. "io"
  6. "math/big"
  7. "strings"
  8. "src.elv.sh/pkg/eval"
  9. "src.elv.sh/pkg/eval/errs"
  10. "src.elv.sh/pkg/eval/vals"
  11. "src.elv.sh/pkg/getopt"
  12. )
  13. // Ns is the namespace for the flag: module.
  14. var Ns = eval.BuildNsNamed("flag").
  15. AddGoFns(map[string]any{
  16. "call": call,
  17. "parse": parse,
  18. "parse-getopt": parseGetopt,
  19. }).Ns()
  20. func call(fm *eval.Frame, fn *eval.Closure, argsVal vals.List) error {
  21. var args []string
  22. err := vals.ScanListToGo(argsVal, &args)
  23. if err != nil {
  24. return err
  25. }
  26. fs := newFlagSet("")
  27. for i, name := range fn.OptNames {
  28. value := fn.OptDefaults[i]
  29. addFlag(fs, name, value, "")
  30. }
  31. err = fs.Parse(args)
  32. if err != nil {
  33. return err
  34. }
  35. m := make(map[string]any)
  36. fs.VisitAll(func(f *flag.Flag) {
  37. m[f.Name] = f.Value.(flag.Getter).Get()
  38. })
  39. return fn.Call(fm.Fork("parse:call"), callArgs(fs.Args()), m)
  40. }
  // callArgs converts a slice of strings into a slice of any with the same
  // elements in the same order. The result is never nil, even for empty input.
  func callArgs(ss []string) []any {
  	out := make([]any, 0, len(ss))
  	for _, s := range ss {
  		out = append(out, s)
  	}
  	return out
  }
  48. func parse(argsVal vals.List, specsVal vals.List) (vals.Map, vals.List, error) {
  49. var args []string
  50. err := vals.ScanListToGo(argsVal, &args)
  51. if err != nil {
  52. return nil, nil, err
  53. }
  54. var specs []vals.List
  55. err = vals.ScanListToGo(specsVal, &specs)
  56. if err != nil {
  57. return nil, nil, err
  58. }
  59. fs := newFlagSet("")
  60. for _, spec := range specs {
  61. var (
  62. name string
  63. value any
  64. description string
  65. )
  66. vals.ScanListElementsToGo(spec, &name, &value, &description)
  67. err := addFlag(fs, name, value, description)
  68. if err != nil {
  69. return nil, nil, err
  70. }
  71. }
  72. err = fs.Parse(args)
  73. if err != nil {
  74. return nil, nil, err
  75. }
  76. m := vals.EmptyMap
  77. fs.VisitAll(func(f *flag.Flag) {
  78. m = m.Assoc(f.Name, f.Value.(flag.Getter).Get())
  79. })
  80. return m, vals.MakeListSlice(fs.Args()), nil
  81. }
  // newFlagSet returns an empty *flag.FlagSet with the given name. Parse
  // errors are returned to the caller (flag.ContinueOnError) instead of
  // exiting. All usage and error output is discarded, so nothing is printed
  // on the caller's behalf.
  func newFlagSet(name string) *flag.FlagSet {
  	set := flag.NewFlagSet(name, flag.ContinueOnError)
  	set.SetOutput(io.Discard)
  	return set
  }
  87. func addFlag(fs *flag.FlagSet, name string, value any, description string) error {
  88. switch value := value.(type) {
  89. case bool:
  90. fs.Bool(name, value, description)
  91. case string:
  92. fs.String(name, value, description)
  93. case int, *big.Int, *big.Rat, float64:
  94. fs.Var(&numFlag{value}, name, description)
  95. case vals.List:
  96. fs.Var(&listFlag{value}, name, description)
  97. default:
  98. return errs.BadValue{What: "flag default value",
  99. Valid: "boolean, number, string or list",
  100. Actual: vals.ReprPlain(value)}
  101. }
  102. return nil
  103. }
  104. type numFlag struct{ value vals.Num }
  105. func (nf *numFlag) String() string { return vals.ToString(nf.value) }
  106. func (nf *numFlag) Get() any { return nf.value }
  107. func (nf *numFlag) Set(s string) error { return vals.ScanToGo(s, &nf.value) }
  108. type listFlag struct{ value vals.List }
  109. func (lf *listFlag) String() string { return vals.ToString(lf.value) }
  110. func (lf *listFlag) Get() any { return lf.value }
  111. func (lf *listFlag) Set(s string) error {
  112. lf.value = vals.MakeListSlice(strings.Split(s, ","))
  113. return nil
  114. }
  // specStruct is the Go form of one option spec map given to
  // flag:parse-getopt, filled in with vals.ScanMapToGo. The error messages in
  // this file refer to its fields by the map keys &short, &long, &arg-required
  // and &arg-optional.
  type specStruct struct {
  	Short                    rune   // short form; 0 when absent
  	Long                     string // long form; "" when absent
  	ArgRequired, ArgOptional bool   // at most one may be true
  }
  var (
  	// errShortLong is returned by specStruct.OptionSpec when a spec has
  	// neither a short nor a long form.
  	errShortLong = errors.New("at least one of &short and &long must be non-empty")

  	// errArgRequiredArgOptional is returned by specStruct.OptionSpec when a
  	// spec sets both ArgRequired and ArgOptional.
  	errArgRequiredArgOptional = errors.New("at most one of &arg-required and &arg-optional may be true")
  )
  125. func (s *specStruct) OptionSpec() (*getopt.OptionSpec, error) {
  126. if s.Short == 0 && s.Long == "" {
  127. return nil, errShortLong
  128. }
  129. arity := getopt.NoArgument
  130. switch {
  131. case s.ArgRequired && s.ArgOptional:
  132. return nil, errArgRequiredArgOptional
  133. case s.ArgRequired:
  134. arity = getopt.RequiredArgument
  135. case s.ArgOptional:
  136. arity = getopt.OptionalArgument
  137. }
  138. return &getopt.OptionSpec{Short: s.Short, Long: s.Long, Arity: arity}, nil
  139. }
  // parseGetoptOptions holds the options accepted by flag:parse-getopt. Each
  // field is translated into a getopt.Config bit by Config.
  type parseGetoptOptions struct {
  	StopAfterDoubleDash bool // true by default; see SetDefaultOptions
  	StopBeforeNonFlag   bool
  	LongOnly            bool
  }

  // SetDefaultOptions sets StopAfterDoubleDash to true and leaves the other
  // fields unchanged.
  // NOTE(review): presumably the eval package calls this before applying the
  // caller's options — confirm.
  func (o *parseGetoptOptions) SetDefaultOptions() {
  	o.StopAfterDoubleDash = true
  }
  146. func (o *parseGetoptOptions) Config() getopt.Config {
  147. c := getopt.Config(0)
  148. if o.StopAfterDoubleDash {
  149. c |= getopt.StopAfterDoubleDash
  150. }
  151. if o.StopBeforeNonFlag {
  152. c |= getopt.StopBeforeFirstNonOption
  153. }
  154. if o.LongOnly {
  155. c |= getopt.LongOnly
  156. }
  157. return c
  158. }
  159. func parseGetopt(opts parseGetoptOptions, argsVal vals.List, specsVal vals.List) (vals.List, vals.List, error) {
  160. var args []string
  161. err := vals.ScanListToGo(argsVal, &args)
  162. if err != nil {
  163. return nil, nil, err
  164. }
  165. var specMaps []vals.Map
  166. err = vals.ScanListToGo(specsVal, &specMaps)
  167. if err != nil {
  168. return nil, nil, err
  169. }
  170. specs := make([]*getopt.OptionSpec, len(specMaps))
  171. originalSpecMap := make(map[*getopt.OptionSpec]vals.Map)
  172. for i, specMap := range specMaps {
  173. var s specStruct
  174. vals.ScanMapToGo(specMap, &s)
  175. spec, err := s.OptionSpec()
  176. if err != nil {
  177. return nil, nil, err
  178. }
  179. specs[i] = spec
  180. originalSpecMap[spec] = specMap
  181. }
  182. flags, nonFlagArgs, err := getopt.Parse(args, specs, opts.Config())
  183. if err != nil {
  184. return nil, nil, err
  185. }
  186. flagsList := vals.EmptyList
  187. for _, flag := range flags {
  188. flagsList = flagsList.Conj(
  189. vals.MakeMap(
  190. "spec", originalSpecMap[flag.Spec],
  191. "arg", flag.Argument,
  192. "long", flag.Long))
  193. }
  194. return flagsList, vals.MakeListSlice(nonFlagArgs), nil
  195. }