import math import torch from keras.src.backend import config from keras.src.backend import standardize_dtype from keras.src.backend.common import dtypes from keras.src.backend.torch.core import cast from keras.src.backend.torch.core import convert_to_tensor from keras.src.backend.torch.core import get_device from keras.src.backend.torch.numpy import pad def _segment_reduction_fn(data, segment_ids, reduction_method, num_segments): num_repeats = torch.prod( torch.tensor(data.shape[1:], device=get_device()) ).long() # To use `scatter_add` in torch, we need to replicate `segment_ids` into the # shape of `data`. segment_ids = ( segment_ids.repeat_interleave(num_repeats) .view(*data.shape) .type(torch.int64) ) num_segments = num_segments or len(torch.unique(segment_ids)) # .scatter_add does not support -1 in the indices. # Add all out-of-bound indices value to an extra dimension after # num_segments, which is removed before returning the result. # Replacing the out-of-bound indices. segment_ids = torch.where(segment_ids >= 0, segment_ids, num_segments) segment_ids = torch.where( segment_ids < num_segments, segment_ids, num_segments ) # Add one more dimension to the result shape with the "+1". shape = (num_segments + 1,) + tuple(data.shape[1:]) if reduction_method == "amax": result = torch.ones(*shape, device=get_device()) * -float("Inf") else: result = torch.zeros(*shape, device=get_device()) result = result.scatter_reduce( 0, segment_ids, data.float(), reduction_method ) # Removing the extra dimension. result = result[:-1, ...] return result.type(data.dtype) def segment_sum(data, segment_ids, num_segments=None, sorted=False): data = convert_to_tensor(data) segment_ids = convert_to_tensor(segment_ids) return _segment_reduction_fn(data, segment_ids, "sum", num_segments) def segment_max(data, segment_ids, num_segments=None, sorted=False): data = convert_to_tensor(data) segment_ids = convert_to_tensor(segment_ids) return _segment_reduction_fn(data, segment_ids, "amax", num_segments) def top_k(x, k, sorted=True): x = convert_to_tensor(x) return torch.topk(x, k, sorted=sorted) def in_top_k(targets, predictions, k): targets = convert_to_tensor(targets).type(torch.int64) targets = targets[:, None] predictions = convert_to_tensor(predictions) topk_values = top_k(predictions, k).values targets_values = torch.take_along_dim(predictions, targets, dim=-1) mask = targets_values >= topk_values return torch.any(mask, axis=-1) def logsumexp(x, axis=None, keepdims=False): x = convert_to_tensor(x) axis = tuple(range(x.dim())) if axis is None else axis return torch.logsumexp(x, dim=axis, keepdim=keepdims) def qr(x, mode="reduced"): x = convert_to_tensor(x) if mode not in {"reduced", "complete"}: raise ValueError( "`mode` argument value not supported. " "Expected one of {'reduced', 'complete'}. " f"Received: mode={mode}" ) x = convert_to_tensor(x) return torch.linalg.qr(x, mode=mode) def extract_sequences(x, sequence_length, sequence_stride): x = convert_to_tensor(x) return torch.unfold_copy( x, dimension=-1, size=sequence_length, step=sequence_stride ) def _overlap_sequences(x, sequence_stride): # Ref: https://github.com/google/jax/blob/main/jax/_src/scipy/signal.py x = convert_to_tensor(x) *batch_shape, num_sequences, sequence_length = x.shape if sequence_stride > sequence_length: raise ValueError( "`sequence_stride` must equal or less than x.shape[-1]. " f"Received: sequence_stride={sequence_stride}, " f"x.shape[-1]={sequence_length}" ) if sequence_stride < (sequence_length / num_sequences): raise ValueError( "`sequence_stride` must equal or greater than " "x.shape[-1] / x.shape[-2]. " f"Received: sequence_stride={sequence_stride}, " f"x.shape[-1]={sequence_length}, x.shape[-2]={num_sequences}" ) flat_batchsize = math.prod(batch_shape) x = torch.reshape(x, (flat_batchsize, num_sequences, sequence_length)) output_size = sequence_stride * (num_sequences - 1) + sequence_length nstep_per_segment = 1 + (sequence_length - 1) // sequence_stride # Here, we use shorter notation for axes. # B: batch_size, N: num_sequences, S: nstep_per_segment, # T: sequence_length divided by S padded_segment_len = nstep_per_segment * sequence_stride x = torch.nn.functional.pad( x, (0, padded_segment_len - sequence_length, 0, 0, 0, 0) ) x = torch.reshape( x, (flat_batchsize, num_sequences, nstep_per_segment, sequence_stride) ) # For obtaining shifted signals, this routine reinterprets flattened array # with a shrinked axis. With appropriate truncation/ padding, this # operation pushes the last padded elements of the previous row to the head # of the current row. # See implementation of `overlap_and_add` in Tensorflow for details. x = torch.permute(x, (0, 2, 1, 3)) # x: (B, S, N, T) x = torch.nn.functional.pad(x, (0, 0, 0, num_sequences, 0, 0, 0, 0)) # x: (B, S, N*2, T) shrinked = x.shape[2] - 1 x = torch.reshape(x, (flat_batchsize, -1)) x = x[:, : (nstep_per_segment * shrinked * sequence_stride)] x = torch.reshape( x, (flat_batchsize, nstep_per_segment, shrinked * sequence_stride) ) # Finally, sum shifted segments, and truncate results to the output_size. x = torch.sum(x, dim=1)[:, :output_size] return torch.reshape(x, tuple(batch_shape) + (-1,)) def _get_complex_tensor_from_tuple(x): if not isinstance(x, (tuple, list)) or len(x) != 2: raise ValueError( "Input `x` should be a tuple of two tensors - real and imaginary." f"Received: x={x}" ) # `convert_to_tensor` does not support passing complex tensors. We separate # the input out into real and imaginary and convert them separately. real, imag = x real = convert_to_tensor(real) imag = convert_to_tensor(imag) # Check shape. if real.shape != imag.shape: raise ValueError( "Input `x` should be a tuple of two tensors - real and imaginary." "Both the real and imaginary parts should have the same shape. " f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" ) # Ensure dtype is float. if not torch.is_floating_point(real) or not torch.is_floating_point(imag): raise ValueError( "At least one tensor in input `x` is not of type float." f"Received: x={x}." ) complex_input = torch.complex(real, imag) return complex_input def fft(x): complex_input = _get_complex_tensor_from_tuple(x) complex_output = torch.fft.fft(complex_input) return complex_output.real, complex_output.imag def fft2(x): complex_input = _get_complex_tensor_from_tuple(x) complex_output = torch.fft.fft2(complex_input) return complex_output.real, complex_output.imag def ifft2(x): complex_input = _get_complex_tensor_from_tuple(x) complex_output = torch.fft.ifft2(complex_input) return complex_output.real, complex_output.imag def rfft(x, fft_length=None): x = convert_to_tensor(x) complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward") return complex_output.real, complex_output.imag def irfft(x, fft_length=None): complex_input = _get_complex_tensor_from_tuple(x) return torch.fft.irfft(complex_input, n=fft_length, dim=-1, norm="backward") def stft( x, sequence_length, sequence_stride, fft_length, window="hann", center=True ): if standardize_dtype(x.dtype) not in {"float32", "float64"}: raise TypeError( "Invalid input type. Expected `float32` or `float64`. " f"Received: input type={x.dtype}" ) if fft_length < sequence_length: raise ValueError( "`fft_length` must equal or larger than `sequence_length`. " f"Received: sequence_length={sequence_length}, " f"fft_length={fft_length}" ) if isinstance(window, str): if window not in {"hann", "hamming"}: raise ValueError( "If a string is passed to `window`, it must be one of " f'`"hann"`, `"hamming"`. Received: window={window}' ) x = convert_to_tensor(x) if window is not None: if isinstance(window, str): if window == "hann": win = torch.hann_window( sequence_length, periodic=True, dtype=x.dtype, device=get_device(), ) else: win = torch.hamming_window( sequence_length, periodic=True, dtype=x.dtype, device=get_device(), ) else: win = convert_to_tensor(window, dtype=x.dtype) if len(win.shape) != 1 or win.shape[-1] != sequence_length: raise ValueError( "The shape of `window` must be equal to [sequence_length]." f"Received: window shape={win.shape}" ) else: win = torch.ones((sequence_length,), dtype=x.dtype, device=get_device()) need_unpack = False *batch_shape, samples = x.shape if len(x.shape) > 2: need_unpack = True flat_batchsize = math.prod(batch_shape) x = torch.reshape(x, (flat_batchsize, samples)) x = torch.stft( x, n_fft=fft_length, hop_length=sequence_stride, win_length=sequence_length, window=win, center=center, return_complex=True, ) if need_unpack: fft_unique_bins, num_sequences = x.shape[-2:] x = torch.reshape(x, (*batch_shape, fft_unique_bins, num_sequences)) x = torch.swapaxes(x, -2, -1) return x.real, x.imag def istft( x, sequence_length, sequence_stride, fft_length, length=None, window="hann", center=True, ): complex_input = _get_complex_tensor_from_tuple(x) dtype = complex_input.real.dtype win = None if window is not None: if isinstance(window, str): if window == "hann": win = torch.hann_window( sequence_length, periodic=True, dtype=dtype, device=get_device(), ) else: win = torch.hamming_window( sequence_length, periodic=True, dtype=dtype, device=get_device(), ) else: win = convert_to_tensor(window, dtype=dtype) if len(win.shape) != 1 or win.shape[-1] != sequence_length: raise ValueError( "The shape of `window` must be equal to [sequence_length]." f"Received: window shape={win.shape}" ) if sequence_length == fft_length and center is True and win is not None: # can be fallen back to torch.istft need_unpack = False *batch_shape, num_sequences, fft_unique_bins = complex_input.shape if len(complex_input.shape) > 3: need_unpack = True flat_batchsize = math.prod(batch_shape) complex_input = torch.reshape( complex_input, (flat_batchsize, num_sequences, fft_unique_bins) ) complex_input = torch.swapaxes(complex_input, -2, -1) x = torch.istft( complex_input, n_fft=fft_length, hop_length=sequence_stride, win_length=sequence_length, window=win, center=center, length=length, return_complex=False, ) if need_unpack: samples = x.shape[-1] x = torch.reshape(x, (*batch_shape, samples)) return x # custom implementation with irfft and _overlap_sequences # references: # torch: aten/src/ATen/native/SpectralOps.cpp # tf: tf.signal.inverse_stft_window_fn x = irfft(x, fft_length) expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) if win is not None: l_pad = (fft_length - sequence_length) // 2 r_pad = fft_length - sequence_length - l_pad win = pad(win, [[l_pad, r_pad]], "constant") # square and sum _sequence_length = sequence_length + l_pad + r_pad denom = torch.square(win) overlaps = -(-_sequence_length // sequence_stride) denom = pad(denom, [(0, overlaps * sequence_stride - _sequence_length)]) denom = torch.reshape(denom, [overlaps, sequence_stride]) denom = torch.sum(denom, 0, keepdims=True) denom = torch.tile(denom, [overlaps, 1]) denom = torch.reshape(denom, [overlaps * sequence_stride]) win = torch.divide(win, denom[:_sequence_length]) x = torch.multiply(x, win) x = _overlap_sequences(x, sequence_stride) start = 0 if center is False else fft_length // 2 if length is not None: end = start + length elif center is True: end = -(fft_length // 2) else: end = expected_output_len return x[..., start:end] def rsqrt(x): x = convert_to_tensor(x) return torch.rsqrt(x) def erf(x): x = convert_to_tensor(x) return torch.erf(x) def erfinv(x): x = convert_to_tensor(x) return torch.erfinv(x) def solve(a, b): a = convert_to_tensor(a) b = convert_to_tensor(b) return torch.linalg.solve(a, b) def norm(x, ord=None, axis=None, keepdims=False): x = convert_to_tensor(x) if standardize_dtype(x.dtype) == "int64": dtype = config.floatx() else: dtype = dtypes.result_type(x.dtype, float) x = cast(x, dtype) return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims) def logdet(x): x = convert_to_tensor(x) return torch.logdet(x)