import torch


# Whether the Apple Metal (MPS) backend is actually usable at runtime.
# `torch.has_mps` (like `torch.backends.mps.is_built()`) only says this PyTorch
# binary was compiled with MPS support, not that an MPS device is present and
# working on this machine, so prefer `torch.backends.mps.is_available()`.
# Builds that lack `torch.backends.mps.is_available` (early nightlies) fall back
# to the build-time flag via `getattr` for compatibility.
try:
    has_mps = bool(torch.backends.mps.is_available())
except AttributeError:
    has_mps = bool(getattr(torch, 'has_mps', False))

def get_optimal_device():
    """Return the preferred torch device for this process.

    Preference order: CUDA if `torch.cuda.is_available()`, otherwise MPS if the
    module-level `has_mps` flag is truthy, otherwise CPU.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif has_mps:
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)
