Merge pull request #6 from kakuteki/main

Fix GPU device ID mismatch causing CUDA invalid device error
This commit is contained in:
Nick Card
2025-07-28 21:40:21 -07:00
committed by GitHub

View File

@@ -83,12 +83,33 @@ class BrainToTextDecoder_Trainer:
# Configure device pytorch will use
if torch.cuda.is_available():
self.device = f"cuda:{self.args['gpu_number']}"
else:
self.device = "cpu"
gpu_num = self.args.get('gpu_number', 0)
try:
gpu_num = int(gpu_num)
except ValueError:
self.logger.warning(f"Invalid gpu_number value: {gpu_num}. Using 0 instead.")
gpu_num = 0
max_gpu_index = torch.cuda.device_count() - 1
if gpu_num > max_gpu_index:
self.logger.warning(f"Requested GPU {gpu_num} not available. Using GPU 0 instead.")
gpu_num = 0
try:
self.device = torch.device(f"cuda:{gpu_num}")
test_tensor = torch.tensor([1.0]).to(self.device)
test_tensor = test_tensor * 2
except Exception as e:
self.logger.error(f"Error initializing CUDA device {gpu_num}: {str(e)}")
self.logger.info("Falling back to CPU")
self.device = torch.device("cpu")
else:
self.device = torch.device("cpu")
self.logger.info(f'Using device: {self.device}')
# Set seed if provided
if self.args['seed'] != -1:
np.random.seed(self.args['seed'])