0029935: Foundation Classes - introduce OSD_ThreadPool class defining a thread pool
[occt.git] / src / OSD / OSD_ThreadPool.cxx
1 // Created by: Kirill Gavrilov
2 // Copyright (c) 2017 OPEN CASCADE SAS
3 //
4 // This file is part of commercial software by OPEN CASCADE SAS.
5 //
6 // This software is furnished in accordance with the terms and conditions
7 // of the contract and with the inclusion of this copyright notice.
8 // This software or any other copy thereof may not be provided or otherwise
9 // be made available to any third party.
10 // No ownership title to the software is transferred hereby.
11 //
12 // OPEN CASCADE SAS makes no representation or warranties with respect to the
13 // performance of this software, and specifically disclaims any responsibility
14 // for any damages, special or consequential, connected with its use.
15
16 #include <OSD_ThreadPool.hxx>
17
18 #include <OSD.hxx>
19 #include <Standard_Atomic.hxx>
20 #include <TCollection_AsciiString.hxx>
21
22 IMPLEMENT_STANDARD_RTTIEXT(OSD_ThreadPool, Standard_Transient)
23
24 // =======================================================================
25 // function : Lock
26 // purpose  :
27 // =======================================================================
28 bool OSD_ThreadPool::EnumeratedThread::Lock()
29 {
30   return Standard_Atomic_CompareAndSwap (&myUsageCounter, 0, 1);
31 }
32
33 // =======================================================================
34 // function : Free
35 // purpose  :
36 // =======================================================================
37 void OSD_ThreadPool::EnumeratedThread::Free()
38 {
39   Standard_Atomic_CompareAndSwap (&myUsageCounter, 1, 0);
40 }
41
42 // =======================================================================
43 // function : WakeUp
44 // purpose  :
45 // =======================================================================
46 void OSD_ThreadPool::EnumeratedThread::WakeUp (JobInterface* theJob, bool theToCatchFpe)
47 {
48   myJob = theJob;
49   myToCatchFpe = theToCatchFpe;
50   if (myIsSelfThread)
51   {
52     if (theJob != NULL)
53     {
54       OSD_ThreadPool::performJob (myFailure, myJob, myThreadIndex);
55     }
56     return;
57   }
58
59   myWakeEvent.Set();
60   if (theJob != NULL && !myIsStarted)
61   {
62     myIsStarted = true;
63     Run (this);
64   }
65 }
66
67 // =======================================================================
68 // function : WaitIdle
69 // purpose  :
70 // =======================================================================
71 void OSD_ThreadPool::EnumeratedThread::WaitIdle()
72 {
73   if (!myIsSelfThread)
74   {
75     myIdleEvent.Wait();
76     myIdleEvent.Reset();
77   }
78 }
79
80 // =======================================================================
81 // function : DefaultPool
82 // purpose  :
83 // =======================================================================
84 const Handle(OSD_ThreadPool)& OSD_ThreadPool::DefaultPool (int theNbThreads)
85 {
86   static const Handle(OSD_ThreadPool) THE_GLOBAL_POOL = new OSD_ThreadPool (theNbThreads);
87   return THE_GLOBAL_POOL;
88 }
89
90 // =======================================================================
91 // function : OSD_ThreadPool
92 // purpose  :
93 // =======================================================================
94 OSD_ThreadPool::OSD_ThreadPool (int theNbThreads)
95 : myNbDefThreads (0),
96   myShutDown (false)
97 {
98   Init (theNbThreads);
99   myNbDefThreads = NbThreads();
100 }
101
102 // =======================================================================
103 // function : IsInUse
104 // purpose  :
105 // =======================================================================
106 bool OSD_ThreadPool::IsInUse()
107 {
108   for (NCollection_Array1<EnumeratedThread>::Iterator aThreadIter (myThreads);
109        aThreadIter.More(); aThreadIter.Next())
110   {
111     EnumeratedThread& aThread = aThreadIter.ChangeValue();
112     if (!aThread.Lock())
113     {
114       return true;
115     }
116     aThread.Free();
117   }
118   return false;
119 }
120
121 // =======================================================================
122 // function : Init
123 // purpose  :
124 // =======================================================================
125 void OSD_ThreadPool::Init (int theNbThreads)
126 {
127   const int aNbThreads = Max (0, (theNbThreads > 0 ? theNbThreads : OSD_Parallel::NbLogicalProcessors()) - 1);
128   if (myThreads.Size() == aNbThreads)
129   {
130     return;
131   }
132
133   // release old threads
134   if (!myThreads.IsEmpty())
135   {
136     NCollection_Array1<EnumeratedThread*> aLockThreads (myThreads.Lower(), myThreads.Upper());
137     aLockThreads.Init (NULL);
138     int aThreadIndex = myThreads.Lower();
139     for (NCollection_Array1<EnumeratedThread>::Iterator aThreadIter (myThreads);
140          aThreadIter.More(); aThreadIter.Next())
141     {
142       EnumeratedThread& aThread = aThreadIter.ChangeValue();
143       if (!aThread.Lock())
144       {
145         for (NCollection_Array1<EnumeratedThread*>::Iterator aLockThreadIter (aLockThreads);
146              aLockThreadIter.More() && aLockThreadIter.Value() != NULL; aLockThreadIter.Next())
147         {
148           aLockThreadIter.ChangeValue()->Free();
149         }
150         throw Standard_ProgramError ("Error: active ThreadPool is reinitialized");
151       }
152       aLockThreads.SetValue (aThreadIndex++, &aThread);
153     }
154   }
155   release();
156
157   myShutDown = false;
158   if (aNbThreads > 0)
159   {
160     myThreads.Resize (0, aNbThreads - 1, false);
161     int aLastThreadIndex = 0;
162     for (NCollection_Array1<EnumeratedThread>::Iterator aThreadIter (myThreads);
163          aThreadIter.More(); aThreadIter.Next())
164     {
165       EnumeratedThread& aThread = aThreadIter.ChangeValue();
166       aThread.myPool        = this;
167       aThread.myThreadIndex = aLastThreadIndex++;
168       aThread.SetFunction (&OSD_ThreadPool::EnumeratedThread::runThread);
169     }
170   }
171   else
172   {
173     NCollection_Array1<EnumeratedThread> anEmpty;
174     myThreads.Move (anEmpty);
175   }
176 }
177
178 // =======================================================================
179 // function : ~OSD_ThreadPool
180 // purpose  :
181 // =======================================================================
182 OSD_ThreadPool::~OSD_ThreadPool()
183 {
184   release();
185 }
186
187 // =======================================================================
188 // function : release
189 // purpose  :
190 // =======================================================================
191 void OSD_ThreadPool::release()
192 {
193   if (myThreads.IsEmpty())
194   {
195     return;
196   }
197
198   myShutDown = true;
199   for (NCollection_Array1<EnumeratedThread>::Iterator aThreadIter (myThreads);
200        aThreadIter.More(); aThreadIter.Next())
201   {
202     aThreadIter.ChangeValue().WakeUp (NULL, false);
203     aThreadIter.ChangeValue().Wait();
204   }
205 }
206
207 // =======================================================================
208 // function : perform
209 // purpose  :
210 // =======================================================================
211 void OSD_ThreadPool::Launcher::perform (JobInterface& theJob)
212 {
213   run (theJob);
214   wait();
215 }
216
217 // =======================================================================
218 // function : run
219 // purpose  :
220 // =======================================================================
221 void OSD_ThreadPool::Launcher::run (JobInterface& theJob)
222 {
223   bool toCatchFpe = OSD::ToCatchFloatingSignals();
224   for (NCollection_Array1<EnumeratedThread*>::Iterator aThreadIter (myThreads);
225        aThreadIter.More() && aThreadIter.Value() != NULL; aThreadIter.Next())
226   {
227     aThreadIter.ChangeValue()->WakeUp (&theJob, toCatchFpe);
228   }
229 }
230
231 // =======================================================================
232 // function : wait
233 // purpose  :
234 // =======================================================================
235 void OSD_ThreadPool::Launcher::wait()
236 {
237   int aNbFailures = 0;
238   for (NCollection_Array1<EnumeratedThread*>::Iterator aThreadIter (myThreads);
239        aThreadIter.More() && aThreadIter.Value() != NULL; aThreadIter.Next())
240   {
241     aThreadIter.ChangeValue()->WaitIdle();
242     if (!aThreadIter.Value()->myFailure.IsNull())
243     {
244       ++aNbFailures;
245     }
246   }
247   if (aNbFailures == 0)
248   {
249     return;
250   }
251
252   TCollection_AsciiString aFailures;
253   for (NCollection_Array1<EnumeratedThread*>::Iterator aThreadIter (myThreads);
254        aThreadIter.More() && aThreadIter.Value() != NULL; aThreadIter.Next())
255   {
256     if (!aThreadIter.Value()->myFailure.IsNull())
257     {
258       if (aNbFailures == 1)
259       {
260         aThreadIter.Value()->myFailure->Reraise();
261       }
262
263       if (!aFailures.IsEmpty())
264       {
265         aFailures += "\n";
266       }
267       aFailures += aThreadIter.Value()->myFailure->GetMessageString();
268     }
269   }
270
271   aFailures = TCollection_AsciiString("Multiple exceptions:\n") + aFailures;
272   throw Standard_ProgramError (aFailures.ToCString());
273 }
274
275 // =======================================================================
276 // function : performJob
277 // purpose  :
278 // =======================================================================
279 void OSD_ThreadPool::performJob (Handle(Standard_Failure)& theFailure,
280                                  OSD_ThreadPool::JobInterface* theJob,
281                                  int theThreadIndex)
282 {
283   try
284   {
285     OCC_CATCH_SIGNALS
286     theJob->Perform (theThreadIndex);
287   }
288   catch (Standard_Failure const& aFailure)
289   {
290     TCollection_AsciiString aMsg = TCollection_AsciiString (aFailure.DynamicType()->Name())
291                                  + ": " + aFailure.GetMessageString();
292     theFailure = new Standard_ProgramError (aMsg.ToCString());
293   }
294   catch (std::exception& anStdException)
295   {
296     TCollection_AsciiString aMsg = TCollection_AsciiString (typeid(anStdException).name())
297                                  + ": " + anStdException.what();
298     theFailure = new Standard_ProgramError (aMsg.ToCString());
299   }
300   catch (...)
301   {
302     theFailure = new Standard_ProgramError ("Error: Unknown exception");
303   }
304 }
305
306 // =======================================================================
307 // function : performThread
308 // purpose  :
309 // =======================================================================
310 void OSD_ThreadPool::EnumeratedThread::performThread()
311 {
312   OSD::SetSignal (false);
313   for (;;)
314   {
315     myWakeEvent.Wait();
316     myWakeEvent.Reset();
317     if (myPool->myShutDown)
318     {
319       return;
320     }
321
322     myFailure.Nullify();
323     if (myJob != NULL)
324     {
325       OSD::SetSignal (myToCatchFpe);
326       OSD_ThreadPool::performJob (myFailure, myJob, myThreadIndex);
327       myJob = NULL;
328     }
329     myIdleEvent.Set();
330   }
331 }
332
333 // =======================================================================
334 // function : runThread
335 // purpose  :
336 // =======================================================================
337 Standard_Address OSD_ThreadPool::EnumeratedThread::runThread (Standard_Address theTask)
338 {
339   EnumeratedThread* aThread = static_cast<EnumeratedThread*>(theTask);
340   aThread->performThread();
341   return NULL;
342 }
343
344 // =======================================================================
345 // function : Launcher
346 // purpose  :
347 // =======================================================================
348 OSD_ThreadPool::Launcher::Launcher (OSD_ThreadPool& thePool, Standard_Integer theMaxThreads)
349 : mySelfThread (true),
350   myNbThreads (0)
351 {
352   const int aNbThreads = theMaxThreads > 0
353                        ? Min (theMaxThreads, thePool.NbThreads())
354                        : (theMaxThreads < 0
355                         ? Max (thePool.NbDefaultThreadsToLaunch(), 1)
356                         : 1);
357   myThreads.Resize (0, aNbThreads - 1, false);
358   myThreads.Init (NULL);
359   if (aNbThreads > 1)
360   {
361     for (NCollection_Array1<EnumeratedThread>::Iterator aThreadIter (thePool.myThreads);
362          aThreadIter.More(); aThreadIter.Next())
363     {
364       if (aThreadIter.ChangeValue().Lock())
365       {
366         myThreads.SetValue (myNbThreads, &aThreadIter.ChangeValue());
367         // make thread index to fit into myThreads range
368         aThreadIter.ChangeValue().myThreadIndex = myNbThreads;
369         if (++myNbThreads == aNbThreads - 1)
370         {
371           break;
372         }
373       }
374     }
375   }
376
377   // self thread should be executed last
378   myThreads.SetValue (myNbThreads, &mySelfThread);
379   mySelfThread.myThreadIndex = myNbThreads;
380   ++myNbThreads;
381 }
382
383 // =======================================================================
384 // function : Release
385 // purpose  :
386 // =======================================================================
387 void OSD_ThreadPool::Launcher::Release()
388 {
389   for (NCollection_Array1<EnumeratedThread*>::Iterator aThreadIter (myThreads);
390        aThreadIter.More() && aThreadIter.Value() != NULL; aThreadIter.Next())
391   {
392     if (aThreadIter.Value() != &mySelfThread)
393     {
394       aThreadIter.Value()->Free();
395     }
396   }
397
398   NCollection_Array1<EnumeratedThread*> anEmpty;
399   myThreads.Move (anEmpty);
400   myNbThreads = 0;
401 }