Skip to content

Commit d096d3a

Browse files
authored
update_ofa (PaddlePaddle#1012)
1 parent caf3002 commit d096d3a

File tree

5 files changed

+324
-138
lines changed

5 files changed

+324
-138
lines changed

paddleslim/nas/ofa/convert_super.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from . import layers
3636
Layer = paddle.nn.Layer
3737
from .layers_base import Block
38-
38+
from . import layers_old
3939
_logger = get_logger(__name__, level=logging.INFO)
4040

4141
__all__ = ['supernet', 'Convert']
@@ -58,11 +58,16 @@ class Convert:
5858
def __init__(self, context):
5959
self.context = context
6060

61-
def _change_name(self, layer, pd_ver, has_bias=True, conv=False):
61+
def _change_name(self,
62+
layer,
63+
pd_ver,
64+
has_bias=True,
65+
conv=False,
66+
use_bn_old=False):
6267
if conv:
6368
w_attr = layer._param_attr
6469
else:
65-
w_attr = layer._param_attr if pd_ver == 185 else layer._weight_attr
70+
w_attr = layer._param_attr if pd_ver == 185 or use_bn_old else layer._weight_attr
6671

6772
if isinstance(w_attr, ParamAttr):
6873
if w_attr != None and not isinstance(w_attr,
@@ -241,28 +246,32 @@ def convert(self, network):
241246
layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
242247
model[idx] = layer
243248

244-
elif isinstance(layer,
245-
getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
246-
getattr(self.context, 'expand', None) != None or
247-
getattr(self.context, 'channel', None) != None):
249+
elif (isinstance(layer, nn.BatchNorm2D) or
250+
isinstance(layer, nn.BatchNorm)) and (
251+
getattr(self.context, 'expand', None) != None or
252+
getattr(self.context, 'channel', None) != None):
248253
# num_features in BatchNorm don't change after last weight operators
249254
if idx > last_weight_layer_idx:
250255
continue
251256

257+
use_bn_old = False
258+
if isinstance(layer, nn.BatchNorm):
259+
use_bn_old = True
260+
252261
attr_dict = layer.__dict__
253262
new_attr_name = ['momentum', 'epsilon', 'bias_attr']
254263

255-
if pd_ver == 185:
264+
if pd_ver == 185 or use_bn_old:
256265
new_attr_name += [
257266
'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
258267
'is_test', 'use_global_stats', 'trainable_statistics'
259268
]
260269
else:
261270
new_attr_name += ['weight_attr', 'data_format', 'name']
262271

263-
self._change_name(layer, pd_ver)
272+
self._change_name(layer, pd_ver, use_bn_old=use_bn_old)
264273
new_attr_dict = dict.fromkeys(new_attr_name, None)
265-
if pd_ver == 185:
274+
if pd_ver == 185 or use_bn_old:
266275
new_attr_dict['num_channels'] = None
267276
else:
268277
new_attr_dict['num_features'] = None
@@ -284,9 +293,10 @@ def convert(self, network):
284293

285294
del layer, attr_dict
286295

287-
layer = layers.SuperBatchNorm(
296+
layer = layers_old.SuperBatchNorm(
288297
**new_attr_dict
289-
) if pd_ver == 185 else layers.SuperBatchNorm2D(**new_attr_dict)
298+
) if pd_ver == 185 or use_bn_old else layers.SuperBatchNorm2D(
299+
**new_attr_dict)
290300
model[idx] = layer
291301

292302
elif isinstance(layer, SyncBatchNorm) and (
@@ -755,4 +765,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):
755765
# def convert(*args, **kwargs):
756766
# supernet_convert(*args, **kwargs)
757767
# return convert
758-
# return _ofa_supernet
768+
# return _ofa_supernet

paddleslim/nas/ofa/layers.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@
4040

4141
class SuperConv2D(nn.Conv2D):
4242
"""This interface is used to construct a callable object of the ``SuperConv2D`` class.
43-
4443
Note: the channel in config need to less than first defined.
45-
4644
The super convolution2D layer calculates the output based on the input, filter
4745
and strides, paddings, dilations, groups parameters. Input and
4846
Output are in NCHW format, where N is batch size, C is the number of
@@ -59,17 +57,14 @@ class SuperConv2D(nn.Conv2D):
5957
applied to the final result.
6058
For each input :math:`X`, the equation is:
6159
.. math::
62-
6360
Out = sigma (W \\ast X + b)
64-
6561
Where:
6662
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
6763
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
6864
* :math:`\\ast`: Convolution operation.
6965
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
7066
* :math:`\\sigma`: Activation function.
7167
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
72-
7368
Example:
7469
- Input:
7570
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
@@ -78,11 +73,8 @@ class SuperConv2D(nn.Conv2D):
7873
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
7974
Where
8075
.. math::
81-
8276
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
83-
8477
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
85-
8678
Parameters:
8779
num_channels(int): The number of channels in the input image.
8880
num_filters(int): The number of filter. It is as same as the output
@@ -144,7 +136,6 @@ class SuperConv2D(nn.Conv2D):
144136
config = {'channel': 5}
145137
data = paddle.to_tensor(data)
146138
conv = super_conv2d(data, config)
147-
148139
"""
149140

150141
### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network.
@@ -214,10 +205,6 @@ def __init__(self,
214205
setattr(self, name, param)
215206

216207
def get_active_filter(self, in_nc, out_nc, kernel_size):
217-
### Unsupport for asymmetric kernels
218-
if self._kernel_size[0] != self._kernel_size[1]:
219-
return self.weight[:out_nc, :in_nc, :, :]
220-
221208
start, end = compute_start_end(self._kernel_size[0], kernel_size)
222209
### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
223210
filters = self.weight[:out_nc, :in_nc, start:end, start:end]
@@ -292,14 +279,9 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
292279
out_nc = int(channel)
293280
else:
294281
out_nc = self._out_channels
295-
296282
ks = int(self._kernel_size[0]) if kernel_size == None else int(
297283
kernel_size)
298284

299-
if kernel_size is not None and self._kernel_size[
300-
0] != self._kernel_size[1]:
301-
_logger.error("Searching for asymmetric kernels is NOT supported")
302-
303285
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
304286
out_nc)
305287

@@ -324,6 +306,7 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
324306
else:
325307
bias = self.bias
326308
self.cur_config['prune_dim'] = list(weight.shape)
309+
self.cur_config['prune_group'] = groups
327310
out = F.conv2d(
328311
input,
329312
weight,
@@ -361,9 +344,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
361344
"""
362345
This interface is used to construct a callable object of the ``SuperConv2DTranspose``
363346
class.
364-
365347
Note: the channel in config need to less than first defined.
366-
367348
The super convolution2D transpose layer calculates the output based on the input,
368349
filter, and dilations, strides, paddings. Input and output
369350
are in NCHW format. Where N is batch size, C is the number of feature map,
@@ -527,9 +508,6 @@ def __init__(self,
527508
setattr(self, name, param)
528509

529510
def get_active_filter(self, in_nc, out_nc, kernel_size):
530-
### Unsupport for asymmetric kernels
531-
if self._kernel_size[0] != self._kernel_size[1]:
532-
return self.weight[:out_nc, :in_nc, :, :]
533511
start, end = compute_start_end(self._kernel_size[0], kernel_size)
534512
filters = self.weight[:in_nc, :out_nc, start:end, start:end]
535513
if self.transform_kernel != False and kernel_size < self._kernel_size[
@@ -612,10 +590,6 @@ def forward(self,
612590
ks = int(self._kernel_size[0]) if kernel_size == None else int(
613591
kernel_size)
614592

615-
if kernel_size is not None and self._kernel_size[
616-
0] != self._kernel_size[1]:
617-
_logger.error("Searching for asymmetric kernels is NOT supported")
618-
619593
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
620594
out_nc)
621595

@@ -638,6 +612,7 @@ def forward(self,
638612
else:
639613
bias = self.bias
640614
self.cur_config['prune_dim'] = list(weight.shape)
615+
self.cur_config['prune_group'] = groups
641616
out = F.conv2d_transpose(
642617
input,
643618
weight,
@@ -682,12 +657,10 @@ class SuperSeparableConv2D(nn.Layer):
682657
{'channel', num_of_channel} represents the channels of the first conv's outputs and
683658
the second conv's inputs, used to change the first dimension of weight and bias,
684659
only train the first channels of the weight and bias.
685-
686660
The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm2D
687661
or InstanceNorm2D), Conv2D]. The first conv is depthwise conv, the filter number is input channel
688662
multiply scale_factor, the group is equal to the number of input channel. The second conv
689663
is standard conv, which filter size and stride size are 1.
690-
691664
Parameters:
692665
num_channels(int): The number of channels in the input image.
693666
num_filters(int): The number of the second conv's filter. It is as same as the output
@@ -923,7 +896,6 @@ def forward(self, input, expand_ratio=None, channel=None):
923896
class SuperBatchNorm2D(nn.BatchNorm2D):
924897
"""
925898
This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
926-
927899
Parameters:
928900
num_features(int): Indicate the number of channels of the input ``Tensor``.
929901
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
@@ -938,7 +910,6 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
938910
If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
939911
data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
940912
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
941-
942913
Examples:
943914
.. code-block:: python
944915
import paddle
@@ -1062,7 +1033,6 @@ def forward(self, input):
10621033
class SuperInstanceNorm2D(nn.InstanceNorm2D):
10631034
"""
10641035
This interface is used to construct a callable object of the ``SuperInstanceNorm2D`` class.
1065-
10661036
Parameters:
10671037
num_features(int): Indicate the number of channels of the input ``Tensor``.
10681038
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
@@ -1077,7 +1047,6 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
10771047
If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
10781048
data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
10791049
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
1080-
10811050
Examples:
10821051
.. code-block:: python
10831052
import paddle
@@ -1121,11 +1090,9 @@ def forward(self, input):
11211090
class SuperLayerNorm(nn.LayerNorm):
11221091
"""
11231092
This interface is used to construct a callable object of the ``SuperLayerNorm`` class.
1124-
11251093
The difference between ```SuperLayerNorm``` and ```LayerNorm``` is:
11261094
the trained weight and bias in ```SuperLayerNorm``` can be changed according to the shape of input,
11271095
only train the first channels of the weight and bias.
1128-
11291096
Parameters:
11301097
normalized_shape(int|list|tuple): Input shape from an expected input of
11311098
size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`.
@@ -1193,7 +1160,6 @@ def forward(self, input):
11931160
class SuperEmbedding(nn.Embedding):
11941161
"""
11951162
This interface is used to construct a callable object of the ``SuperEmbedding`` class.
1196-
11971163
Parameters:
11981164
num_embeddings (int): Just one element which indicate the size
11991165
of the dictionary of embeddings.
@@ -1280,4 +1246,4 @@ def forward(self, input, expand_ratio=None, channel=None):
12801246
weight=weight,
12811247
padding_idx=self._padding_idx,
12821248
sparse=self._sparse,
1283-
name=self._name)
1249+
name=self._name)

0 commit comments

Comments
 (0)