Skip to content

Commit 9c30132

Browse files
committed
Merge pull request #447 from kylelutz/transform-if
Add transform_if() algorithm
2 parents 415e7a0 + f9136ed commit 9c30132

File tree

7 files changed

+163
-158
lines changed

7 files changed

+163
-158
lines changed

include/boost/compute/algorithm/copy_if.hpp

Lines changed: 12 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11,87 +11,13 @@
1111
#ifndef BOOST_COMPUTE_ALGORITHM_COPY_IF_HPP
1212
#define BOOST_COMPUTE_ALGORITHM_COPY_IF_HPP
1313

14-
#include <boost/compute/cl.hpp>
15-
#include <boost/compute/system.hpp>
16-
#include <boost/compute/command_queue.hpp>
17-
#include <boost/compute/algorithm/count.hpp>
18-
#include <boost/compute/algorithm/count_if.hpp>
19-
#include <boost/compute/algorithm/exclusive_scan.hpp>
20-
#include <boost/compute/container/vector.hpp>
21-
#include <boost/compute/detail/meta_kernel.hpp>
22-
#include <boost/compute/detail/iterator_range_size.hpp>
23-
#include <boost/compute/iterator/discard_iterator.hpp>
14+
#include <boost/compute/algorithm/transform_if.hpp>
15+
#include <boost/compute/functional/identity.hpp>
2416

2517
namespace boost {
2618
namespace compute {
2719
namespace detail {
2820

29-
template<class InputIterator, class OutputIterator, class Predicate>
30-
inline OutputIterator copy_if_impl(InputIterator first,
31-
InputIterator last,
32-
OutputIterator result,
33-
Predicate predicate,
34-
bool copyIndex,
35-
command_queue &queue)
36-
{
37-
typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type;
38-
39-
size_t count = detail::iterator_range_size(first, last);
40-
if(count == 0){
41-
return result;
42-
}
43-
44-
const context &context = queue.get_context();
45-
46-
// storage for destination indices
47-
::boost::compute::vector<cl_uint> indices(count, context);
48-
49-
// write counts
50-
::boost::compute::detail::meta_kernel k1("copy_if_write_counts");
51-
k1 << indices.begin()[k1.get_global_id(0)] << " = "
52-
<< predicate(first[k1.get_global_id(0)]) << " ? 1 : 0;\n";
53-
k1.exec_1d(queue, 0, count);
54-
55-
// count number of elements to be copied
56-
size_t copied_element_count =
57-
::boost::compute::count(indices.begin(), indices.end(), 1, queue);
58-
59-
// scan indices
60-
::boost::compute::exclusive_scan(indices.begin(),
61-
indices.end(),
62-
indices.begin(),
63-
queue);
64-
65-
// copy values
66-
::boost::compute::detail::meta_kernel k2("copy_if_do_copy");
67-
k2 << "if(" << predicate(first[k2.get_global_id(0)]) << ")" <<
68-
" " << result[indices.begin()[k2.get_global_id(0)]] << "=";
69-
70-
if(copyIndex){
71-
k2 << k2.get_global_id(0) << ";\n";
72-
}
73-
else {
74-
k2 << first[k2.get_global_id(0)] << ";\n";
75-
}
76-
77-
k2.exec_1d(queue, 0, count);
78-
79-
return result + static_cast<difference_type>(copied_element_count);
80-
}
81-
82-
template<class InputIterator, class Predicate>
83-
inline discard_iterator copy_if_impl(InputIterator first,
84-
InputIterator last,
85-
discard_iterator result,
86-
Predicate predicate,
87-
bool copyIndex,
88-
command_queue &queue)
89-
{
90-
(void) copyIndex;
91-
92-
return result + count_if(first, last, predicate, queue);
93-
}
94-
9521
// like the copy_if() algorithm but writes the indices of the values for which
9622
// predicate returns true.
9723
template<class InputIterator, class OutputIterator, class Predicate>
@@ -101,7 +27,11 @@ inline OutputIterator copy_index_if(InputIterator first,
10127
Predicate predicate,
10228
command_queue &queue = system::default_queue())
10329
{
104-
return detail::copy_if_impl(first, last, result, predicate, true, queue);
30+
typedef typename std::iterator_traits<InputIterator>::value_type T;
31+
32+
return detail::transform_if_impl(
33+
first, last, result, identity<T>(), predicate, true, queue
34+
);
10535
}
10636

10737
} // end detail namespace
@@ -115,7 +45,11 @@ inline OutputIterator copy_if(InputIterator first,
11545
Predicate predicate,
11646
command_queue &queue = system::default_queue())
11747
{
118-
return detail::copy_if_impl(first, last, result, predicate, false, queue);
48+
typedef typename std::iterator_traits<InputIterator>::value_type T;
49+
50+
return ::boost::compute::transform_if(
51+
first, last, result, identity<T>(), predicate, queue
52+
);
11953
}
12054

12155
} // end compute namespace
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//---------------------------------------------------------------------------//
2+
// Copyright (c) 2013-2015 Kyle Lutz <kyle.r.lutz@gmail.com>
3+
//
4+
// Distributed under the Boost Software License, Version 1.0
5+
// See accompanying file LICENSE_1_0.txt or copy at
6+
// http://www.boost.org/LICENSE_1_0.txt
7+
//
8+
// See http://kylelutz.github.com/compute for more information.
9+
//---------------------------------------------------------------------------//
10+
11+
#ifndef BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP
12+
#define BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP
13+
14+
#include <boost/compute/cl.hpp>
15+
#include <boost/compute/system.hpp>
16+
#include <boost/compute/command_queue.hpp>
17+
#include <boost/compute/algorithm/count.hpp>
18+
#include <boost/compute/algorithm/count_if.hpp>
19+
#include <boost/compute/algorithm/exclusive_scan.hpp>
20+
#include <boost/compute/container/vector.hpp>
21+
#include <boost/compute/detail/meta_kernel.hpp>
22+
#include <boost/compute/detail/iterator_range_size.hpp>
23+
#include <boost/compute/iterator/discard_iterator.hpp>
24+
25+
namespace boost {
26+
namespace compute {
27+
namespace detail {
28+
29+
template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate>
30+
inline OutputIterator transform_if_impl(InputIterator first,
31+
InputIterator last,
32+
OutputIterator result,
33+
UnaryFunction function,
34+
Predicate predicate,
35+
bool copyIndex,
36+
command_queue &queue)
37+
{
38+
typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type;
39+
40+
size_t count = detail::iterator_range_size(first, last);
41+
if(count == 0){
42+
return result;
43+
}
44+
45+
const context &context = queue.get_context();
46+
47+
// storage for destination indices
48+
::boost::compute::vector<cl_uint> indices(count, context);
49+
50+
// write counts
51+
::boost::compute::detail::meta_kernel k1("transform_if_write_counts");
52+
k1 << indices.begin()[k1.get_global_id(0)] << " = "
53+
<< predicate(first[k1.get_global_id(0)]) << " ? 1 : 0;\n";
54+
k1.exec_1d(queue, 0, count);
55+
56+
// count number of elements to be copied
57+
size_t copied_element_count =
58+
::boost::compute::count(indices.begin(), indices.end(), 1, queue);
59+
60+
// scan indices
61+
::boost::compute::exclusive_scan(
62+
indices.begin(), indices.end(), indices.begin(), queue
63+
);
64+
65+
// copy values
66+
::boost::compute::detail::meta_kernel k2("transform_if_do_copy");
67+
k2 << "if(" << predicate(first[k2.get_global_id(0)]) << ")" <<
68+
" " << result[indices.begin()[k2.get_global_id(0)]] << "=";
69+
70+
if(copyIndex){
71+
k2 << k2.get_global_id(0) << ";\n";
72+
}
73+
else {
74+
k2 << function(first[k2.get_global_id(0)]) << ";\n";
75+
}
76+
77+
k2.exec_1d(queue, 0, count);
78+
79+
return result + static_cast<difference_type>(copied_element_count);
80+
}
81+
82+
template<class InputIterator, class UnaryFunction, class Predicate>
83+
inline discard_iterator transform_if_impl(InputIterator first,
84+
InputIterator last,
85+
discard_iterator result,
86+
UnaryFunction function,
87+
Predicate predicate,
88+
bool copyIndex,
89+
command_queue &queue)
90+
{
91+
(void) function;
92+
(void) copyIndex;
93+
94+
return result + count_if(first, last, predicate, queue);
95+
}
96+
97+
} // end detail namespace
98+
99+
/// Copies each element in the range [\p first, \p last) for which
100+
/// \p predicate returns \c true to the range beginning at \p result.
101+
template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate>
102+
inline OutputIterator transform_if(InputIterator first,
103+
InputIterator last,
104+
OutputIterator result,
105+
UnaryFunction function,
106+
Predicate predicate,
107+
command_queue &queue = system::default_queue())
108+
{
109+
return detail::transform_if_impl(
110+
first, last, result, function, predicate, false, queue
111+
);
112+
}
113+
114+
} // end compute namespace
115+
} // end boost namespace
116+
117+
#endif // BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP

include/boost/compute/experimental/transform_if.hpp

Lines changed: 0 additions & 63 deletions
This file was deleted.

include/boost/compute/types/struct.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
#include <boost/static_assert.hpp>
1717

18+
#include <boost/preprocessor/expr_if.hpp>
1819
#include <boost/preprocessor/stringize.hpp>
19-
#include <boost/preprocessor/seq/for_each.hpp>
2020
#include <boost/preprocessor/seq/fold_left.hpp>
21+
#include <boost/preprocessor/seq/for_each.hpp>
2122
#include <boost/preprocessor/seq/transform.hpp>
2223

2324
#include <boost/compute/type_traits/type_definition.hpp>

test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ add_compute_test("algorithm.sort_by_key" test_sort_by_key.cpp)
117117
add_compute_test("algorithm.stable_partition" test_stable_partition.cpp)
118118
add_compute_test("algorithm.stable_sort" test_stable_sort.cpp)
119119
add_compute_test("algorithm.transform" test_transform.cpp)
120+
add_compute_test("algorithm.transform_if" test_transform_if.cpp)
120121
add_compute_test("algorithm.transform_reduce" test_transform_reduce.cpp)
121122
add_compute_test("algorithm.unique" test_unique.cpp)
122123
add_compute_test("algorithm.unique_copy" test_unique_copy.cpp)
@@ -189,7 +190,6 @@ add_compute_test("experimental.clamp_range" test_clamp_range.cpp)
189190
add_compute_test("experimental.malloc" test_malloc.cpp)
190191
add_compute_test("experimental.sort_by_transform" test_sort_by_transform.cpp)
191192
add_compute_test("experimental.tabulate" test_tabulate.cpp)
192-
add_compute_test("experimental.transform_if" test_transform_if.cpp)
193193

194194
# miscellaneous tests
195195
add_compute_test("misc.amd_cpp_kernel_language" test_amd_cpp_kernel_language.cpp)

test/test_transform.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,4 +298,27 @@ boost::compute::transform(
298298
CHECK_RANGE_EQUAL(int, 4, vec, (1, 2, 3, 4));
299299
}
300300

301+
BOOST_AUTO_TEST_CASE(abs_if_odd)
302+
{
303+
// return absolute value only for odd values
304+
BOOST_COMPUTE_FUNCTION(int, abs_if_odd, (int x),
305+
{
306+
if(x & 1){
307+
return abs(x);
308+
}
309+
else {
310+
return x;
311+
}
312+
});
313+
314+
int data[] = { -2, -3, -4, -5, -6, -7, -8, -9 };
315+
compute::vector<int> vector(data, data + 8, queue);
316+
317+
compute::transform(
318+
vector.begin(), vector.end(), vector.begin(), abs_if_odd, queue
319+
);
320+
321+
CHECK_RANGE_EQUAL(int, 8, vector, (-2, +3, -4, +5, -6, +7, -8, +9));
322+
}
323+
301324
BOOST_AUTO_TEST_SUITE_END()

test/test_transform_if.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,28 @@
1212
#include <boost/test/unit_test.hpp>
1313

1414
#include <boost/compute/lambda.hpp>
15-
#include <boost/compute/functional.hpp>
16-
#include <boost/compute/experimental/transform_if.hpp>
15+
#include <boost/compute/algorithm/transform_if.hpp>
1716
#include <boost/compute/container/vector.hpp>
1817

1918
#include "check_macros.hpp"
2019
#include "context_setup.hpp"
2120

2221
namespace compute = boost::compute;
2322

24-
BOOST_AUTO_TEST_CASE(abs_if_odd)
23+
BOOST_AUTO_TEST_CASE(transform_if_odd)
2524
{
26-
using compute::lambda::_1;
25+
using boost::compute::abs;
26+
using boost::compute::lambda::_1;
2727

28-
// input data
2928
int data[] = { -2, -3, -4, -5, -6, -7, -8, -9 };
3029
compute::vector<int> vector(data, data + 8, queue);
3130

32-
// calculate absolute value only for odd values
33-
compute::experimental::transform_if(
34-
vector.begin(),
35-
vector.end(),
36-
vector.begin(),
37-
compute::abs<int>(),
38-
_1 % 2 != 0,
39-
queue
31+
compute::vector<int>::iterator end = compute::transform_if(
32+
vector.begin(), vector.end(), vector.begin(), abs<int>(), _1 % 2 != 0, queue
4033
);
34+
BOOST_CHECK_EQUAL(std::distance(vector.begin(), end), 4);
4135

42-
// check transformed values
43-
CHECK_RANGE_EQUAL(int, 8, vector, (-2, +3, -4, +5, -6, +7, -8, +9));
36+
CHECK_RANGE_EQUAL(int, 4, vector, (+3, +5, +7, +9));
4437
}
4538

4639
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)