From 4ebd870f7564cd46dc796f2ecb22348103ca240e Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 5 Oct 2025 11:01:20 +0800
Subject: [PATCH] Add optional device backend support (Accelerate, torch_xla)
 and flexible device selection for model loading

---
 language_model/language-model-standalone.py | 80 +++++++++++++++++++--
 language_model/requirements.txt             |  3 +-
 2 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/language_model/language-model-standalone.py b/language_model/language-model-standalone.py
index 936fcdf..7938fbc 100644
--- a/language_model/language-model-standalone.py
+++ b/language_model/language-model-standalone.py
@@ -7,6 +7,23 @@ import os
 import re
 import logging
 import torch
+"""
+Optional device backends:
+- Accelerate: simplifies multi-GPU/TPU/CPU device placement
+- torch_xla: TPU support (only available in TPU environments)
+Both imports are optional to keep backward compatibility.
+"""
+try:
+    from accelerate import Accelerator  # type: ignore
+except Exception:
+    Accelerator = None  # accelerate is optional
+
+try:
+    import torch_xla.core.xla_model as xm  # type: ignore
+    XLA_AVAILABLE = True
+except Exception:
+    xm = None
+    XLA_AVAILABLE = False
 import lm_decoder
 from functools import lru_cache
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -92,7 +109,8 @@ def update_ngram_params(
 def build_opt(
         model_name='facebook/opt-6.7b',
         cache_dir=None,
-        device='cuda' if torch.cuda.is_available() else 'cpu',
+        device=None,
+        accelerator=None,
     ):
     '''
 
@@ -101,17 +119,50 @@ def build_opt(
     Put the model onto the GPU (if available).
     '''
 
+    # Resolve device automatically if not provided
+    if device is None:
+        if accelerator is not None:
+            device = accelerator.device
+        elif XLA_AVAILABLE and xm is not None:
+            # Use all available TPU cores for multi-core processing
+            device = xm.xla_device()
+            logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+        elif torch.cuda.is_available():
+            device = torch.device('cuda')
+        else:
+            device = torch.device('cpu')
+
+    # Choose appropriate dtype per device
+    try:
+        device_type = device.type  # torch.device or XLA device
+    except AttributeError:
+        # Fallback for XLA device objects
+        device_type = str(device)
+
+    if XLA_AVAILABLE and (str(device).startswith('xla') or device_type == 'xla'):
+        load_dtype = torch.bfloat16  # TPU prefers bfloat16
+    elif device_type == 'cuda':
+        load_dtype = torch.float16
+    else:
+        load_dtype = torch.float32
+
     # load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         cache_dir=cache_dir,
-        torch_dtype=torch.float16,
+        torch_dtype=load_dtype,
     )
 
-    if device != 'cpu':
-        # Move the model to the GPU
+    # Device placement
+    if accelerator is not None:
+        model = accelerator.prepare(model)
+    else:
         model = model.to(device)
+    # For TPU multi-core, ensure model is replicated across all cores
+    if XLA_AVAILABLE and str(device).startswith('xla') and xm is not None:
+        # This will be handled by torch_xla internally when using xla_device()
+        pass
 
     # Set the model to evaluation mode
     model.eval()
@@ -466,8 +517,24 @@ def main(args):
         'score_penalty_percent': float(score_penalty_percent),
     }
 
-    # pick GPU
-    device = torch.device(f"cuda:{gpu_number}" if torch.cuda.is_available() else "cpu")
+    # pick device (Accelerate -> TPU -> CUDA -> CPU)
+    accelerator = None
+    if Accelerator is not None:
+        try:
+            accelerator = Accelerator()
+        except Exception:
+            accelerator = None
+
+    if accelerator is not None:
+        device = accelerator.device
+    elif XLA_AVAILABLE and xm is not None:
+        # Use all available TPU cores for multi-core processing
+        device = xm.xla_device()
+        logging.info(f"TPU cores available: {xm.xrt_world_size()}")
+    elif torch.cuda.is_available():
+        device = torch.device(f"cuda:{gpu_number}")
+    else:
+        device = torch.device("cpu")
     logging.info(f'Using device: {device}')
 
     # initialize opt model
@@ -477,6 +544,7 @@
     lm, lm_tokenizer = build_opt(
         cache_dir=opt_cache_dir,
         device=device,
+        accelerator=accelerator,
     )
     logging.info(f'OPT model successfully built in {(time.time()-start_time):0.4f} seconds.')
 
diff --git a/language_model/requirements.txt b/language_model/requirements.txt
index 933f5d1..ad81b1a 100644
--- a/language_model/requirements.txt
+++ b/language_model/requirements.txt
@@ -6,4 +6,5 @@ tensorboard
 tensorboardX
 typeguard
 textgrid
-redis
\ No newline at end of file
+redis
+accelerate>=0.33.0
\ No newline at end of file