add memory limit

This commit is contained in:
2026-04-10 18:45:40 +03:00
parent 86b8d83643
commit a977d4d9f5
13 changed files with 662 additions and 37 deletions

15
runner/limiter.go Normal file
View File

@@ -0,0 +1,15 @@
package runner
import "os/exec"
type limiter interface {
prepare(cmd *exec.Cmd) error
afterStart(cmd *exec.Cmd) error
collect() limitStats
cleanup()
}
type limitStats struct {
PeakMemory int64
MemoryExceeded bool
}

153
runner/limiter_linux.go Normal file
View File

@@ -0,0 +1,153 @@
//go:build linux
package runner
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
var (
cgroupRootOnce sync.Once
cgroupRoot string
cgroupInitErr error
cgroupCounter int64
)
const cgroupFSRoot = "/sys/fs/cgroup"
func ensureCgroupRoot() (string, error) {
cgroupRootOnce.Do(func() {
data, err := os.ReadFile("/proc/self/cgroup")
if err != nil {
cgroupInitErr = fmt.Errorf("read /proc/self/cgroup: %w", err)
return
}
var rel string
for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
if strings.HasPrefix(line, "0::") {
rel = strings.TrimPrefix(line, "0::")
break
}
}
if rel == "" {
cgroupInitErr = fmt.Errorf("cgroup v2 not found in /proc/self/cgroup (unified hierarchy required)")
return
}
ownCg := filepath.Join(cgroupFSRoot, rel)
controllers, err := os.ReadFile(filepath.Join(ownCg, "cgroup.controllers"))
if err != nil {
cgroupInitErr = fmt.Errorf("cgroup %s not accessible: %w", ownCg, err)
return
}
if !strings.Contains(" "+string(controllers)+" ", " memory ") {
cgroupInitErr = fmt.Errorf("memory controller not delegated to %s (controllers: %s)", ownCg, strings.TrimSpace(string(controllers)))
return
}
initCg := filepath.Join(ownCg, "judge.init")
if err := os.MkdirAll(initCg, 0755); err != nil {
cgroupInitErr = fmt.Errorf("mkdir %s: %w", initCg, err)
return
}
if err := os.WriteFile(filepath.Join(initCg, "cgroup.procs"), []byte(strconv.Itoa(os.Getpid())), 0644); err != nil {
cgroupInitErr = fmt.Errorf("move judge into %s: %w", initCg, err)
return
}
if err := os.WriteFile(filepath.Join(ownCg, "cgroup.subtree_control"), []byte("+memory"), 0644); err != nil {
current, _ := os.ReadFile(filepath.Join(ownCg, "cgroup.subtree_control"))
if !strings.Contains(" "+string(current)+" ", " memory ") {
cgroupInitErr = fmt.Errorf("enable +memory in %s/cgroup.subtree_control: %w", ownCg, err)
return
}
}
cgroupRoot = ownCg
})
return cgroupRoot, cgroupInitErr
}
type linuxLimiter struct {
memLimit int64
cgPath string
}
func newLimiter(memLimit int64) limiter {
return &linuxLimiter{memLimit: memLimit}
}
func (l *linuxLimiter) prepare(cmd *exec.Cmd) error {
if l.memLimit <= 0 {
return nil
}
root, err := ensureCgroupRoot()
if err != nil {
return err
}
name := fmt.Sprintf("judge.test.%d.%d", os.Getpid(), atomic.AddInt64(&cgroupCounter, 1))
l.cgPath = filepath.Join(root, name)
if err := os.Mkdir(l.cgPath, 0755); err != nil {
l.cgPath = ""
return fmt.Errorf("mkdir %s: %w", name, err)
}
if err := os.WriteFile(filepath.Join(l.cgPath, "memory.max"), []byte(strconv.FormatInt(l.memLimit, 10)), 0644); err != nil {
_ = os.Remove(l.cgPath)
l.cgPath = ""
return fmt.Errorf("write memory.max: %w", err)
}
return nil
}
func (l *linuxLimiter) afterStart(cmd *exec.Cmd) error {
if l.cgPath == "" || cmd.Process == nil {
return nil
}
return os.WriteFile(filepath.Join(l.cgPath, "cgroup.procs"), []byte(strconv.Itoa(cmd.Process.Pid)), 0644)
}
func (l *linuxLimiter) collect() limitStats {
if l.cgPath == "" {
return limitStats{}
}
var s limitStats
if data, err := os.ReadFile(filepath.Join(l.cgPath, "memory.peak")); err == nil {
if n, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64); err == nil {
s.PeakMemory = n
}
}
if data, err := os.ReadFile(filepath.Join(l.cgPath, "memory.events")); err == nil {
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Fields(line)
if len(fields) != 2 {
continue
}
if (fields[0] == "oom_kill" || fields[0] == "oom_group_kill") && fields[1] != "0" {
s.MemoryExceeded = true
}
}
}
return s
}
func (l *linuxLimiter) cleanup() {
if l.cgPath == "" {
return
}
for i := 0; i < 10; i++ {
err := os.Remove(l.cgPath)
if err == nil || os.IsNotExist(err) {
l.cgPath = ""
return
}
time.Sleep(20 * time.Millisecond)
}
}

27
runner/limiter_other.go Normal file
View File

@@ -0,0 +1,27 @@
//go:build !linux && !windows
package runner
import (
"fmt"
"os/exec"
)
type noopLimiter struct {
memLimit int64
}
func newLimiter(memLimit int64) limiter {
return &noopLimiter{memLimit: memLimit}
}
func (l *noopLimiter) prepare(cmd *exec.Cmd) error {
if l.memLimit > 0 {
return fmt.Errorf("memory_limit is not supported on this platform (only linux/windows)")
}
return nil
}
func (l *noopLimiter) afterStart(cmd *exec.Cmd) error { return nil }
func (l *noopLimiter) collect() limitStats { return limitStats{} }
func (l *noopLimiter) cleanup() {}

134
runner/limiter_windows.go Normal file
View File

@@ -0,0 +1,134 @@
//go:build windows
package runner
import (
"fmt"
"os/exec"
"unsafe"
"golang.org/x/sys/windows"
)
type windowsLimiter struct {
memLimit int64
job windows.Handle
peak int64
exceeded bool
}
func newLimiter(memLimit int64) limiter {
return &windowsLimiter{memLimit: memLimit}
}
const (
jobObjectExtendedLimitInformationClass = 9
jobObjectLimitProcessMemory = 0x00000100
jobObjectLimitKillOnJobClose = 0x00002000
)
type ioCounters struct {
ReadOperationCount uint64
WriteOperationCount uint64
OtherOperationCount uint64
ReadTransferCount uint64
WriteTransferCount uint64
OtherTransferCount uint64
}
type jobObjectBasicLimitInformation struct {
PerProcessUserTimeLimit int64
PerJobUserTimeLimit int64
LimitFlags uint32
MinimumWorkingSetSize uintptr
MaximumWorkingSetSize uintptr
ActiveProcessLimit uint32
Affinity uintptr
PriorityClass uint32
SchedulingClass uint32
}
type jobObjectExtendedLimitInformation struct {
BasicLimitInformation jobObjectBasicLimitInformation
IoInfo ioCounters
ProcessMemoryLimit uintptr
JobMemoryLimit uintptr
PeakProcessMemoryUsed uintptr
PeakJobMemoryUsed uintptr
}
func (l *windowsLimiter) prepare(cmd *exec.Cmd) error {
if l.memLimit <= 0 {
return nil
}
job, err := windows.CreateJobObject(nil, nil)
if err != nil {
return fmt.Errorf("CreateJobObject: %w", err)
}
info := jobObjectExtendedLimitInformation{
BasicLimitInformation: jobObjectBasicLimitInformation{
LimitFlags: jobObjectLimitProcessMemory | jobObjectLimitKillOnJobClose,
},
ProcessMemoryLimit: uintptr(l.memLimit),
}
if _, err := windows.SetInformationJobObject(
job,
jobObjectExtendedLimitInformationClass,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)),
); err != nil {
windows.CloseHandle(job)
return fmt.Errorf("SetInformationJobObject: %w", err)
}
l.job = job
return nil
}
func (l *windowsLimiter) afterStart(cmd *exec.Cmd) error {
if l.job == 0 || cmd.Process == nil {
return nil
}
procHandle, err := windows.OpenProcess(
windows.PROCESS_SET_QUOTA|windows.PROCESS_TERMINATE|windows.PROCESS_QUERY_INFORMATION,
false,
uint32(cmd.Process.Pid),
)
if err != nil {
return fmt.Errorf("OpenProcess: %w", err)
}
defer windows.CloseHandle(procHandle)
if err := windows.AssignProcessToJobObject(l.job, procHandle); err != nil {
return fmt.Errorf("AssignProcessToJobObject: %w", err)
}
return nil
}
func (l *windowsLimiter) collect() limitStats {
if l.job == 0 {
return limitStats{}
}
var info jobObjectExtendedLimitInformation
var ret uint32
err := windows.QueryInformationJobObject(
l.job,
jobObjectExtendedLimitInformationClass,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)),
&ret,
)
if err != nil {
return limitStats{}
}
l.peak = int64(info.PeakProcessMemoryUsed)
if l.memLimit > 0 && l.peak >= l.memLimit {
l.exceeded = true
}
return limitStats{PeakMemory: l.peak, MemoryExceeded: l.exceeded}
}
func (l *windowsLimiter) cleanup() {
if l.job != 0 {
windows.CloseHandle(l.job)
l.job = 0
}
}

View File

@@ -11,6 +11,7 @@ const (
StatusPass Status = iota
StatusFail
StatusTLE
StatusMLE
StatusBuildError
StatusRuntimeError
)
@@ -23,6 +24,8 @@ func (s Status) String() string {
return "FAIL"
case StatusTLE:
return "TLE"
case StatusMLE:
return "MLE"
case StatusBuildError:
return "BUILD_ERROR"
case StatusRuntimeError:
@@ -37,6 +40,9 @@ type TestResult struct {
Status Status
Elapsed time.Duration
PeakMemory int64 // bytes; 0 if not measured
MemoryLimit int64 // bytes; 0 if unlimited
Failures []string
ActualStdout string

View File

@@ -225,6 +225,9 @@ func (r *Runner) runGroup(g *dsl.Group) *GroupResult {
if t.Timeout == 0 {
t.Timeout = g.Timeout
}
if t.MemoryLimit == 0 {
t.MemoryLimit = g.MemoryLimit
}
if t.Wrapper == "" {
t.Wrapper = g.Wrapper
}
@@ -313,10 +316,34 @@ func (r *Runner) runTest(t *dsl.Test) *TestResult {
cmd.Stdout = stdout
cmd.Stderr = stderr
tr.MemoryLimit = t.MemoryLimit
lim := newLimiter(t.MemoryLimit)
if err := lim.prepare(cmd); err != nil {
tr.Status = StatusRuntimeError
tr.addFailure("memory limiter setup: %v", err)
return tr
}
defer lim.cleanup()
start := time.Now()
runErr := cmd.Run()
if err := cmd.Start(); err != nil {
tr.Status = StatusRuntimeError
tr.addFailure("start: %v", err)
return tr
}
if err := lim.afterStart(cmd); err != nil {
killProcessGroup(cmd)
_ = cmd.Wait()
tr.Status = StatusRuntimeError
tr.addFailure("memory limiter attach: %v", err)
return tr
}
runErr := cmd.Wait()
tr.Elapsed = time.Since(start)
stats := lim.collect()
tr.PeakMemory = stats.PeakMemory
if ctx.Err() == context.DeadlineExceeded {
killProcessGroup(cmd)
}
@@ -334,6 +361,12 @@ func (r *Runner) runTest(t *dsl.Test) *TestResult {
return tr
}
if stats.MemoryExceeded {
tr.Status = StatusMLE
tr.addFailure("memory limit exceeded (limit %d bytes, peak %d bytes)", t.MemoryLimit, stats.PeakMemory)
return tr
}
actualCode := 0
if runErr != nil {
if exitErr, ok := runErr.(*exec.ExitError); ok {