from keras.src import tree
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.data_adapters.data_adapter import DataAdapter


class TFDatasetAdapter(DataAdapter):
    """Adapter that handles `tf.data.Dataset`."""

    def __init__(self, dataset, class_weight=None, distribution=None):
        """Initialize the TFDatasetAdapter.

        Args:
            dataset: The input `tf.data.Dataset` instance.
            class_weight: A map where the keys are integer class ids and
                values are the class weights, e.g.
                `{0: 0.2, 1: 0.6, 2: 0.3}`.
            distribution: A `keras.distribution.Distribution` instance. Used
                to shard the input dataset into per-worker/per-process
                dataset instances.
        """
        from keras.src.utils.module_utils import tensorflow as tf

        if not isinstance(
            dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
        ):
            raise ValueError(
                "Expected argument `dataset` to be a tf.data.Dataset. "
                f"Received: {dataset}"
            )
        if class_weight is not None:
            dataset = dataset.map(
                make_class_weight_map_fn(class_weight)
            ).prefetch(tf.data.AUTOTUNE)
        if distribution is not None:
            dataset = distribution.distribute_dataset(dataset)
        self._dataset = dataset

    def get_numpy_iterator(self):
        from keras.src.backend.tensorflow.core import convert_to_numpy

        for batch in self._dataset:
            yield tree.map_structure(
                convert_to_numpy, batch, none_is_leaf=False
            )

    def get_jax_iterator(self):
        from keras.src.backend.tensorflow.core import convert_to_numpy
        from keras.src.utils.module_utils import tensorflow as tf

        def convert_to_jax(x):
            if isinstance(x, tf.SparseTensor):
                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
            else:
                # We use numpy as an intermediary because it is faster.
                return convert_to_numpy(x)

        for batch in self._dataset:
            yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False)

    def get_tf_dataset(self):
        return self._dataset

    def get_torch_dataloader(self):
        return data_adapter_utils.get_torch_dataloader(self._dataset)

    @property
    def builtin_prefetch(self):
        return True

    @property
    def num_batches(self):
        cardinality = self._dataset.cardinality
        if callable(cardinality):
            # `dataset.cardinality` is normally expected to be a callable.
            cardinality = int(self._dataset.cardinality())
        else:
            # However, in the case of `DistributedDataset`, it's a np.int64.
            cardinality = int(cardinality)
        # Return None for datasets of unknown or infinite cardinality.
        if cardinality < 0:
            return None
        return cardinality

    @property
    def batch_size(self):
        first_element_spec = tree.flatten(self._dataset.element_spec)[0]
        return first_element_spec.shape[0]

    @property
    def has_partial_batch(self):
        return None

    @property
    def partial_batch_size(self):
        return None


def make_class_weight_map_fn(class_weight):
    """Applies class weighting to a `Dataset`.

    The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
    `y` must be a single `Tensor`.

    Args:
        class_weight: A map where the keys are integer class ids and values
            are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`.

    Returns:
        A function that can be used with `tf.data.Dataset.map` to apply class
        weighting.
    """
    from keras.src.utils.module_utils import tensorflow as tf

    class_weight_tensor = tf.convert_to_tensor(
        [
            class_weight.get(int(c), 1.0)
            for c in range(max(class_weight.keys()) + 1)
        ]
    )

    def class_weights_map_fn(*data):
        """Convert `class_weight` to `sample_weight`."""
        x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)

        if sw is not None:
            raise ValueError(
                "You cannot use `class_weight` and `sample_weight` "
                "at the same time."
            )
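        # Class ids are recovered from `y` by argmax for one-hot or
        # probability outputs, or by rounding for single-column and
        # sparse integer labels (see the branches below).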
        if tree.is_nested(y):
            raise ValueError(
                "`class_weight` is only supported for Models with a single "
                "output."
            )

        if y.shape.rank >= 2:
            y_classes = tf.__internal__.smart_cond.smart_cond(
                tf.shape(y)[-1] > 1,
                lambda: tf.argmax(y, axis=-1, output_type=tf.int32),
                lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
            )
        else:
            # Special casing for rank 1, where we can guarantee sparse
            # encoding.
            y_classes = tf.cast(tf.round(y), tf.int32)

        cw = tf.gather(class_weight_tensor, y_classes)
        return x, y, cw

    return class_weights_map_fn
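

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the adapter API):
    # shows how `make_class_weight_map_fn` turns a `class_weight` dict into
    # per-sample weights on a toy dataset. The data below is hypothetical.
    import numpy as np

    from keras.src.utils.module_utils import tensorflow as tf

    x = np.random.rand(4, 3).astype("float32")
    # Sparse (integer-encoded) labels, stored as floats so the rank-1
    # rounding branch above applies.
    y = np.array([0.0, 1.0, 2.0, 1.0], dtype="float32")
    ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
    ds = ds.map(make_class_weight_map_fn({0: 0.2, 1: 0.6, 2: 0.3}))
    for _, _, sw in ds:
        # Each label is mapped to its class weight, e.g. [0.2, 0.6].
        print(sw.numpy())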