import torch
[docs]
def set_device(device) -> torch.device:
    """Select a working torch device, falling back to alternatives if needed.

    The requested device is probed first. If it is unusable, the standard
    devices are tried in the order ``cuda`` -> ``mps`` -> ``cpu``. "Unusable"
    means creating a small tensor on it or adding to it raises.

    Args:
        device: The requested device. This may be a device string such as
            ``"cuda"``, ``"cuda:1"``, ``"mps"`` or ``"cpu"``, or a
            ``torch.device``. Pass ``None`` to skip the request and
            auto-select from the fallback order.

    Returns:
        The first device that passed the smoke test. If every candidate
        failed, ``torch.device("cpu")``.
    """
    # Order of preference for fallback
    fallback_order = ["cuda", "mps", "cpu"]
    print(f"Requested device: {device}")

    def _device_works(device_str: str) -> bool:
        """Return True if a small tensor can be created and used on *device_str*."""
        try:
            probe_device = torch.device(device_str)
            # Constructing a torch.device does not touch the backend, so
            # allocate a tensor and run an op to exercise it for real.
            probe = torch.tensor(
                [[0, 3], [5, 7]], dtype=torch.float32, device=probe_device
            )
            _ = probe + 1
            return True
        except Exception as e:  # best-effort probe: any failure means "unusable"
            print(f"Device {device_str} test failed: {e}")
            return False

    # Normalise torch.device objects (and other values) to their string form.
    # This way explicit requests like "cuda:1" are honoured instead of being
    # skipped because they are not literally in fallback_order.
    requested = None if device is None else str(device)

    # First try the requested device
    if requested is not None:
        if _device_works(requested):
            print(f"Using requested device: {requested}")
            return torch.device(requested)
        print(f"Requested device '{requested}' not available, trying alternatives...")

    # Try devices in fallback order, skipping the one that already failed
    for fallback_device in fallback_order:
        if fallback_device != requested and _device_works(fallback_device):
            print(f"Using fallback device: {fallback_device}")
            return torch.device(fallback_device)

    # Ultimate fallback to CPU (only reached if even the CPU probe failed)
    print("All devices failed, forcing CPU")
    return torch.device("cpu")