"""!!!DO NOT USE!!! Distribution related utilities for the Tensorflow backend.

This is just a prototype and we might want to unify it with other backends in
the future.
"""

import tensorflow as tf
from tensorflow.experimental import dtensor


def list_devices(device_type=None):
    """Return all the available devices based on the device type.

    Note that this should return the global devices in a distributed setting.

    Args:
        device_type: string of `"cpu"`, `"gpu"` or `"tpu"` (case-insensitive).
            When not provided, defaults to the non-CPU devices (e.g. `gpu` or
            `tpu`) if any are available, otherwise returns the `cpu` devices.

    Returns:
        List of device strings available for distributed computation. Each
        entry is the lower-cased device type followed by `:` and the last
        `:`-separated component of the TF device name, e.g. `"gpu:0"`.
    """
    device_type = device_type.upper() if device_type else None

    # DTensor doesn't support getting global devices, even when knowing the
    # Mesh. Use TF API instead to get global devices. Coordinator service is
    # enabled by default with DTensor, so that list_logical_devices() returns
    # a list of global devices. More context can be found in b/254911601.
    tf_devices = tf.config.list_logical_devices(device_type=device_type)

    if device_type is None:
        # No explicit type requested: prefer non-CPU devices and keep the
        # full list (which then contains only CPU devices) when none exist.
        accelerators = [
            device
            for device in tf_devices
            if device.device_type.lower() != "cpu"
        ]
        if accelerators:
            tf_devices = accelerators

    return [
        f"{device.device_type.lower()}:{device.name.split(':')[-1]}"
        for device in tf_devices
    ]


def distribute_value(value, tensor_layout):
    """Distribute `value` according to `tensor_layout`.

    Not implemented yet: this prototype stub does nothing and returns `None`.
    """
    # TODO: implement value distribution for the Tensorflow backend.


def _to_backend_mesh(device_mesh):
    """Convert the DeviceMesh to Tensorflow backend specific Mesh.

    Args:
        device_mesh: DeviceMesh instance to convert.

    Returns:
        A `tf.experimental.dtensor.Mesh` instance.
    """
    # One (axis_name, axis_size) pair per mesh dimension.
    mesh_dims = list(zip(device_mesh.axis_names, device_mesh.shape))
    # NOTE(review): assumes `device_mesh.devices` is an array of device
    # strings (e.g. produced by `list_devices`); confirm against DeviceMesh.
    return dtensor.create_distributed_mesh(
        mesh_dims=mesh_dims, local_devices=device_mesh.devices.flatten()
    )


def _to_backend_layout(tensor_layout):
    """Convert the TensorLayout to Tensorflow backend specific Layout.

    Args:
        tensor_layout: TensorLayout instance to convert.

    Returns:
        A `tf.experimental.dtensor.Layout` instance.

    Raises:
        ValueError: If `tensor_layout.device_mesh` is not set.
    """
    if tensor_layout.device_mesh is None:
        raise ValueError(
            "Cannot create sharding when device mesh is not set for "
            "TensorLayout."
        )
    # A falsy axis entry (e.g. None) means the dimension is not sharded.
    sharding_specs = [
        axis if axis else dtensor.UNSHARDED for axis in tensor_layout.axes
    ]
    dtensor_mesh = tensor_layout.device_mesh.backend_mesh
    return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)