Merge pull request #6 from kakuteki/main
Fix GPU device ID mismatch causing CUDA invalid device error
This commit is contained in:
@@ -83,12 +83,33 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
# Configure device pytorch will use
|
# Configure device pytorch will use
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.device = f"cuda:{self.args['gpu_number']}"
|
gpu_num = self.args.get('gpu_number', 0)
|
||||||
else:
|
try:
|
||||||
self.device = "cpu"
|
gpu_num = int(gpu_num)
|
||||||
|
except ValueError:
|
||||||
|
self.logger.warning(f"Invalid gpu_number value: {gpu_num}. Using 0 instead.")
|
||||||
|
gpu_num = 0
|
||||||
|
|
||||||
|
max_gpu_index = torch.cuda.device_count() - 1
|
||||||
|
if gpu_num > max_gpu_index:
|
||||||
|
self.logger.warning(f"Requested GPU {gpu_num} not available. Using GPU 0 instead.")
|
||||||
|
gpu_num = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.device = torch.device(f"cuda:{gpu_num}")
|
||||||
|
test_tensor = torch.tensor([1.0]).to(self.device)
|
||||||
|
test_tensor = test_tensor * 2
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error initializing CUDA device {gpu_num}: {str(e)}")
|
||||||
|
self.logger.info("Falling back to CPU")
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
self.logger.info(f'Using device: {self.device}')
|
self.logger.info(f'Using device: {self.device}')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Set seed if provided
|
# Set seed if provided
|
||||||
if self.args['seed'] != -1:
|
if self.args['seed'] != -1:
|
||||||
np.random.seed(self.args['seed'])
|
np.random.seed(self.args['seed'])
|
||||||
|
Reference in New Issue
Block a user