1010from .errors import DataJointError
1111from .table import FreeTable
1212import signal
13+ import multiprocessing as mp
1314
1415# noinspection PyExceptionInherit,PyCallingNonCallable
1516
1617logger = logging .getLogger (__name__ )
1718
1819
def initializer(table):
    """
    Set up a worker process so that call_make_key() can find its table.

    populate() passes this function to the process pool as the worker initializer,
    with the (disconnected, pickled) table as its only argument. The table is stored
    as an attribute of the current process object and its connection is re-opened.

    :param table: the table object delivered to this process; it must expose
        ``connection.connect()``
    """
    # keep the table on the process object so call_make_key() can retrieve it later
    mp.current_process().table = table
    table.connection.connect()  # reconnect
25+
def call_make_key(key):
    """
    Call make_key(key) on the table stored in the current process.

    The table is read from the ``table`` attribute of the current process object,
    which initializer() sets when the worker starts.

    :param key: the key to hand to ``table.make_key``
    :return: whatever ``table.make_key(key)`` returns
    """
    table = mp.current_process().table
    return table.make_key(key)
31+
32+
1933class AutoPopulate :
2034 """
2135 AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -103,29 +117,36 @@ def _jobs_to_do(self, restrictions):
103117
104118 def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
105119 reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
106- display_progress = False ):
120+ display_progress = False , multiprocess = False ):
107121 """
108122 rel.populate() calls rel.make(key) for every primary key in self.key_source
109123 for which there is not already a tuple in rel.
110124 :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
111125 :param suppress_errors: if True, do not terminate execution.
112126 :param return_exception_objects: return error objects instead of just error messages
113- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
127+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
114128 :param order: "original"|"reverse"|"random" - the order of execution
129+ :param limit: if not None, check at most this many keys
130+ :param max_calls: if not None, populate at most this many keys
115131 :param display_progress: if True, report progress_bar
116- :param limit : if not None, checks at most that many keys
117- :param max_calls: if not None, populates at max that many keys
132+ :param multiprocess : if True, use as many processes as CPU cores, or use the integer
133+ number of processes specified
118134 """
119135 if self .connection .in_transaction :
120136 raise DataJointError ('Populate cannot be called during a transaction.' )
121137
122138 valid_order = ['original' , 'reverse' , 'random' ]
123139 if order not in valid_order :
124140 raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
125- error_list = [] if suppress_errors else None
126141 jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
127142
128- # define and setup signal handler for SIGTERM
143+ self ._make_key_kwargs = {'suppress_errors' :suppress_errors ,
144+ 'return_exception_objects' :return_exception_objects ,
145+ 'reserve_jobs' :reserve_jobs ,
146+ 'jobs' :jobs ,
147+ }
148+
149+ # define and set up signal handler for SIGTERM:
129150 if reserve_jobs :
130151 def handler (signum , frame ):
131152 logger .info ('Populate terminated by SIGTERM' )
@@ -138,55 +159,101 @@ def handler(signum, frame):
138159 elif order == "random" :
139160 random .shuffle (keys )
140161
141- call_count = 0
142162 logger .info ('Found %d keys to populate' % len (keys ))
143163
144- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
164+ if max_calls is not None :
165+ keys = keys [:max_calls ]
166+ nkeys = len (keys )
145167
146- for key in (tqdm (keys ) if display_progress else keys ):
147- if max_calls is not None and call_count >= max_calls :
148- break
149- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
150- self .connection .start_transaction ()
151- if key in self .target : # already populated
152- self .connection .cancel_transaction ()
153- if reserve_jobs :
154- jobs .complete (self .target .table_name , self ._job_key (key ))
168+ if multiprocess : # True or int, presumably
169+ if multiprocess is True :
170+ nproc = mp .cpu_count ()
171+ else :
172+ if not isinstance (multiprocess , int ):
173+ raise DataJointError ("multiprocess can be False, True or a positive integer" )
174+ nproc = multiprocess
175+ else :
176+ nproc = 1
177+ nproc = min (nproc , nkeys ) # no sense spawning more than can be used
178+ error_list = []
179+ if nproc > 1 : # spawn multiple processes
180+ # prepare to pickle self:
181+ self .connection .close () # disconnect parent process from MySQL server
182+ del self .connection ._conn .ctx # SSLContext is not picklable
183+ print ('*** Spawning pool of %d processes' % nproc )
184+ # send pickled copy of self to each process,
185+ # each worker process calls initializer(*initargs) when it starts
186+ with mp .Pool (nproc , initializer , (self ,)) as pool :
187+ if display_progress :
188+ with tqdm (total = nkeys ) as pbar :
189+ for error in pool .imap (call_make_key , keys , chunksize = 1 ):
190+ if error is not None :
191+ error_list .append (error )
192+ pbar .update ()
155193 else :
156- logger .info ('Populating: ' + str (key ))
157- call_count += 1
158- self .__class__ ._allow_insert = True
159- try :
160- make (dict (key ))
161- except (KeyboardInterrupt , SystemExit , Exception ) as error :
162- try :
163- self .connection .cancel_transaction ()
164- except OperationalError :
165- pass
166- error_message = '{exception}{msg}' .format (
167- exception = error .__class__ .__name__ ,
168- msg = ': ' + str (error ) if str (error ) else '' )
169- if reserve_jobs :
170- # show error name and error message (if any)
171- jobs .error (
172- self .target .table_name , self ._job_key (key ),
173- error_message = error_message , error_stack = traceback .format_exc ())
174- if not suppress_errors or isinstance (error , SystemExit ):
175- raise
176- else :
177- logger .error (error )
178- error_list .append ((key , error if return_exception_objects else error_message ))
179- else :
180- self .connection .commit_transaction ()
181- if reserve_jobs :
182- jobs .complete (self .target .table_name , self ._job_key (key ))
183- finally :
184- self .__class__ ._allow_insert = False
194+ for error in pool .imap (call_make_key , keys ):
195+ if error is not None :
196+ error_list .append (error )
197+ self .connection .connect () # reconnect parent process to MySQL server
198+ else : # use single process
199+ for key in tqdm (keys ) if display_progress else keys :
200+ error = self .make_key (key )
201+ if error is not None :
202+ error_list .append (error )
185203
186- # place back the original signal handler
204+ del self ._make_key_kwargs # clean up
205+
206+ # restore original signal handler:
187207 if reserve_jobs :
188208 signal .signal (signal .SIGTERM , old_handler )
189- return error_list
209+
210+ if suppress_errors :
211+ return error_list
212+
def make_key(self, key):
    """
    Populate a single key inside its own transaction.

    The options are read from ``self._make_key_kwargs`` (set by populate() before
    keys are processed): ``suppress_errors``, ``return_exception_objects``,
    ``reserve_jobs`` and ``jobs``.

    :param key: the key to populate; a ``dict`` copy of it is passed to the make callable
    :return: None when the key was populated, was already present in ``self.target``,
        or could not be reserved; otherwise, when the make callable raised and
        ``suppress_errors`` is set, the tuple ``(key, error)`` where ``error`` is the
        exception object if ``return_exception_objects`` is set and the formatted
        message if not
    :raises: re-raises whatever the make callable raised when ``suppress_errors`` is
        not set, and always re-raises SystemExit
    """
    # prefer the _make_tuples callable when the class defines one
    make = self._make_tuples if hasattr(self, '_make_tuples') else self.make

    kwargs = self._make_key_kwargs
    suppress_errors = kwargs['suppress_errors']
    return_exception_objects = kwargs['return_exception_objects']
    reserve_jobs = kwargs['reserve_jobs']
    jobs = kwargs['jobs']

    # when reserving jobs, skip the key if the reservation is refused
    if reserve_jobs and not jobs.reserve(self.target.table_name, self._job_key(key)):
        return None

    self.connection.start_transaction()
    if key in self.target:  # already populated
        self.connection.cancel_transaction()
        if reserve_jobs:
            jobs.complete(self.target.table_name, self._job_key(key))
        return None

    logger.info('Populating: ' + str(key))
    self.__class__._allow_insert = True
    try:
        make(dict(key))
    except (KeyboardInterrupt, SystemExit, Exception) as error:
        # roll back; an OperationalError from the rollback itself is ignored
        try:
            self.connection.cancel_transaction()
        except OperationalError:
            pass
        error_message = '{exception}{msg}'.format(
            exception=error.__class__.__name__,
            msg=': ' + str(error) if str(error) else '')
        if reserve_jobs:
            # show error name and error message (if any)
            jobs.error(
                self.target.table_name, self._job_key(key),
                error_message=error_message, error_stack=traceback.format_exc())
        if not suppress_errors or isinstance(error, SystemExit):
            raise
        logger.error(error)
        return (key, error if return_exception_objects else error_message)
    else:
        self.connection.commit_transaction()
        if reserve_jobs:
            jobs.complete(self.target.table_name, self._job_key(key))
    finally:
        # runs on every exit path, including the return and raise above
        self.__class__._allow_insert = False
    return None
190257
191258 def progress (self , * restrictions , display = True ):
192259 """
0 commit comments