Skip to content

Commit 01a9435

Browse files
Mark Saroufimmsaroufim
authored andcommitted
Add problem sync module for updating problem sets
- Create problem_sync.py with shared logic for downloading repos, parsing competition YAMLs, and creating/updating leaderboards - Provides sync_problems() function usable by both API and Discord bot - Includes ProblemData, CompetitionData, SyncResult data classes
1 parent e2ee3d7 commit 01a9435

1 file changed

Lines changed: 300 additions & 0 deletions

File tree

src/libkernelbot/problem_sync.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""Shared logic for syncing problems from a repository.
2+
3+
This module provides the core functionality for downloading problem sets from GitHub
4+
and creating/updating leaderboards. Used by both the API and Discord bot.
5+
"""
6+
7+
import subprocess
8+
import tempfile
9+
from dataclasses import dataclass, field
10+
from datetime import datetime, timedelta, timezone
11+
from pathlib import Path
12+
from typing import Optional, TypedDict
13+
14+
import yaml
15+
16+
from .task import LeaderboardDefinition, make_task_definition
17+
from .utils import parse_deadline, setup_logging
18+
19+
logger = setup_logging(__name__)
20+
21+
22+
class ProblemData(TypedDict):
23+
name: str
24+
directory: str
25+
deadline: str
26+
gpus: list[str]
27+
28+
29+
class CompetitionData(TypedDict):
30+
name: str
31+
description: str
32+
deadline: str
33+
problems: list[ProblemData]
34+
35+
36+
@dataclass
37+
class SyncResult:
38+
"""Result of a problem sync operation."""
39+
40+
created: list[str] = field(default_factory=list)
41+
updated: list[str] = field(default_factory=list)
42+
skipped: list[dict] = field(default_factory=list)
43+
errors: list[dict] = field(default_factory=list)
44+
45+
46+
@dataclass
47+
class ProblemPlan:
48+
"""Plan for creating or updating a problem."""
49+
50+
name: str
51+
directory: str
52+
definition: LeaderboardDefinition
53+
deadline: datetime
54+
gpus: list[str]
55+
action: str # "create" or "update"
56+
57+
58+
def download_problem_repo(repository: str, branch: str, temp_dir: str) -> Path:
59+
"""Download and extract a problem repository from GitHub.
60+
61+
Args:
62+
repository: Repository in "owner/repo" format
63+
branch: Branch name to download
64+
temp_dir: Temporary directory to extract to
65+
66+
Returns:
67+
Path to the problems directory
68+
69+
Raises:
70+
RuntimeError: If download or extraction fails
71+
"""
72+
url = f"https://github.com/{repository}/archive/{branch}.zip"
73+
folder_name = repository.split("/")[-1] + "-" + branch
74+
75+
# Download
76+
try:
77+
subprocess.check_call(
78+
["wget", "-q", "-O", f"{temp_dir}/problems.zip", url],
79+
encoding="utf-8",
80+
timeout=60,
81+
)
82+
except subprocess.CalledProcessError as e:
83+
raise RuntimeError(f"Could not download repository from {url}: {e}") from e
84+
except subprocess.TimeoutExpired as e:
85+
raise RuntimeError("Timeout downloading repository") from e
86+
87+
# Extract
88+
try:
89+
subprocess.check_call(
90+
["unzip", "-q", f"{temp_dir}/problems.zip", "-d", temp_dir],
91+
encoding="utf-8",
92+
timeout=30,
93+
)
94+
except subprocess.CalledProcessError as e:
95+
raise RuntimeError(f"Could not unzip repository: {e}") from e
96+
97+
problem_dir = Path(temp_dir) / folder_name / "problems"
98+
if not problem_dir.exists():
99+
raise RuntimeError("No 'problems' directory found in repository")
100+
101+
return problem_dir
102+
103+
104+
def create_update_plan( # noqa: C901
105+
competition: CompetitionData,
106+
problem_dir: Path,
107+
existing_leaderboards: dict,
108+
force: bool = False,
109+
) -> tuple[list[ProblemPlan], list[dict]]:
110+
"""Determine which problems to create or update.
111+
112+
Args:
113+
competition: Parsed competition YAML data
114+
problem_dir: Path to the problems directory
115+
existing_leaderboards: Dict mapping leaderboard names to their data
116+
force: If True, allow significant task changes
117+
118+
Returns:
119+
Tuple of (list of ProblemPlan objects, list of skip/error dicts)
120+
"""
121+
plans = []
122+
skipped = []
123+
124+
for problem in competition.get("problems", []):
125+
name = problem.get("name")
126+
directory = problem.get("directory")
127+
deadline_str = problem.get("deadline")
128+
gpus = problem.get("gpus", [])
129+
130+
if not name or not directory:
131+
skipped.append({"name": name or "unknown", "reason": "Missing name or directory"})
132+
continue
133+
134+
source_path = problem_dir / directory
135+
if not source_path.exists():
136+
skipped.append({"name": name, "reason": f"Directory {directory} not found"})
137+
continue
138+
139+
try:
140+
definition = make_task_definition(source_path)
141+
except Exception as e:
142+
skipped.append({"name": name, "reason": f"Failed to parse task.yml: {e}"})
143+
continue
144+
145+
deadline = parse_deadline(deadline_str) if deadline_str else None
146+
if deadline is None:
147+
deadline = datetime.now(timezone.utc) + timedelta(days=365)
148+
elif deadline.tzinfo is None:
149+
deadline = deadline.replace(tzinfo=timezone.utc)
150+
151+
# Use GPUs from YAML or task definition
152+
if not gpus:
153+
gpus = definition.gpus if definition.gpus else []
154+
155+
if name in existing_leaderboards:
156+
old_lb = existing_leaderboards[name]
157+
old_deadline = old_lb["deadline"]
158+
if hasattr(old_deadline, "tzinfo") and old_deadline.tzinfo is None:
159+
old_deadline = old_deadline.replace(tzinfo=timezone.utc)
160+
161+
deadline_changed = old_deadline != deadline
162+
task_changed = old_lb["task"] != definition.task
163+
164+
if not deadline_changed and not task_changed:
165+
skipped.append({"name": name, "reason": "no changes"})
166+
continue
167+
168+
if task_changed and not force:
169+
old_task = old_lb["task"]
170+
new_task = definition.task
171+
if (
172+
old_task.files != new_task.files
173+
or old_task.config != new_task.config
174+
or old_task.lang != new_task.lang
175+
or old_task.benchmarks != new_task.benchmarks
176+
):
177+
skipped.append({"name": name, "reason": "significant task changes require --force"})
178+
continue
179+
180+
plans.append(
181+
ProblemPlan(
182+
name=name,
183+
directory=directory,
184+
definition=definition,
185+
deadline=deadline,
186+
gpus=gpus,
187+
action="update",
188+
)
189+
)
190+
else:
191+
if not gpus:
192+
skipped.append({"name": name, "reason": "No GPUs specified in task.yml or YAML"})
193+
continue
194+
195+
plans.append(
196+
ProblemPlan(
197+
name=name,
198+
directory=directory,
199+
definition=definition,
200+
deadline=deadline,
201+
gpus=gpus,
202+
action="create",
203+
)
204+
)
205+
206+
return plans, skipped
207+
208+
209+
def sync_problems( # noqa: C901
210+
db_context,
211+
repository: str = "gpu-mode/reference-kernels",
212+
problem_set: Optional[str] = None,
213+
branch: str = "main",
214+
force: bool = False,
215+
creator_id: int = 0,
216+
forum_id: int = -1,
217+
) -> SyncResult:
218+
"""Sync problems from a GitHub repository.
219+
220+
Downloads the repository, parses competition YAML files, and creates/updates leaderboards.
221+
222+
Args:
223+
db_context: Database context manager
224+
repository: Repository in "owner/repo" format
225+
problem_set: Specific problem set to sync, or None for all
226+
branch: Branch to download
227+
force: If True, allow significant task changes
228+
creator_id: ID of the creator (0 for API)
229+
forum_id: Discord forum ID (-1 for API)
230+
231+
Returns:
232+
SyncResult with created, updated, skipped, and errors lists
233+
"""
234+
if "/" in branch:
235+
raise ValueError("Branch names with slashes are not supported")
236+
237+
result = SyncResult()
238+
239+
with tempfile.TemporaryDirectory() as temp_dir:
240+
try:
241+
problem_dir = download_problem_repo(repository, branch, temp_dir)
242+
except RuntimeError as e:
243+
result.errors.append({"name": "download", "error": str(e)})
244+
return result
245+
246+
# Find YAML files
247+
if problem_set is None:
248+
yaml_files = list(problem_dir.glob("*.yaml"))
249+
else:
250+
yaml_file = problem_dir / f"{problem_set}.yaml"
251+
if not yaml_file.exists():
252+
available = [f.stem for f in problem_dir.glob("*.yaml")]
253+
result.errors.append({
254+
"name": problem_set,
255+
"error": f"Problem set not found. Available: {available}"
256+
})
257+
return result
258+
yaml_files = [yaml_file]
259+
260+
# Get existing leaderboards
261+
with db_context as db:
262+
existing_leaderboards = {lb["name"]: lb for lb in db.get_leaderboards()}
263+
264+
# Process each YAML file
265+
for yaml_file in yaml_files:
266+
try:
267+
with open(yaml_file) as f:
268+
competition = yaml.safe_load(f)
269+
270+
plans, skipped = create_update_plan(
271+
competition, problem_dir, existing_leaderboards, force
272+
)
273+
result.skipped.extend(skipped)
274+
275+
for plan in plans:
276+
try:
277+
if plan.action == "create":
278+
with db_context as db:
279+
db.create_leaderboard(
280+
name=plan.name,
281+
deadline=plan.deadline,
282+
definition=plan.definition,
283+
creator_id=creator_id,
284+
forum_id=forum_id,
285+
gpu_types=plan.gpus,
286+
)
287+
result.created.append(plan.name)
288+
else: # update
289+
with db_context as db:
290+
db.update_leaderboard(plan.name, plan.deadline, plan.definition)
291+
result.updated.append(plan.name)
292+
except Exception as e:
293+
result.errors.append({"name": plan.name, "error": f"{plan.action} failed: {e}"})
294+
295+
except yaml.YAMLError as e:
296+
result.errors.append({"name": yaml_file.stem, "error": f"Invalid YAML: {e}"})
297+
except Exception as e:
298+
result.errors.append({"name": yaml_file.stem, "error": str(e)})
299+
300+
return result

0 commit comments

Comments
 (0)