diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 6c38d72..b6904d5 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,9 +1,9 @@ package main import ( - "flag" "fmt" "os" + "strings" "github.com/Mond1c/judge/dsl" "github.com/Mond1c/judge/reporter" @@ -29,30 +29,30 @@ Example: ` func main() { - if len(os.Args) >= 2 && os.Args[1] == "aggregate" { - runAggregate() + args := os.Args[1:] + + if len(args) == 0 || hasFlag(args, "--help") || hasFlag(args, "-h") { + fmt.Print(usage) + os.Exit(0) + } + + if len(args) >= 1 && args[0] == "aggregate" { + runAggregate(args[1:]) return } - fs := flag.NewFlagSet("judge", flag.ContinueOnError) - fs.SetOutput(os.Stderr) - fs.Usage = func() { fmt.Fprint(os.Stderr, usage) } + jsonOutput := hasFlag(args, "--json") + wrapper := flagValue(args, "--wrapper") + binary := flagValue(args, "--binary") + positional := positionalArgs(args) - jsonOutput := fs.Bool("json", false, "output as JSON") - wrapper := fs.String("wrapper", "", "exec wrapper command") - binary := fs.String("binary", "", "binary name override") - - if err := fs.Parse(os.Args[1:]); err != nil { - os.Exit(2) - } - - if fs.NArg() < 2 { + if len(positional) < 2 { fmt.Fprintf(os.Stderr, "error: need and \n\n%s", usage) os.Exit(1) } - testFile := fs.Arg(0) - solutionDir := fs.Arg(1) + testFile := positional[0] + solutionDir := positional[1] src, err := os.ReadFile(testFile) if err != nil { @@ -73,12 +73,12 @@ func main() { r := runner.New(f, runner.Config{ WorkDir: solutionDir, - BinaryName: *binary, - Wrapper: *wrapper, + BinaryName: binary, + Wrapper: wrapper, }) result := r.Run() - if *jsonOutput { + if jsonOutput { if err := reporter.JSON(os.Stdout, result); err != nil { fatalf("json output error: %v", err) } @@ -91,12 +91,11 @@ func main() { } } -func runAggregate() { - if len(os.Args) < 3 { +func runAggregate(args []string) { + if len(args) < 1 { fatalf("usage: judge aggregate ") } - dir := os.Args[2] - if err := reporter.Aggregate(os.Stdout, dir); err != nil { + if err := reporter.Aggregate(os.Stdout, args[0]); err != nil { fatalf("%v", err) } } @@ -105,3 +104,51 @@ func fatalf(msg string, args ...any) { fmt.Fprintf(os.Stderr, "error: "+msg+"\n", args...) os.Exit(1) } + +func hasFlag(args []string, name string) bool { + for _, a := range args { + if a == name { + return true + } + } + return false +} + +func flagValue(args []string, name string) string { + prefix := name + "=" + for i, a := range args { + if a == name && i+1 < len(args) { + return args[i+1] + } + if strings.HasPrefix(a, prefix) { + return a[len(prefix):] + } + } + return "" +} + +func positionalArgs(args []string) []string { + known := map[string]bool{"--json": true, "--help": true, "-h": true} + withValue := map[string]bool{"--wrapper": true, "--binary": true} + + var out []string + skip := false + for _, a := range args { + if skip { + skip = false + continue + } + if known[a] { + continue + } + if withValue[a] { + skip = true + continue + } + if strings.HasPrefix(a, "--wrapper=") || strings.HasPrefix(a, "--binary=") { + continue + } + out = append(out, a) + } + return out +}