import gc
import threading

from keras.src import backend
from keras.src.api_export import keras_export

# Thread-local containers for Keras global state and settings. Both are
# replaced with fresh objects by `clear_session()`.
GLOBAL_STATE_TRACKER = threading.local()
GLOBAL_SETTINGS_TRACKER = threading.local()


def set_global_attribute(name, value):
    """Stores `value` under `name` on the thread-local global state tracker.

    Args:
        name: Attribute name to set on `GLOBAL_STATE_TRACKER`.
        value: Value to associate with `name`.
    """
    setattr(GLOBAL_STATE_TRACKER, name, value)


def get_global_attribute(name, default=None, set_to_default=False):
    """Looks up `name` on the thread-local global state tracker.

    Args:
        name: Attribute name to read from `GLOBAL_STATE_TRACKER`.
        default: Fallback returned when the attribute is missing or is
            `None`. A `default` of `None` means "no fallback".
        set_to_default: If `True` and the fallback is used, the fallback
            is also stored under `name` via `set_global_attribute()`.

    Returns:
        The stored value when it is not `None`; otherwise `default`
        (which may itself be `None`).
    """
    stored = getattr(GLOBAL_STATE_TRACKER, name, None)
    # Nothing to substitute: either a real value exists, or the caller
    # supplied no usable fallback.
    if stored is not None or default is None:
        return stored
    if set_to_default:
        set_global_attribute(name, default)
    return default


@keras_export(["keras.utils.clear_session", "keras.backend.clear_session"])
def clear_session(free_memory=True):
    """Resets all state generated by Keras.

    Keras keeps a global state that it relies on to implement the
    Functional model-building API and to uniquify autogenerated layer
    names.

    When many models are created in a loop, this global state consumes an
    increasing amount of memory over time, and you may want to clear it.
    Calling `clear_session()` releases the global state: this helps avoid
    clutter from old models and layers, especially when memory is limited.

    Args:
        free_memory: Whether to call Python garbage collection.
            Calling it is usually good practice, since it makes sure that
            memory used by deleted objects is freed immediately. It may,
            however, take a few seconds to execute, so you may want to
            skip it when using `clear_session()` in a short loop.

    Example 1: calling `clear_session()` when creating models in a loop

    ```python
    for _ in range(100):
        # Without `clear_session()`, each iteration of this loop will
        # slightly increase the size of the global state managed by Keras
        model = keras.Sequential([
            keras.layers.Dense(10) for _ in range(10)])

    for _ in range(100):
        # With `clear_session()` called at the beginning,
        # Keras starts with a blank state at each iteration
        # and memory consumption is constant over time.
        keras.backend.clear_session()
        model = keras.Sequential([
            keras.layers.Dense(10) for _ in range(10)])
    ```

    Example 2: resetting the layer name generation counter

    >>> layers = [keras.layers.Dense(10) for _ in range(10)]
    >>> new_layer = keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense_10
    >>> keras.backend.clear_session()
    >>> new_layer = keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense
    """
    global GLOBAL_STATE_TRACKER, GLOBAL_SETTINGS_TRACKER

    # Drop all tracked state and settings by rebinding both module-level
    # trackers to brand-new thread-local objects.
    GLOBAL_STATE_TRACKER = threading.local()
    GLOBAL_SETTINGS_TRACKER = threading.local()

    active_backend = backend.backend()
    if active_backend == "tensorflow":
        from keras.src.utils.module_utils import tensorflow as tf

        tf.compat.v1.reset_default_graph()
        if tf.executing_eagerly():
            # Clear pending nodes in eager executors, kernel caches and
            # step_containers.
            from tensorflow.python.eager import context

            context.context().clear_kernel_cache()
    elif active_backend == "torch":
        import torch._dynamo as dynamo

        # Resets torchdynamo's cache so that cached guards, compiled fn,
        # etc. do not persist between clear_session() calls.
        dynamo.reset()

    if free_memory:
        # Manually trigger garbage collection.
        gc.collect()