Support for 'mixed_float16' #29

@To3No7

Description

Saw your article on Medium, from October 7, and wanted to try it out on my 3090 (training in 143.27s and inference in 196.57s for the Semantic Segmentation example). Then I wanted to test with 16-bit floats by adding:

from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
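
(Aside: per the TF mixed precision guide, this policy makes layers compute in float16 while keeping float32 variables; easy to confirm:)

from tensorflow.keras import mixed_precision

print(mixed_precision.global_policy().compute_dtype)   # float16
print(mixed_precision.global_policy().variable_dtype)  # float32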

However, the code crashes in model.fit due to missing float16 support in your _custom_layers_and_blocks.py:

File ~/anaconda3/lib/python3.9/site-packages/tensorflow_advanced_segmentation_models/models/_custom_layers_and_blocks.py:252, in AtrousSpatialPyramidPoolingV3.call(self, input_tensor, training)
    249 z = self.atrous_sepconv_bn_relu_3(input_tensor, training=training)
    251 # concatenation
--> 252 net = tf.concat([glob_avg_pool, w, x, y, z], axis=-1)
    253 net = self.conv_reduction_1(net, training=training)
    255 return net

InvalidArgumentError: Exception encountered when calling layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3).

cannot compute ConcatV2 as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:ConcatV2] name: concat

Call arguments received by layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3):
  • input_tensor=tf.Tensor(shape=(16, 40, 40, 288), dtype=float16)
  • training=True

Any chance of an update to support float16?
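
In case it helps, here's a minimal repro plus the workaround I'd expect to fix it. This is only a sketch: I'm assuming the glob_avg_pool branch is the one left in float32 (e.g. by an upsampling/resize op inside it), and the shapes below are made up:

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

# Reproduce the failure mode: one branch left in float32, the rest in float16.
glob_avg_pool = tf.zeros((16, 40, 40, 64), dtype=tf.float32)  # hypothetical float32 branch
w = tf.zeros((16, 40, 40, 64), dtype=tf.float16)              # the sepconv branches come out float16
# tf.concat([glob_avg_pool, w], axis=-1)  # -> InvalidArgumentError, as in the traceback above

# Workaround: cast every branch to a common dtype before concatenating.
common = mixed_precision.global_policy().compute_dtype  # 'float16'
net = tf.concat([tf.cast(t, common) for t in (glob_avg_pool, w)], axis=-1)
print(net.dtype)  # float16

Inside AtrousSpatialPyramidPoolingV3.call the equivalent would presumably be tf.cast(..., self.compute_dtype) on each branch right before the tf.concat at line 252.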
