import tensorflow as tf from keras.src.backend.common import standardize_dtype from keras.src.backend.config import floatx from keras.src.random.seed_generator import SeedGenerator from keras.src.random.seed_generator import draw_seed from keras.src.random.seed_generator import make_default_seed def _cast_seed(seed): # TensorFlow has a device placement issue that `Variable` must be int64 # in `SeedGenerator`. However, all `tf.random.stateless_*` expect the seed # to be int32 to run with XLA. # This function addresses the inconsistency using `floormod`. # Ref: https://www.tensorflow.org/api_docs/python/tf/random if standardize_dtype(seed.dtype) == "int32": return seed else: seed = tf.cast(tf.math.floormod(seed, tf.int32.max - 1), dtype="int32") return seed def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_normal( shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed ) def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_uniform( shape=shape, minval=tf.cast(minval, dtype), maxval=tf.cast(maxval, dtype), dtype=dtype, seed=seed, ) def categorical(logits, num_samples, dtype="int64", seed=None): seed = _cast_seed(draw_seed(seed)) output = tf.random.stateless_categorical(logits, num_samples, seed=seed) return tf.cast(output, dtype) def randint(shape, minval, maxval, dtype="int32", seed=None): intermediate_dtype = dtype if standardize_dtype(dtype) not in ["int32", "int64"]: intermediate_dtype = "int64" seed = _cast_seed(draw_seed(seed)) output = tf.random.stateless_uniform( shape=shape, minval=minval, maxval=maxval, dtype=intermediate_dtype, seed=seed, ) return tf.cast(output, dtype) def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_truncated_normal( shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed ) def _get_concrete_noise_shape(inputs, noise_shape): if noise_shape is None: return tf.shape(inputs) concrete_inputs_shape = tf.shape(inputs) concrete_noise_shape = [] for i, value in enumerate(noise_shape): concrete_noise_shape.append( concrete_inputs_shape[i] if value is None else value ) return concrete_noise_shape def dropout(inputs, rate, noise_shape=None, seed=None): seed = _cast_seed(draw_seed(seed)) noise_shape = _get_concrete_noise_shape(inputs, noise_shape) return tf.nn.experimental.stateless_dropout( inputs, rate=rate, noise_shape=noise_shape, seed=seed, ) def shuffle(x, axis=0, seed=None): seed = _cast_seed(draw_seed(seed)) indices = tf.argsort( tf.random.stateless_uniform(shape=[tf.shape(x)[axis]], seed=seed) ) return tf.gather(x, indices, axis=axis) def gamma(shape, alpha, dtype=None, seed=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) # TODO: `tf.random.stateless_gamma` doesn't support bfloat16 intermediate_dtype = dtype if standardize_dtype(dtype) == "bfloat16": intermediate_dtype = "float32" return tf.cast( tf.random.stateless_gamma( shape, alpha=alpha, dtype=intermediate_dtype, seed=seed, ), dtype, ) def binomial(shape, counts, probabilities, dtype=None, seed=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) # TODO: `tf.random.stateless_binomial` doesn't support bfloat16 intermediate_dtype = dtype if standardize_dtype(dtype) == "bfloat16": intermediate_dtype = "float32" return tf.cast( tf.random.stateless_binomial( shape=shape, seed=seed, counts=counts, probs=probabilities, output_dtype=intermediate_dtype, ), dtype, ) def beta(shape, alpha, beta, dtype=None, seed=None): dtype = dtype or floatx() # since tensorflow doesn't offer a beta distribution function # so we'll use the formula U(a,b) = (X(a) / (X(a) + Y(b)), # where U(a,b) is a beta-distributed random variable with # parameters a and b, and X(a) and Y(b) are gamma-distributed # random variables with parameters a and b respectively. # Additionally, we'll use two different seeds for our two # gamma random variables to prevent any unintended # dependencies and correlations between the generated values # due to the usage of same seed. seed_1 = _cast_seed(draw_seed(seed)) # The choice of 12 is totally arbitrary, as we're # incrementing the first drawn seed by a CONSTANT to # ensure deterministic results. seed_2 = seed_1 + 12 # TODO: `tf.random.stateless_gamma` doesn't support bfloat16 intermediate_dtype = dtype if standardize_dtype(dtype) == "bfloat16": intermediate_dtype = "float32" alpha = tf.convert_to_tensor(alpha, dtype=intermediate_dtype) beta = tf.convert_to_tensor(beta, dtype=intermediate_dtype) # tensorflow's tf.random.stateless_gamma has a bit of unconventional # implementation of the stateless_gamma function where it checks the # broadcastability of alpha's shape with ONLY the RIGHTMOST dimension of # the specified output shape instead of considering the whole. # Consequently, it then results in errors for perfectly broadcastable shapes # such as for output shape of (2, 3) and alpha shape of (1, 3) # So to resolve this, we explicitly broadcast alpha and beta to shape before # passing them to the stateless_gamma function. alpha = tf.broadcast_to(alpha, shape) beta = tf.broadcast_to(beta, shape) gamma_a = tf.cast( tf.random.stateless_gamma( shape=shape, seed=seed_1, alpha=alpha, dtype=intermediate_dtype ), dtype, ) gamma_b = tf.cast( tf.random.stateless_gamma( shape=shape, seed=seed_2, alpha=beta, dtype=intermediate_dtype ), dtype, ) sample = gamma_a / (gamma_a + gamma_b) return sample