0026377: Passing Handle objects as arguments to functions as non-const reference...
[occt.git] / src / OSD / OSD_Parallel.hxx
1 // Copyright (c) 2013-2014 OPEN CASCADE SAS
2 //
3 // This file is part of Open CASCADE Technology software library.
4 //
5 // This library is free software; you can redistribute it and/or modify it under
6 // the terms of the GNU Lesser General Public License version 2.1 as published
7 // by the Free Software Foundation, with special exception defined in the file
8 // OCCT_LGPL_EXCEPTION.txt. Consult the file LICENSE_LGPL_21.txt included in OCCT
9 // distribution for complete text of the license and disclaimer of any warranty.
10 //
11 // Alternatively, this file may be used under the terms of Open CASCADE
12 // commercial license or contractual agreement.
13
14 #ifndef OSD_Parallel_HeaderFile
15 #define OSD_Parallel_HeaderFile
16
17 #include <OSD_Thread.hxx>
18 #include <Standard_Mutex.hxx>
19 #include <Standard_NotImplemented.hxx>
20 #include <Standard_Atomic.hxx>
21 #include <NCollection_Array1.hxx>
22
23 #ifdef HAVE_TBB
24 #include <tbb/parallel_for.h>
25 #include <tbb/parallel_for_each.h>
26 #include <tbb/blocked_range.h>
27 #endif
28
29 //! @class OSD_Parallel
30 //! @brief Simplifies code parallelization.
31 //!
32 //! The Class provides an interface of parallel processing "for" and "foreach" loops.
33 //! These primitives encapsulates complete logic for creating and managing parallel context of loops.
34 //! Moreover the primitives may be a wrapper for some primitives from 3rd-party library - TBB.
35 //! To use it is necessary to implement TBB like interface which is based on functors.
36 //!
37 //! @code
38 //! class Functor
39 //! {
40 //! public:
41 //!   void operator() ([proccesing instance]) const
42 //!   {
43 //!     //...
44 //!   }
45 //! };
46 //! @endcode
47 //!
48 //! In the body of the operator () should be implemented thread-safe logic of computations that can be performed in parallel context.
49 //! If parallelized loop iterates on the collections with direct access by index (such as Vector, Array),
50 //! it is more efficient to use the primitive ParallelFor (because it has no critical section).
51 class OSD_Parallel
52 {
53   //! Auxiliary class which ensures exclusive
54   //! access to iterators of processed data pool.
55   template <typename Value>
56   class Range
57   {
58   public: //! @name public methods
59
60     typedef Value Iterator;
61
62     //! Constructor
63     Range(const Value& theBegin, const Value& theEnd)
64     : myBegin(theBegin),
65       myEnd  (theEnd),
66       myIt   (theBegin)
67     {
68     }
69
70     //! Returns const link on the first element.
71     inline const Value& Begin() const
72     {
73       return myBegin;
74     }
75
76     //! Returns const link on the last element.
77     inline const Value& End() const
78     {
79       return myEnd;
80     }
81
82     //! Returns first non processed element or end.
83     //! Thread-safe method.
84     inline Iterator It() const
85     {
86       Standard_Mutex::Sentry aMutex( myMutex );
87       return ( myIt != myEnd ) ? myIt++ : myEnd;
88     }
89
90   private: //! @name private methods
91
92     //! Empty copy constructor
93     Range(const Range& theCopy);
94
95     //! Empty copy operator.
96     Range& operator=(const Range& theCopy);
97
98   private: //! @name private fields
99
100     const Value&           myBegin; //!< Fisrt element of range.
101     const Value&           myEnd;   //!< Last element of range.
102     mutable Value          myIt;    //!< First non processed element of range.
103     mutable Standard_Mutex myMutex; //!< Access controller for the first non processed element.
104   };
105
106   //! Auxiliary wrapper class for thread function.
107   template <typename Functor, typename InputIterator>
108   class Task
109   {
110   public: //! @name public methods
111
112     //! Constructor.
113     Task(const Functor& thePerformer, Range<InputIterator>& theRange)
114     : myPerformer(thePerformer),
115       myRange    (theRange)
116     {
117     }
118
119     //! Method is executed in the context of thread,
120     //! so this method defines the main calculations.
121     static Standard_Address RunWithIterator(Standard_Address theTask)
122     {
123       Task<Functor, InputIterator>& aTask =
124         *( static_cast< Task<Functor, InputIterator>* >(theTask) );
125
126       const Range<InputIterator>& aData( aTask.myRange );
127       typename Range<InputIterator>::Iterator i = aData.It();
128
129       for ( ; i != aData.End(); i = aData.It() )
130       {
131         aTask.myPerformer(*i);
132       }
133
134       return NULL;
135     }
136
137     //! Method is executed in the context of thread,
138     //! so this method defines the main calculations.
139     static Standard_Address RunWithIndex(Standard_Address theTask)
140     {
141       Task<Functor, InputIterator>& aTask =
142         *( static_cast< Task<Functor, Standard_Integer>* >(theTask) );
143
144       const Range<Standard_Integer>& aData( aTask.myRange );
145       Standard_Integer i = aData.It();
146
147       for ( ; i < aData.End(); i = aData.It())
148       {
149         aTask.myPerformer(i);
150       }
151
152       return NULL;
153     }
154
155   private: //! @name private methods
156
157     //! Empty copy constructor.
158     Task(const Task& theCopy);
159
160     //! Empty copy operator.
161     Task& operator=(const Task& theCopy);
162
163   private: //! @name private fields
164
165     const Functor&              myPerformer; //!< Link on functor.
166     const Range<InputIterator>& myRange;     //!< Link on processed data block.
167   };
168
169 public: //! @name public methods
170
171   //! Returns number of logical proccesrs.
172   Standard_EXPORT static Standard_Integer NbLogicalProcessors();
173
174   //! Simple primitive for parallelization of "foreach" loops.
175   template <typename InputIterator, typename Functor>
176   static void ForEach( InputIterator  theBegin,
177                        InputIterator  theEnd,
178                        const Functor& theFunctor,
179                        const Standard_Boolean isForceSingleThreadExecution
180                          = Standard_False );
181
182   //! Simple primitive for parallelization of "for" loops.
183   template <typename Functor>
184   static void For( const Standard_Integer theBegin,
185                    const Standard_Integer theEnd,
186                    const Functor&         theFunctor,
187                    const Standard_Boolean isForceSingleThreadExecution
188                      = Standard_False );
189 };
190
191 //=======================================================================
192 //function : OSD_Parallel::Range::It
193 //purpose  : Template concretization.
194 //=======================================================================
195 template<> inline Standard_Integer OSD_Parallel::Range<Standard_Integer>::It() const
196 {
197   return Standard_Atomic_Increment( reinterpret_cast<volatile int*>(&myIt) ) - 1;
198 }
199
200 //=======================================================================
201 //function : ParallelForEach
202 //purpose  : 
203 //=======================================================================
204 template <typename InputIterator, typename Functor>
205 void OSD_Parallel::ForEach( InputIterator          theBegin,
206                             InputIterator          theEnd,
207                             const Functor&         theFunctor,
208                             const Standard_Boolean isForceSingleThreadExecution )
209 {
210   if ( isForceSingleThreadExecution )
211   {
212     for ( InputIterator it(theBegin); it != theEnd; it++ )
213       theFunctor(*it);
214
215     return;
216   }
217   #ifdef HAVE_TBB
218   {
219     try
220     {
221       tbb::parallel_for_each(theBegin, theEnd, theFunctor);
222     }
223     catch ( tbb::captured_exception& anException )
224     {
225       Standard_NotImplemented::Raise(anException.what());
226     }
227   }
228   #else
229   {
230     Range<InputIterator> aData(theBegin, theEnd);
231     Task<Functor, InputIterator> aTask(theFunctor, aData);
232
233     const Standard_Integer aNbThreads = OSD_Parallel::NbLogicalProcessors();
234     NCollection_Array1<OSD_Thread> aThreads(0, aNbThreads - 1);
235
236     for ( Standard_Integer i = 0; i < aNbThreads; ++i )
237     {
238       OSD_Thread& aThread = aThreads(i);
239       aThread.SetFunction(&Task<Functor, InputIterator>::RunWithIterator);
240       aThread.Run(&aTask);
241     }
242
243     for ( Standard_Integer i = 0; i < aNbThreads; ++i )
244       aThreads(i).Wait();
245   }
246   #endif
247 }
248
249 //=======================================================================
250 //function : ParallelFor
251 //purpose  : 
252 //=======================================================================
253 template <typename Functor>
254 void OSD_Parallel::For( const Standard_Integer theBegin,
255                         const Standard_Integer theEnd,
256                         const Functor&         theFunctor,
257                         const Standard_Boolean isForceSingleThreadExecution )
258 {
259   if ( isForceSingleThreadExecution )
260   {
261     for ( Standard_Integer i = theBegin; i < theEnd; ++i )
262       theFunctor(i);
263
264     return;
265   }
266   #ifdef HAVE_TBB
267   {
268     try
269     {
270       tbb::parallel_for( theBegin, theEnd, theFunctor );
271     }
272     catch ( tbb::captured_exception& anException )
273     {
274       Standard_NotImplemented::Raise(anException.what());
275     }
276   }
277   #else
278   {
279     Range<Standard_Integer> aData(theBegin, theEnd);
280     Task<Functor, Standard_Integer> aTask(theFunctor, aData);
281
282     const Standard_Integer aNbThreads = OSD_Parallel::NbLogicalProcessors();
283     NCollection_Array1<OSD_Thread> aThreads(0, aNbThreads - 1);
284
285     for ( Standard_Integer i = 0; i < aNbThreads; ++i )
286     {
287       OSD_Thread& aThread = aThreads(i);
288       aThread.SetFunction(&Task<Functor, Standard_Integer>::RunWithIndex);
289       aThread.Run(&aTask);
290     }
291
292     for ( Standard_Integer i = 0; i < aNbThreads; ++i )
293       aThreads(i).Wait();
294   }
295   #endif
296 }
297
298 #endif