1111#ifndef BOOST_COMPUTE_ALGORITHM_COPY_IF_HPP
1212#define BOOST_COMPUTE_ALGORITHM_COPY_IF_HPP
1313
14- #include < boost/compute/cl.hpp>
15- #include < boost/compute/system.hpp>
16- #include < boost/compute/command_queue.hpp>
17- #include < boost/compute/algorithm/count.hpp>
18- #include < boost/compute/algorithm/count_if.hpp>
19- #include < boost/compute/algorithm/exclusive_scan.hpp>
20- #include < boost/compute/container/vector.hpp>
21- #include < boost/compute/detail/meta_kernel.hpp>
22- #include < boost/compute/detail/iterator_range_size.hpp>
23- #include < boost/compute/iterator/discard_iterator.hpp>
14+ #include < boost/compute/algorithm/transform_if.hpp>
15+ #include < boost/compute/functional/identity.hpp>
2416
2517namespace boost {
2618namespace compute {
2719namespace detail {
2820
29- template <class InputIterator , class OutputIterator , class Predicate >
30- inline OutputIterator copy_if_impl (InputIterator first,
31- InputIterator last,
32- OutputIterator result,
33- Predicate predicate,
34- bool copyIndex,
35- command_queue &queue)
36- {
37- typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type;
38-
39- size_t count = detail::iterator_range_size (first, last);
40- if (count == 0 ){
41- return result;
42- }
43-
44- const context &context = queue.get_context ();
45-
46- // storage for destination indices
47- ::boost::compute::vector<cl_uint> indices (count, context);
48-
49- // write counts
50- ::boost::compute::detail::meta_kernel k1 (" copy_if_write_counts" );
51- k1 << indices.begin ()[k1.get_global_id (0 )] << " = "
52- << predicate (first[k1.get_global_id (0 )]) << " ? 1 : 0;\n " ;
53- k1.exec_1d (queue, 0 , count);
54-
55- // count number of elements to be copied
56- size_t copied_element_count =
57- ::boost::compute::count (indices.begin(), indices.end(), 1, queue);
58-
59- // scan indices
60- ::boost::compute::exclusive_scan (indices.begin(),
61- indices.end(),
62- indices.begin(),
63- queue);
64-
65- // copy values
66- ::boost::compute::detail::meta_kernel k2 (" copy_if_do_copy" );
67- k2 << " if(" << predicate (first[k2.get_global_id (0 )]) << " )" <<
68- " " << result[indices.begin ()[k2.get_global_id (0 )]] << " =" ;
69-
70- if (copyIndex){
71- k2 << k2.get_global_id (0 ) << " ;\n " ;
72- }
73- else {
74- k2 << first[k2.get_global_id (0 )] << " ;\n " ;
75- }
76-
77- k2.exec_1d (queue, 0 , count);
78-
79- return result + static_cast <difference_type>(copied_element_count);
80- }
81-
82- template <class InputIterator , class Predicate >
83- inline discard_iterator copy_if_impl (InputIterator first,
84- InputIterator last,
85- discard_iterator result,
86- Predicate predicate,
87- bool copyIndex,
88- command_queue &queue)
89- {
90- (void ) copyIndex;
91-
92- return result + count_if (first, last, predicate, queue);
93- }
94-
9521// like the copy_if() algorithm but writes the indices of the values for which
9622// predicate returns true.
9723template <class InputIterator , class OutputIterator , class Predicate >
@@ -101,7 +27,11 @@ inline OutputIterator copy_index_if(InputIterator first,
10127 Predicate predicate,
10228 command_queue &queue = system::default_queue())
10329{
104- return detail::copy_if_impl (first, last, result, predicate, true , queue);
30+ typedef typename std::iterator_traits<InputIterator>::value_type T;
31+
32+ return detail::transform_if_impl (
33+ first, last, result, identity<T>(), predicate, true , queue
34+ );
10535}
10636
10737} // end detail namespace
@@ -115,7 +45,11 @@ inline OutputIterator copy_if(InputIterator first,
11545 Predicate predicate,
11646 command_queue &queue = system::default_queue())
11747{
118- return detail::copy_if_impl (first, last, result, predicate, false , queue);
48+ typedef typename std::iterator_traits<InputIterator>::value_type T;
49+
50+ return ::boost::compute::transform_if (
51+ first, last, result, identity<T>(), predicate, queue
52+ );
11953}
12054
12155} // end compute namespace
0 commit comments