@@ -409,6 +409,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
409409
410410 return model_proto , external_tensor_storage
411411
412+
412413def from_keras3 (model , input_signature = None , opset = None , custom_ops = None , custom_op_handlers = None ,
413414 custom_rewriter = None , inputs_as_nchw = None , outputs_as_nchw = None , extra_opset = None , shape_override = None ,
414415 target = None , large_model = False , output_path = None , optimizers = None ):
@@ -434,10 +435,25 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
434435 A tuple (model_proto, external_tensor_storage_dict)
435436 """
436437 if not input_signature :
438+ if hasattr (model , "inputs" ):
439+ model_input = model .inputs
440+ elif hasattr (model , "input_dtype" ) and hasattr (model , "_build_shapes_dict" ):
441+ if len (model ._build_shapes_dict ) == 1 :
442+ shape = list (model ._build_shapes_dict .values ())[0 ]
443+ model_input = [tf .Variable (tf .zeros (shape , dtype = model .input_dtype ), name = "input" )]
444+ else :
445+ raise RuntimeError (f"Not implemented yet with input_dtype={ model .input_dtype } and model._build_shapes_dict={ model ._build_shapes_dict } " )
446+ else :
447+ if not hasattr (model , "inputs_spec" ):
448+ raise RuntimeError ("You may set attribute 'inputs_spec' with your inputs (model.input_specs = ...)" )
449+ model_input = model .inputs_spec
450+
437451 input_signature = [
438452 tf .TensorSpec (tensor .shape , tensor .dtype , name = tensor .name .split (":" )[0 ])
439- for tensor in model . inputs
453+ for tensor in model_input
440454 ]
455+ else :
456+ model_input = None
441457
442458 # Trace model
443459 function = tf .function (model )
@@ -459,13 +475,33 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
459475 reverse_lookup = {v : k for k , v in tensors_to_rename .items ()}
460476
461477 valid_names = []
462- for out in [t .name for t in model .outputs ]:
478+ if hasattr (model , "outputs" ):
479+ model_output = model .outputs
480+ else :
481+ if hasattr (model , "outputs_spec" ):
482+ model_output = model .outputs_spec
483+ elif model_input and len (model_input ) == 1 :
484+ # Let's try something to make unit test work. This should be replaced.
485+ model_output = [tf .Variable (model_input [0 ], name = "output" )]
486+ else :
487+ raise RuntimeError (
488+ "You should set attribute 'outputs_spec' with your outputs "
489+ "so that the expected can use that information."
490+ )
491+
492+ def _get_name (t , i ):
493+ try :
494+ return t .name
495+ except AttributeError :
496+ return f"output:{ i } "
497+
498+ for out in [_get_name (t , i ) for i , t in enumerate (model_output )]:
463499 if out in reverse_lookup :
464500 valid_names .append (reverse_lookup [out ])
465501 else :
466502 print (f"Warning: Output name '{ out } ' not found in reverse_lookup." )
467503 # Fallback: verwende TensorFlow-Ausgangsnamen direkt
468- valid_names = [t . name for t in concrete_func .outputs if t .dtype != tf .dtypes .resource ]
504+ valid_names = [_get_name ( t , i ) for i , t in enumerate ( concrete_func .outputs ) if t .dtype != tf .dtypes .resource ]
469505 break
470506 output_names = valid_names
471507
0 commit comments