import re

from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.optimizers import optimizer


@keras_export(["keras.optimizers.Muon"])
class Muon(optimizer.Optimizer):
    """Optimizer that implements the Muon algorithm.

    Note that this optimizer should not be used in the following layers:

    1. Embedding layer
    2. Final output fully connected layer
    3. Any {0,1}-D variables

    These should all be optimized using AdamW.

    The Muon optimizer can use either the Muon update step or the AdamW
    update step, based on the following:

    - For any variable that isn't 2D, the AdamW step will be used. This is
      not configurable.
    - If the argument `exclude_embeddings` (defaults to `True`) is set to
      `True`, the AdamW step will be used for embeddings.
    - For any variable with a name that matches an expression listed in the
      argument `exclude_layers` (a list), the AdamW step will be used.
    - Any other variable uses the Muon step.

    Typically, you only need to pass the name of your densely-connected
    output layer to `exclude_layers`, e.g.
    `exclude_layers=["output_dense"]`.

    References:
        - [Original implementation](https://github.com/KellerJordan/Muon)
        - [Liu et al., 2025](https://arxiv.org/abs/2502.16982)

    Args:
        learning_rate: A float,
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value
            to use. The learning rate. Defaults to `0.001`.
        adam_beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        adam_beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        adam_weight_decay: Float. If set, weight decay is applied when using
            the AdamW step.
        epsilon: A small constant for numerical stability. This is "epsilon
            hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. It is
            only used by the AdamW step. Defaults to `1e-7`.
        exclude_layers: List of strings, keywords of layer names to exclude.
            All variables whose path matches one of these keywords will use
            the AdamW step.
        exclude_embeddings: Boolean. If `True`, embedding layers will use
            the AdamW step.
        muon_a: Float, parameter `a` of the Muon algorithm. It is
            recommended to use the default value.
        muon_b: Float, parameter `b` of the Muon algorithm. It is
            recommended to use the default value.
        muon_c: Float, parameter `c` of the Muon algorithm. It is
            recommended to use the default value.
        adam_lr_ratio: Float, the ratio of the learning rate used by the
            AdamW step to the main learning rate. It is recommended to set
            it to `1`.
        momentum: Float, momentum used by the internal SGD.
        ns_steps: Integer, number of Newton-Schulz iterations to run.
        nesterov: Boolean, whether to use Nesterov-style momentum.
        {{base_optimizer_keyword_args}}
        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
            that can enhance the stability of Muon, allowing it to use the
            same learning rate and weight decay as Adam. Defaults to `0.2`.
            Set to `None` to disable this feature.
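
    Example (a minimal usage sketch; the model architecture and the layer
    name `"my_output"` are illustrative, not part of this module):

    ```python
    import keras

    model = keras.Sequential(
        [
            keras.layers.Dense(64, activation="relu", name="hidden"),
            keras.layers.Dense(10, name="my_output"),
        ]
    )
    # 2D kernels use the Muon step; biases and the excluded output layer
    # fall back to the AdamW step.
    optimizer = keras.optimizers.Muon(
        learning_rate=1e-3, exclude_layers=["my_output"]
    )
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")
    ```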
""" def __init__( self, learning_rate=0.001, adam_beta_1=0.9, adam_beta_2=0.999, adam_weight_decay=0.004, epsilon=1e-7, weight_decay=0.004, clipnorm=None, clipvalue=None, global_clipnorm=None, use_ema=False, ema_momentum=0.99, ema_overwrite_frequency=None, loss_scale_factor=None, gradient_accumulation_steps=None, name="muon", exclude_layers=None, exclude_embeddings=True, muon_a=3.4445, muon_b=-4.7750, muon_c=2.0315, adam_lr_ratio=1, momentum=0.95, ns_steps=5, nesterov=True, rms_rate=0.2, **kwargs, ): super().__init__( learning_rate=learning_rate, name=name, weight_decay=weight_decay, clipnorm=clipnorm, clipvalue=clipvalue, global_clipnorm=global_clipnorm, use_ema=use_ema, ema_momentum=ema_momentum, ema_overwrite_frequency=ema_overwrite_frequency, loss_scale_factor=loss_scale_factor, gradient_accumulation_steps=gradient_accumulation_steps, **kwargs, ) self.adam_beta_1 = adam_beta_1 self.adam_beta_2 = adam_beta_2 self.epsilon = epsilon self.muon_a = muon_a self.muon_b = muon_b self.muon_c = muon_c self.adam_lr_ratio = adam_lr_ratio self.momentum = momentum self.ns_steps = ns_steps self.nesterov = nesterov self.exclude_embeddings = exclude_embeddings self.exclude_layers = exclude_layers or [] self.adam_weight_decay = adam_weight_decay self.rms_rate = rms_rate def _should_use_adamw(self, variable): # it works well to just flatten their last 3 dimensions. # any {0,1}-D parameters should all be optimized by adam if len(variable.shape) != 2: return True if self.exclude_embeddings and "embedding" in variable.path.lower(): return True for keyword in self.exclude_layers: if re.search(keyword, variable.path): return True return False def build(self, var_list): """Initialize optimizer variables. Adam optimizer has 3 types of variables: momentums, velocities and velocity_hat (only set when amsgrad is applied), Args: var_list: list of model variables to build Adam variables on. """ if self.built: return super().build(var_list) # Momentums are for both Muon and Adam self.momentums = [None] * len(var_list) # Velocities are just for Adam self.adam_velocities = [None] * len(var_list) for var in var_list: if not self._overwrite_variable_with_gradient(var): self.momentums[self._get_variable_index(var)] = ( self.add_variable_from_reference( reference_variable=var, name="momentum" ) ) if self._should_use_adamw(var): self.adam_velocities[self._get_variable_index(var)] = ( self.add_variable_from_reference( reference_variable=var, name="velocity" ) ) def update_step(self, gradient, variable, learning_rate): variable_index = self._get_variable_index(variable) m = self.momentums[variable_index] v = self.adam_velocities[variable_index] # The presence of the velocity tells us that this variable is for Adam if v is not None: # It should be noted that lr is one-tenth when using adamw. 
            self._adamw_update_step(
                gradient, variable, learning_rate * self.adam_lr_ratio, m, v
            )
        else:
            self._muon_update_step(gradient, variable, learning_rate, m)

    def _muon_update_step(self, gradient, variable, lr, m):
        self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
        if self.nesterov:
            g = ops.add(gradient, self.momentum * m)
        else:
            g = m
        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
        self.assign_sub(variable, self.lr_adjust(lr * update))

    def _adamw_update_step(self, gradient, variable, learning_rate, m, v):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        local_step = ops.cast(self.iterations + 1, variable.dtype)
        adam_beta_1_power = ops.power(
            ops.cast(self.adam_beta_1, variable.dtype), local_step
        )
        adam_beta_2_power = ops.power(
            ops.cast(self.adam_beta_2, variable.dtype), local_step
        )

        alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)

        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), 1 - self.adam_beta_1)
        )
        self.assign_add(
            v,
            ops.multiply(
                ops.subtract(ops.square(gradient), v), 1 - self.adam_beta_2
            ),
        )
        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)
            ),
        )

    def transpose_last_axis(self, X):
        # Swap the last two axes of `X`, keeping any leading batch axes.
        shape = ops.shape(X)
        temp_order = list(range(len(shape)))
        temp_order[-2] = temp_order[-1]
        temp_order[-1] = len(shape) - 2
        X = ops.transpose(X, temp_order)
        return X

    def lr_adjust(self, x):
        """Adjusts the learning rate based on the Moonlight implementation.

        This method enhances the stability of Muon, allowing it to use the
        same learning rate and weight decay as Adam. For details, see
        https://arxiv.org/abs/2502.16982.

        For a 2D matrix, the update is scaled by
        `sqrt(max(n, m)) * rms_rate`, where `n` and `m` are the dimensions
        of the matrix.
        """
        if self.rms_rate is None:
            return x
        # Moonlight version:
        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate

    def zeropower_via_newtonschulz5(self, x, steps: int):
        """Apply the Newton-Schulz iteration to orthogonalize the matrix `x`.

        We select a quintic iteration that maximizes the slope at zero. This
        approach helps minimize steps, even if the iteration doesn't fully
        converge across the interval. The result isn't exactly `U V^T` (from
        the SVD of `x`), but rather an approximation like `U S' V^T`. Despite
        this approximation, model performance remains unaffected compared to
        using the exact `U V^T` from the SVD.
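
        Concretely, with `A = X @ X.T` (taken over the last two axes), each
        of the `steps` iterations below applies the quintic update

        `X = a * X + (b * A + c * A @ A) @ X`

        where `a`, `b`, and `c` are `muon_a`, `muon_b`, and `muon_c`.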
""" shape = ops.shape(x) assert len(shape) >= 2 a, b, c = self.muon_a, self.muon_b, self.muon_c if shape[-2] > shape[-1]: x = self.transpose_last_axis(x) # Ensure spectral norm is at most 1 x = x / (ops.norm(x, axis=(-2, -1), keepdims=True) + 1e-7) # Perform the NS iterations for _ in range(steps): temp_a = x @ self.transpose_last_axis(x) temp_b = b * temp_a + c * temp_a @ temp_a x = a * x + temp_b @ x if shape[-2] > shape[-1]: x = self.transpose_last_axis(x) return x def _apply_weight_decay(self, variables): for variable in variables: if not self._use_weight_decay(variable): continue if self._should_use_adamw(variable): weight_decay_value = self.adam_weight_decay else: weight_decay_value = self.weight_decay if weight_decay_value is None: continue wd = ops.cast(weight_decay_value, variable.dtype) lr = ops.cast(self.learning_rate, variable.dtype) variable.assign(variable - variable * wd * lr) def get_config(self): config = super().get_config() config.update( { "adam_beta_1": self.adam_beta_1, "adam_beta_2": self.adam_beta_2, "epsilon": self.epsilon, "exclude_layers": self.exclude_layers, "muon_a": self.muon_a, "muon_b": self.muon_b, "muon_c": self.muon_c, "adam_lr_ratio": self.adam_lr_ratio, "momentum": self.momentum, "ns_steps": self.ns_steps, "nesterov": self.nesterov, "exclude_embeddings": self.exclude_embeddings, "adam_weight_decay": self.adam_weight_decay, "rms_rate": self.rms_rate, } ) return config