Open
Description
I saw your Medium article from October 7 and wanted to try it out on my 3090 (training took 143.27 s and inference 196.57 s for the semantic segmentation example). Then I wanted to test with 16-bit floats by adding:
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
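For reference, here is a minimal sketch of what the 'mixed_float16' policy changes. Layers compute in float16 while their weights stay in float32, and TensorFlow's mixed precision guide recommends keeping the model's final activation in float32 for numeric stability. The small Dense/softmax model is a hypothetical illustration, not the segmentation model from the article.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

mixed_precision.set_global_policy("mixed_float16")

policy = mixed_precision.global_policy()
print(policy.compute_dtype)   # float16: layer math runs in half precision
print(policy.variable_dtype)  # float32: weights are stored in full precision

dense = layers.Dense(4)
# The float32 input is autocast to float16 by the layer.
hidden = dense(tf.random.normal((2, 3)))
# Hypothetical output head kept in float32, as the mixed precision guide suggests.
probs = layers.Activation("softmax", dtype="float32")(hidden)
print(hidden.dtype.name, probs.dtype.name)  # float16 float32
```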
However, the code crashes in model.fit because the custom layers in _custom_layers_and_blocks.py do not handle float16:
File ~/anaconda3/lib/python3.9/site-packages/tensorflow_advanced_segmentation_models/models/_custom_layers_and_blocks.py:252, in AtrousSpatialPyramidPoolingV3.call(self, input_tensor, training)
249 z = self.atrous_sepconv_bn_relu_3(input_tensor, training=training)
251 # concatenation
--> 252 net = tf.concat([glob_avg_pool, w, x, y, z], axis=-1)
253 net = self.conv_reduction_1(net, training=training)
255 return net
InvalidArgumentError: Exception encountered when calling layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3).
cannot compute ConcatV2 as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:ConcatV2] name: concat
Call arguments received by layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3):
• input_tensor=tf.Tensor(shape=(16, 40, 40, 288), dtype=float16)
• training=True
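The failure can be reproduced without the library: tf.concat does not promote dtypes, so concatenating one float32 tensor with float16 tensors raises an error. A minimal sketch; the helper concat_as is hypothetical and not part of TensorFlow or tensorflow_advanced_segmentation_models.

```python
import tensorflow as tf

full = tf.zeros((1, 4, 4, 2), dtype=tf.float32)  # e.g. a branch that ended up in float32
half = tf.zeros((1, 4, 4, 2), dtype=tf.float16)  # a branch computed under mixed_float16

raised = False
try:
    tf.concat([full, half], axis=-1)
except (tf.errors.InvalidArgumentError, TypeError) as err:
    # Eager execution raises InvalidArgumentError, as in the traceback above;
    # graph tracing may raise TypeError instead.
    raised = True
    print(type(err).__name__)


def concat_as(tensors, dtype, axis=-1):
    """Hypothetical helper: cast every input to one dtype, then concatenate."""
    return tf.concat([tf.cast(t, dtype) for t in tensors], axis=axis)


merged = concat_as([full, half], tf.float16)
print(merged.dtype.name, merged.shape)  # float16 (1, 4, 4, 4)
```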
Any chance of an update to support float16?
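One possible fix is to cast every branch to the layer's compute dtype before tf.concat. The sketch below is a hypothetical, simplified ASPP-style block (PoolingConcatBlock is not the library's class). It assumes, without having checked the library source, that the float32 tensor comes from the global-pooling branch being upsampled with bilinear tf.image.resize, which is documented to return float32 for every method except nearest neighbor. That would fit the error, since glob_avg_pool is input #0 of the concat and input #1 is the first float16 tensor.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

mixed_precision.set_global_policy("mixed_float16")


class PoolingConcatBlock(layers.Layer):
    """Hypothetical, simplified ASPP-style block: an image-level pooling branch
    plus a 1x1 conv branch, concatenated along the channel axis."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.pool_conv = layers.Conv2D(filters, 1)
        self.branch_conv = layers.Conv2D(filters, 1)

    def call(self, inputs):
        spatial_size = tf.shape(inputs)[1:3]
        # Image-level features: global average pool, then a 1x1 conv.
        pooled = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
        pooled = self.pool_conv(pooled)
        # Bilinear resize returns float32 even for float16 input...
        pooled = tf.image.resize(pooled, spatial_size)
        # ...so cast back to the layer's compute dtype (float16 under mixed_float16).
        pooled = tf.cast(pooled, self.compute_dtype)
        branch = self.branch_conv(inputs)
        return tf.concat([pooled, branch], axis=-1)


block = PoolingConcatBlock(filters=8)
out = block(tf.random.normal((2, 10, 10, 4)))
print(out.dtype.name, out.shape)  # float16 (2, 10, 10, 16)
```

Applying the same cast to glob_avg_pool just before the tf.concat in AtrousSpatialPyramidPoolingV3.call might be enough to get past this crash, though other blocks in the file could need similar changes.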