from keras.src.backend.config import backend

if backend() == "torch":
    # When using the torch backend,
    # torch needs to be imported first, otherwise it will segfault
    # upon import. Keep this guard ahead of every other import below.
    import torch

# NOTE(review): most of the names imported below are not referenced in this
# file; presumably they are imported so they can be accessed through this
# package's namespace -- confirm before pruning any of them.
from keras.src.api_export import keras_export
from keras.src.backend.common.dtypes import result_type
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
from keras.src.backend.common.keras_tensor import is_keras_tensor
from keras.src.backend.common.masking import get_keras_mask
from keras.src.backend.common.masking import set_keras_mask
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.stateless_scope import get_stateless_scope
from keras.src.backend.common.stateless_scope import in_stateless_scope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import Variable
from keras.src.backend.common.variables import get_autocast_scope
from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import is_int_dtype
from keras.src.backend.common.variables import standardize_dtype
from keras.src.backend.common.variables import standardize_shape
from keras.src.backend.config import epsilon
from keras.src.backend.config import floatx
from keras.src.backend.config import image_data_format
from keras.src.backend.config import set_epsilon
from keras.src.backend.config import set_floatx
from keras.src.backend.config import set_image_data_format
from keras.src.backend.config import standardize_data_format

# Import backend functions.
if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable elif backend() == "torch": from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None else: raise ValueError(f"Unable to import backend : {backend()}") @keras_export("keras.Variable") class Variable(BackendVariable): # noqa: F811 pass backend_name_scope = name_scope # noqa: F405 @keras_export("keras.name_scope") class name_scope(backend_name_scope): pass @keras_export("keras.device") def device(device_name): return device_scope(device_name) # noqa: F405