Merge pull request #6 from kakuteki/main
Fix GPU device ID mismatch causing CUDA invalid device error
This commit is contained in:
@@ -83,12 +83,33 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
# Configure the device PyTorch will use
|
||||
if torch.cuda.is_available():
|
||||
self.device = f"cuda:{self.args['gpu_number']}"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
gpu_num = self.args.get('gpu_number', 0)
|
||||
try:
|
||||
gpu_num = int(gpu_num)
|
||||
except ValueError:
|
||||
self.logger.warning(f"Invalid gpu_number value: {gpu_num}. Using 0 instead.")
|
||||
gpu_num = 0
|
||||
|
||||
max_gpu_index = torch.cuda.device_count() - 1
|
||||
if gpu_num > max_gpu_index:
|
||||
self.logger.warning(f"Requested GPU {gpu_num} not available. Using GPU 0 instead.")
|
||||
gpu_num = 0
|
||||
|
||||
try:
|
||||
self.device = torch.device(f"cuda:{gpu_num}")
|
||||
test_tensor = torch.tensor([1.0]).to(self.device)
|
||||
test_tensor = test_tensor * 2
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error initializing CUDA device {gpu_num}: {str(e)}")
|
||||
self.logger.info("Falling back to CPU")
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.logger.info(f'Using device: {self.device}')
|
||||
|
||||
|
||||
|
||||
# Set the random seed if one was provided (a seed of -1 means no seed is set)
|
||||
if self.args['seed'] != -1:
|
||||
np.random.seed(self.args['seed'])
|
||||
|
Reference in New Issue
Block a user