99from .errors import DataJointError , LostConnectionError
1010from .table import FreeTable
1111import signal
12+ import multiprocessing as mp
1213
1314# noinspection PyExceptionInherit,PyCallingNonCallable
1415
1516logger = logging .getLogger (__name__ )
1617
1718
def initializer(table):
    """Prepare a worker process for call_make_key().

    Stores ``table`` as an attribute of the current process object and then
    opens the table's connection by calling ``table.connection.connect()``.
    populate() passes this function as the ``initializer`` of its
    ``mp.Pool``, so it runs once in each worker when that worker starts.

    :param table: the table object handed to the worker; it must expose
        ``connection.connect()``
    """
    worker = mp.current_process()
    worker.table = table  # looked up again by call_make_key()
    table.connection.connect()  # reconnect
24+
def call_make_key(key):
    """Run ``make_key(key)`` on the table stored on the current process.

    The table is the one initializer() attached to the process object.

    :param key: the key forwarded unchanged to ``table.make_key``
    :return: whatever ``table.make_key(key)`` returns
    """
    return mp.current_process().table.make_key(key)
30+
31+
1832class AutoPopulate :
1933 """
2034 AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -102,29 +116,36 @@ def _jobs_to_do(self, restrictions):
102116
103117 def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
104118 reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
105- display_progress = False ):
119+ display_progress = False , multiprocess = False ):
106120 """
107121 rel.populate() calls rel.make(key) for every primary key in self.key_source
108122 for which there is not already a tuple in rel.
109123 :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
110124 :param suppress_errors: if True, do not terminate execution.
111125 :param return_exception_objects: return error objects instead of just error messages
112- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
126+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
113127 :param order: "original"|"reverse"|"random" - the order of execution
128+ :param limit: if not None, check at most this many keys
129+ :param max_calls: if not None, populate at most this many keys
114130 :param display_progress: if True, report progress_bar
115- :param limit : if not None, checks at most that many keys
116- :param max_calls: if not None, populates at max that many keys
131+ :param multiprocess : if True, use as many processes as CPU cores, or use the integer
132+ number of processes specified
117133 """
118134 if self .connection .in_transaction :
119135 raise DataJointError ('Populate cannot be called during a transaction.' )
120136
121137 valid_order = ['original' , 'reverse' , 'random' ]
122138 if order not in valid_order :
123139 raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
124- error_list = [] if suppress_errors else None
125140 jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
126141
127- # define and setup signal handler for SIGTERM
142+ self ._make_key_kwargs = {'suppress_errors' :suppress_errors ,
143+ 'return_exception_objects' :return_exception_objects ,
144+ 'reserve_jobs' :reserve_jobs ,
145+ 'jobs' :jobs ,
146+ }
147+
148+ # define and set up signal handler for SIGTERM:
128149 if reserve_jobs :
129150 def handler (signum , frame ):
130151 logger .info ('Populate terminated by SIGTERM' )
@@ -137,55 +158,101 @@ def handler(signum, frame):
137158 elif order == "random" :
138159 random .shuffle (keys )
139160
140- call_count = 0
141161 logger .info ('Found %d keys to populate' % len (keys ))
142162
143- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
163+ if max_calls is not None :
164+ keys = keys [:max_calls ]
165+ nkeys = len (keys )
144166
145- for key in (tqdm (keys ) if display_progress else keys ):
146- if max_calls is not None and call_count >= max_calls :
147- break
148- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
149- self .connection .start_transaction ()
150- if key in self .target : # already populated
151- self .connection .cancel_transaction ()
152- if reserve_jobs :
153- jobs .complete (self .target .table_name , self ._job_key (key ))
167+ if multiprocess : # True or int, presumably
168+ if multiprocess is True :
169+ nproc = mp .cpu_count ()
170+ else :
171+ if not isinstance (multiprocess , int ):
172+ raise DataJointError ("multiprocess can be False, True or a positive integer" )
173+ nproc = multiprocess
174+ else :
175+ nproc = 1
176+ nproc = min (nproc , nkeys ) # no sense spawning more than can be used
177+ error_list = []
178+ if nproc > 1 : # spawn multiple processes
179+ # prepare to pickle self:
180+ self .connection .close () # disconnect parent process from MySQL server
181+ del self .connection ._conn .ctx # SSLContext is not picklable
182+ print ('*** Spawning pool of %d processes' % nproc )
183+ # send pickled copy of self to each process,
184+ # each worker process calls initializer(*initargs) when it starts
185+ with mp .Pool (nproc , initializer , (self ,)) as pool :
186+ if display_progress :
187+ with tqdm (total = nkeys ) as pbar :
188+ for error in pool .imap (call_make_key , keys , chunksize = 1 ):
189+ if error is not None :
190+ error_list .append (error )
191+ pbar .update ()
154192 else :
155- logger .info ('Populating: ' + str (key ))
156- call_count += 1
157- self .__class__ ._allow_insert = True
158- try :
159- make (dict (key ))
160- except (KeyboardInterrupt , SystemExit , Exception ) as error :
161- try :
162- self .connection .cancel_transaction ()
163- except LostConnectionError :
164- pass
165- error_message = '{exception}{msg}' .format (
166- exception = error .__class__ .__name__ ,
167- msg = ': ' + str (error ) if str (error ) else '' )
168- if reserve_jobs :
169- # show error name and error message (if any)
170- jobs .error (
171- self .target .table_name , self ._job_key (key ),
172- error_message = error_message , error_stack = traceback .format_exc ())
173- if not suppress_errors or isinstance (error , SystemExit ):
174- raise
175- else :
176- logger .error (error )
177- error_list .append ((key , error if return_exception_objects else error_message ))
178- else :
179- self .connection .commit_transaction ()
180- if reserve_jobs :
181- jobs .complete (self .target .table_name , self ._job_key (key ))
182- finally :
183- self .__class__ ._allow_insert = False
193+ for error in pool .imap (call_make_key , keys ):
194+ if error is not None :
195+ error_list .append (error )
196+ self .connection .connect () # reconnect parent process to MySQL server
197+ else : # use single process
198+ for key in tqdm (keys ) if display_progress else keys :
199+ error = self .make_key (key )
200+ if error is not None :
201+ error_list .append (error )
184202
185- # place back the original signal handler
203+ del self ._make_key_kwargs # clean up
204+
205+ # restore original signal handler:
186206 if reserve_jobs :
187207 signal .signal (signal .SIGTERM , old_handler )
188- return error_list
208+
209+ if suppress_errors :
210+ return error_list
211+
212+ def make_key (self , key ):
213+ make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
214+
215+ kwargs = self ._make_key_kwargs
216+ suppress_errors = kwargs ['suppress_errors' ]
217+ return_exception_objects = kwargs ['return_exception_objects' ]
218+ reserve_jobs = kwargs ['reserve_jobs' ]
219+ jobs = kwargs ['jobs' ]
220+
221+ if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
222+ self .connection .start_transaction ()
223+ if key in self .target : # already populated
224+ self .connection .cancel_transaction ()
225+ if reserve_jobs :
226+ jobs .complete (self .target .table_name , self ._job_key (key ))
227+ else :
228+ logger .info ('Populating: ' + str (key ))
229+ self .__class__ ._allow_insert = True
230+ try :
231+ make (dict (key ))
232+ except (KeyboardInterrupt , SystemExit , Exception ) as error :
233+ try :
234+ self .connection .cancel_transaction ()
235+ except LostConnectionError :
236+ pass
237+ error_message = '{exception}{msg}' .format (
238+ exception = error .__class__ .__name__ ,
239+ msg = ': ' + str (error ) if str (error ) else '' )
240+ if reserve_jobs :
241+ # show error name and error message (if any)
242+ jobs .error (
243+ self .target .table_name , self ._job_key (key ),
244+ error_message = error_message , error_stack = traceback .format_exc ())
245+ if not suppress_errors or isinstance (error , SystemExit ):
246+ raise
247+ else :
248+ logger .error (error )
249+ return (key , error if return_exception_objects else error_message )
250+ else :
251+ self .connection .commit_transaction ()
252+ if reserve_jobs :
253+ jobs .complete (self .target .table_name , self ._job_key (key ))
254+ finally :
255+ self .__class__ ._allow_insert = False
189256
190257 def progress (self , * restrictions , display = True ):
191258 """
0 commit comments