"""A class for Tensorflow specific optimizer logic. The major behavior change for this class is for tf.distribute. It will override methods from base Keras core Optimizer, which provide distribute specific functionality, e.g. variable creation, loss reduction, etc. """ import warnings import tensorflow as tf from keras.src import backend from keras.src.backend.tensorflow.trackable import KerasAutoTrackable from keras.src.optimizers import base_optimizer class TFOptimizer(KerasAutoTrackable, base_optimizer.BaseOptimizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._distribution_strategy = tf.distribute.get_strategy() def add_variable_from_reference( self, reference_variable, name=None, initializer="zeros" ): if isinstance(reference_variable, backend.Variable): colocate_var = reference_variable.value else: colocate_var = reference_variable with self._distribution_strategy.extended.colocate_vars_with( colocate_var ): return super().add_variable_from_reference( reference_variable, name=name, initializer=initializer ) def stateless_apply(self, optimizer_variables, grads, trainable_variables): # This is mainly due to the interaction with tf.distribute.Strategy, # which requires tf.Variable as the inputs for most of its APIs. raise ValueError( "stateless_apply is not supported with the TensorFlow backend " "(as it is incompatible with tf.distribute)." ) def assign(self, variable, value): if isinstance(variable, backend.Variable): variable = variable.value value = tf.cast(value, variable.dtype) if isinstance(value, tf.IndexedSlices): variable.scatter_update(value) else: variable.assign(value) def assign_add(self, variable, value): if isinstance(variable, backend.Variable): variable = variable.value value = tf.cast(value, variable.dtype) if isinstance(value, tf.IndexedSlices): variable.scatter_add(value) else: variable.assign_add(value) def assign_sub(self, variable, value): if isinstance(variable, backend.Variable): variable = variable.value value = tf.cast(value, variable.dtype) if isinstance(value, tf.IndexedSlices): variable.scatter_sub(value) else: variable.assign_sub(value) def _var_key(self, variable): if isinstance(variable, backend.Variable): variable = variable.value # Convert to tf.Variable if hasattr(variable, "_distributed_container"): variable = variable._distributed_container() elif ( isinstance(variable, tf.__internal__.CompositeTensor) and hasattr(variable, "handle") and hasattr(variable.handle, "_distributed_container") ): # For ResourceVariables, the _distributed_container attribute # is added to their handle tensors. variable = variable.handle._distributed_container() return variable._unique_id def _apply_weight_decay(self, variables): if self.weight_decay is None: return def distributed_apply_weight_decay(distribution, variables, **kwargs): def weight_decay_fn(variable): if self._use_weight_decay(variable): lr = tf.cast(self.learning_rate, variable.dtype) wd = tf.cast(self.weight_decay, variable.dtype) variable.assign_sub(variable * wd * lr) for variable in variables: if isinstance(variable, backend.Variable): variable = variable.value # Convert to tf.Variable distribution.extended.update( variable, weight_decay_fn, group=False ) tf.__internal__.distribute.interim.maybe_merge_call( distributed_apply_weight_decay, self._distribution_strategy, variables, ) def _backend_update_step(self, grads, trainable_variables, learning_rate): trainable_variables = [ v.value if isinstance(v, backend.Variable) else v for v in trainable_variables ] grads_and_vars = list(zip(grads, trainable_variables)) grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars) tf.__internal__.distribute.interim.maybe_merge_call( self._distributed_tf_update_step, self._distribution_strategy, grads_and_vars, learning_rate, ) def _distributed_tf_update_step( self, distribution, grads_and_vars, learning_rate ): def apply_grad_to_update_var(var, grad, learning_rate): return self.update_step(grad, var, learning_rate) for grad, var in grads_and_vars: distribution.extended.update( var, apply_grad_to_update_var, args=(grad, learning_rate), group=False, ) def _all_reduce_sum_gradients(self, grads_and_vars): """Returns all-reduced gradients aggregated via summation. Args: grads_and_vars: List of (gradient, variable) pairs. Returns: List of (gradient, variable) pairs where gradients have been all-reduced. """ replica_context = tf.distribute.get_replica_context() if not replica_context: return grads_and_vars grads_and_vars = list(grads_and_vars) filtered_grads_and_vars = filter_empty_gradients(grads_and_vars) if filtered_grads_and_vars: grads = [pair[0] for pair in filtered_grads_and_vars] reduced = tf.distribute.get_replica_context().all_reduce( tf.distribute.ReduceOp.SUM, grads ) else: reduced = [] # Copy 'reduced' but add None gradients back in reduced_with_nones = [] reduced_pos = 0 for g, v in grads_and_vars: if g is None: reduced_with_nones.append((None, v)) else: reduced_with_nones.append((reduced[reduced_pos], v)) reduced_pos += 1 assert reduced_pos == len(reduced), "Failed to add all gradients" return reduced_with_nones def _overwrite_model_variables_with_average_value( self, trainable_variables ): """Overwrite model variables with their moving average values. This function overwrites variables on each device. Args: var_list: list of model variables. """ trainable_variables = [ v.value if isinstance(v, backend.Variable) else v for v in trainable_variables ] # Override model variable by the stored average value on all devices. for var, average_var in zip( trainable_variables, self._model_variables_moving_average ): self._distribution_strategy.extended.update( var, lambda a, b: a.assign(b), args=(average_var,) ) def _backend_increment_gradient_accumulators(self, grads, acc_grads): def update_accumulator(var, grad): var.assign(var + grad) accumulators = [v.value for v in acc_grads] def _distributed_tf_increment_grad_acc( distribution, grads, accumulators ): for grad, var in zip(grads, accumulators): distribution.extended.update( var, update_accumulator, args=(grad,), group=False ) tf.__internal__.distribute.interim.maybe_merge_call( _distributed_tf_increment_grad_acc, self._distribution_strategy, grads, accumulators, ) def _clip_by_norm(self, values, axes=None): # We need to use TF-specific OP to support the case, # when `values` are `tf.IndexedSlices`. return tf.clip_by_norm(values, self.clipnorm, axes) def filter_empty_gradients(grads_and_vars): """Filter out `(grad, var)` pairs that have a gradient equal to `None`.""" grads_and_vars = tuple(grads_and_vars) if not grads_and_vars: return grads_and_vars filtered = [] vars_with_empty_grads = [] for grad, var in grads_and_vars: if grad is None: vars_with_empty_grads.append(var) else: filtered.append((grad, var)) filtered = tuple(filtered) if not filtered: variable = ([v.name for _, v in grads_and_vars],) raise ValueError( f"No gradients provided for any variable: {variable}. " f"Provided `grads_and_vars` is {grads_and_vars}." ) if vars_with_empty_grads: warnings.warn( "Gradients do not exist for variables %s when minimizing the " "loss. If you're using `model.compile()`, did you forget to " "provide a `loss` argument?", ([v.name for v in vars_with_empty_grads]), ) return filtered