from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.optimizers import optimizer


@keras_export("keras.optimizers.SGD")
class SGD(optimizer.Optimizer):
    """Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value
            to use. The learning rate. Defaults to `0.01`.
        momentum: float hyperparameter >= 0 that accelerates gradient descent
            in the relevant direction and dampens oscillations. 0 is vanilla
            gradient descent. Defaults to `0.0`.
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="SGD",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
            raise ValueError(
                "`momentum` must be a float in the range [0, 1]."
            )
        self.momentum = momentum
        self.nesterov = nesterov

    def build(self, variables):
        """Initialize optimizer variables.

        SGD optimizer has one variable `momentums`, only set if
        `self.momentum` is not 0.

        Args:
            variables: list of model variables to build SGD variables on.
        """
        if self.built:
            return
        super().build(variables)
        self.momentums = []
        if self.momentum != 0:
            self.momentums = self.add_optimizer_variables(
                variables, "momentum"
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        m = None
        if self.momentum != 0:
            m = self.momentums[self._get_variable_index(variable)]
        if m is not None:
            momentum = ops.cast(self.momentum, variable.dtype)
            # velocity = momentum * velocity - learning_rate * gradient
            self.assign(
                m,
                ops.subtract(
                    ops.multiply(m, momentum),
                    ops.multiply(gradient, learning_rate),
                ),
            )
            if self.nesterov:
                # Nesterov look-ahead:
                # w = w + momentum * velocity - learning_rate * gradient
                self.assign_add(
                    variable,
                    ops.subtract(
                        ops.multiply(m, momentum),
                        ops.multiply(gradient, learning_rate),
                    ),
                )
            else:
                # w = w + velocity
                self.assign_add(variable, m)
        else:
            # Vanilla gradient descent: w = w - learning_rate * gradient
            self.assign_sub(variable, ops.multiply(gradient, learning_rate))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config


SGD.__doc__ = SGD.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
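

# --- Usage sketch (editor's illustration, not part of the Keras source) ---
# A minimal example of how this optimizer is typically constructed through the
# public Keras API; `model` below is a placeholder for any compiled-compatible
# Keras model.
#
#   import keras
#
#   # Plain SGD with Nesterov momentum.
#   opt = keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
#
#   # Or pass a learning-rate schedule instead of a fixed float.
#   schedule = keras.optimizers.schedules.ExponentialDecay(
#       initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96
#   )
#   opt = keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
#
#   model.compile(optimizer=opt, loss="mse")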