|
10 | 10 | from fastapi import Depends, FastAPI, Header, HTTPException, Request, UploadFile |
11 | 11 | from fastapi.responses import JSONResponse, StreamingResponse |
12 | 12 |
|
| 13 | +from kernelbot.env import env |
13 | 14 | from libkernelbot.backend import KernelBackend |
14 | 15 | from libkernelbot.background_submission_manager import BackgroundSubmissionManager |
15 | 16 | from libkernelbot.consts import SubmissionMode |
16 | 17 | from libkernelbot.db_types import IdentityType |
17 | 18 | from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry |
| 19 | +from libkernelbot.problem_sync import sync_problems |
18 | 20 | from libkernelbot.submission import ( |
19 | 21 | ProcessedSubmissionRequest, |
20 | 22 | SubmissionRequest, |
21 | 23 | prepare_submission, |
22 | 24 | ) |
23 | | -from libkernelbot.utils import KernelBotError, setup_logging |
| 25 | +from libkernelbot.task import make_task_definition |
| 26 | +from libkernelbot.utils import ( |
| 27 | + KernelBotError, |
| 28 | + resolve_problem_directory, |
| 29 | + setup_logging, |
| 30 | +) |
24 | 31 |
|
25 | 32 | from .api_utils import ( |
26 | 33 | _handle_discord_oauth, |
@@ -165,6 +172,16 @@ async def validate_user_header( |
165 | 172 | return user_info |
166 | 173 |
|
167 | 174 |
|
| 175 | +def require_admin( |
| 176 | + authorization: Optional[str] = Header(None, alias="Authorization"), |
| 177 | +) -> None: |
| 178 | + if not authorization: |
| 179 | + raise HTTPException(status_code=401, detail="Missing Authorization header") |
| 180 | + expected = f"Bearer {env.ADMIN_TOKEN}" |
| 181 | + if authorization != expected: |
| 182 | + raise HTTPException(status_code=401, detail="Invalid admin token") |
| 183 | + |
| 184 | + |
168 | 185 | @app.get("/auth/init") |
169 | 186 | async def auth_init(provider: str, db_context=Depends(get_db)) -> dict: |
170 | 187 | if provider not in ["discord", "github"]: |
@@ -470,6 +487,162 @@ async def run_submission_async( |
470 | 487 | logger.error(f"Unexpected error in api submissoin: {e}") |
471 | 488 | raise HTTPException(status_code=500, detail="Internal server error") from e |
472 | 489 |
|
| 490 | + |
| 491 | +@app.post("/admin/start") |
| 492 | +async def admin_start( |
| 493 | + _: Annotated[None, Depends(require_admin)], |
| 494 | +) -> dict: |
| 495 | + backend_instance.accepts_jobs = True |
| 496 | + return {"status": "ok", "accepts_jobs": True} |
| 497 | + |
| 498 | + |
| 499 | +@app.post("/admin/stop") |
| 500 | +async def admin_stop( |
| 501 | + _: Annotated[None, Depends(require_admin)], |
| 502 | +) -> dict: |
| 503 | + backend_instance.accepts_jobs = False |
| 504 | + return {"status": "ok", "accepts_jobs": False} |
| 505 | + |
| 506 | + |
| 507 | +@app.post("/admin/leaderboards") |
| 508 | +async def create_dev_leaderboard( |
| 509 | + payload: dict, |
| 510 | + _: Annotated[None, Depends(require_admin)], |
| 511 | + db_context=Depends(get_db), |
| 512 | +) -> dict: |
| 513 | + """Create a dev leaderboard from a problem directory. |
| 514 | +
|
| 515 | + Mirrors the Discord /admin leaderboard create-local command. |
| 516 | + - Only requires 'directory' (e.g., "identity_py") |
| 517 | + - Name is auto-derived as "{directory}-dev" |
| 518 | + - Deadline defaults to 1 year from now |
| 519 | + - GPU(s) must be specified in task.yml |
| 520 | + """ |
| 521 | + directory = payload.get("directory") |
| 522 | + |
| 523 | + if not directory: |
| 524 | + raise HTTPException(status_code=400, detail="Missing required field: directory") |
| 525 | + |
| 526 | + directory_path = resolve_problem_directory(directory, env.PROBLEM_DEV_DIR) |
| 527 | + if not directory_path: |
| 528 | + raise HTTPException(status_code=400, detail="Invalid problem directory") |
| 529 | + |
| 530 | + definition = make_task_definition(directory_path) |
| 531 | + |
| 532 | + # Auto-derive name and deadline like admin_cog.leaderboard_create_local |
| 533 | + leaderboard_name = f"{directory}-dev" |
| 534 | + deadline_value = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365) |
| 535 | + |
| 536 | + # GPUs must be specified in task.yml |
| 537 | + if not definition.gpus: |
| 538 | + raise HTTPException( |
| 539 | + status_code=400, |
| 540 | + detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types." |
| 541 | + ) |
| 542 | + |
| 543 | + with db_context as db: |
| 544 | + # Delete existing leaderboard if it exists (like create-local does) |
| 545 | + try: |
| 546 | + db.delete_leaderboard(leaderboard_name, force=True) |
| 547 | + except Exception: |
| 548 | + pass # Leaderboard doesn't exist, that's fine |
| 549 | + |
| 550 | + db.create_leaderboard( |
| 551 | + name=leaderboard_name, |
| 552 | + deadline=deadline_value, |
| 553 | + definition=definition, |
| 554 | + creator_id=0, |
| 555 | + forum_id=-1, |
| 556 | + gpu_types=definition.gpus, |
| 557 | + ) |
| 558 | + return {"status": "ok", "leaderboard": leaderboard_name} |
| 559 | + |
| 560 | + |
| 561 | +@app.delete("/admin/leaderboards/{leaderboard_name}") |
| 562 | +async def admin_delete_leaderboard( |
| 563 | + leaderboard_name: str, |
| 564 | + _: Annotated[None, Depends(require_admin)], |
| 565 | + db_context=Depends(get_db), |
| 566 | + force: bool = False, |
| 567 | +) -> dict: |
| 568 | + with db_context as db: |
| 569 | + db.delete_leaderboard(leaderboard_name, force=force) |
| 570 | + return {"status": "ok", "leaderboard": leaderboard_name, "force": force} |
| 571 | + |
| 572 | + |
| 573 | +@app.delete("/admin/submissions/{submission_id}") |
| 574 | +async def admin_delete_submission( |
| 575 | + submission_id: int, |
| 576 | + _: Annotated[None, Depends(require_admin)], |
| 577 | + db_context=Depends(get_db), |
| 578 | +) -> dict: |
| 579 | + with db_context as db: |
| 580 | + db.delete_submission(submission_id) |
| 581 | + return {"status": "ok", "submission_id": submission_id} |
| 582 | + |
| 583 | + |
| 584 | +@app.get("/admin/stats") |
| 585 | +async def admin_stats( |
| 586 | + _: Annotated[None, Depends(require_admin)], |
| 587 | + db_context=Depends(get_db), |
| 588 | + last_day_only: bool = False, |
| 589 | +) -> dict: |
| 590 | + with db_context as db: |
| 591 | + stats = db.generate_stats(last_day_only) |
| 592 | + return {"status": "ok", "stats": stats} |
| 593 | + |
| 594 | + |
| 595 | +@app.get("/admin/submissions/{submission_id}") |
| 596 | +async def admin_get_submission( |
| 597 | + submission_id: int, |
| 598 | + _: Annotated[None, Depends(require_admin)], |
| 599 | + db_context=Depends(get_db), |
| 600 | +) -> dict: |
| 601 | + with db_context as db: |
| 602 | + submission = db.get_submission_by_id(submission_id) |
| 603 | + if submission is None: |
| 604 | + raise HTTPException(status_code=404, detail="Submission not found") |
| 605 | + return {"status": "ok", "submission": submission} |
| 606 | + |
| 607 | + |
| 608 | +@app.post("/admin/update-problems") |
| 609 | +async def admin_update_problems( |
| 610 | + payload: dict, |
| 611 | + _: Annotated[None, Depends(require_admin)], |
| 612 | + db_context=Depends(get_db), |
| 613 | +) -> dict: |
| 614 | + """Update problems from a GitHub repository. |
| 615 | +
|
| 616 | + Mirrors the Discord /admin update-problems command. |
| 617 | + Downloads the repository, parses competition YAML files, and creates/updates leaderboards. |
| 618 | + """ |
| 619 | + repository = payload.get("repository", "gpu-mode/reference-kernels") |
| 620 | + problem_set = payload.get("problem_set") |
| 621 | + branch = payload.get("branch", "main") |
| 622 | + force = payload.get("force", False) |
| 623 | + |
| 624 | + try: |
| 625 | + result = sync_problems( |
| 626 | + db_context=db_context, |
| 627 | + repository=repository, |
| 628 | + problem_set=problem_set, |
| 629 | + branch=branch, |
| 630 | + force=force, |
| 631 | + creator_id=0, # API-created |
| 632 | + forum_id=-1, # No Discord forum |
| 633 | + ) |
| 634 | + except ValueError as e: |
| 635 | + raise HTTPException(status_code=400, detail=str(e)) from e |
| 636 | + |
| 637 | + return { |
| 638 | + "status": "ok", |
| 639 | + "created": result.created, |
| 640 | + "updated": result.updated, |
| 641 | + "skipped": result.skipped, |
| 642 | + "errors": result.errors, |
| 643 | + } |
| 644 | + |
| 645 | + |
473 | 646 | @app.get("/leaderboards") |
474 | 647 | async def get_leaderboards(db_context=Depends(get_db)): |
475 | 648 | """An endpoint that returns all leaderboards. |
|
0 commit comments