@@ -2286,8 +2286,15 @@ def __init__(self, name, inputs, **xargs):
22862286
22872287@config_layer ('pool' )
22882288class PoolLayer (LayerBase ):
2289+ layer_type = 'pool'
2290+
22892291 def __init__ (self , name , inputs , ceil_mode = True , ** xargs ):
2290- super (PoolLayer , self ).__init__ (name , 'pool' , 0 , inputs = inputs , ** xargs )
2292+ use_mkldnn = int (g_command_config_args .get ("use_mkldnn" , 0 ))
2293+ if self .layer_type == "mkldnn_pool" :
2294+ config_assert (use_mkldnn , "mkldnn_pool only support MKLDNN" )
2295+ self .layer_type = 'mkldnn_pool' if use_mkldnn else 'pool'
2296+ super (PoolLayer , self ).__init__ (
2297+ name , self .layer_type , 0 , inputs = inputs , ** xargs )
22912298 for input_index in xrange (len (self .inputs )):
22922299 input_layer = self .get_input_layer (input_index )
22932300 pool_conf = self .config .inputs [input_index ].pool_conf
@@ -2297,6 +2304,11 @@ def __init__(self, name, inputs, ceil_mode=True, **xargs):
22972304 pool_conf .channels )
22982305
22992306
2307+ @config_layer ('mkldnn_pool' )
2308+ class MKLDNNPoolLayer (PoolLayer ):
2309+ layer_type = 'mkldnn_pool'
2310+
2311+
23002312@config_layer ('pool3d' )
23012313class Pool3DLayer (LayerBase ):
23022314 def __init__ (self , name , inputs , ceil_mode = True , ** xargs ):
0 commit comments