Skip to content

Commit c664c2b

Browse files
Mark Saroufimmsaroufim
authored andcommitted
Add API-only mode with admin endpoints and CLI support
- Add --api-only flag to run FastAPI server without Discord bot - Add admin endpoints: start, stop, stats, submissions, leaderboards - Add POST /admin/update-problems endpoint using shared problem_sync - Rename admin_create_leaderboard to create_dev_leaderboard - Add ADMIN_TOKEN environment variable for API authentication - Extract parse_deadline and resolve_problem_directory to shared utils - Add comprehensive tests for admin API endpoints
1 parent 01a9435 commit c664c2b

7 files changed

Lines changed: 563 additions & 53 deletions

File tree

src/kernelbot/api/main.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,24 @@
1010
from fastapi import Depends, FastAPI, Header, HTTPException, Request, UploadFile
1111
from fastapi.responses import JSONResponse, StreamingResponse
1212

13+
from kernelbot.env import env
1314
from libkernelbot.backend import KernelBackend
1415
from libkernelbot.background_submission_manager import BackgroundSubmissionManager
1516
from libkernelbot.consts import SubmissionMode
1617
from libkernelbot.db_types import IdentityType
1718
from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry
19+
from libkernelbot.problem_sync import sync_problems
1820
from libkernelbot.submission import (
1921
ProcessedSubmissionRequest,
2022
SubmissionRequest,
2123
prepare_submission,
2224
)
23-
from libkernelbot.utils import KernelBotError, setup_logging
25+
from libkernelbot.task import make_task_definition
26+
from libkernelbot.utils import (
27+
KernelBotError,
28+
resolve_problem_directory,
29+
setup_logging,
30+
)
2431

2532
from .api_utils import (
2633
_handle_discord_oauth,
@@ -165,6 +172,16 @@ async def validate_user_header(
165172
return user_info
166173

167174

175+
def require_admin(
176+
authorization: Optional[str] = Header(None, alias="Authorization"),
177+
) -> None:
178+
if not authorization:
179+
raise HTTPException(status_code=401, detail="Missing Authorization header")
180+
expected = f"Bearer {env.ADMIN_TOKEN}"
181+
if authorization != expected:
182+
raise HTTPException(status_code=401, detail="Invalid admin token")
183+
184+
168185
@app.get("/auth/init")
169186
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
170187
if provider not in ["discord", "github"]:
@@ -470,6 +487,162 @@ async def run_submission_async(
470487
logger.error(f"Unexpected error in api submissoin: {e}")
471488
raise HTTPException(status_code=500, detail="Internal server error") from e
472489

490+
491+
@app.post("/admin/start")
492+
async def admin_start(
493+
_: Annotated[None, Depends(require_admin)],
494+
) -> dict:
495+
backend_instance.accepts_jobs = True
496+
return {"status": "ok", "accepts_jobs": True}
497+
498+
499+
@app.post("/admin/stop")
500+
async def admin_stop(
501+
_: Annotated[None, Depends(require_admin)],
502+
) -> dict:
503+
backend_instance.accepts_jobs = False
504+
return {"status": "ok", "accepts_jobs": False}
505+
506+
507+
@app.post("/admin/leaderboards")
508+
async def create_dev_leaderboard(
509+
payload: dict,
510+
_: Annotated[None, Depends(require_admin)],
511+
db_context=Depends(get_db),
512+
) -> dict:
513+
"""Create a dev leaderboard from a problem directory.
514+
515+
Mirrors the Discord /admin leaderboard create-local command.
516+
- Only requires 'directory' (e.g., "identity_py")
517+
- Name is auto-derived as "{directory}-dev"
518+
- Deadline defaults to 1 year from now
519+
- GPU(s) must be specified in task.yml
520+
"""
521+
directory = payload.get("directory")
522+
523+
if not directory:
524+
raise HTTPException(status_code=400, detail="Missing required field: directory")
525+
526+
directory_path = resolve_problem_directory(directory, env.PROBLEM_DEV_DIR)
527+
if not directory_path:
528+
raise HTTPException(status_code=400, detail="Invalid problem directory")
529+
530+
definition = make_task_definition(directory_path)
531+
532+
# Auto-derive name and deadline like admin_cog.leaderboard_create_local
533+
leaderboard_name = f"{directory}-dev"
534+
deadline_value = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365)
535+
536+
# GPUs must be specified in task.yml
537+
if not definition.gpus:
538+
raise HTTPException(
539+
status_code=400,
540+
detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
541+
)
542+
543+
with db_context as db:
544+
# Delete existing leaderboard if it exists (like create-local does)
545+
try:
546+
db.delete_leaderboard(leaderboard_name, force=True)
547+
except Exception:
548+
pass # Leaderboard doesn't exist, that's fine
549+
550+
db.create_leaderboard(
551+
name=leaderboard_name,
552+
deadline=deadline_value,
553+
definition=definition,
554+
creator_id=0,
555+
forum_id=-1,
556+
gpu_types=definition.gpus,
557+
)
558+
return {"status": "ok", "leaderboard": leaderboard_name}
559+
560+
561+
@app.delete("/admin/leaderboards/{leaderboard_name}")
562+
async def admin_delete_leaderboard(
563+
leaderboard_name: str,
564+
_: Annotated[None, Depends(require_admin)],
565+
db_context=Depends(get_db),
566+
force: bool = False,
567+
) -> dict:
568+
with db_context as db:
569+
db.delete_leaderboard(leaderboard_name, force=force)
570+
return {"status": "ok", "leaderboard": leaderboard_name, "force": force}
571+
572+
573+
@app.delete("/admin/submissions/{submission_id}")
574+
async def admin_delete_submission(
575+
submission_id: int,
576+
_: Annotated[None, Depends(require_admin)],
577+
db_context=Depends(get_db),
578+
) -> dict:
579+
with db_context as db:
580+
db.delete_submission(submission_id)
581+
return {"status": "ok", "submission_id": submission_id}
582+
583+
584+
@app.get("/admin/stats")
585+
async def admin_stats(
586+
_: Annotated[None, Depends(require_admin)],
587+
db_context=Depends(get_db),
588+
last_day_only: bool = False,
589+
) -> dict:
590+
with db_context as db:
591+
stats = db.generate_stats(last_day_only)
592+
return {"status": "ok", "stats": stats}
593+
594+
595+
@app.get("/admin/submissions/{submission_id}")
596+
async def admin_get_submission(
597+
submission_id: int,
598+
_: Annotated[None, Depends(require_admin)],
599+
db_context=Depends(get_db),
600+
) -> dict:
601+
with db_context as db:
602+
submission = db.get_submission_by_id(submission_id)
603+
if submission is None:
604+
raise HTTPException(status_code=404, detail="Submission not found")
605+
return {"status": "ok", "submission": submission}
606+
607+
608+
@app.post("/admin/update-problems")
609+
async def admin_update_problems(
610+
payload: dict,
611+
_: Annotated[None, Depends(require_admin)],
612+
db_context=Depends(get_db),
613+
) -> dict:
614+
"""Update problems from a GitHub repository.
615+
616+
Mirrors the Discord /admin update-problems command.
617+
Downloads the repository, parses competition YAML files, and creates/updates leaderboards.
618+
"""
619+
repository = payload.get("repository", "gpu-mode/reference-kernels")
620+
problem_set = payload.get("problem_set")
621+
branch = payload.get("branch", "main")
622+
force = payload.get("force", False)
623+
624+
try:
625+
result = sync_problems(
626+
db_context=db_context,
627+
repository=repository,
628+
problem_set=problem_set,
629+
branch=branch,
630+
force=force,
631+
creator_id=0, # API-created
632+
forum_id=-1, # No Discord forum
633+
)
634+
except ValueError as e:
635+
raise HTTPException(status_code=400, detail=str(e)) from e
636+
637+
return {
638+
"status": "ok",
639+
"created": result.created,
640+
"updated": result.updated,
641+
"skipped": result.skipped,
642+
"errors": result.errors,
643+
}
644+
645+
473646
@app.get("/leaderboards")
474647
async def get_leaderboards(db_context=Depends(get_db)):
475648
"""An endpoint that returns all leaderboards.

src/kernelbot/cogs/admin_cog.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from libkernelbot.task import LeaderboardDefinition, make_task_definition
2525
from libkernelbot.utils import (
2626
KernelBotError,
27+
parse_deadline,
2728
setup_logging,
2829
)
2930

@@ -217,17 +218,6 @@ async def leaderboard_create_local(
217218
f"Leaderboard '{leaderboard_name}' created.",
218219
)
219220

220-
def _parse_deadline(self, deadline: str):
221-
# Try parsing with time first
222-
try:
223-
return datetime.strptime(deadline, "%Y-%m-%d %H:%M")
224-
except ValueError:
225-
try:
226-
return datetime.strptime(deadline, "%Y-%m-%d")
227-
except ValueError as ve:
228-
logger.error(f"Value Error: {str(ve)}", exc_info=True)
229-
return None
230-
231221
def _leaderboard_opening_message(
232222
self, leaderboard_name: str, deadline: datetime, description: str
233223
):
@@ -254,7 +244,7 @@ async def leaderboard_create_impl( # noqa: C901
254244
)
255245
return
256246

257-
date_value = self._parse_deadline(deadline)
247+
date_value = parse_deadline(deadline)
258248
if date_value is None:
259249
await send_discord_message(
260250
interaction,
@@ -632,7 +622,7 @@ async def _create_update_plan( # noqa: C901
632622

633623
# from the database, we get datetime with timezone,
634624
# so we need to convert here to enable comparison
635-
new_dl = self._parse_deadline(problem["deadline"])
625+
new_dl = parse_deadline(problem["deadline"])
636626
new_dl = new_dl.astimezone(timezone.utc)
637627
if old["deadline"] != new_dl:
638628
pass
@@ -749,7 +739,7 @@ async def update_competition(
749739
with self.bot.leaderboard_db as db:
750740
task = make_task_definition(root / entry["directory"])
751741
db.update_leaderboard(
752-
entry["name"], self._parse_deadline(entry["deadline"]), task
742+
entry["name"], parse_deadline(entry["deadline"]), task
753743
)
754744
new_lb: LeaderboardItem = db.get_leaderboard(entry["name"])
755745

src/kernelbot/env.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,8 @@
55

66
from libkernelbot.utils import get_github_branch_name
77

8-
9-
def init_environment():
10-
load_dotenv()
11-
12-
# Validate environment
13-
required_env_vars = ["DISCORD_TOKEN", "GITHUB_TOKEN", "GITHUB_REPO"]
14-
for var in required_env_vars:
15-
if not os.getenv(var):
16-
raise ValueError(f"{var} not found")
17-
18-
19-
init_environment()
8+
# Load .env at module level
9+
load_dotenv()
2010

2111
env = types.SimpleNamespace()
2212

@@ -26,6 +16,8 @@ def init_environment():
2616
env.DISCORD_CLUSTER_STAGING_ID = os.getenv("DISCORD_CLUSTER_STAGING_ID")
2717
env.DISCORD_DEBUG_CLUSTER_STAGING_ID = os.getenv("DISCORD_DEBUG_CLUSTER_STAGING_ID")
2818

19+
env.ADMIN_TOKEN = os.getenv("ADMIN_TOKEN")
20+
2921
# Only required to run the CLI against this instance
3022
# setting these is required only to run the CLI against local instance
3123
env.CLI_DISCORD_CLIENT_ID = os.getenv("CLI_DISCORD_CLIENT_ID", "")
@@ -47,3 +39,13 @@ def init_environment():
4739
# PostgreSQL-specific constants
4840
env.DATABASE_URL = os.getenv("DATABASE_URL")
4941
env.DISABLE_SSL = os.getenv("DISABLE_SSL")
42+
43+
44+
def init_environment(skip_discord: bool = False):
45+
"""Validate required environment variables."""
46+
required_env_vars = ["GITHUB_TOKEN", "GITHUB_REPO"]
47+
if not skip_discord:
48+
required_env_vars.append("DISCORD_TOKEN")
49+
for var in required_env_vars:
50+
if not os.getenv(var):
51+
raise ValueError(f"{var} not found")

0 commit comments

Comments
 (0)