Add optional device backend support (Accelerate, torch_xla) and flexible device selection for model loading

Zchen
2025-10-05 11:01:20 +08:00
parent 2f695a8d5d
commit 4ebd870f75
2 changed files with 76 additions and 7 deletions

View File

@@ -7,6 +7,23 @@ import os
 import re
 import logging
 import torch
+
+# Optional device backends:
+#   - Accelerate: simplifies multi-GPU/TPU/CPU device placement
+#   - torch_xla: TPU support (only importable in TPU environments)
+# Both imports are optional to keep backward compatibility.
+try:
+    from accelerate import Accelerator  # type: ignore
+except Exception:
+    Accelerator = None  # accelerate is optional
+try:
+    import torch_xla.core.xla_model as xm  # type: ignore
+    XLA_AVAILABLE = True
+except Exception:
+    xm = None
+    XLA_AVAILABLE = False
 import lm_decoder
 from functools import lru_cache
 from transformers import AutoModelForCausalLM, AutoTokenizer
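
A quick way to check what actually imported, using only names defined above (the helper name is hypothetical, not part of this commit):

    def report_backends():
        # Accelerator is None when accelerate is not installed;
        # XLA_AVAILABLE is False outside TPU environments.
        logging.info(f"accelerate available: {Accelerator is not None}")
        logging.info(f"torch_xla available: {XLA_AVAILABLE}")
        logging.info(f"CUDA available: {torch.cuda.is_available()}")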
@@ -92,7 +109,8 @@ def update_ngram_params(
 def build_opt(
     model_name='facebook/opt-6.7b',
     cache_dir=None,
-    device='cuda' if torch.cuda.is_available() else 'cpu',
+    device=None,
+    accelerator=None,
 ):
     '''
@@ -101,17 +119,50 @@ def build_opt(
     Put the model onto the GPU (if available).
     '''
+    # Resolve the device automatically if not provided
+    if device is None:
+        if accelerator is not None:
+            device = accelerator.device
+        elif XLA_AVAILABLE and xm is not None:
+            # Use all available TPU cores for multi-core processing
+            device = xm.xla_device()
+            logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+        elif torch.cuda.is_available():
+            device = torch.device('cuda')
+        else:
+            device = torch.device('cpu')
+
+    # Choose an appropriate dtype for the device
+    try:
+        device_type = device.type  # torch.device exposes .type
+    except AttributeError:
+        # Fallback for XLA device objects, which are not torch.device
+        device_type = str(device)
+    if XLA_AVAILABLE and (str(device).startswith('xla') or device_type == 'xla'):
+        load_dtype = torch.bfloat16  # TPUs prefer bfloat16
+    elif device_type == 'cuda':
+        load_dtype = torch.float16
+    else:
+        load_dtype = torch.float32
+
     # load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         cache_dir=cache_dir,
-        torch_dtype=torch.float16,
+        torch_dtype=load_dtype,
     )
-    if device != 'cpu':
-        # Move the model to the GPU
+    # Device placement
+    if accelerator is not None:
+        model = accelerator.prepare(model)
+    else:
         model = model.to(device)
+        # For TPU multi-core, replication across cores is handled by
+        # torch_xla internally once the model lives on xm.xla_device()
     # Set the model to evaluation mode
     model.eval()
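
A hedged usage sketch of the new auto-resolution (the cache path is illustrative): with device=None the function walks Accelerate -> TPU -> CUDA -> CPU and matches the load dtype to the device (bfloat16 on XLA, float16 on CUDA, float32 on CPU):

    lm, lm_tokenizer = build_opt(
        model_name='facebook/opt-6.7b',
        cache_dir='/tmp/opt_cache',  # illustrative cache location
    )
    # On a CUDA machine this loads in float16 on 'cuda';
    # on a CPU-only machine, float32 on 'cpu'.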
@@ -466,8 +517,24 @@ def main(args):
         'score_penalty_percent': float(score_penalty_percent),
     }
-    # pick GPU
-    device = torch.device(f"cuda:{gpu_number}" if torch.cuda.is_available() else "cpu")
+    # pick device (Accelerate -> TPU -> CUDA -> CPU)
+    accelerator = None
+    if Accelerator is not None:
+        try:
+            accelerator = Accelerator()
+        except Exception:
+            accelerator = None
+    if accelerator is not None:
+        device = accelerator.device
+    elif XLA_AVAILABLE and xm is not None:
+        # Use all available TPU cores for multi-core processing
+        device = xm.xla_device()
+        logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+    elif torch.cuda.is_available():
+        device = torch.device(f"cuda:{gpu_number}")
+    else:
+        device = torch.device("cpu")
     logging.info(f'Using device: {device}')
     # initialize opt model
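
For reference, what `device` resolves to in each environment under this priority chain (a sketch; values are illustrative):

    # accelerate installed, Accelerator() succeeds -> accelerator.device
    # TPU host with torch_xla                      -> xm.xla_device()
    # CUDA available                               -> torch.device(f"cuda:{gpu_number}")
    # otherwise                                    -> torch.device("cpu")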
@@ -477,6 +544,7 @@ def main(args):
     lm, lm_tokenizer = build_opt(
         cache_dir=opt_cache_dir,
         device=device,
+        accelerator=accelerator,
     )
     logging.info(f'OPT model successfully built in {(time.time()-start_time):0.4f} seconds.')
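
A minimal sketch of the Accelerate path this wiring enables (assumes accelerate is installed; configuration comes from `accelerate config` or the `accelerate launch` CLI):

    accelerator = Accelerator()  # reads the launch/config environment
    lm, lm_tokenizer = build_opt(
        cache_dir=opt_cache_dir,
        accelerator=accelerator,  # device is resolved from the accelerator
    )
    # build_opt calls accelerator.prepare(model), so lm is already
    # placed (and, under a distributed launch, wrapped) appropriately.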

View File

@@ -6,4 +6,5 @@ tensorboard
 tensorboardX
 typeguard
 textgrid
-redis
+redis
+accelerate>=0.33.0
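
Note that torch_xla is deliberately absent from the requirements file: its wheels are environment-specific and it is only importable on TPU hosts, which is exactly what the optional import above accounts for. A hedged install sketch (file name assumed to be requirements.txt; the exact torch_xla wheel depends on the runtime):

    pip install -r requirements.txt     # CPU/GPU hosts
    pip install torch_xla               # TPU hosts only; wheel varies by runtime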