Skip to content

Commit 971acff

Browse files
committed
add python interface for mkldnn_pool
1 parent 9d108a2 commit 971acff

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

python/paddle/trainer/config_parser.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2286,8 +2286,15 @@ def __init__(self, name, inputs, **xargs):
22862286

22872287
@config_layer('pool')
22882288
class PoolLayer(LayerBase):
2289+
layer_type = 'pool'
2290+
22892291
def __init__(self, name, inputs, ceil_mode=True, **xargs):
2290-
super(PoolLayer, self).__init__(name, 'pool', 0, inputs=inputs, **xargs)
2292+
use_mkldnn = int(g_command_config_args.get("use_mkldnn", 0))
2293+
if self.layer_type == "mkldnn_pool":
2294+
config_assert(use_mkldnn, "mkldnn_pool only support MKLDNN")
2295+
self.layer_type = 'mkldnn_pool' if use_mkldnn else 'pool'
2296+
super(PoolLayer, self).__init__(
2297+
name, self.layer_type, 0, inputs=inputs, **xargs)
22912298
for input_index in xrange(len(self.inputs)):
22922299
input_layer = self.get_input_layer(input_index)
22932300
pool_conf = self.config.inputs[input_index].pool_conf
@@ -2297,6 +2304,11 @@ def __init__(self, name, inputs, ceil_mode=True, **xargs):
22972304
pool_conf.channels)
22982305

22992306

2307+
@config_layer('mkldnn_pool')
2308+
class MKLDNNPoolLayer(PoolLayer):
2309+
layer_type = 'mkldnn_pool'
2310+
2311+
23002312
@config_layer('pool3d')
23012313
class Pool3DLayer(LayerBase):
23022314
def __init__(self, name, inputs, ceil_mode=True, **xargs):

0 commit comments

Comments
 (0)