@@ -77,17 +77,15 @@ def __init__(self, encoder, tensorspec):
7777 # These dictionaries are filled inside of the initial_state_fn and encode_fn
7878 # methods, to be used in encode_fn and decode_fn methods, respectively.
7979 # Decorated by tf.function, their necessary side effects are realized during
80- # call to get_concrete_function(). Because of fixed input_signatures, these
81- # are traced only once. See the tf.function tutorial for more details on
82- # the tracing semantics.
80+ # call to get_concrete_function().
8381 state_py_structure = {}
8482 encoded_py_structure = {}
8583
8684 @tf .function
8785 def initial_state_fn ():
8886 state = encoder .initial_state ()
89- assert not state_py_structure # This should be traced only once.
90- state_py_structure ['state' ] = nest .map_structure (lambda _ : None , state )
87+ if not state_py_structure :
88+ state_py_structure ['state' ] = nest .map_structure (lambda _ : None , state )
9189 # Simplify the structure that needs to be manipulated by the user.
9290 return tuple (nest .flatten (state ))
9391
@@ -119,10 +117,10 @@ def encode_fn(x, flat_state):
119117 flat_encoded_py_structure , flat_encoded_tf_structure = (
120118 py_utils .split_dict_py_tf (flat_encoded_structure ))
121119
122- assert not encoded_py_structure # This should be traced only once.
123- encoded_py_structure ['full' ] = nest .map_structure (lambda _ : None ,
124- full_encoded_structure )
125- encoded_py_structure ['flat_py' ] = flat_encoded_py_structure
120+ if not encoded_py_structure :
121+ encoded_py_structure ['full' ] = nest .map_structure (
122+ lambda _ : None , full_encoded_structure )
123+ encoded_py_structure ['flat_py' ] = flat_encoded_py_structure
126124 return flat_encoded_tf_structure , updated_flat_state
127125
128126 @tf .function (input_signature = [
@@ -145,6 +143,12 @@ def decode_fn(encoded_structure):
145143 self ._initial_state_fn = initial_state_fn
146144 self ._encode_fn = encode_fn
147145 self ._decode_fn = decode_fn
146+ self ._tensorspec = tensorspec
147+
148+ @property
149+ def input_tensorspec (self ):
150+ """Returns `tf.TensorSpec` describing input expected by `SimpleEncoder`."""
151+ return self ._tensorspec
148152
149153 def initial_state (self , name = None ):
150154 """Returns the initial state.
@@ -182,8 +186,7 @@ def encode(self, x, state=None, name=None):
182186 """
183187 if state is None :
184188 state = self .initial_state ()
185- with tf .name_scope (name , 'simple_encoder_encode' ,
186- [x ] + list (state )):
189+ with tf .name_scope (name , 'simple_encoder_encode' , [x ] + list (state )):
187190 return self ._encode_fn (x , state )
188191
189192 def decode (self , encoded_x , name = None ):
@@ -205,4 +208,3 @@ def decode(self, encoded_x, name=None):
205208 """
206209 with tf .name_scope (name , 'simple_encoder_decode' , encoded_x .values ()):
207210 return self ._decode_fn (encoded_x )
208-
0 commit comments