From a977d4d9f52c8a3bbeafc7afef68fdf4b83a6d72 Mon Sep 17 00:00:00 2001 From: Mikhail Kornilovich Date: Fri, 10 Apr 2026 18:45:40 +0300 Subject: [PATCH] add memory limit --- dsl/ast.go | 23 +++--- dsl/lexer.go | 38 ++++++++++ dsl/memory_test.go | 95 +++++++++++++++++++++++ dsl/parser.go | 120 ++++++++++++++++++++++++++---- go.mod | 2 + go.sum | 2 + reporter/reporter.go | 49 +++++++++--- runner/limiter.go | 15 ++++ runner/limiter_linux.go | 153 ++++++++++++++++++++++++++++++++++++++ runner/limiter_other.go | 27 +++++++ runner/limiter_windows.go | 134 +++++++++++++++++++++++++++++++++ runner/result.go | 6 ++ runner/runner.go | 35 ++++++++- 13 files changed, 662 insertions(+), 37 deletions(-) create mode 100644 dsl/memory_test.go create mode 100644 go.sum create mode 100644 runner/limiter.go create mode 100644 runner/limiter_linux.go create mode 100644 runner/limiter_other.go create mode 100644 runner/limiter_windows.go diff --git a/dsl/ast.go b/dsl/ast.go index 1d1ccbd..cb4a0f9 100644 --- a/dsl/ast.go +++ b/dsl/ast.go @@ -8,6 +8,7 @@ type File struct { BuildWindows string BuildDarwin string Timeout time.Duration + MemoryLimit int64 // bytes; 0 means no limit Binary string // executable name produced by build (default: solution) Sources string // glob pattern for source files, expanded as $SOURCES in build @@ -18,12 +19,13 @@ type File struct { } type Group struct { - Name string - Weight float64 - Timeout time.Duration - Env map[string]string - Scoring ScoringMode - Wrapper string // exec wrapper command (e.g., "valgrind --error-exitcode=1") + Name string + Weight float64 + Timeout time.Duration + MemoryLimit int64 + Env map[string]string + Scoring ScoringMode + Wrapper string // exec wrapper command (e.g., "valgrind --error-exitcode=1") Tests []*Test Pattern *Pattern @@ -50,10 +52,11 @@ func (p *Pattern) IsDirMode() bool { } type Test struct { - Name string - Timeout time.Duration - Env map[string]string - Wrapper string + Name string + Timeout time.Duration + MemoryLimit int64 + Env map[string]string + Wrapper string Stdin *string Args []string diff --git a/dsl/lexer.go b/dsl/lexer.go index da1b12d..8d48673 100644 --- a/dsl/lexer.go +++ b/dsl/lexer.go @@ -14,6 +14,7 @@ const ( TOKEN_FLOAT TOKEN_INT TOKEN_DURATION + TOKEN_SIZE TOKEN_LBRACE TOKEN_RBRACE @@ -37,6 +38,8 @@ func (t TokenType) String() string { return "INT" case TOKEN_DURATION: return "DURATION" + case TOKEN_SIZE: + return "SIZE" case TOKEN_LBRACE: return "{" case TOKEN_RBRACE: @@ -353,6 +356,10 @@ func (l *Lexer) readNumberOrDuration(line, col int) (Token, error) { } } + if sizeSuffix := l.tryReadSizeSuffix(); sizeSuffix != "" { + return Token{TOKEN_SIZE, buf.String() + sizeSuffix, line, col}, nil + } + suffix := l.tryReadDurationSuffix() if suffix != "" { return Token{TOKEN_DURATION, buf.String() + suffix, line, col}, nil @@ -364,6 +371,37 @@ func (l *Lexer) readNumberOrDuration(line, col int) (Token, error) { return Token{TOKEN_INT, buf.String(), line, col}, nil } +// tryReadSizeSuffix reads memory size suffixes: B, K, KB, KiB, M, MB, MiB, G, GB, GiB. +// Units are case-sensitive uppercase to avoid collision with duration "m" (minutes). +func (l *Lexer) tryReadSizeSuffix() string { + ch, ok := l.peek() + if !ok { + return "" + } + var unit rune + switch ch { + case 'B': + l.advance() + return "B" + case 'K', 'M', 'G': + unit = ch + default: + return "" + } + l.advance() + var buf strings.Builder + buf.WriteRune(unit) + if next, ok := l.peek(); ok && next == 'i' { + l.advance() + buf.WriteRune('i') + } + if next, ok := l.peek(); ok && next == 'B' { + l.advance() + buf.WriteRune('B') + } + return buf.String() +} + func (l *Lexer) tryReadDurationSuffix() string { ch, ok := l.peek() if !ok { diff --git a/dsl/memory_test.go b/dsl/memory_test.go new file mode 100644 index 0000000..3020750 --- /dev/null +++ b/dsl/memory_test.go @@ -0,0 +1,95 @@ +package dsl + +import "testing" + +func TestParseSizeLiteral(t *testing.T) { + cases := []struct { + in string + want int64 + }{ + {"256", 256}, + {"256B", 256}, + {"1K", 1024}, + {"2KB", 2 * 1024}, + {"4KiB", 4 * 1024}, + {"256M", 256 * 1024 * 1024}, + {"256MB", 256 * 1024 * 1024}, + {"512MiB", 512 * 1024 * 1024}, + {"1G", 1024 * 1024 * 1024}, + {"2GB", 2 * 1024 * 1024 * 1024}, + {"3GiB", 3 * 1024 * 1024 * 1024}, + } + for _, c := range cases { + got, err := parseSizeLiteral(c.in, 0, 0) + if err != nil { + t.Errorf("parseSizeLiteral(%q) error: %v", c.in, err) + continue + } + if got != c.want { + t.Errorf("parseSizeLiteral(%q) = %d, want %d", c.in, got, c.want) + } + } +} + +func TestParseSizeLiteralInvalid(t *testing.T) { + bad := []string{"abc", "100TB", "10XB", ""} + for _, s := range bad { + if _, err := parseSizeLiteral(s, 0, 0); err == nil { + t.Errorf("parseSizeLiteral(%q) expected error", s) + } + } +} + +func TestParseMemoryLimit(t *testing.T) { + src := ` +build "go build -o solution ." +timeout 10s +memory_limit = 256MB + +group("g1") { + weight = 0.5 + memory_limit = 128MiB + + test("t1") { + stdout = "ok\n" + } + + test("t2") { + memory_limit = 64M + stdout = "ok\n" + } +} + +group("g2") { + weight = 0.5 + + test("inherits") { + stdout = "ok\n" + } +} +` + f, _, err := Parse(src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if f.MemoryLimit != 256*1024*1024 { + t.Errorf("file memory: got %d", f.MemoryLimit) + } + g1 := f.Groups[0] + if g1.MemoryLimit != 128*1024*1024 { + t.Errorf("g1 memory: got %d", g1.MemoryLimit) + } + if g1.Tests[0].MemoryLimit != 128*1024*1024 { + t.Errorf("g1.t1 memory (inherited from group): got %d", g1.Tests[0].MemoryLimit) + } + if g1.Tests[1].MemoryLimit != 64*1024*1024 { + t.Errorf("g1.t2 memory (override): got %d", g1.Tests[1].MemoryLimit) + } + g2 := f.Groups[1] + if g2.MemoryLimit != 256*1024*1024 { + t.Errorf("g2 memory (inherited from file): got %d", g2.MemoryLimit) + } + if g2.Tests[0].MemoryLimit != 256*1024*1024 { + t.Errorf("g2.inherits memory: got %d", g2.Tests[0].MemoryLimit) + } +} diff --git a/dsl/parser.go b/dsl/parser.go index d379296..ba85de4 100644 --- a/dsl/parser.go +++ b/dsl/parser.go @@ -169,8 +169,19 @@ func (p *Parser) parseFile() (*File, error) { } f.Timeout = d + case "memory_limit": + p.advance() + if _, err := p.expect(TOKEN_ASSIGN); err != nil { + return nil, err + } + n, err := p.parseSize() + if err != nil { + return nil, err + } + f.MemoryLimit = n + case "group": - g, err := p.parseGroup(f.Timeout) + g, err := p.parseGroup(f.Timeout, f.MemoryLimit) if err != nil { return nil, err } @@ -202,7 +213,7 @@ func (p *Parser) validateWeights(f *File) error { return nil } -func (p *Parser) parseGroup(defaultTimeout time.Duration) (*Group, error) { +func (p *Parser) parseGroup(defaultTimeout time.Duration, defaultMemory int64) (*Group, error) { if err := p.expectIdent("group"); err != nil { return nil, err } @@ -221,10 +232,11 @@ func (p *Parser) parseGroup(defaultTimeout time.Duration) (*Group, error) { } g := &Group{ - Name: name.Value, - Timeout: defaultTimeout, - Env: map[string]string{}, - Scoring: ScoringPartial, + Name: name.Value, + Timeout: defaultTimeout, + MemoryLimit: defaultMemory, + Env: map[string]string{}, + Scoring: ScoringPartial, } for !p.isRBrace() { @@ -256,6 +268,17 @@ func (p *Parser) parseGroup(defaultTimeout time.Duration) (*Group, error) { } g.Timeout = d + case "memory_limit": + p.advance() + if _, err := p.expect(TOKEN_ASSIGN); err != nil { + return nil, err + } + n, err := p.parseSize() + if err != nil { + return nil, err + } + g.MemoryLimit = n + case "scoring": p.advance() if _, err := p.expect(TOKEN_ASSIGN); err != nil { @@ -307,7 +330,7 @@ func (p *Parser) parseGroup(defaultTimeout time.Duration) (*Group, error) { g.Wrapper = s.Value case "test": - test, err := p.parseTest(g.Timeout) + test, err := p.parseTest(g.Timeout, g.MemoryLimit) if err != nil { return nil, err } @@ -331,7 +354,7 @@ func (p *Parser) parseGroup(defaultTimeout time.Duration) (*Group, error) { return g, nil } -func (p *Parser) parseTest(defaultTimeout time.Duration) (*Test, error) { +func (p *Parser) parseTest(defaultTimeout time.Duration, defaultMemory int64) (*Test, error) { if err := p.expectIdent("test"); err != nil { return nil, err } @@ -351,14 +374,15 @@ func (p *Parser) parseTest(defaultTimeout time.Duration) (*Test, error) { zero := 0 test := &Test{ - Name: name.Value, - Timeout: defaultTimeout, - Env: map[string]string{}, - InFiles: map[string]string{}, - OutFiles: map[string]string{}, - ExitCode: &zero, - Stdout: NoMatcher{}, - Stderr: NoMatcher{}, + Name: name.Value, + Timeout: defaultTimeout, + MemoryLimit: defaultMemory, + Env: map[string]string{}, + InFiles: map[string]string{}, + OutFiles: map[string]string{}, + ExitCode: &zero, + Stdout: NoMatcher{}, + Stderr: NoMatcher{}, } for !p.isRBrace() { @@ -428,6 +452,17 @@ func (p *Parser) parseTest(defaultTimeout time.Duration) (*Test, error) { } test.Timeout = d + case "memory_limit": + p.advance() + if _, err := p.expect(TOKEN_ASSIGN); err != nil { + return nil, err + } + n, err := p.parseSize() + if err != nil { + return nil, err + } + test.MemoryLimit = n + case "wrapper": p.advance() if _, err := p.expect(TOKEN_ASSIGN); err != nil { @@ -680,6 +715,59 @@ func (p *Parser) parseInt() (int, error) { return n, nil } +// parseSize accepts either a TOKEN_SIZE (e.g. "256MB", "1GiB", "512K") or a bare +// TOKEN_INT interpreted as bytes. MiB/MB are both 1024² — we use IEC semantics. +func (p *Parser) parseSize() (int64, error) { + t := p.peek() + switch t.Type { + case TOKEN_SIZE: + p.advance() + return parseSizeLiteral(t.Value, t.Line, t.Col) + case TOKEN_INT: + p.advance() + n, err := strconv.ParseInt(t.Value, 10, 64) + if err != nil { + return 0, fmt.Errorf("%d:%d: invalid size %q", t.Line, t.Col, t.Value) + } + return n, nil + default: + return 0, fmt.Errorf("%d:%d: expected size (e.g. 256MB, 1GiB), got %s %q", t.Line, t.Col, t.Type, t.Value) + } +} + +func parseSizeLiteral(s string, line, col int) (int64, error) { + i := 0 + for i < len(s) && (s[i] >= '0' && s[i] <= '9') { + i++ + } + if i == 0 { + return 0, fmt.Errorf("%d:%d: invalid size %q", line, col, s) + } + numPart := s[:i] + unit := s[i:] + n, err := strconv.ParseInt(numPart, 10, 64) + if err != nil { + return 0, fmt.Errorf("%d:%d: invalid size %q", line, col, s) + } + var mult int64 + switch unit { + case "", "B": + mult = 1 + case "K", "KB", "KiB": + mult = 1024 + case "M", "MB", "MiB": + mult = 1024 * 1024 + case "G", "GB", "GiB": + mult = 1024 * 1024 * 1024 + default: + return 0, fmt.Errorf("%d:%d: unknown size unit %q (use B/K/M/G or KiB/MiB/GiB)", line, col, unit) + } + if n < 0 { + return 0, fmt.Errorf("%d:%d: size must be non-negative", line, col) + } + return n * mult, nil +} + func (p *Parser) parseDuration() (time.Duration, error) { t := p.peek() if t.Type != TOKEN_DURATION { diff --git a/go.mod b/go.mod index 6008e60..b4857f3 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/Mond1c/judge go 1.26.1 + +require golang.org/x/sys v0.27.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..bacf432 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/reporter/reporter.go b/reporter/reporter.go index 271a775..021c096 100644 --- a/reporter/reporter.go +++ b/reporter/reporter.go @@ -34,8 +34,15 @@ func Text(w io.Writer, result *runner.SuiteResult) { if tr.Status != runner.StatusPass { icon = "✗" } - fmt.Fprintf(w, "│ %s [%s] %s (%dms)\n", - icon, tr.Status, tr.Name, tr.Elapsed.Milliseconds()) + mem := "" + if tr.PeakMemory > 0 { + mem = fmt.Sprintf(", %s", humanBytes(tr.PeakMemory)) + if tr.MemoryLimit > 0 { + mem = fmt.Sprintf(", %s/%s", humanBytes(tr.PeakMemory), humanBytes(tr.MemoryLimit)) + } + } + fmt.Fprintf(w, "│ %s [%s] %s (%dms%s)\n", + icon, tr.Status, tr.Name, tr.Elapsed.Milliseconds(), mem) for _, f := range tr.Failures { for _, line := range strings.Split(f, "\n") { @@ -71,10 +78,12 @@ type jsonGroupResult struct { } type jsonTestResult struct { - Name string `json:"name"` - Status string `json:"status"` - ElapsedMs int64 `json:"elapsed_ms"` - Failures []string `json:"failures,omitempty"` + Name string `json:"name"` + Status string `json:"status"` + ElapsedMs int64 `json:"elapsed_ms"` + PeakMemoryKB int64 `json:"peak_memory_kb,omitempty"` + MemoryLimitKB int64 `json:"memory_limit_kb,omitempty"` + Failures []string `json:"failures,omitempty"` } func Aggregate(w io.Writer, dir string) error { @@ -131,6 +140,24 @@ func Aggregate(w io.Writer, dir string) error { return nil } +func humanBytes(n int64) string { + const ( + KiB = 1024 + MiB = 1024 * KiB + GiB = 1024 * MiB + ) + switch { + case n >= GiB: + return fmt.Sprintf("%.2fGiB", float64(n)/float64(GiB)) + case n >= MiB: + return fmt.Sprintf("%.1fMiB", float64(n)/float64(MiB)) + case n >= KiB: + return fmt.Sprintf("%.0fKiB", float64(n)/float64(KiB)) + default: + return fmt.Sprintf("%dB", n) + } +} + func jsonResult(r *runner.SuiteResult) jsonSuiteResult { res := jsonSuiteResult{ TotalScore: r.TotalScore, @@ -146,10 +173,12 @@ func jsonResult(r *runner.SuiteResult) jsonSuiteResult { } for _, tr := range gr.Tests { jgr.Tests = append(jgr.Tests, jsonTestResult{ - Name: tr.Name, - Status: tr.Status.String(), - ElapsedMs: tr.Elapsed.Milliseconds(), - Failures: tr.Failures, + Name: tr.Name, + Status: tr.Status.String(), + ElapsedMs: tr.Elapsed.Milliseconds(), + PeakMemoryKB: tr.PeakMemory / 1024, + MemoryLimitKB: tr.MemoryLimit / 1024, + Failures: tr.Failures, }) } res.Groups = append(res.Groups, jgr) diff --git a/runner/limiter.go b/runner/limiter.go new file mode 100644 index 0000000..88a1042 --- /dev/null +++ b/runner/limiter.go @@ -0,0 +1,15 @@ +package runner + +import "os/exec" + +type limiter interface { + prepare(cmd *exec.Cmd) error + afterStart(cmd *exec.Cmd) error + collect() limitStats + cleanup() +} + +type limitStats struct { + PeakMemory int64 + MemoryExceeded bool +} diff --git a/runner/limiter_linux.go b/runner/limiter_linux.go new file mode 100644 index 0000000..1ec4101 --- /dev/null +++ b/runner/limiter_linux.go @@ -0,0 +1,153 @@ +//go:build linux + +package runner + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +var ( + cgroupRootOnce sync.Once + cgroupRoot string + cgroupInitErr error + cgroupCounter int64 +) + +const cgroupFSRoot = "/sys/fs/cgroup" + +func ensureCgroupRoot() (string, error) { + cgroupRootOnce.Do(func() { + data, err := os.ReadFile("/proc/self/cgroup") + if err != nil { + cgroupInitErr = fmt.Errorf("read /proc/self/cgroup: %w", err) + return + } + var rel string + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if strings.HasPrefix(line, "0::") { + rel = strings.TrimPrefix(line, "0::") + break + } + } + if rel == "" { + cgroupInitErr = fmt.Errorf("cgroup v2 not found in /proc/self/cgroup (unified hierarchy required)") + return + } + ownCg := filepath.Join(cgroupFSRoot, rel) + + controllers, err := os.ReadFile(filepath.Join(ownCg, "cgroup.controllers")) + if err != nil { + cgroupInitErr = fmt.Errorf("cgroup %s not accessible: %w", ownCg, err) + return + } + if !strings.Contains(" "+string(controllers)+" ", " memory ") { + cgroupInitErr = fmt.Errorf("memory controller not delegated to %s (controllers: %s)", ownCg, strings.TrimSpace(string(controllers))) + return + } + + initCg := filepath.Join(ownCg, "judge.init") + if err := os.MkdirAll(initCg, 0755); err != nil { + cgroupInitErr = fmt.Errorf("mkdir %s: %w", initCg, err) + return + } + if err := os.WriteFile(filepath.Join(initCg, "cgroup.procs"), []byte(strconv.Itoa(os.Getpid())), 0644); err != nil { + cgroupInitErr = fmt.Errorf("move judge into %s: %w", initCg, err) + return + } + + if err := os.WriteFile(filepath.Join(ownCg, "cgroup.subtree_control"), []byte("+memory"), 0644); err != nil { + current, _ := os.ReadFile(filepath.Join(ownCg, "cgroup.subtree_control")) + if !strings.Contains(" "+string(current)+" ", " memory ") { + cgroupInitErr = fmt.Errorf("enable +memory in %s/cgroup.subtree_control: %w", ownCg, err) + return + } + } + + cgroupRoot = ownCg + }) + return cgroupRoot, cgroupInitErr +} + +type linuxLimiter struct { + memLimit int64 + cgPath string +} + +func newLimiter(memLimit int64) limiter { + return &linuxLimiter{memLimit: memLimit} +} + +func (l *linuxLimiter) prepare(cmd *exec.Cmd) error { + if l.memLimit <= 0 { + return nil + } + root, err := ensureCgroupRoot() + if err != nil { + return err + } + name := fmt.Sprintf("judge.test.%d.%d", os.Getpid(), atomic.AddInt64(&cgroupCounter, 1)) + l.cgPath = filepath.Join(root, name) + if err := os.Mkdir(l.cgPath, 0755); err != nil { + l.cgPath = "" + return fmt.Errorf("mkdir %s: %w", name, err) + } + if err := os.WriteFile(filepath.Join(l.cgPath, "memory.max"), []byte(strconv.FormatInt(l.memLimit, 10)), 0644); err != nil { + _ = os.Remove(l.cgPath) + l.cgPath = "" + return fmt.Errorf("write memory.max: %w", err) + } + return nil +} + +func (l *linuxLimiter) afterStart(cmd *exec.Cmd) error { + if l.cgPath == "" || cmd.Process == nil { + return nil + } + return os.WriteFile(filepath.Join(l.cgPath, "cgroup.procs"), []byte(strconv.Itoa(cmd.Process.Pid)), 0644) +} + +func (l *linuxLimiter) collect() limitStats { + if l.cgPath == "" { + return limitStats{} + } + var s limitStats + if data, err := os.ReadFile(filepath.Join(l.cgPath, "memory.peak")); err == nil { + if n, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64); err == nil { + s.PeakMemory = n + } + } + if data, err := os.ReadFile(filepath.Join(l.cgPath, "memory.events")); err == nil { + for _, line := range strings.Split(string(data), "\n") { + fields := strings.Fields(line) + if len(fields) != 2 { + continue + } + if (fields[0] == "oom_kill" || fields[0] == "oom_group_kill") && fields[1] != "0" { + s.MemoryExceeded = true + } + } + } + return s +} + +func (l *linuxLimiter) cleanup() { + if l.cgPath == "" { + return + } + for i := 0; i < 10; i++ { + err := os.Remove(l.cgPath) + if err == nil || os.IsNotExist(err) { + l.cgPath = "" + return + } + time.Sleep(20 * time.Millisecond) + } +} diff --git a/runner/limiter_other.go b/runner/limiter_other.go new file mode 100644 index 0000000..e317096 --- /dev/null +++ b/runner/limiter_other.go @@ -0,0 +1,27 @@ +//go:build !linux && !windows + +package runner + +import ( + "fmt" + "os/exec" +) + +type noopLimiter struct { + memLimit int64 +} + +func newLimiter(memLimit int64) limiter { + return &noopLimiter{memLimit: memLimit} +} + +func (l *noopLimiter) prepare(cmd *exec.Cmd) error { + if l.memLimit > 0 { + return fmt.Errorf("memory_limit is not supported on this platform (only linux/windows)") + } + return nil +} + +func (l *noopLimiter) afterStart(cmd *exec.Cmd) error { return nil } +func (l *noopLimiter) collect() limitStats { return limitStats{} } +func (l *noopLimiter) cleanup() {} diff --git a/runner/limiter_windows.go b/runner/limiter_windows.go new file mode 100644 index 0000000..e826caf --- /dev/null +++ b/runner/limiter_windows.go @@ -0,0 +1,134 @@ +//go:build windows + +package runner + +import ( + "fmt" + "os/exec" + "unsafe" + + "golang.org/x/sys/windows" +) + +type windowsLimiter struct { + memLimit int64 + job windows.Handle + peak int64 + exceeded bool +} + +func newLimiter(memLimit int64) limiter { + return &windowsLimiter{memLimit: memLimit} +} + +const ( + jobObjectExtendedLimitInformationClass = 9 + jobObjectLimitProcessMemory = 0x00000100 + jobObjectLimitKillOnJobClose = 0x00002000 +) + +type ioCounters struct { + ReadOperationCount uint64 + WriteOperationCount uint64 + OtherOperationCount uint64 + ReadTransferCount uint64 + WriteTransferCount uint64 + OtherTransferCount uint64 +} + +type jobObjectBasicLimitInformation struct { + PerProcessUserTimeLimit int64 + PerJobUserTimeLimit int64 + LimitFlags uint32 + MinimumWorkingSetSize uintptr + MaximumWorkingSetSize uintptr + ActiveProcessLimit uint32 + Affinity uintptr + PriorityClass uint32 + SchedulingClass uint32 +} + +type jobObjectExtendedLimitInformation struct { + BasicLimitInformation jobObjectBasicLimitInformation + IoInfo ioCounters + ProcessMemoryLimit uintptr + JobMemoryLimit uintptr + PeakProcessMemoryUsed uintptr + PeakJobMemoryUsed uintptr +} + +func (l *windowsLimiter) prepare(cmd *exec.Cmd) error { + if l.memLimit <= 0 { + return nil + } + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return fmt.Errorf("CreateJobObject: %w", err) + } + info := jobObjectExtendedLimitInformation{ + BasicLimitInformation: jobObjectBasicLimitInformation{ + LimitFlags: jobObjectLimitProcessMemory | jobObjectLimitKillOnJobClose, + }, + ProcessMemoryLimit: uintptr(l.memLimit), + } + if _, err := windows.SetInformationJobObject( + job, + jobObjectExtendedLimitInformationClass, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + ); err != nil { + windows.CloseHandle(job) + return fmt.Errorf("SetInformationJobObject: %w", err) + } + l.job = job + return nil +} + +func (l *windowsLimiter) afterStart(cmd *exec.Cmd) error { + if l.job == 0 || cmd.Process == nil { + return nil + } + procHandle, err := windows.OpenProcess( + windows.PROCESS_SET_QUOTA|windows.PROCESS_TERMINATE|windows.PROCESS_QUERY_INFORMATION, + false, + uint32(cmd.Process.Pid), + ) + if err != nil { + return fmt.Errorf("OpenProcess: %w", err) + } + defer windows.CloseHandle(procHandle) + if err := windows.AssignProcessToJobObject(l.job, procHandle); err != nil { + return fmt.Errorf("AssignProcessToJobObject: %w", err) + } + return nil +} + +func (l *windowsLimiter) collect() limitStats { + if l.job == 0 { + return limitStats{} + } + var info jobObjectExtendedLimitInformation + var ret uint32 + err := windows.QueryInformationJobObject( + l.job, + jobObjectExtendedLimitInformationClass, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + &ret, + ) + if err != nil { + return limitStats{} + } + l.peak = int64(info.PeakProcessMemoryUsed) + if l.memLimit > 0 && l.peak >= l.memLimit { + l.exceeded = true + } + return limitStats{PeakMemory: l.peak, MemoryExceeded: l.exceeded} +} + +func (l *windowsLimiter) cleanup() { + if l.job != 0 { + windows.CloseHandle(l.job) + l.job = 0 + } +} diff --git a/runner/result.go b/runner/result.go index 8bd8f49..8fa40d9 100644 --- a/runner/result.go +++ b/runner/result.go @@ -11,6 +11,7 @@ const ( StatusPass Status = iota StatusFail StatusTLE + StatusMLE StatusBuildError StatusRuntimeError ) @@ -23,6 +24,8 @@ func (s Status) String() string { return "FAIL" case StatusTLE: return "TLE" + case StatusMLE: + return "MLE" case StatusBuildError: return "BUILD_ERROR" case StatusRuntimeError: @@ -37,6 +40,9 @@ type TestResult struct { Status Status Elapsed time.Duration + PeakMemory int64 // bytes; 0 if not measured + MemoryLimit int64 // bytes; 0 if unlimited + Failures []string ActualStdout string diff --git a/runner/runner.go b/runner/runner.go index 6f49aea..125c0d7 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -225,6 +225,9 @@ func (r *Runner) runGroup(g *dsl.Group) *GroupResult { if t.Timeout == 0 { t.Timeout = g.Timeout } + if t.MemoryLimit == 0 { + t.MemoryLimit = g.MemoryLimit + } if t.Wrapper == "" { t.Wrapper = g.Wrapper } @@ -313,10 +316,34 @@ func (r *Runner) runTest(t *dsl.Test) *TestResult { cmd.Stdout = stdout cmd.Stderr = stderr + tr.MemoryLimit = t.MemoryLimit + lim := newLimiter(t.MemoryLimit) + if err := lim.prepare(cmd); err != nil { + tr.Status = StatusRuntimeError + tr.addFailure("memory limiter setup: %v", err) + return tr + } + defer lim.cleanup() + start := time.Now() - runErr := cmd.Run() + if err := cmd.Start(); err != nil { + tr.Status = StatusRuntimeError + tr.addFailure("start: %v", err) + return tr + } + if err := lim.afterStart(cmd); err != nil { + killProcessGroup(cmd) + _ = cmd.Wait() + tr.Status = StatusRuntimeError + tr.addFailure("memory limiter attach: %v", err) + return tr + } + runErr := cmd.Wait() tr.Elapsed = time.Since(start) + stats := lim.collect() + tr.PeakMemory = stats.PeakMemory + if ctx.Err() == context.DeadlineExceeded { killProcessGroup(cmd) } @@ -334,6 +361,12 @@ func (r *Runner) runTest(t *dsl.Test) *TestResult { return tr } + if stats.MemoryExceeded { + tr.Status = StatusMLE + tr.addFailure("memory limit exceeded (limit %d bytes, peak %d bytes)", t.MemoryLimit, stats.PeakMemory) + return tr + } + actualCode := 0 if runErr != nil { if exitErr, ok := runErr.(*exec.ExitError); ok {