|
8 | 8 | Callable, |
9 | 9 | Dict, |
10 | 10 | Generator, |
| 11 | + List, |
11 | 12 | Optional, |
12 | 13 | TypeVar, |
13 | 14 | Union, |
|
27 | 28 | validate, |
28 | 29 | ) |
29 | 30 |
|
| 31 | +from .graphql_request import GraphQLRequest |
30 | 32 | from .transport.async_transport import AsyncTransport |
31 | 33 | from .transport.exceptions import TransportClosed, TransportQueryError |
32 | 34 | from .transport.local_schema import LocalSchemaTransport |
@@ -236,6 +238,24 @@ def execute_sync( |
236 | 238 | **kwargs, |
237 | 239 | ) |
238 | 240 |
|
| 241 | + def execute_batch_sync( |
| 242 | + self, |
| 243 | + reqs: List[GraphQLRequest], |
| 244 | + serialize_variables: Optional[bool] = None, |
| 245 | + parse_result: Optional[bool] = None, |
| 246 | + get_execution_result: bool = False, |
| 247 | + **kwargs, |
| 248 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 249 | + """:meta private:""" |
| 250 | + with self as session: |
| 251 | + return session.execute_batch( |
| 252 | + reqs, |
| 253 | + serialize_variables=serialize_variables, |
| 254 | + parse_result=parse_result, |
| 255 | + get_execution_result=get_execution_result, |
| 256 | + **kwargs, |
| 257 | + ) |
| 258 | + |
239 | 259 | @overload |
240 | 260 | async def execute_async( |
241 | 261 | self, |
@@ -375,7 +395,6 @@ def execute( |
375 | 395 | """ |
376 | 396 |
|
377 | 397 | if isinstance(self.transport, AsyncTransport): |
378 | | - |
379 | 398 | # Get the current asyncio event loop |
380 | 399 | # Or create a new event loop if there isn't one (in a new Thread) |
381 | 400 | try: |
@@ -418,6 +437,48 @@ def execute( |
418 | 437 | **kwargs, |
419 | 438 | ) |
420 | 439 |
|
| 440 | + def execute_batch( |
| 441 | + self, |
| 442 | + reqs: List[GraphQLRequest], |
| 443 | + serialize_variables: Optional[bool] = None, |
| 444 | + parse_result: Optional[bool] = None, |
| 445 | + get_execution_result: bool = False, |
| 446 | + **kwargs, |
| 447 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 448 | + """Execute multiple GraphQL requests in a batch against the remote server using |
| 449 | + the transport provided during init. |
| 450 | +
|
| 451 | + This function **WILL BLOCK** until the result is received from the server. |
| 452 | +
|
| 453 | + Either the transport is sync and we execute the query synchronously directly |
| 454 | + OR the transport is async and we execute the query in the asyncio loop |
| 455 | + (blocking here until answer). |
| 456 | +
|
| 457 | + This method will: |
| 458 | +
|
| 459 | + - connect using the transport to get a session |
| 460 | + - execute the GraphQL requests on the transport session |
| 461 | + - close the session and close the connection to the server |
| 462 | +
|
| 463 | + If you want to perform multiple executions, it is better to use |
| 464 | + the context manager to keep a session active. |
| 465 | +
|
| 466 | + The extra arguments passed in the method will be passed to the transport |
| 467 | + execute method. |
| 468 | + """ |
| 469 | + |
| 470 | + if isinstance(self.transport, AsyncTransport): |
| 471 | + raise NotImplementedError("Batching is not implemented for async yet.") |
| 472 | + |
| 473 | + else: # Sync transports |
| 474 | + return self.execute_batch_sync( |
| 475 | + reqs, |
| 476 | + serialize_variables=serialize_variables, |
| 477 | + parse_result=parse_result, |
| 478 | + get_execution_result=get_execution_result, |
| 479 | + **kwargs, |
| 480 | + ) |
| 481 | + |
421 | 482 | @overload |
422 | 483 | def subscribe_async( |
423 | 484 | self, |
@@ -476,7 +537,6 @@ async def subscribe_async( |
476 | 537 | ]: |
477 | 538 | """:meta private:""" |
478 | 539 | async with self as session: |
479 | | - |
480 | 540 | generator = session.subscribe( |
481 | 541 | document, |
482 | 542 | variable_values=variable_values, |
@@ -600,7 +660,6 @@ def subscribe( |
600 | 660 | pass |
601 | 661 |
|
602 | 662 | except (KeyboardInterrupt, Exception, GeneratorExit): |
603 | | - |
604 | 663 | # Graceful shutdown |
605 | 664 | asyncio.ensure_future(async_generator.aclose(), loop=loop) |
606 | 665 |
|
@@ -661,11 +720,9 @@ async def close_async(self): |
661 | 720 | await self.transport.close() |
662 | 721 |
|
663 | 722 | async def __aenter__(self): |
664 | | - |
665 | 723 | return await self.connect_async() |
666 | 724 |
|
667 | 725 | async def __aexit__(self, exc_type, exc, tb): |
668 | | - |
669 | 726 | await self.close_async() |
670 | 727 |
|
671 | 728 | def connect_sync(self): |
@@ -705,7 +762,6 @@ def close_sync(self): |
705 | 762 | self.transport.close() |
706 | 763 |
|
707 | 764 | def __enter__(self): |
708 | | - |
709 | 765 | return self.connect_sync() |
710 | 766 |
|
711 | 767 | def __exit__(self, *args): |
@@ -880,6 +936,108 @@ def execute( |
880 | 936 |
|
881 | 937 | return result.data |
882 | 938 |
|
| 939 | + def _execute_batch( |
| 940 | + self, |
| 941 | + reqs: List[GraphQLRequest], |
| 942 | + serialize_variables: Optional[bool] = None, |
| 943 | + parse_result: Optional[bool] = None, |
| 944 | + **kwargs, |
| 945 | + ) -> List[ExecutionResult]: |
| 946 | + """Execute multiple GraphQL requests in a batch, using |
| 947 | + the sync transport, returning a list of ExecutionResult objects. |
| 948 | +
|
| 949 | + :param reqs: List of requests that will be executed. |
| 950 | + :param serialize_variables: whether the variable values should be |
| 951 | + serialized. Used for custom scalars and/or enums. |
| 952 | + By default use the serialize_variables argument of the client. |
| 953 | + :param parse_result: Whether gql will unserialize the result. |
| 954 | + By default use the parse_results argument of the client. |
| 955 | +
|
| 956 | + The extra arguments are passed to the transport execute method.""" |
| 957 | + |
| 958 | + # Validate document |
| 959 | + if self.client.schema: |
| 960 | + for req in reqs: |
| 961 | + self.client.validate(req.document) |
| 962 | + |
| 963 | + # Parse variable values for custom scalars if requested |
| 964 | + if serialize_variables or ( |
| 965 | + serialize_variables is None and self.client.serialize_variables |
| 966 | + ): |
| 967 | + reqs = [ |
| 968 | + req.serialize_variable_values(self.client.schema) |
| 969 | + if req.variable_values is not None |
| 970 | + else req |
| 971 | + for req in reqs |
| 972 | + ] |
| 973 | + |
| 974 | + results = self.transport.execute_batch(reqs, **kwargs) |
| 975 | + |
| 976 | + # Unserialize the result if requested |
| 977 | + if self.client.schema: |
| 978 | + if parse_result or (parse_result is None and self.client.parse_results): |
| 979 | + for result in results: |
| 980 | + result.data = parse_result_fn( |
| 981 | + self.client.schema, |
| 982 | + req.document, |
| 983 | + result.data, |
| 984 | + operation_name=req.operation_name, |
| 985 | + ) |
| 986 | + |
| 987 | + return results |
| 988 | + |
| 989 | + def execute_batch( |
| 990 | + self, |
| 991 | + reqs: List[GraphQLRequest], |
| 992 | + serialize_variables: Optional[bool] = None, |
| 993 | + parse_result: Optional[bool] = None, |
| 994 | + get_execution_result: bool = False, |
| 995 | + **kwargs, |
| 996 | + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: |
| 997 | + """Execute multiple GraphQL requests in a batch, using |
| 998 | + the sync transport. This method sends the requests to the server all at once. |
| 999 | +
|
| 1000 | + Raises a TransportQueryError if an error has been returned in any |
| 1001 | + ExecutionResult. |
| 1002 | +
|
| 1003 | + :param reqs: List of requests that will be executed. |
| 1004 | + :param serialize_variables: whether the variable values should be |
| 1005 | + serialized. Used for custom scalars and/or enums. |
| 1006 | + By default use the serialize_variables argument of the client. |
| 1007 | + :param parse_result: Whether gql will unserialize the result. |
| 1008 | + By default use the parse_results argument of the client. |
| 1009 | + :param get_execution_result: return the full ExecutionResult instance instead of |
| 1010 | + only the "data" field. Necessary if you want to get the "extensions" field. |
| 1011 | +
|
| 1012 | + The extra arguments are passed to the transport execute method.""" |
| 1013 | + |
| 1014 | + # Validate and execute on the transport |
| 1015 | + results = self._execute_batch( |
| 1016 | + reqs, |
| 1017 | + serialize_variables=serialize_variables, |
| 1018 | + parse_result=parse_result, |
| 1019 | + **kwargs, |
| 1020 | + ) |
| 1021 | + |
| 1022 | + for result in results: |
| 1023 | + # Raise an error if an error is returned in the ExecutionResult object |
| 1024 | + if result.errors: |
| 1025 | + raise TransportQueryError( |
| 1026 | + str_first_element(result.errors), |
| 1027 | + errors=result.errors, |
| 1028 | + data=result.data, |
| 1029 | + extensions=result.extensions, |
| 1030 | + ) |
| 1031 | + |
| 1032 | + assert ( |
| 1033 | + result.data is not None |
| 1034 | + ), "Transport returned an ExecutionResult without data or errors" |
| 1035 | + |
| 1036 | + if get_execution_result: |
| 1037 | + return results |
| 1038 | + |
| 1039 | + return cast(List[Dict[str, Any]], [result.data for result in results]) |
| 1040 | + |
883 | 1041 | def fetch_schema(self) -> None: |
884 | 1042 | """Fetch the GraphQL schema explicitly using introspection. |
885 | 1043 |
|
@@ -966,7 +1124,6 @@ async def _subscribe( |
966 | 1124 |
|
967 | 1125 | try: |
968 | 1126 | async for result in inner_generator: |
969 | | - |
970 | 1127 | if self.client.schema: |
971 | 1128 | if parse_result or ( |
972 | 1129 | parse_result is None and self.client.parse_results |
@@ -1070,7 +1227,6 @@ async def subscribe( |
1070 | 1227 | try: |
1071 | 1228 | # Validate and subscribe on the transport |
1072 | 1229 | async for result in inner_generator: |
1073 | | - |
1074 | 1230 | # Raise an error if an error is returned in the ExecutionResult object |
1075 | 1231 | if result.errors: |
1076 | 1232 | raise TransportQueryError( |
@@ -1343,7 +1499,6 @@ async def _connection_loop(self): |
1343 | 1499 | """ |
1344 | 1500 |
|
1345 | 1501 | while True: |
1346 | | - |
1347 | 1502 | # Connect to the transport with the retry decorator |
1348 | 1503 | # By default it should keep retrying until it connect |
1349 | 1504 | await self._connect_with_retries() |
|
0 commit comments