import tensorflow as tf

from keras.src import tree


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    """Iterates over the time dimension of a tensor.

    Args:
        step_function: RNN step function.
            Args:
                `input`: Tensor with shape `(samples, ...)` (no time
                    dimension), representing input for the batch of samples
                    at a certain time step.
                `states`: List of tensors.
            Returns:
                `output`: Tensor with shape `(samples, output_dim)` (no
                    time dimension).
                `new_states`: List of tensors, same length and shapes as
                    `states`. The first state in the list must be the
                    output tensor at the previous timestep.
        inputs: Tensor of temporal data of shape `(samples, time, ...)`
            (at least 3D), or nested tensors, each of which has shape
            `(samples, time, ...)`.
        initial_states: Tensor with shape `(samples, state_size)` (no time
            dimension), containing the initial values for the states used
            in the step function. If `state_size` is a nested shape, the
            shape of `initial_states` follows the same nested structure.
        go_backwards: Boolean. If `True`, do the iteration over the time
            dimension in reverse order and return the reversed sequence.
        mask: Binary tensor with shape `(samples, time, 1)`, with a zero
            for every element that is masked.
        constants: List of constant values passed at each step.
        unroll: Whether to unroll the RNN or to use a symbolic
            `while_loop`.
        input_length: An integer or a 1-D Tensor, depending on whether the
            time dimension is fixed-length or not. In case of variable
            length input, it is used for masking when no mask is specified.
        time_major: Boolean. If `True`, the inputs and outputs will be in
            shape `(timesteps, batch, ...)`, whereas in the `False` case,
            they will be `(batch, timesteps, ...)`. Using `time_major=True`
            is a bit more efficient because it avoids transposes at the
            beginning and end of the RNN calculation. However, most
            TensorFlow data is batch-major, so by default this function
            accepts input and emits output in batch-major form.
        zero_output_for_mask: Boolean. If `True`, the output for masked
            timesteps will be zeros, whereas in the `False` case, the
            output from the previous timestep is returned.
        return_all_outputs: Boolean. If `True`, return the recurrent
            outputs for all timesteps in the sequence. If `False`, only
            return the output for the last timestep (which consumes less
            memory).

    Returns:
        A tuple, `(last_output, outputs, new_states)`.
            - `last_output`: the latest output of the rnn,
                with shape `(samples, ...)`.
            - `outputs`:
                - If `return_all_outputs=True`: a tensor with shape
                    `(samples, time, ...)` where each entry `outputs[s, t]`
                    is the output of the step function at time `t` for
                    sample `s`
                - Else, a tensor equal to `last_output` with shape
                    `(samples, 1, ...)`
            - `new_states`: list of tensors, latest states returned by the
                step function, of shape `(samples, ...)`.
    """
    input_length = input_length or inputs.shape[1]

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
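        # For example, a batch-major tensor of shape (batch, time, features)
        # becomes (time, batch, features); any trailing dims are unchanged.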
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    if not time_major:
        inputs = tree.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = tree.flatten(inputs)
    time_steps = flattened_inputs[0].shape[0]
    time_steps_t = (
        tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps
    )

    for input_ in flattened_inputs:
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = tf.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of inputs.
    # That's what the tile call does: it repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tree.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tree.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. The input tensors need to be split on
        # the time_step dim and reversed if go_backwards is True. In the
        # case of nested input, the input is flattened and then transformed
        # individually.
        # The result of this will be a tuple of lists, where each item in
        # the tuple is a list of tensors with shape (batch, feature).
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if tree.is_nested(inputs):
            processed_input = tree.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return tree.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                if not successive_outputs:
                    prev_output = tf.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                output = tf.where(tiled_mask_t, output, prev_output)

                flat_states = tree.flatten(states)
                flat_new_states = tree.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = tree.pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    tf.zeros_like(last_output),
                )
                outputs = tf.where(
                    _expand_mask(mask, outputs, fixed_dim=2),
                    outputs,
                    tf.zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create input tensor arrays. If the inputs are nested tensors,
        # they are flattened first, and one tensor array is created per
        # flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flattened_inputs)
        )
        input_ta = tuple(
            (
                ta.unstack(input_)
                if not go_backwards
                else ta.unstack(tf.reverse(input_, [0]))
            )
            for ta, input_ in zip(input_ta, flattened_inputs)
        )

        # Get the time(0) input and compute the output for that. The output
        # will be used to determine the dtype of the output tensor array.
        # Don't read from input_ta because TensorArray's clear_after_read
        # defaults to True.
        input_time_zero = tree.pack_sequence_as(
            inputs, [inp[0] for inp in flattened_inputs]
        )
        # output_time_zero is used to determine the cell output shape and
        # its dtype. The value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tree.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        if mask is not None:
            if go_backwards:
                mask = tf.reverse(mask, [0])

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                return tuple(
                    tf.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # The mask for the output at time T is based on the output at
            # time T - 1. For T = 0, a zero-filled tensor is used instead.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tree.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) +
                        tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = tree.pack_sequence_as(inputs, current_input)
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tree.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tree.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tree.flatten(states)
                flat_new_state = tree.flatten(new_states)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = tree.pack_sequence_as(
                    new_states, flat_final_state
                )

                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
                    new_states
                )

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2], which is the output for the final
            # timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.
                Returns:
                    Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = tree.pack_sequence_as(inputs, current_input)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_new_state = tree.flatten(new_states)

                flat_output = tree.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = tree.pack_sequence_as(
                    initial_states, flat_new_state
                )
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = tree.pack_sequence_as(output_time_zero, outputs)
        last_output = tree.pack_sequence_as(output_time_zero, last_output)

    if not time_major:
        outputs = tree.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states


def gru(
    inputs,
    initial_state,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
    reset_after=True,
):
    cudnn_supported = cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
        reset_after=reset_after,
    )
    if not cudnn_supported:
        raise NotImplementedError

    from keras.src.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_gru(
            inputs,
            initial_state,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except tf.errors.InvalidArgumentError:
        # cuDNN op not found.
        raise NotImplementedError
    except tf.errors.NotFoundError:
        # alternative error: device not found for op
        raise NotImplementedError


def _do_gru_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after,
):
    from keras.src import activations
    from keras.src import ops

    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and use_bias
        and reset_after
    )


def _do_lstm_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
):
    from keras.src import activations
    from keras.src import ops

    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and use_bias
    )


def _has_fully_masked_sequence(mask):
    # The cuDNN kernel will error out if the input sequence contains any
    # fully masked data. We work around this issue by rerouting the
    # computation to the standard kernel until the issue on the cuDNN side
    # has been fixed. A fully masked sequence contains all False values. To
    # make the check easy, we invert the boolean mask and check whether any
    # sequence is all True.
    return tf.reduce_any(
        tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
    )


def _assert_valid_mask(mask):
    valid = tf.logical_and(
        tf.logical_not(_has_fully_masked_sequence(mask)),
        _is_sequence_right_padded(mask),
    )
    tf.Assert(
        valid,
        [
            (
                "You are passing a RNN mask that does not correspond to "
                "right-padded sequences, while using cuDNN, which is not "
                "supported. With cuDNN, RNN masks can only be used for "
                "right-padding, e.g. `[[True, True, False, False]]` would "
                "be a valid mask, but any mask that isn't just contiguous "
                "`True`'s on the left and contiguous `False`'s on the right "
                "would be invalid. You can pass `use_cudnn=False` to your "
                "RNN layer to stop using cuDNN (this may be slower)."
            )
        ],
    )


def _standardize_cudnn_weights(
    weights, biases, shape, transpose_weights=False
):
    """Utility function to convert variables to cuDNN-compatible parameters.

    Note that Keras weights for kernels are different from the cuDNN format.
    E.g.:

    ```
      Keras                 cuDNN
      [[0, 1, 2],  <--->  [[0, 2, 4],
       [3, 4, 5]]          [1, 3, 5]]
    ```

    If the input weights need to be in a unified format, then set
    `transpose_weights=True` to convert the weights.

    Args:
        weights: list of weights for the kernels and recurrent kernels.
        biases: list of biases for each individual gate.
        shape: the shape for the converted variables that will be fed to
            cuDNN.
        transpose_weights: boolean, whether to transpose the weights.

    Returns:
        The converted weights that can be fed to cuDNN ops as params.
    """

    def convert(w):
        return tf.transpose(w) if transpose_weights else w

    weights = [tf.reshape(convert(x), shape) for x in weights]
    biases = [tf.reshape(x, shape) for x in biases]
    return tf.concat(weights + biases, axis=0)


def _is_sequence_right_padded(mask):
    """Check the mask tensor and see if it is right-padded.

    cuDNN uses the sequence length param to skip the tailing timesteps. If
    the data is left-padded, or not strictly right-padded (has masked
    values in the middle of the sequence), then cuDNN won't work properly
    in those cases.

    Left-padded data: [[False, False, True, True, True]].
    Right-padded data: [[True, True, True, False, False]].
    Mixture of masked/unmasked data: [[True, False, True, False, False]].

    Note that for the mixed data example above, the actual data the RNN
    should see are the two Trues (index 0 and 2); the False at index 1
    should be ignored and not pollute the internal states.

    Args:
        mask: the boolean tensor with shape [batch, timestep].

    Returns:
        Boolean scalar tensor, whether the mask is strictly right-padded.
    """
    max_seq_length = tf.shape(mask)[1]
    count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)
    return tf.reduce_all(
        tf.equal(
            tf.cast(mask, dtype="bool"),
            tf.cast(right_padded_mask, dtype="bool"),
        )
    )


def _compute_sequence_length_from_mask(mask, time_major):
    """Calculate the sequence length tensor (1-D) based on the masking tensor.

    The masking tensor is a 2D boolean tensor with shape [batch, timestep].
    For any timestep that should be masked, the corresponding field will be
    False. Consider the following example:

        a = [[True, True, False, False],
             [True, True, True, False]]

    It is a (2, 4) tensor, and the corresponding sequence length result
    should be a 1D tensor with values [2, 3]. Note that the masking tensor
    must be right-padded, which can be checked by e.g.
    `_is_sequence_right_padded()`.

    Args:
        mask: Boolean tensor with shape [batch, timestep] or
            [timestep, batch] if time_major=True.
        time_major: Boolean, which indicates whether the mask is time-major
            or batch-major.

    Returns:
        sequence_length: 1D int32 tensor.
""" timestep_index = 0 if time_major else 1 return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index) def _is_gpu_available(): return bool(tf.config.list_logical_devices("GPU")) def _cudnn_gru( inputs, initial_state, kernel, recurrent_kernel, bias, mask, time_major, go_backwards, return_sequences, ): """GRU with cuDNN implementation which is only available for GPU.""" if mask is not None: _assert_valid_mask(mask) sequence_lengths = _compute_sequence_length_from_mask(mask, time_major) else: if time_major: batch_dim = tf.shape(inputs)[1] max_sequence_length = tf.shape(inputs)[0] else: batch_dim = tf.shape(inputs)[0] max_sequence_length = tf.shape(inputs)[1] sequence_lengths = tf.fill([batch_dim], max_sequence_length) if not time_major and sequence_lengths is None: inputs = tf.transpose(inputs, perm=(1, 0, 2)) seq_axis, batch_axis = (0, 1) else: seq_axis, batch_axis = (0, 1) if time_major else (1, 0) # For init_h, cuDNN expects one more dim of num_layers before or after batch # dim for time major or batch major inputs respectively init_h = tf.expand_dims(initial_state, axis=seq_axis) weights = tf.split(kernel, 3, axis=1) weights += tf.split(recurrent_kernel, 3, axis=1) # Note that the bias was initialized as shape (2, 3 * units), flatten it to # (6 * units) bias = tf.split(tf.reshape(bias, [-1]), 6) if tf.sysconfig.get_build_info()["is_cuda_build"]: # Note that the gate order for cuDNN is different from the canonical # format. canonical format is [z, r, h], whereas cuDNN is [r, z, h]. # The swap need to be done for kernel, recurrent_kernel, input_bias, # recurrent_bias. # z is update gate weights. # r is reset gate weights. # h is output gate weights. weights[0], weights[1] = weights[1], weights[0] weights[3], weights[4] = weights[4], weights[3] bias[0], bias[1] = bias[1], bias[0] bias[3], bias[4] = bias[4], bias[3] params = _standardize_cudnn_weights( weights=weights, biases=bias, shape=tf.constant([-1]), transpose_weights=True, ) if go_backwards: # Three reversals are required. E.g., # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked # reversed_input_to_cudnn = [3, 2, 1, 0, 0] # output_from_cudnn = [6, 5, 4, 0, 0] # expected_output = [0, 0, 6, 5 ,4] inputs = tf.reverse_sequence( inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis, ) outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3( input=inputs, input_h=init_h, input_c=0, params=params, is_training=True, rnn_mode="gru", sequence_lengths=sequence_lengths, time_major=time_major, ) if go_backwards: outputs = tf.reverse_sequence( outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis, ) outputs = tf.reverse(outputs, axis=[seq_axis]) last_output = outputs[-1] if not time_major and sequence_lengths is None and return_sequences: outputs = tf.transpose(outputs, perm=[1, 0, 2]) state = tf.squeeze(h, axis=seq_axis) # In the case of variable length input, the cudnn kernel will fill zeros for # the output, whereas the default keras behavior is to bring over the # previous output for t-1, so that in the return_sequence=False case, user # can quickly get the final effect output instead just 0s at the last # timestep. In order to mimic the default keras behavior, we copy the final # h state as the last_output, since it is numerically same as the output. 
    if sequence_lengths is not None:
        last_output = state

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        [state],
    )


def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after=None,
):
    if reset_after is None:
        args_supported = _do_lstm_arguments_support_cudnn(
            activation=activation,
            recurrent_activation=recurrent_activation,
            unroll=unroll,
            use_bias=use_bias,
        )
    else:
        args_supported = _do_gru_arguments_support_cudnn(
            activation=activation,
            recurrent_activation=recurrent_activation,
            unroll=unroll,
            use_bias=use_bias,
            reset_after=reset_after,
        )
    return args_supported and _is_gpu_available()


def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
):
    cudnn_supported = cudnn_ok(
        activation, recurrent_activation, unroll, use_bias=bias is not None
    )
    if not cudnn_supported:
        raise NotImplementedError

    from keras.src.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_lstm(
            inputs,
            initial_state_h,
            initial_state_c,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except tf.errors.InvalidArgumentError:
        # cuDNN op not found.
        raise NotImplementedError
    except tf.errors.NotFoundError:
        # alternative error: device not found for op
        raise NotImplementedError


def _cudnn_lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    if mask is not None:
        _assert_valid_mask(mask)
        sequence_lengths = _compute_sequence_length_from_mask(
            mask, time_major
        )
    else:
        if time_major:
            batch_dim = tf.shape(inputs)[1]
            max_sequence_length = tf.shape(inputs)[0]
        else:
            batch_dim = tf.shape(inputs)[0]
            max_sequence_length = tf.shape(inputs)[1]
        sequence_lengths = tf.fill([batch_dim], max_sequence_length)

    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h and init_c, cuDNN expects one more dim of num_layers before
    # or after the batch dim for time-major or batch-major inputs,
    # respectively.
    init_h = tf.expand_dims(initial_state_h, axis=seq_axis)
    init_c = tf.expand_dims(initial_state_c, axis=seq_axis)

    weights = tf.split(kernel, 4, axis=1)
    weights += tf.split(recurrent_kernel, 4, axis=1)
    # cuDNN has an extra set of biases for inputs. We disable them (setting
    # them to 0), so that mathematically it is the same as the canonical
    # LSTM implementation.
    full_bias = tf.concat((tf.zeros_like(bias), bias), 0)

    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        # ROCm MIOpen's weight sequence for LSTM is different from both the
        # canonical and the cuDNN format.
        # MIOpen: [i, f, o, c]   cuDNN/canonical: [i, f, c, o]
        # i is input gate weights.
        # f is forget gate weights.
        # o is output gate weights.
        # c is cell gate weights.
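        # The index permutation below swaps positions 2 <-> 3 and 6 <-> 7,
        # i.e. it reorders both the kernel half (indices 0-3) and the
        # recurrent half (indices 4-7) from [i, f, c, o] to [i, f, o, c].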
        weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
        # full_bias is a tensor of shape (8*n,)
        full_bias = tf.split(full_bias, 8, axis=0)
        full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=tf.split(full_bias, 8),
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if go_backwards:
        # Three reversals are required. E.g.,
        # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
        # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
        # output_from_cudnn = [6, 5, 4, 0, 0]
        # expected_output = [0, 0, 6, 5, 4]
        inputs = tf.reverse_sequence(
            inputs,
            sequence_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis,
        )
    outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
        input=inputs,
        input_h=init_h,
        input_c=init_c,
        params=params,
        is_training=True,
        rnn_mode="lstm",
        sequence_lengths=sequence_lengths,
        time_major=time_major,
    )
    if go_backwards:
        outputs = tf.reverse_sequence(
            outputs,
            sequence_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis,
        )
        outputs = tf.reverse(outputs, axis=[seq_axis])

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    h = tf.squeeze(h, axis=seq_axis)
    c = tf.squeeze(c, axis=seq_axis)

    # In the case of variable-length input, the cuDNN kernel fills zeros for
    # the output, whereas the default Keras behavior is to bring over the
    # output from t - 1, so that in the return_sequences=False case, the
    # user can quickly get the final effective output instead of just 0s at
    # the last timestep. In order to mimic the default Keras behavior, we
    # copy the final h state as last_output, since it is numerically the
    # same as the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (last_output, outputs, [h, c])
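

# A minimal usage sketch of the generic `rnn` helper defined above, for
# illustration only. The shapes (`batch`, `timesteps`, `features`, `units`)
# and the plain matmul "cell" are illustrative assumptions, not part of the
# module's API. It only runs when this file is executed directly.
if __name__ == "__main__":
    batch, timesteps, features, units = 2, 5, 3, 4
    x = tf.random.normal((batch, timesteps, features))
    w = tf.random.normal((features, units))
    u = tf.random.normal((units, units))

    def simple_step(cell_inputs, states):
        # A bare-bones recurrent step: h_t = tanh(x_t @ W + h_{t-1} @ U).
        prev_h = states[0]
        h = tf.tanh(tf.matmul(cell_inputs, w) + tf.matmul(prev_h, u))
        return h, [h]

    last_output, outputs, new_states = rnn(
        simple_step,
        x,
        [tf.zeros((batch, units))],
    )
    print(last_output.shape)  # (2, 4)
    print(outputs.shape)  # (2, 5, 4)
    print(new_states[0].shape)  # (2, 4)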