diff --git a/onnx2keras/pooling_layers.py b/onnx2keras/pooling_layers.py index 963032d9..21860316 100644 --- a/onnx2keras/pooling_layers.py +++ b/onnx2keras/pooling_layers.py @@ -149,9 +149,9 @@ def convert_global_avg_pool(node, params, layers, lambda_func, node_name, keras_ global_pool = keras.layers.GlobalAveragePooling2D(data_format='channels_first', name=keras_name) input_0 = global_pool(input_0) - def target_layer(x): + def target_layer(x, axis=-1): from tensorflow import keras - return keras.backend.expand_dims(x) + return keras.backend.expand_dims(x, axis=axis) logger.debug('Now expand dimensions twice.') lambda_layer1 = keras.layers.Lambda(target_layer, name=keras_name + '_EXPAND1')