|
| 1 | +import asyncio |
1 | 2 | from io import IOBase, TextIOBase |
2 | 3 | from typing import IO, AsyncIterator, List, Literal, Optional, Union, overload |
3 | 4 |
|
|
16 | 17 | from e2b.envd.api import ENVD_API_FILES_ROUTE, ahandle_envd_api_exception |
17 | 18 | from e2b.envd.filesystem import filesystem_connect, filesystem_pb2 |
18 | 19 | from e2b.envd.rpc import authentication_header, handle_rpc_exception |
19 | | -from e2b.envd.versions import ENVD_DEFAULT_USER, ENVD_VERSION_RECURSIVE_WATCH |
| 20 | +from e2b.envd.versions import ( |
| 21 | + ENVD_DEFAULT_USER, |
| 22 | + ENVD_OCTET_STREAM_UPLOAD, |
| 23 | + ENVD_VERSION_RECURSIVE_WATCH, |
| 24 | +) |
20 | 25 | from e2b.exceptions import ( |
21 | 26 | FileNotFoundException, |
22 | 27 | InvalidArgumentException, |
@@ -223,51 +228,108 @@ async def write_files( |
223 | 228 | if username is None and self._envd_version < ENVD_DEFAULT_USER: |
224 | 229 | username = default_username |
225 | 230 |
|
226 | | - params = {} |
227 | | - if username: |
228 | | - params["username"] = username |
229 | | - if len(files) == 1: |
230 | | - params["path"] = files[0]["path"] |
231 | | - |
232 | | - # Prepare the files for the multipart/form-data request |
233 | | - httpx_files = [] |
234 | | - for file in files: |
235 | | - file_path, file_data = file["path"], file["data"] |
236 | | - if isinstance(file_data, (str, bytes)): |
237 | | - # str and bytes can be passed directly |
238 | | - httpx_files.append(("file", (file_path, file_data))) |
239 | | - elif isinstance(file_data, TextIOBase): |
240 | | - # Text streams must be read first |
241 | | - httpx_files.append(("file", (file_path, file_data.read()))) |
242 | | - elif isinstance(file_data, IOBase): |
243 | | - # Binary streams can be passed directly |
244 | | - httpx_files.append(("file", (file_path, file_data))) |
245 | | - else: |
246 | | - raise InvalidArgumentException( |
247 | | - f"Unsupported data type for file {file_path}" |
| 231 | + if len(files) == 0: |
| 232 | + return [] |
| 233 | + |
| 234 | + use_octet_stream = self._envd_version >= ENVD_OCTET_STREAM_UPLOAD |
| 235 | + |
| 236 | + results: List[WriteInfo] = [] |
| 237 | + |
| 238 | + if use_octet_stream: |
| 239 | + |
| 240 | + async def _upload_file(file): |
| 241 | + file_path, file_data = file["path"], file["data"] |
| 242 | + |
| 243 | + if isinstance(file_data, str): |
| 244 | + content = file_data.encode("utf-8") |
| 245 | + elif isinstance(file_data, bytes): |
| 246 | + content = file_data |
| 247 | + elif isinstance(file_data, TextIOBase): |
| 248 | + content = file_data.read().encode("utf-8") |
| 249 | + elif isinstance(file_data, IOBase): |
| 250 | + content = file_data.read() |
| 251 | + else: |
| 252 | + raise InvalidArgumentException( |
| 253 | + f"Unsupported data type for file {file_path}" |
| 254 | + ) |
| 255 | + |
| 256 | + params = {"path": file_path} |
| 257 | + if username: |
| 258 | + params["username"] = username |
| 259 | + |
| 260 | + r = await self._envd_api.post( |
| 261 | + ENVD_API_FILES_ROUTE, |
| 262 | + content=content, |
| 263 | + headers={"Content-Type": "application/octet-stream"}, |
| 264 | + params=params, |
| 265 | + timeout=self._connection_config.get_request_timeout( |
| 266 | + request_timeout |
| 267 | + ), |
248 | 268 | ) |
249 | 269 |
|
250 | | - # Allow passing empty list of files |
251 | | - if len(httpx_files) == 0: |
252 | | - return [] |
| 270 | + err = await _ahandle_filesystem_envd_api_exception(r) |
| 271 | + if err: |
| 272 | + raise err |
253 | 273 |
|
254 | | - r = await self._envd_api.post( |
255 | | - ENVD_API_FILES_ROUTE, |
256 | | - files=httpx_files, |
257 | | - params=params, |
258 | | - timeout=self._connection_config.get_request_timeout(request_timeout), |
259 | | - ) |
| 274 | + write_result = r.json() |
260 | 275 |
|
261 | | - err = await _ahandle_filesystem_envd_api_exception(r) |
262 | | - if err: |
263 | | - raise err |
| 276 | + if not isinstance(write_result, list) or len(write_result) == 0: |
| 277 | + raise SandboxException( |
| 278 | + "Expected to receive information about written file" |
| 279 | + ) |
| 280 | + |
| 281 | + return [WriteInfo(**f) for f in write_result] |
| 282 | + |
| 283 | + upload_results = await asyncio.gather( |
| 284 | + *[_upload_file(file) for file in files] |
| 285 | + ) |
| 286 | + for file_results in upload_results: |
| 287 | + results.extend(file_results) |
| 288 | + else: |
| 289 | + params = {} |
| 290 | + if username: |
| 291 | + params["username"] = username |
| 292 | + if len(files) == 1: |
| 293 | + params["path"] = files[0]["path"] |
| 294 | + |
| 295 | + httpx_files = [] |
| 296 | + for file in files: |
| 297 | + file_path, file_data = file["path"], file["data"] |
| 298 | + if isinstance(file_data, (str, bytes)): |
| 299 | + httpx_files.append(("file", (file_path, file_data))) |
| 300 | + elif isinstance(file_data, TextIOBase): |
| 301 | + httpx_files.append(("file", (file_path, file_data.read()))) |
| 302 | + elif isinstance(file_data, IOBase): |
| 303 | + httpx_files.append(("file", (file_path, file_data))) |
| 304 | + else: |
| 305 | + raise InvalidArgumentException( |
| 306 | + f"Unsupported data type for file {file_path}" |
| 307 | + ) |
264 | 308 |
|
265 | | - write_files = r.json() |
| 309 | + if len(httpx_files) == 0: |
| 310 | + return [] |
| 311 | + |
| 312 | + r = await self._envd_api.post( |
| 313 | + ENVD_API_FILES_ROUTE, |
| 314 | + files=httpx_files, |
| 315 | + params=params, |
| 316 | + timeout=self._connection_config.get_request_timeout(request_timeout), |
| 317 | + ) |
| 318 | + |
| 319 | + err = await _ahandle_filesystem_envd_api_exception(r) |
| 320 | + if err: |
| 321 | + raise err |
| 322 | + |
| 323 | + write_result = r.json() |
| 324 | + |
| 325 | + if not isinstance(write_result, list) or len(write_result) == 0: |
| 326 | + raise SandboxException( |
| 327 | + "Expected to receive information about written file" |
| 328 | + ) |
266 | 329 |
|
267 | | - if not isinstance(write_files, list) or len(write_files) == 0: |
268 | | - raise SandboxException("Expected to receive information about written file") |
| 330 | + results.extend([WriteInfo(**f) for f in write_result]) |
269 | 331 |
|
270 | | - return [WriteInfo(**file) for file in write_files] |
| 332 | + return results |
271 | 333 |
|
272 | 334 | async def list( |
273 | 335 | self, |
|
0 commit comments