Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions lib/crewai-files/src/crewai_files/cache/upload_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import json
import logging
from typing import TYPE_CHECKING, Any

from aiocache import Cache # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from aiocache.serializers import BaseSerializer # type: ignore[import-untyped]

from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
from crewai_files.uploaders.factory import ProviderType
Expand All @@ -25,6 +26,62 @@
logger = logging.getLogger(__name__)


class _CachedUploadSerializer(BaseSerializer):
"""JSON serializer for cached upload metadata.

The UploadCache supports external backends (e.g. redis). Avoid pickle-based
serialization for cached values so cache poisoning cannot turn into code
execution via unsafe deserialization.
"""

@staticmethod
def _to_json(obj: CachedUpload) -> dict[str, Any]:
return {
"file_id": obj.file_id,
"provider": obj.provider,
"file_uri": obj.file_uri,
"content_type": obj.content_type,
"uploaded_at": obj.uploaded_at.isoformat(),
"expires_at": obj.expires_at.isoformat()
if obj.expires_at is not None
else None,
}

@staticmethod
def _from_json(data: dict[str, Any]) -> CachedUpload:
return CachedUpload(
file_id=data["file_id"],
provider=data["provider"],
file_uri=data.get("file_uri"),
content_type=data["content_type"],
uploaded_at=datetime.fromisoformat(data["uploaded_at"]),
expires_at=(
datetime.fromisoformat(data["expires_at"])
if data.get("expires_at") is not None
else None
),
)

def dumps(self, value: CachedUpload | None) -> str: # type: ignore[override]
if value is None:
return "null"
return json.dumps(self._to_json(value), sort_keys=True)

def loads(self, value: str | None) -> CachedUpload | None: # type: ignore[override]
if value is None:
return None
try:
parsed = json.loads(value)
if parsed is None:
return None
if not isinstance(parsed, dict):
return None
return self._from_json(parsed)
except (TypeError, ValueError, KeyError) as exc:
logger.debug("Ignoring unreadable cached upload payload: %s", exc)
return None


@dataclass
class CachedUpload:
"""Represents a cached file upload.
Expand Down Expand Up @@ -123,13 +180,13 @@ def __init__(
if cache_type == "redis":
self._cache = Cache(
Cache.REDIS,
serializer=PickleSerializer(),
serializer=_CachedUploadSerializer(),
namespace=namespace,
**cache_kwargs,
)
else:
self._cache = Cache(
serializer=PickleSerializer(),
serializer=_CachedUploadSerializer(),
namespace=namespace,
)

Expand Down
43 changes: 42 additions & 1 deletion lib/crewai-files/tests/test_upload_cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Tests for upload cache."""

from datetime import datetime, timedelta, timezone
import json
import pickle

from crewai_files import FileBytes, ImageFile
from crewai_files.cache.upload_cache import CachedUpload, UploadCache
from crewai_files.cache.upload_cache import (
CachedUpload,
UploadCache,
_CachedUploadSerializer,
)


# Minimal valid PNG
Expand Down Expand Up @@ -76,6 +82,41 @@ def test_is_expired_no_expiry(self):
assert cached.is_expired() is False


class TestCachedUploadSerializer:
"""Tests for cache serializer compatibility."""

def test_json_round_trip(self):
"""Test cached uploads round-trip through the JSON serializer."""
now = datetime.now(timezone.utc)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri="files/file-123",
content_type="image/png",
uploaded_at=now,
expires_at=now + timedelta(hours=48),
)
serializer = _CachedUploadSerializer()

loaded = serializer.loads(serializer.dumps(cached))

assert loaded == cached

def test_unreadable_payloads_are_cache_misses(self):
"""Test unreadable payloads are treated as cache misses."""
serializer = _CachedUploadSerializer()
old_pickle_payload = pickle.dumps({"file_id": "file-123"})
payloads = [
old_pickle_payload,
"not-json",
json.dumps(["not", "a", "dict"]),
json.dumps({"file_id": "file-123"}),
]

for payload in payloads:
assert serializer.loads(payload) is None # type: ignore[arg-type]


class TestUploadCache:
"""Tests for UploadCache class."""

Expand Down