2626from typing import (
2727 TYPE_CHECKING ,
2828 Any ,
29- Generator ,
29+ Callable ,
30+ Iterable ,
3031 Iterator ,
3132 Mapping ,
3233 Optional ,
@@ -111,9 +112,6 @@ def __init__(
111112 self .uses_hint_update = False
112113 self .uses_hint_delete = False
113114 self .uses_sort = False
114- self .is_retryable = True
115- self .retrying = False
116- self .started_retryable_write = False
117115 # Extra state so that we know where to pick up on a retry attempt.
118116 self .current_run = None
119117 self .next_run = None
@@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
129127 self .is_encrypted = False
130128 return _BulkWriteContext
131129
132- def add_insert (self , document : _DocumentOut ) -> None :
130+ @property
131+ def is_retryable (self ) -> bool :
132+ if self .current_run :
133+ return self .current_run .is_retryable
134+ return True
135+
136+ @property
137+ def retrying (self ) -> bool :
138+ if self .current_run :
139+ return self .current_run .retrying
140+ return False
141+
142+ @property
143+ def started_retryable_write (self ) -> bool :
144+ if self .current_run :
145+ return self .current_run .started_retryable_write
146+ return False
147+
148+ def add_insert (self , document : _DocumentOut ) -> bool :
133149 """Add an insert document to the list of ops."""
134150 validate_is_document_type ("document" , document )
135151 # Generate ObjectId client side.
136152 if not (isinstance (document , RawBSONDocument ) or "_id" in document ):
137153 document ["_id" ] = ObjectId ()
138154 self .ops .append ((_INSERT , document ))
155+ return True
139156
140157 def add_update (
141158 self ,
@@ -147,7 +164,7 @@ def add_update(
147164 array_filters : Optional [list [Mapping [str , Any ]]] = None ,
148165 hint : Union [str , dict [str , Any ], None ] = None ,
149166 sort : Optional [Mapping [str , Any ]] = None ,
150- ) -> None :
167+ ) -> bool :
151168 """Create an update document and add it to the list of ops."""
152169 validate_ok_for_update (update )
153170 cmd : dict [str , Any ] = {"q" : selector , "u" : update , "multi" : multi }
@@ -165,10 +182,12 @@ def add_update(
165182 if sort is not None :
166183 self .uses_sort = True
167184 cmd ["sort" ] = sort
185+
186+ self .ops .append ((_UPDATE , cmd ))
168187 if multi :
169188 # A bulk_write containing an update_many is not retryable.
170- self . is_retryable = False
171- self . ops . append (( _UPDATE , cmd ))
189+ return False
190+ return True
172191
173192 def add_replace (
174193 self ,
@@ -178,7 +197,7 @@ def add_replace(
178197 collation : Optional [Mapping [str , Any ]] = None ,
179198 hint : Union [str , dict [str , Any ], None ] = None ,
180199 sort : Optional [Mapping [str , Any ]] = None ,
181- ) -> None :
200+ ) -> bool :
182201 """Create a replace document and add it to the list of ops."""
183202 validate_ok_for_replace (replacement )
184203 cmd : dict [str , Any ] = {"q" : selector , "u" : replacement }
@@ -194,14 +213,15 @@ def add_replace(
194213 self .uses_sort = True
195214 cmd ["sort" ] = sort
196215 self .ops .append ((_UPDATE , cmd ))
216+ return True
197217
198218 def add_delete (
199219 self ,
200220 selector : Mapping [str , Any ],
201221 limit : int ,
202222 collation : Optional [Mapping [str , Any ]] = None ,
203223 hint : Union [str , dict [str , Any ], None ] = None ,
204- ) -> None :
224+ ) -> bool :
205225 """Create a delete document and add it to the list of ops."""
206226 cmd : dict [str , Any ] = {"q" : selector , "limit" : limit }
207227 if collation is not None :
@@ -210,44 +230,50 @@ def add_delete(
210230 if hint is not None :
211231 self .uses_hint_delete = True
212232 cmd ["hint" ] = hint
233+
234+ self .ops .append ((_DELETE , cmd ))
213235 if limit == _DELETE_ALL :
214236 # A bulk_write containing a delete_many is not retryable.
215- self . is_retryable = False
216- self . ops . append (( _DELETE , cmd ))
237+ return False
238+ return True
217239
218- def gen_ordered (self , requests ) -> Iterator [Optional [_Run ]]:
240+ def gen_ordered (
241+ self ,
242+ requests : Iterable [Any ],
243+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
244+ ) -> Iterator [_Run ]:
219245 """Generate batches of operations, batched by type of
220246 operation, in the order **provided**.
221247 """
222248 run = None
223249 for idx , request in enumerate (requests ):
224- try :
225- request ._add_to_bulk (self )
226- except AttributeError :
227- raise TypeError (f"{ request !r} is not a valid request" ) from None
250+ retryable = process (request )
228251 (op_type , operation ) = self .ops [idx ]
229252 if run is None :
230253 run = _Run (op_type )
231254 elif run .op_type != op_type :
232255 yield run
233256 run = _Run (op_type )
234257 run .add (idx , operation )
258+ run .is_retryable = run .is_retryable and retryable
235259 if run is None :
236260 raise InvalidOperation ("No operations to execute" )
237261 yield run
238262
239- def gen_unordered (self , requests ) -> Iterator [_Run ]:
263+ def gen_unordered (
264+ self ,
265+ requests : Iterable [Any ],
266+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
267+ ) -> Iterator [_Run ]:
240268 """Generate batches of operations, batched by type of
241269 operation, in arbitrary order.
242270 """
243271 operations = [_Run (_INSERT ), _Run (_UPDATE ), _Run (_DELETE )]
244272 for idx , request in enumerate (requests ):
245- try :
246- request ._add_to_bulk (self )
247- except AttributeError :
248- raise TypeError (f"{ request !r} is not a valid request" ) from None
273+ retryable = process (request )
249274 (op_type , operation ) = self .ops [idx ]
250275 operations [op_type ].add (idx , operation )
276+ operations [op_type ].is_retryable = operations [op_type ].is_retryable and retryable
251277 if (
252278 len (operations [_INSERT ].ops ) == 0
253279 and len (operations [_UPDATE ].ops ) == 0
@@ -488,8 +514,8 @@ async def _execute_command(
488514 session : Optional [AsyncClientSession ],
489515 conn : AsyncConnection ,
490516 op_id : int ,
491- retryable : bool ,
492517 full_result : MutableMapping [str , Any ],
518+ validate : bool ,
493519 final_write_concern : Optional [WriteConcern ] = None ,
494520 ) -> None :
495521 db_name = self .collection .database .name
@@ -507,7 +533,7 @@ async def _execute_command(
507533 last_run = False
508534
509535 while run :
510- if not self .retrying :
536+ if not run .retrying :
511537 self .next_run = next (generator , None )
512538 if self .next_run is None :
513539 last_run = True
@@ -541,20 +567,21 @@ async def _execute_command(
541567 if session :
542568 # Start a new retryable write unless one was already
543569 # started for this command.
544- if retryable and not self .started_retryable_write :
570+ if run . is_retryable and not run .started_retryable_write :
545571 session ._start_retryable_write ()
546572 self .started_retryable_write = True
547- session ._apply_to (cmd , retryable , ReadPreference .PRIMARY , conn )
573+ session ._apply_to (cmd , run . is_retryable , ReadPreference .PRIMARY , conn )
548574 conn .send_cluster_time (cmd , session , client )
549575 conn .add_server_api (cmd )
550576 # CSOT: apply timeout before encoding the command.
551577 conn .apply_timeout (client , cmd )
552578 ops = islice (run .ops , run .idx_offset , None )
553579
554580 # Run as many ops as possible in one command.
581+ if validate :
582+ await self .validate_batch (conn , write_concern )
555583 if write_concern .acknowledged :
556584 result , to_send = await self ._execute_batch (bwc , cmd , ops , client )
557-
558585 # Retryable writeConcernErrors halt the execution of this run.
559586 wce = result .get ("writeConcernError" , {})
560587 if wce .get ("code" , 0 ) in _RETRYABLE_ERROR_CODES :
@@ -567,8 +594,8 @@ async def _execute_command(
567594 _merge_command (run , full_result , run .idx_offset , result )
568595
569596 # We're no longer in a retry once a command succeeds.
570- self .retrying = False
571- self .started_retryable_write = False
597+ run .retrying = False
598+ run .started_retryable_write = False
572599
573600 if self .ordered and "writeErrors" in result :
574601 break
@@ -606,34 +633,33 @@ async def execute_command(
606633 op_id = _randint ()
607634
608635 async def retryable_bulk (
609- session : Optional [AsyncClientSession ], conn : AsyncConnection , retryable : bool
636+ session : Optional [AsyncClientSession ],
637+ conn : AsyncConnection ,
610638 ) -> None :
611639 await self ._execute_command (
612640 generator ,
613641 write_concern ,
614642 session ,
615643 conn ,
616644 op_id ,
617- retryable ,
618645 full_result ,
646+ validate = False ,
619647 )
620648
621649 client = self .collection .database .client
622650 _ = await client ._retryable_write (
623- self .is_retryable ,
624651 retryable_bulk ,
625652 session ,
626653 operation ,
627654 bulk = self , # type: ignore[arg-type]
628655 operation_id = op_id ,
629656 )
630-
631657 if full_result ["writeErrors" ] or full_result ["writeConcernErrors" ]:
632658 _raise_bulk_write_error (full_result )
633659 return full_result
634660
635661 async def execute_op_msg_no_results (
636- self , conn : AsyncConnection , generator : Iterator [Any ]
662+ self , conn : AsyncConnection , generator : Iterator [Any ], write_concern : WriteConcern
637663 ) -> None :
638664 """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
639665 db_name = self .collection .database .name
@@ -667,6 +693,7 @@ async def execute_op_msg_no_results(
667693 conn .add_server_api (cmd )
668694 ops = islice (run .ops , run .idx_offset , None )
669695 # Run as many ops as possible.
696+ await self .validate_batch (conn , write_concern )
670697 to_send = await self ._execute_batch_unack (bwc , cmd , ops , client )
671698 run .idx_offset += len (to_send )
672699 self .current_run = run = next (generator , None )
@@ -700,12 +727,15 @@ async def execute_command_no_results(
700727 None ,
701728 conn ,
702729 op_id ,
703- False ,
704730 full_result ,
731+ True ,
705732 write_concern ,
706733 )
707- except OperationFailure :
708- pass
734+ except OperationFailure as exc :
735+ if "Cannot set bypass_document_validation with unacknowledged write concern" in str (
736+ exc
737+ ):
738+ raise exc
709739
710740 async def execute_no_results (
711741 self ,
@@ -714,6 +744,11 @@ async def execute_no_results(
714744 write_concern : WriteConcern ,
715745 ) -> None :
716746 """Execute all operations, returning no results (w=0)."""
747+ if self .ordered :
748+ return await self .execute_command_no_results (conn , generator , write_concern )
749+ return await self .execute_op_msg_no_results (conn , generator , write_concern )
750+
751+ async def validate_batch (self , conn : AsyncConnection , write_concern : WriteConcern ) -> None :
717752 if self .uses_collation :
718753 raise ConfigurationError ("Collation is unsupported for unacknowledged writes." )
719754 if self .uses_array_filters :
@@ -738,13 +773,10 @@ async def execute_no_results(
738773 "Cannot set bypass_document_validation with unacknowledged write concern"
739774 )
740775
741- if self .ordered :
742- return await self .execute_command_no_results (conn , generator , write_concern )
743- return await self .execute_op_msg_no_results (conn , generator )
744-
745776 async def execute (
746777 self ,
747- generator : Generator [_WriteOp [_DocumentType ]],
778+ generator : Iterable [Any ],
779+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
748780 write_concern : WriteConcern ,
749781 session : Optional [AsyncClientSession ],
750782 operation : str ,
@@ -757,9 +789,9 @@ async def execute(
757789 session = _validate_session_write_concern (session , write_concern )
758790
759791 if self .ordered :
760- generator = self .gen_ordered (generator )
792+ generator = self .gen_ordered (generator , process )
761793 else :
762- generator = self .gen_unordered (generator )
794+ generator = self .gen_unordered (generator , process )
763795
764796 client = self .collection .database .client
765797 if not write_concern .acknowledged :
0 commit comments