1006 lines
42 KiB
Python
1006 lines
42 KiB
Python
import os
|
||
import tensorflow as tf
|
||
import numpy as np
|
||
import time
|
||
import json
|
||
import pickle
|
||
import logging
|
||
import pathlib
|
||
import sys
|
||
from typing import Dict, Any, Tuple, Optional, List
|
||
from omegaconf import OmegaConf
|
||
|
||
# For accurate PER calculation
|
||
try:
|
||
import editdistance
|
||
except ImportError:
|
||
print("Warning: editdistance not available, falling back to approximation")
|
||
editdistance = None
|
||
|
||
from rnn_model_tf import (
|
||
TripleGRUDecoder,
|
||
CTCLoss,
|
||
create_tpu_strategy,
|
||
build_model_for_tpu,
|
||
configure_mixed_precision
|
||
)
|
||
from dataset_tf import (
|
||
BrainToTextDatasetTF,
|
||
DataAugmentationTF,
|
||
train_test_split_indices,
|
||
create_input_fn
|
||
)
|
||
|
||
|
||
class BrainToTextDecoderTrainerTF:
|
||
"""
|
||
TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8
|
||
|
||
This trainer implements the same training logic as the PyTorch version but uses
|
||
TensorFlow operations optimized for TPU hardware.
|
||
"""
|
||
|
||
def __init__(self, args: Dict[str, Any]):
    """
    Initialize the TensorFlow trainer

    Runs, in order: TPU strategy creation, optional mixed-precision setup,
    directory creation, logging setup, seeding, dataset initialization,
    model/optimizer construction inside the strategy scope, checkpoint
    restore (if a checkpoint exists) and adversarial-training configuration.

    Args:
        args: Configuration dictionary containing all training parameters

    Raises:
        RuntimeError: If create_tpu_strategy() returns None.
    """
    self.args = args
    self.logger = None

    # Initialize TPU strategy
    self.strategy = create_tpu_strategy()
    if self.strategy is None:
        raise RuntimeError("Failed to create TPU strategy - strategy is None")

    print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
    print(f"Strategy type: {type(self.strategy).__name__}")
    print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
    print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
    print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
    print(" and drop_remainder=True to avoid dynamic shape warnings")

    # Configure mixed precision for TPU v5e-8.
    # Must be set before _setup_logging(), which reports this flag.
    if args.get('use_amp', True):
        configure_mixed_precision()
        self.mixed_precision = True
    else:
        self.mixed_precision = False

    # Initialize tracking variables
    self.best_val_per = float('inf')
    self.best_val_loss = float('inf')

    # Setup directories
    if args['mode'] == 'train':
        os.makedirs(self.args['output_dir'], exist_ok=True)

    if (args.get('save_best_checkpoint', True) or
            args.get('save_all_val_steps', False) or
            args.get('save_final_model', False)):
        os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

    # Setup logging; from here on self.logger is always a Logger instance.
    self._setup_logging()

    # Set random seeds (-1 means "do not seed")
    if self.args['seed'] != -1:
        tf.random.set_seed(self.args['seed'])
        np.random.seed(self.args['seed'])

    # Initialize datasets
    self._initialize_datasets()

    # Build model within strategy scope
    with self.strategy.scope():
        self.model = self._build_model()
        self.optimizer = self._create_optimizer()

        print("🔧 Initializing optimizer for TPU training...")
        print(f"Optimizer type: {type(self.optimizer).__name__}")

        # Explicitly build the optimizer within the strategy scope before
        # training. This forces creation of all slot variables (e.g. AdamW
        # momentum) here instead of lazily inside a @tf.function.
        # The model must be built first for optimizer.build() to work, so
        # run one forward pass on dummy data if it is not built yet.
        if not self.model.built:
            dummy_batch_size = 2
            dummy_time_steps = 100
            dummy_features = tf.zeros(
                (dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features'])
            )
            dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
            _ = self.model(dummy_features, dummy_day_idx, training=False)

        print("🔧 Building optimizer with model variables...")
        self.optimizer.build(self.model.trainable_variables)
        print("✅ Optimizer built successfully")

        print("✅ Optimizer ready for TPU training")

        self.lr_scheduler = self._create_lr_scheduler()
        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')

    # Create unified checkpoint management
    self.ckpt = tf.train.Checkpoint(
        optimizer=self.optimizer,
        model=self.model
    )
    self.ckpt_manager = tf.train.CheckpointManager(
        self.ckpt,
        directory=self.args['checkpoint_dir'],
        max_to_keep=5  # Keep only the 5 most recent checkpoints
    )

    # Try to restore from latest checkpoint
    latest_checkpoint = self.ckpt_manager.latest_checkpoint
    if latest_checkpoint:
        print(f"🔄 Restoring from {latest_checkpoint}")
        self.logger.info(f"Restoring from {latest_checkpoint}")
        # expect_partial() avoids failures due to model structure changes
        self.ckpt.restore(latest_checkpoint).expect_partial()
        print("✅ Checkpoint restored successfully")
    else:
        print("🆕 Initializing training from scratch")
        self.logger.info("Initializing training from scratch")

    # Log model information
    self._log_model_info()

    # Adversarial training configuration.
    # `or {}` also covers a config that contains `adversarial: null`.
    adv_cfg = self.args.get('adversarial') or {}
    self.adv_enabled = adv_cfg.get('enabled', False)
    self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
    self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
    self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
    self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))

    # Manual weight decay handling - disabled since AdamW handles it
    self.manual_weight_decay = False
    if self.args.get('weight_decay', 0.0) > 0:
        print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}")
    else:
        print("💡 No weight decay configured")

    if self.adv_enabled:
        self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                         f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                         f"noise_l2_weight={self.adv_noise_l2_weight}, "
                         f"warmup_steps={self.adv_warmup_steps}")
|
||
|
||
def _setup_logging(self):
    """Setup logging configuration

    Creates the module logger, replaces any handlers already attached to
    it, and attaches a stdout handler plus (in 'train' mode) a file handler
    writing to <output_dir>/training_log. Reads self.args, self.strategy
    and self.mixed_precision; sets self.logger.
    """
    self.logger = logging.getLogger(__name__)

    # Drop handlers left over from an earlier setup and close them so a
    # previously opened log file is released.
    for handler in self.logger.handlers[:]:
        self.logger.removeHandler(handler)
        handler.close()

    self.logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

    if self.args['mode'] == 'train':
        log_path = pathlib.Path(self.args['output_dir'], 'training_log')
        # Messages logged by this trainer contain emoji; force UTF-8 so the
        # file handler does not depend on the platform default encoding.
        fh = logging.FileHandler(str(log_path), encoding='utf-8')
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    self.logger.addHandler(sh)

    self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
    if self.mixed_precision:
        self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
|
||
|
||
def _configure_cpu_optimization(self):
    """Set TensorFlow inter-op / intra-op thread counts from the CPU core count.

    Machines reporting 200 or more cores get larger caps (64 inter-op,
    32 intra-op); all others use roughly 1/4 of the cores for inter-op
    and 1/8 for intra-op, capped at 32 and 16.
    """
    import multiprocessing

    n_cores = multiprocessing.cpu_count()
    print(f"💻 Available CPU cores: {n_cores}")

    # (cap, divisor) pairs for inter-op and intra-op parallelism.
    if n_cores >= 200:
        # High core count system (the original targets a 224-core host).
        inter_cap, inter_div = 64, 3
        intra_cap, intra_div = 32, 6
    else:
        inter_cap, inter_div = 32, 4
        intra_cap, intra_div = 16, 8

    inter_op_threads = min(inter_cap, n_cores // inter_div)
    intra_op_threads = min(intra_cap, n_cores // intra_div)

    threading_cfg = tf.config.threading
    threading_cfg.set_inter_op_parallelism_threads(inter_op_threads)
    threading_cfg.set_intra_op_parallelism_threads(intra_op_threads)

    print("🔧 CPU optimization configured:")
    print(f" Inter-op parallelism: {inter_op_threads} threads")
    print(f" Intra-op parallelism: {intra_op_threads} threads")
    print(" This will accelerate data loading and preprocessing while TPU handles training")
|
||
|
||
def _get_tpu_status(self) -> str:
    """Get current TPU status and HBM utilization info

    Returns:
        A short one-line summary: device count, replica count and an HBM
        usage fragment. Never raises; on failure the string describes the
        error (truncated to 20 characters).
    """
    try:
        # Get TPU devices
        tpu_devices = tf.config.list_logical_devices('TPU')
        if not tpu_devices:
            return "TPU: No devices"

        # Get strategy info
        num_replicas = getattr(self.strategy, 'num_replicas_in_sync', 1)

        # Get TPU memory info, querying each logical device by name
        try:
            active_cores = 0
            total_current_mb = 0
            max_peak_mb = 0

            for device in tpu_devices:
                try:
                    memory_info = tf.config.experimental.get_memory_info(device.name)
                    if memory_info and 'current' in memory_info:
                        current_mb = memory_info['current'] // (1024 * 1024)
                        peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)

                        if current_mb > 1:  # >1MB considered active
                            active_cores += 1
                            total_current_mb += current_mb
                            max_peak_mb = max(max_peak_mb, peak_mb)
                except Exception:
                    # Best effort: skip devices whose memory info cannot be
                    # read, but let KeyboardInterrupt/SystemExit propagate.
                    continue

            if active_cores == 0:
                hbm_info = "HBM:idle"
            elif active_cores == 1:
                hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
            else:
                hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"

        except Exception:
            # Fallback: simple TPU activity check
            try:
                with tf.device('/TPU:0'):
                    _ = tf.constant(1.0)
                hbm_info = "HBM:active"
            except Exception:
                hbm_info = "HBM:inactive"

        return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
                f"{hbm_info}")

    except Exception as e:
        return f"TPU: status_error({str(e)[:20]})"
|
||
|
||
def _get_detailed_tpu_status(self) -> str:
    """Get detailed TPU status for training start

    Returns:
        A one-line report with device count, strategy type, replica count,
        HBM usage and the result of a tiny op placed on '/TPU:0'. Never
        raises; on failure the string describes the error (truncated).
    """
    try:
        # Get TPU devices
        tpu_devices = tf.config.list_logical_devices('TPU')
        if not tpu_devices:
            return "❌ No TPU devices detected"

        # Get strategy info
        num_replicas = getattr(self.strategy, 'num_replicas_in_sync', 1)
        strategy_type = type(self.strategy).__name__

        # Get TPU HBM memory info
        try:
            gib = 1024 * 1024 * 1024
            active_cores = 0
            total_current_bytes = 0
            max_peak_bytes = 0

            for device in tpu_devices:
                try:
                    memory_info = tf.config.experimental.get_memory_info(device.name)
                    if memory_info and 'current' in memory_info:
                        current_bytes = memory_info['current']
                        peak_bytes = memory_info.get('peak', current_bytes)

                        if current_bytes > 1024 * 1024:  # >1MB considered active
                            active_cores += 1
                            total_current_bytes += current_bytes
                            max_peak_bytes = max(max_peak_bytes, peak_bytes)
                except Exception:
                    # Best effort: skip devices whose memory info cannot be
                    # read, but let KeyboardInterrupt/SystemExit propagate.
                    continue

            # Assumed per-core HBM capacity in GB. This is the original
            # author's conservative estimate, derived from an observed
            # 14.5GB peak on TPU:0.
            estimated_per_core = 16
            estimated_total_gb = estimated_per_core * len(tpu_devices)

            if active_cores > 0:
                # Sum bytes first and convert once, so cores using less than
                # 1GB still count towards the reported total.
                total_current_gb = total_current_bytes // gib
                max_peak_gb = max_peak_bytes // gib
                hbm_usage = (f"HBM: {total_current_gb}GB/{estimated_total_gb}GB "
                             f"(peak: {max_peak_gb}GB) active:{active_cores}cores")
            else:
                # Same capacity estimate as the active case (was a fixed 256GB).
                hbm_usage = f"HBM: 0GB/{estimated_total_gb}GB (idle)"

        except Exception:
            hbm_usage = "HBM: unavailable"

        # Simple TPU test
        try:
            with tf.device('/TPU:0'):
                test_result = tf.constant([1.0, 2.0])
                _ = tf.reduce_sum(test_result)
            tpu_test = "✅ responsive"
        except Exception as e:
            tpu_test = f"❌ test_failed({str(e)[:15]})"

        return (f"TPU Devices: {len(tpu_devices)} | "
                f"Strategy: {strategy_type} | "
                f"Cores: {num_replicas} | "
                f"{hbm_usage} | "
                f"Test: {tpu_test}")

    except Exception as e:
        return f"❌ TPU status check failed: {str(e)[:50]}"
|
||
|
||
def _initialize_datasets(self):
    """Initialize training and validation datasets

    Builds per-session HDF5 file paths, computes the train/val trial
    indices, writes them to <output_dir>/train_val_trials.json and creates
    self.train_dataset_tf / self.val_dataset_tf. Process RSS is sampled
    with psutil around each dataset construction for the printed report.

    Raises:
        ValueError: If the configured sessions produce duplicate file paths.
    """
    dataset_cfg = self.args['dataset']

    # Create file paths
    train_file_paths = [
        os.path.join(dataset_cfg["dataset_dir"], s, 'data_train.hdf5')
        for s in dataset_cfg['sessions']
    ]
    val_file_paths = [
        os.path.join(dataset_cfg["dataset_dir"], s, 'data_val.hdf5')
        for s in dataset_cfg['sessions']
    ]

    # Validate no duplicates
    if len(set(train_file_paths)) != len(train_file_paths):
        raise ValueError("Duplicate sessions in train dataset")
    if len(set(val_file_paths)) != len(val_file_paths):
        raise ValueError("Duplicate sessions in val dataset")

    # Split trials: all trials of the train files go to train
    # (test_percentage=0), all trials of the val files go to val
    # (test_percentage=1).
    train_trials, _ = train_test_split_indices(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=dataset_cfg['seed'],
        bad_trials_dict=dataset_cfg.get('bad_trials_dict')
    )

    _, val_trials = train_test_split_indices(
        file_paths=val_file_paths,
        test_percentage=1,
        seed=dataset_cfg['seed'],
        bad_trials_dict=dataset_cfg.get('bad_trials_dict')
    )

    # Save trial splits. The constructor only creates output_dir in 'train'
    # mode, so make sure it exists before writing.
    os.makedirs(self.args['output_dir'], exist_ok=True)
    splits_path = os.path.join(self.args['output_dir'], 'train_val_trials.json')
    with open(splits_path, 'w', encoding='utf-8') as f:
        json.dump({'train': train_trials, 'val': val_trials}, f)

    # Monitor memory usage while the datasets are created
    import psutil
    process = psutil.Process()
    initial_memory_mb = process.memory_info().rss / 1024 / 1024

    print("🔄 Initializing training dataset with GPU-style memory management...")
    init_start_time = time.time()
    self.train_dataset_tf = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=self.args['num_training_batches'],
        split='train',
        batch_size=dataset_cfg['batch_size'],
        days_per_batch=dataset_cfg['days_per_batch'],
        random_seed=dataset_cfg['seed'],
        must_include_days=dataset_cfg.get('must_include_days'),
        feature_subset=dataset_cfg.get('feature_subset'),
        cache_data=True,  # Enable smart caching (like the GPU version)
        preload_all_data=False  # Load on demand (GPU-version strategy) to avoid running out of memory
    )

    # Log training dataset initialization performance
    train_init_time = time.time() - init_start_time
    train_memory_mb = process.memory_info().rss / 1024 / 1024
    train_memory_used = train_memory_mb - initial_memory_mb
    print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")

    print("🔄 Initializing validation dataset with GPU-style memory management...")
    val_init_start_time = time.time()
    self.val_dataset_tf = BrainToTextDatasetTF(
        trial_indices=val_trials,
        n_batches=None,  # Use all validation data
        split='test',
        batch_size=dataset_cfg['batch_size'],
        days_per_batch=1,  # One day per validation batch
        random_seed=dataset_cfg['seed'],
        feature_subset=dataset_cfg.get('feature_subset'),
        cache_data=True,  # Enable smart caching (like the GPU version)
        preload_all_data=False  # Load on demand (GPU-version strategy) to avoid running out of memory
    )

    # Log validation dataset initialization performance
    val_init_time = time.time() - val_init_start_time
    final_memory_mb = process.memory_info().rss / 1024 / 1024
    total_memory_used = final_memory_mb - initial_memory_mb
    val_memory_used = final_memory_mb - train_memory_mb
    print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM")
    print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")

    self.logger.info("Successfully initialized TensorFlow datasets")
|
||
|
||
def _build_model(self) -> TripleGRUDecoder:
    """Construct a TripleGRUDecoder from the 'model' and 'dataset' config sections.

    The number of days passed to the model is the number of configured sessions.
    """
    model_cfg = self.args['model']
    dataset_cfg = self.args['dataset']

    return TripleGRUDecoder(
        neural_dim=model_cfg['n_input_features'],
        n_units=model_cfg['n_units'],
        n_days=len(dataset_cfg['sessions']),
        n_classes=dataset_cfg['n_classes'],
        rnn_dropout=model_cfg['rnn_dropout'],
        input_dropout=model_cfg['input_network']['input_layer_dropout'],
        patch_size=model_cfg['patch_size'],
        patch_stride=model_cfg['patch_stride'],
    )
|
||
|
||
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """Build the AdamW optimizer from the training configuration.

    Reads lr_max, beta0, beta1, epsilon and (optionally) weight_decay from
    self.args.

    NOTE(review): the optimizer is given the constant lr_max; the schedule
    returned by _create_lr_scheduler() is not passed in here. Confirm that
    the schedule is applied somewhere else.
    """
    cfg = self.args
    print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
    print("Using AdamW optimizer for TPU training")

    adamw_kwargs = dict(
        learning_rate=cfg['lr_max'],
        beta_1=cfg['beta0'],
        beta_2=cfg['beta1'],
        epsilon=cfg['epsilon'],
        weight_decay=cfg.get('weight_decay', 0.0),
    )
    optimizer = tf.keras.optimizers.AdamW(**adamw_kwargs)
    print("✅ Using AdamW optimizer")

    return optimizer
|
||
|
||
def _create_lr_scheduler(self):
    """Return the learning-rate schedule selected by args['lr_scheduler_type'].

    'cosine' delegates to _create_cosine_scheduler(); 'linear' builds a
    PolynomialDecay with power 1.0 from lr_max to lr_min over
    lr_decay_steps.

    Raises:
        ValueError: For any other scheduler type.
    """
    scheduler_type = self.args['lr_scheduler_type']

    if scheduler_type == 'cosine':
        return self._create_cosine_scheduler()

    if scheduler_type != 'linear':
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")

    return tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=self.args['lr_max'],
        decay_steps=self.args['lr_decay_steps'],
        end_learning_rate=self.args['lr_min'],
        power=1.0,  # Linear decay
    )
|
||
|
||
def _create_cosine_scheduler(self):
    """Build a CosineDecayRestarts schedule from lr_max, lr_min and lr_decay_steps.

    alpha is lr_min expressed as a fraction of lr_max; t_mul and m_mul are
    both 1.0.
    """
    peak_lr = self.args['lr_max']
    floor_fraction = self.args['lr_min'] / peak_lr

    return tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=peak_lr,
        first_decay_steps=self.args['lr_decay_steps'],
        t_mul=1.0,
        m_mul=1.0,
        alpha=floor_fraction,
    )
|
||
|
||
def _log_model_info(self):
    """Report the model type and its trainable parameter count.

    Messages go to self.logger when one is set, otherwise to stdout. The
    model is called once on an all-zero dummy batch so its weights exist
    before they are counted.
    """
    # Pick the output channel once: logger if available, print otherwise.
    emit = self.logger.info if self.logger else print

    emit("Initialized TripleGRUDecoder model")

    # Build the model by calling it once with dummy data
    n_features = self.args['model']['n_input_features']
    dummy_shape = (2, 100, n_features)  # (batch, time steps, features)
    dummy_features = tf.zeros(dummy_shape)
    dummy_day_idx = tf.zeros((dummy_shape[0],), dtype=tf.int32)
    _ = self.model(dummy_features, dummy_day_idx, training=False)

    # Count parameters
    total_params = sum(tf.size(weight).numpy() for weight in self.model.trainable_weights)

    emit(f"Model has {total_params:,} trainable parameters")
|
||
|
||
@tf.function
def _train_step(self, batch, step):
    """Single training step with gradient tape

    Args:
        batch: Mapping with keys 'input_features', 'seq_class_ids',
            'n_time_steps', 'phone_seq_lens' and 'day_indices'.
        step: Current training step; only used to decide whether the
            adversarial ("full") forward pass is active, i.e. whether
            step >= self.adv_warmup_steps.
            NOTE(review): `use_full` drives plain Python `if` statements
            whose branches bind different names (noisy_logits exists only
            in the adversarial branch), so this presumably requires `step`
            to be a Python number rather than a tensor - confirm before
            changing how it is passed.

    Returns:
        (loss, grad_norm): the scalar loss for this replica and the global
        gradient norm computed before clipping.
    """
    features = batch['input_features']
    labels = batch['seq_class_ids']
    n_time_steps = batch['n_time_steps']
    phone_seq_lens = batch['phone_seq_lens']
    day_indices = batch['day_indices']

    with tf.GradientTape() as tape:
        # Apply data transformations
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=True
        )

        # Calculate adjusted lengths for CTC: number of patches produced
        # from n_time_steps with the configured patch size and stride.
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
            self.args['model']['patch_stride'] + 1,
            tf.int32
        )

        def mean_ctc_loss(logits):
            """Batch-mean CTC loss of `logits` against this batch's labels."""
            loss_input = {
                'labels': labels,
                'input_lengths': adjusted_lens,
                'label_lengths': phone_seq_lens
            }
            return tf.reduce_mean(self.ctc_loss(loss_input, logits))

        # Forward pass
        use_full = self.adv_enabled and (step >= self.adv_warmup_steps)
        if use_full:
            clean_logits, noisy_logits, noise_output = self.model(
                features, day_indices, None, False, 'full',
                grl_lambda=self.adv_grl_lambda, training=True
            )
        else:
            clean_logits = self.model(
                features, day_indices, None, False, 'inference', training=True
            )

        # Calculate losses
        if use_full:
            clean_loss = mean_ctc_loss(clean_logits)
            noisy_loss = mean_ctc_loss(noisy_logits)

            # Optional noise L2 regularization
            noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
            if self.adv_noise_l2_weight > 0.0:
                # Cast to the loss dtype before reducing: under mixed
                # precision the model output may have a lower-precision
                # dtype than the CTC loss, and TensorFlow does not add
                # tensors of different dtypes.
                noise_l2 = tf.reduce_mean(
                    tf.square(tf.cast(noise_output, clean_loss.dtype))
                )

            loss = (clean_loss
                    + self.adv_noisy_loss_weight * noisy_loss
                    + self.adv_noise_l2_weight * noise_l2)
        else:
            loss = mean_ctc_loss(clean_logits)

    # AdamW handles weight decay itself, so no manual L2 term is added.
    # Mixed precision is handled by the Keras policy and bfloat16 on TPU
    # needs no loss scaling, so the loss is used unscaled here.
    all_variables = self.model.trainable_variables
    gradients = tape.gradient(loss, all_variables)

    # Replace None gradients with zeros so every trainable variable gets
    # an update entry, matching the variables the optimizer was built with.
    safe_gradients = [
        grad if grad is not None else tf.zeros_like(var)
        for grad, var in zip(gradients, all_variables)
    ]

    # Clip gradients (grad_norm is the norm before clipping)
    if self.args['grad_norm_clip_value'] > 0:
        safe_gradients, grad_norm = tf.clip_by_global_norm(
            safe_gradients, self.args['grad_norm_clip_value']
        )
    else:
        grad_norm = tf.global_norm(safe_gradients)

    # Apply gradients to ALL variables (consistent with initialization)
    self.optimizer.apply_gradients(zip(safe_gradients, all_variables))

    return loss, grad_norm
|
||
|
||
@tf.function
def _validation_step(self, batch):
    """Run one validation batch and return what the PER computation needs.

    Args:
        batch: Mapping with keys 'input_features', 'seq_class_ids',
            'n_time_steps', 'phone_seq_lens' and 'day_indices'.

    Returns:
        (loss, predicted_ids, labels, adjusted_lens, phone_seq_lens) where
        loss is the batch-mean CTC loss and predicted_ids is the per-frame
        argmax over the logits.
    """
    model_cfg = self.args['model']
    labels = batch['seq_class_ids']
    phone_seq_lens = batch['phone_seq_lens']
    day_indices = batch['day_indices']

    # Apply data transformations (no augmentation for validation)
    features, n_time_steps = DataAugmentationTF.transform_data(
        batch['input_features'],
        batch['n_time_steps'],
        self.args['dataset']['data_transforms'],
        training=False,
    )

    # Number of output frames after patching, used as the CTC input length
    time_steps_f = tf.cast(n_time_steps, tf.float32)
    n_patches = (time_steps_f - model_cfg['patch_size']) / model_cfg['patch_stride'] + 1
    adjusted_lens = tf.cast(n_patches, tf.int32)

    # Forward pass (inference mode only)
    logits = self.model(features, day_indices, None, False, 'inference', training=False)

    # Batch-mean CTC loss
    per_example_loss = self.ctc_loss(
        {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens,
        },
        logits,
    )
    loss = tf.reduce_mean(per_example_loss)

    # Greedy decoding for PER calculation
    predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)

    return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens
|
||
|
||
def train(self) -> Dict[str, Any]:
    """Main training loop

    Distributes the training and validation datasets, runs up to
    args['num_training_batches'] training steps, validates every
    args['batches_per_val_step'] steps (and on the last step), checkpoints
    on improvement and optionally stops early.

    Returns:
        Dict with 'train_losses', 'val_losses', 'val_pers' (lists of
        floats) and 'val_metrics' (list of the per-validation metric dicts).

    Raises:
        RuntimeError: If a training step fails with a strategy
            "merge_call" AttributeError.
    """
    self.logger.info("Starting training loop...")

    # Log initial TPU status
    initial_tpu_status = self._get_detailed_tpu_status()
    self.logger.info(f"Initial TPU Status: {initial_tpu_status}")

    # Create datasets using modern distribution API
    def create_dist_dataset_fn(input_dataset_tf, training):
        """Create distributed dataset function for modern TPU strategy"""
        def dataset_fn(input_context):
            # create_input_fn returns a complete, batched tf.data.Dataset
            return create_input_fn(
                input_dataset_tf,
                self.args['dataset']['data_transforms'],
                training=training
            )
        return self.strategy.distribute_datasets_from_function(dataset_fn)

    # Distribute datasets using modern API
    self.logger.info("🔄 Distributing training dataset across TPU cores...")
    dist_start_time = time.time()
    train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
    train_dist_time = time.time() - dist_start_time
    self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")

    self.logger.info("🔄 Distributing validation dataset across TPU cores...")
    val_start_time = time.time()
    val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
    val_dist_time = time.time() - val_start_time
    self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")

    self.logger.info("Created distributed training and validation datasets")

    # Training metrics
    train_losses = []
    val_losses = []
    val_pers = []
    val_results = []
    val_steps_since_improvement = 0
    self.logger.info("Training time count beginning...")
    train_start_time = time.time()

    num_training_batches = self.args['num_training_batches']
    # Read once so the early-stopping check and its log message agree
    # (the key may be absent, in which case the default applies).
    early_stopping_patience = self.args.get('early_stopping_val_steps', 20)

    # Training loop
    step = 0
    # Index of the last step whose training update actually ran; used to
    # label the final checkpoint. Stays -1 if no step ran.
    last_completed_step = -1

    self.logger.info("🔄 Starting training loop...")
    self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
    self.logger.info(" consider using padded_batch with fixed shapes in create_input_fn")

    for batch in train_dist_dataset:
        if step >= num_training_batches:
            self.logger.info("Reached maximum training batches, stopping training")
            break

        start_time = time.time()

        # _train_step is a tf.function and uses `step` only in the test
        # `step >= self.adv_warmup_steps`. A Python int argument is part of
        # the trace signature, so passing the raw counter would retrace
        # (and recompile for TPU) on every step. Pass one of two
        # representative values that yield the same comparison result.
        if step >= self.adv_warmup_steps:
            traced_step = self.adv_warmup_steps
        else:
            traced_step = self.adv_warmup_steps - 1

        # Distributed training step
        self.logger.info("Running distributed training step...")
        try:
            with self.strategy.scope():
                per_replica_losses, per_replica_grad_norms = self.strategy.run(
                    self._train_step, args=(batch, traced_step)
                )
        except AttributeError as e:
            if "merge_call" not in str(e):
                raise
            error_msg = f"Strategy merge_call error at step {step}: {e}"
            print(error_msg)
            self.logger.error(error_msg)
            self.logger.error("This indicates the strategy is not properly initialized")
            raise RuntimeError(f"TPU strategy failed during training step {step}: {e}") from e

        last_completed_step = step

        # Reduce across replicas
        self.logger.info("Reducing results across replicas...")
        loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
        grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)

        train_step_duration = time.time() - start_time
        loss_value = float(loss.numpy())
        train_losses.append(loss_value)

        # Log training progress with TPU status
        if step % self.args['batches_per_train_log'] == 0:
            tpu_status = self._get_tpu_status()
            self.logger.info(f'Train batch {step}: '
                             f'loss: {loss_value:.2f} '
                             f'grad norm: {float(grad_norm.numpy()):.2f} '
                             f'time: {train_step_duration:.3f}s '
                             f'| {tpu_status}')

        # Validation step
        if step % self.args['batches_per_val_step'] == 0 or step == (num_training_batches - 1):
            self.logger.info(f"Running validation after training batch: {step}")

            val_start_time = time.time()
            val_metrics = self._validate(val_dist_dataset)
            val_step_duration = time.time() - val_start_time

            tpu_status = self._get_tpu_status()
            self.logger.info(f'Val batch {step}: '
                             f'PER (avg): {val_metrics["avg_per"]:.4f} '
                             f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                             f'time: {val_step_duration:.3f}s '
                             f'| {tpu_status}')

            val_pers.append(val_metrics['avg_per'])
            val_losses.append(val_metrics['avg_loss'])
            val_results.append(val_metrics)

            # Check for improvement: lower PER wins, ties broken by loss
            new_best = False
            if val_metrics['avg_per'] < self.best_val_per:
                self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                self.best_val_per = val_metrics['avg_per']
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True
            elif (val_metrics['avg_per'] == self.best_val_per and
                  val_metrics['avg_loss'] < self.best_val_loss):
                self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True

            if new_best:
                if self.args.get('save_best_checkpoint', True):
                    self.logger.info("Checkpointing model")
                    self._save_checkpoint(step)

                if self.args.get('save_val_metrics', True):
                    # The constructor creates checkpoint_dir only when one
                    # of the checkpoint-saving options is enabled.
                    os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
                    metrics_path = os.path.join(self.args['checkpoint_dir'], 'val_metrics.pkl')
                    with open(metrics_path, 'wb') as f:
                        pickle.dump(val_metrics, f)

                val_steps_since_improvement = 0
            else:
                val_steps_since_improvement += 1

            # Optional save all validation checkpoints
            if self.args.get('save_all_val_steps', False):
                self._save_checkpoint(step)

            # Early stopping
            if (self.args.get('early_stopping', False) and
                    val_steps_since_improvement >= early_stopping_patience):
                self.logger.info(f'Validation PER has not improved in {early_stopping_patience} '
                                 f'validation steps. Stopping training early at batch: {step}')
                break

        step += 1

    # Training completed
    training_duration = time.time() - train_start_time
    self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
    self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

    # Save final model, labelled with the last step that was trained
    # (after an early stop `step` has not been incremented yet).
    if self.args.get('save_final_model', False):
        self._save_checkpoint(last_completed_step)

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_pers': val_pers,
        'val_metrics': val_results
    }
|
||
|
||
def _validate(self, val_dataset) -> Dict[str, Any]:
    """Run validation over ``val_dataset`` and compute loss and phoneme error rate.

    Every batch is executed through ``self._validation_step`` with
    ``self.strategy.run``. The per-replica losses are mean-reduced, while the
    predictions, labels and both length tensors are gathered along axis 0 so
    the edit distance can be computed in plain Python on the host.

    Predictions are decoded with the CTC greedy rule: consecutive repeats are
    collapsed first and blanks are dropped afterwards. The order matters:
    dropping blanks first would merge genuine repeats that were only separated
    by a blank (``a, blank, a`` must decode to ``a, a``, not ``a``).

    NOTE(review): this assumes ``_validation_step`` returns per-frame class ids
    with blank index 0, as the previous implementation of this method did --
    confirm against ``_validation_step``.

    Args:
        val_dataset: Iterable of batches accepted by ``self._validation_step``.

    Returns:
        Dict with:
            'avg_loss': mean of the per-batch reduced losses,
            'avg_per': total edit distance divided by total label length,
            'total_edit_distance': summed edit distance over all sequences,
            'total_seq_length': summed label length over all sequences,
            'num_batches': number of batches consumed.
    """
    blank_id = 0  # CTC blank index assumed throughout this trainer

    total_loss = 0.0
    total_edit_distance = 0
    total_seq_length = 0
    num_batches = 0

    for batch in val_dataset:
        # Run the validation step on every replica
        (per_replica_losses, per_replica_preds, per_replica_labels,
         per_replica_pred_lens, per_replica_label_lens) = self.strategy.run(
            self._validation_step, args=(batch,)
        )

        # Reduce loss across replicas
        batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
        total_loss += float(batch_loss.numpy())

        # Gather everything once and convert to NumPy once per batch; the
        # per-sequence work below is then pure host-side Python.
        all_preds = self.strategy.gather(per_replica_preds, axis=0).numpy()
        all_labels = self.strategy.gather(per_replica_labels, axis=0).numpy()
        all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0).numpy()
        all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0).numpy()

        for i in range(all_preds.shape[0]):
            pred_len = int(all_pred_lens[i])
            label_len = int(all_label_lens[i])

            # CTC greedy decode: a token is emitted only when it differs from
            # the previous frame (collapse repeats) and is not the blank.
            # `previous` tracks the raw frame, blanks included, so a blank
            # between two identical tokens keeps them as separate emissions.
            pred_seq = []
            previous = None
            for token in all_preds[i, :pred_len].tolist():
                if token != previous and token != blank_id:
                    pred_seq.append(token)
                previous = token

            true_seq = all_labels[i, :label_len].tolist()

            # Prefer the optional C implementation; fall back to the
            # pure-Python dynamic-programming version otherwise.
            if editdistance is not None:
                edit_dist = editdistance.eval(pred_seq, true_seq)
            else:
                edit_dist = self._simple_edit_distance(pred_seq, true_seq)

            total_edit_distance += edit_dist
            total_seq_length += label_len

        num_batches += 1

    # Guard both divisions so an empty dataset yields zeros instead of raising
    avg_loss = total_loss / max(num_batches, 1)
    avg_per = total_edit_distance / max(total_seq_length, 1e-6)

    return {
        'avg_loss': avg_loss,
        'avg_per': avg_per,
        'total_edit_distance': total_edit_distance,
        'total_seq_length': total_seq_length,
        'num_batches': num_batches
    }
|
||
|
||
def _simple_edit_distance(self, seq1, seq2):
|
||
"""Simple edit distance implementation as fallback"""
|
||
# Dynamic programming implementation of edit distance
|
||
m, n = len(seq1), len(seq2)
|
||
dp = [[0] * (n + 1) for _ in range(m + 1)]
|
||
|
||
# Initialize base cases
|
||
for i in range(m + 1):
|
||
dp[i][0] = i
|
||
for j in range(n + 1):
|
||
dp[0][j] = j
|
||
|
||
# Fill the DP table
|
||
for i in range(1, m + 1):
|
||
for j in range(1, n + 1):
|
||
if seq1[i-1] == seq2[j-1]:
|
||
dp[i][j] = dp[i-1][j-1]
|
||
else:
|
||
dp[i][j] = 1 + min(
|
||
dp[i-1][j], # deletion
|
||
dp[i][j-1], # insertion
|
||
dp[i-1][j-1] # substitution
|
||
)
|
||
|
||
return dp[m][n]
|
||
|
||
def _save_checkpoint(self, step: int, name: str = ""):
    """Persist a checkpoint for ``step`` together with its Python-side state.

    Three artifacts are written:
      * the TensorFlow checkpoint, via ``self.ckpt_manager.save`` numbered
        with ``step``;
      * ``state-<step>.json`` in the checkpoint directory, holding the step
        and the best validation PER / loss seen so far;
      * ``args.yaml`` with the run configuration, only when that file does
        not exist yet.

    Args:
        step: Training step, used as the checkpoint number and in the state
            file name.
        name: Unused; kept so callers that still pass it keep working.
    """
    written_to = self.ckpt_manager.save(checkpoint_number=step)

    # Fall back to stdout when no logger has been configured
    report = self.logger.info if self.logger else print
    report(f"Saved checkpoint for step {step}: {written_to}")

    # Plain-Python values that the TensorFlow checkpoint does not track
    python_state = {
        'step': step,
        'best_val_per': float(self.best_val_per),
        'best_val_loss': float(self.best_val_loss)
    }

    ckpt_dir = self.args['checkpoint_dir']

    # The step in the file name ties this state file to its checkpoint
    with open(os.path.join(ckpt_dir, f'state-{step}.json'), 'w') as state_file:
        json.dump(python_state, state_file)

    # The configuration is written once and never overwritten afterwards
    yaml_path = os.path.join(ckpt_dir, 'args.yaml')
    if os.path.exists(yaml_path):
        return
    with open(yaml_path, 'w') as yaml_file:
        OmegaConf.save(config=self.args, f=yaml_file)
|
||
|
||
def load_checkpoint(self, checkpoint_path: str):
    """Load a specific checkpoint and its associated training state.

    The TensorFlow objects tracked by ``self.ckpt`` are restored from
    ``checkpoint_path``. The Python-side state (best validation PER and loss)
    is then read from ``state-<step>.json`` in the same directory, where
    ``<step>`` is the number after the last ``-`` in ``checkpoint_path``.

    State restoration is best-effort and all-or-nothing: if the state file is
    missing, unreadable or malformed, a warning is emitted and
    ``self.best_val_per`` / ``self.best_val_loss`` are left untouched.

    Args:
        checkpoint_path: Checkpoint prefix ending in ``-<step>``
            (e.g. ``.../ckpt-123``).
    """
    if self.logger:
        self.logger.info(f"Loading checkpoint from: {checkpoint_path}")
    else:
        print(f"Loading checkpoint from: {checkpoint_path}")

    # Restore TensorFlow objects; expect_partial() silences warnings about
    # checkpoint values that are not consumed by the current objects.
    self.ckpt.restore(checkpoint_path).expect_partial()

    # Restore non-TensorFlow training state
    try:
        # Extract step number from checkpoint path (e.g., ckpt-123 -> 123)
        step = int(checkpoint_path.split('-')[-1])
        state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json')

        with open(state_path, 'r') as f:
            state = json.load(f)

        # Read and validate every field before touching self, so a malformed
        # file can never leave the trainer with half-restored state.
        best_val_per = float(state['best_val_per'])
        best_val_loss = float(state['best_val_loss'])

    except (IOError, ValueError, KeyError, TypeError) as e:
        # TypeError covers a state file that is valid JSON but not an object,
        # or whose values are not numeric.
        warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. "
                       f"Starting with fresh state. Error: {e}")
        if self.logger:
            self.logger.warning(warning_msg)
        else:
            print(f"⚠️ {warning_msg}")
        return

    self.best_val_per = best_val_per
    self.best_val_loss = best_val_loss

    if self.logger:
        self.logger.info(f"Restored training state from: {state_path}")
    else:
        print(f"Restored training state from: {state_path}")
|
||
|
||
def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
              n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
    """
    Transform the input features and run the model in non-training mode

    The transforms configured under ``args['dataset']['data_transforms']``
    are applied with ``training=False``, then the model is called on the
    transformed features. The adjusted ``n_time_steps`` returned by the
    transform is not passed on to the model.

    Args:
        features: Input neural features [batch_size, time_steps, features]
        day_indices: Day indices [batch_size]
        n_time_steps: Number of valid time steps [batch_size]
        mode: 'inference' or 'full', forwarded to the model unchanged

    Returns:
        Phoneme logits [batch_size, time_steps, n_classes]
    """
    transform_cfg = self.args['dataset']['data_transforms']

    # Same preprocessing as training, minus the augmentation
    prepared, n_time_steps = DataAugmentationTF.transform_data(
        features, n_time_steps, transform_cfg, training=False
    )

    # Positional arguments are kept exactly as the model expects them
    return self.model(prepared, day_indices, None, False, mode, training=False)