Add optional device backend support (Accelerate, torch_xla) and flexible device selection for model loading
This commit is contained in:
@@ -7,6 +7,23 @@ import os
|
||||
import re
|
||||
import logging
|
||||
import torch
|
||||
"""
|
||||
Optional device backends:
|
||||
- Accelerate: simplifies multi-GPU/TPU/CPU device placement
|
||||
- torch_xla: TPU support (only available in TPU environments)
|
||||
Both imports are optional to keep backward compatibility.
|
||||
"""
|
||||
try:
|
||||
from accelerate import Accelerator # type: ignore
|
||||
except Exception:
|
||||
Accelerator = None # accelerate is optional
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm # type: ignore
|
||||
XLA_AVAILABLE = True
|
||||
except Exception:
|
||||
xm = None
|
||||
XLA_AVAILABLE = False
|
||||
import lm_decoder
|
||||
from functools import lru_cache
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@@ -92,7 +109,8 @@ def update_ngram_params(
|
||||
def build_opt(
|
||||
model_name='facebook/opt-6.7b',
|
||||
cache_dir=None,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
device=None,
|
||||
accelerator=None,
|
||||
):
|
||||
|
||||
'''
|
||||
@@ -101,17 +119,50 @@ def build_opt(
|
||||
Put the model onto the GPU (if available).
|
||||
'''
|
||||
|
||||
# Resolve device automatically if not provided
|
||||
if device is None:
|
||||
if accelerator is not None:
|
||||
device = accelerator.device
|
||||
elif XLA_AVAILABLE and xm is not None:
|
||||
# Use all available TPU cores for multi-core processing
|
||||
device = xm.xla_device()
|
||||
logging.info(f"TPU cores available: {xm.xrt_world_size()}")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
# Choose appropriate dtype per device
|
||||
try:
|
||||
device_type = device.type # torch.device or XLA device
|
||||
except AttributeError:
|
||||
# Fallback for XLA device objects
|
||||
device_type = str(device)
|
||||
|
||||
if XLA_AVAILABLE and (str(device).startswith('xla') or device_type == 'xla'):
|
||||
load_dtype = torch.bfloat16 # TPU prefers bfloat16
|
||||
elif device_type == 'cuda':
|
||||
load_dtype = torch.float16
|
||||
else:
|
||||
load_dtype = torch.float32
|
||||
|
||||
# load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=cache_dir,
|
||||
torch_dtype=torch.float16,
|
||||
torch_dtype=load_dtype,
|
||||
)
|
||||
|
||||
if device != 'cpu':
|
||||
# Move the model to the GPU
|
||||
# Device placement
|
||||
if accelerator is not None:
|
||||
model = accelerator.prepare(model)
|
||||
else:
|
||||
model = model.to(device)
|
||||
# For TPU multi-core, ensure model is replicated across all cores
|
||||
if XLA_AVAILABLE and str(device).startswith('xla') and xm is not None:
|
||||
# This will be handled by torch_xla internally when using xla_device()
|
||||
pass
|
||||
|
||||
# Set the model to evaluation mode
|
||||
model.eval()
|
||||
@@ -466,8 +517,24 @@ def main(args):
|
||||
'score_penalty_percent': float(score_penalty_percent),
|
||||
}
|
||||
|
||||
# pick GPU
|
||||
device = torch.device(f"cuda:{gpu_number}" if torch.cuda.is_available() else "cpu")
|
||||
# pick device (Accelerate -> TPU -> CUDA -> CPU)
|
||||
accelerator = None
|
||||
if Accelerator is not None:
|
||||
try:
|
||||
accelerator = Accelerator()
|
||||
except Exception:
|
||||
accelerator = None
|
||||
|
||||
if accelerator is not None:
|
||||
device = accelerator.device
|
||||
elif XLA_AVAILABLE and xm is not None:
|
||||
# Use all available TPU cores for multi-core processing
|
||||
device = xm.xla_device()
|
||||
logging.info(f"TPU cores available: {xm.xrt_world_size()}")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{gpu_number}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logging.info(f'Using device: {device}')
|
||||
|
||||
# initialize opt model
|
||||
@@ -477,6 +544,7 @@ def main(args):
|
||||
lm, lm_tokenizer = build_opt(
|
||||
cache_dir=opt_cache_dir,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
logging.info(f'OPT model successfully built in {(time.time()-start_time):0.4f} seconds.')
|
||||
|
||||
|
@@ -6,4 +6,5 @@ tensorboard
|
||||
tensorboardX
|
||||
typeguard
|
||||
textgrid
|
||||
redis
|
||||
redis
|
||||
accelerate>=0.33.0
|
Reference in New Issue
Block a user