From eb4e3fc69f5a832d0a260e2be03a5346cf4d083a Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:38:57 +0800 Subject: [PATCH] fff --- model_training_nnn_tpu/trainer_tf.py | 320 +++++++++++++++------------ 1 file changed, 183 insertions(+), 137 deletions(-) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index d1a3a8c..32feeb0 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -10,6 +10,13 @@ import sys from typing import Dict, Any, Tuple, Optional, List from omegaconf import OmegaConf +# For accurate PER calculation +try: + import editdistance +except ImportError: + print("Warning: editdistance not available, falling back to approximation") + editdistance = None + from rnn_model_tf import ( TripleGRUDecoder, CTCLoss, @@ -99,6 +106,30 @@ class BrainToTextDecoderTrainerTF: self.lr_scheduler = self._create_lr_scheduler() self.ctc_loss = CTCLoss(blank_index=0, reduction='none') + # Create unified checkpoint management + self.ckpt = tf.train.Checkpoint( + optimizer=self.optimizer, + model=self.model + ) + self.ckpt_manager = tf.train.CheckpointManager( + self.ckpt, + directory=self.args['checkpoint_dir'], + max_to_keep=5 # Keep only the 5 most recent checkpoints + ) + + # Try to restore from latest checkpoint + if self.ckpt_manager.latest_checkpoint: + print(f"🔄 Restoring from {self.ckpt_manager.latest_checkpoint}") + if self.logger: + self.logger.info(f"Restoring from {self.ckpt_manager.latest_checkpoint}") + # expect_partial() avoids failures due to model structure changes + self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial() + print("✅ Checkpoint restored successfully") + else: + print("🆕 Initializing training from scratch") + if self.logger: + self.logger.info("Initializing training from scratch") + # Log model information self._log_model_info() @@ -110,12 +141,10 @@ class BrainToTextDecoderTrainerTF: self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) - # Manual weight decay handling for all environments (since we use Adam) + # Manual weight decay handling - disable since AdamW handles it self.manual_weight_decay = False if self.args.get('weight_decay', 0.0) > 0: - self.manual_weight_decay = True - self.weight_decay_rate = self.args['weight_decay'] - print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}") + print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}") else: print("💡 No weight decay configured") @@ -402,39 +431,18 @@ class BrainToTextDecoderTrainerTF: return model def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: - """Create AdamW optimizer with parameter groups""" - # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch - # We'll use a single optimizer and handle different learning rates in the scheduler - + """Create AdamW optimizer""" print(f"Creating optimizer with strategy: {type(self.strategy).__name__}") + print("Using AdamW optimizer for TPU training") - # For TPU training, we need to be more explicit about optimizer configuration - # to avoid strategy context issues - # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs - # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0 - # We implement manual L2 regularization (weight decay) in the training step instead - print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)") - print("💡 Manual L2 regularization will be applied in training step") - - # Use legacy Adam optimizer for better TPU distributed training compatibility - # Legacy optimizers have more stable distributed training implementations - try: - optimizer = tf.keras.optimizers.legacy.Adam( - learning_rate=self.args['lr_max'], - beta_1=self.args['beta0'], - beta_2=self.args['beta1'], - epsilon=self.args['epsilon'] - ) - print("✅ Using legacy Adam optimizer for better TPU compatibility") - except AttributeError: - # Fallback to standard Adam if legacy is not available - optimizer = tf.keras.optimizers.Adam( - learning_rate=self.args['lr_max'], - beta_1=self.args['beta0'], - beta_2=self.args['beta1'], - epsilon=self.args['epsilon'] - ) - print("⚠️ Using standard Adam optimizer (legacy not available)") + optimizer = tf.keras.optimizers.AdamW( + learning_rate=self.args['lr_max'], + beta_1=self.args['beta0'], + beta_2=self.args['beta1'], + epsilon=self.args['epsilon'], + weight_decay=self.args.get('weight_decay', 0.0) + ) + print("✅ Using AdamW optimizer") return optimizer @@ -486,6 +494,7 @@ class BrainToTextDecoderTrainerTF: else: print(f"Model has {total_params:,} trainable parameters") + @tf.function def _train_step(self, batch, step): """Single training step with gradient tape""" features = batch['input_features'] @@ -554,16 +563,7 @@ class BrainToTextDecoderTrainerTF: loss = self.ctc_loss(loss_input, clean_logits) loss = tf.reduce_mean(loss) - # Add manual L2 regularization for TPU (since weight_decay is disabled) - if self.manual_weight_decay: - l2_loss = tf.constant(0.0, dtype=loss.dtype) - for var in self.model.trainable_variables: - # Ensure dtype consistency for mixed precision training - var_l2 = tf.nn.l2_loss(var) - var_l2 = tf.cast(var_l2, dtype=loss.dtype) # Cast to match loss dtype - l2_loss += var_l2 - loss += self.weight_decay_rate * l2_loss - + # AdamW handles weight decay automatically - no manual L2 regularization needed # TensorFlow混合精度处理:不需要手动scaling,Keras policy自动处理 # TPU v5e-8使用bfloat16,不需要loss scaling @@ -599,7 +599,7 @@ class BrainToTextDecoderTrainerTF: @tf.function def _validation_step(self, batch): - """Single validation step""" + """Single validation step - returns data for accurate PER calculation""" features = batch['input_features'] labels = batch['seq_class_ids'] n_time_steps = batch['n_time_steps'] @@ -630,38 +630,11 @@ class BrainToTextDecoderTrainerTF: loss = self.ctc_loss(loss_input, logits) loss = tf.reduce_mean(loss) - # Calculate simplified PER approximation (TPU-compatible) - # For TPU training, we use a simplified metric that avoids complex loops - # This gives an approximation of PER but is much faster and TPU-compatible + # Greedy decoding for PER calculation + predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32) - # Greedy decoding - predicted_ids = tf.argmax(logits, axis=-1) - - # Simple approximation: count exact matches vs mismatches - # This is less accurate than true edit distance but TPU-compatible - batch_size = tf.shape(logits)[0] - - # For each sample, compare predicted vs true sequences - total_mismatches = tf.constant(0, dtype=tf.int32) - - for i in tf.range(batch_size): - # Get sequences for this sample - pred_seq = predicted_ids[i, :adjusted_lens[i]] - true_seq = labels[i, :phone_seq_lens[i]] - - # Pad to same length for comparison - max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0]) - pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0) - true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0) - - # Count mismatches - mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32)) - total_mismatches += mismatches - - # Approximate edit distance as number of mismatches - batch_edit_distance = tf.cast(total_mismatches, tf.float32) - - return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32) + # Return all necessary data for accurate PER calculation on CPU + return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens def train(self) -> Dict[str, Any]: """Main training loop""" @@ -671,28 +644,28 @@ class BrainToTextDecoderTrainerTF: initial_tpu_status = self._get_detailed_tpu_status() self.logger.info(f"Initial TPU Status: {initial_tpu_status}") - # Create distributed datasets - train_dataset = create_input_fn( - self.train_dataset_tf, - self.args['dataset']['data_transforms'], - training=True - ) - - val_dataset = create_input_fn( - self.val_dataset_tf, - self.args['dataset']['data_transforms'], - training=False - ) - # Distribute datasets with timing + # Create datasets using modern distribution API + def create_dist_dataset_fn(input_dataset_tf, training): + """Create distributed dataset function for modern TPU strategy""" + def dataset_fn(input_context): + # create_input_fn returns a complete, batched tf.data.Dataset + return create_input_fn( + input_dataset_tf, + self.args['dataset']['data_transforms'], + training=training + ) + return self.strategy.distribute_datasets_from_function(dataset_fn) + + # Distribute datasets using modern API self.logger.info("🔄 Distributing training dataset across TPU cores...") dist_start_time = time.time() - train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset) + train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True) train_dist_time = time.time() - dist_start_time self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s") self.logger.info("🔄 Distributing validation dataset across TPU cores...") val_start_time = time.time() - val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset) + val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False) val_dist_time = time.time() - val_start_time self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s") @@ -709,18 +682,13 @@ class BrainToTextDecoderTrainerTF: # Training loop step = 0 - # Add timing diagnostic for first batch iteration - self.logger.info("🔄 Starting training loop iteration...") - loop_start_time = time.time() + self.logger.info("🔄 Starting training loop...") for batch in train_dist_dataset: - if step == 0: - first_batch_iteration_time = time.time() - loop_start_time - self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s") if step >= self.args['num_training_batches']: self.logger.info("Reached maximum training batches, stopping training") break - + start_time = time.time() # Distributed training step @@ -794,7 +762,7 @@ class BrainToTextDecoderTrainerTF: if new_best: if self.args.get('save_best_checkpoint', True): self.logger.info("Checkpointing model") - self._save_checkpoint('best_checkpoint', step) + self._save_checkpoint(step) if self.args.get('save_val_metrics', True): with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f: @@ -806,7 +774,7 @@ class BrainToTextDecoderTrainerTF: # Optional save all validation checkpoints if self.args.get('save_all_val_steps', False): - self._save_checkpoint(f'checkpoint_batch_{step}', step) + self._save_checkpoint(step) # Early stopping if (self.args.get('early_stopping', False) and @@ -825,7 +793,7 @@ class BrainToTextDecoderTrainerTF: # Save final model if self.args.get('save_final_model', False): last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf') - self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1) + self._save_checkpoint(step-1) return { 'train_losses': train_losses, @@ -835,25 +803,58 @@ class BrainToTextDecoderTrainerTF: } def _validate(self, val_dataset) -> Dict[str, Any]: - """Run validation on entire validation dataset""" + """Run validation on entire validation dataset with accurate PER calculation""" total_loss = 0.0 total_edit_distance = 0 total_seq_length = 0 num_batches = 0 for batch in val_dataset: - per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = ( + # Get predictions and labels from all TPU cores + per_replica_losses, per_replica_preds, per_replica_labels, per_replica_pred_lens, per_replica_label_lens = ( self.strategy.run(self._validation_step, args=(batch,)) ) - # Reduce across replicas + # Reduce loss across replicas batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None) - batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None) - batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None) - total_loss += float(batch_loss.numpy()) - total_edit_distance += float(batch_edit_distance.numpy()) - total_seq_length += float(batch_seq_length.numpy()) + + # Gather all data from TPU cores to CPU for accurate PER calculation + all_preds = self.strategy.gather(per_replica_preds, axis=0) + all_labels = self.strategy.gather(per_replica_labels, axis=0) + all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0) + all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0) + + # Calculate accurate edit distance on CPU + batch_size = all_preds.shape[0] + for i in range(batch_size): + pred_len = int(all_pred_lens[i].numpy()) + label_len = int(all_label_lens[i].numpy()) + + # Extract sequences and remove CTC blanks (assuming blank_index=0) + pred_seq = all_preds[i, :pred_len].numpy() + pred_seq = [p for p in pred_seq if p != 0] # Remove blanks + + # Remove consecutive duplicates (CTC decoding) + if len(pred_seq) > 0: + deduped_pred = [pred_seq[0]] + for j in range(1, len(pred_seq)): + if pred_seq[j] != pred_seq[j-1]: + deduped_pred.append(pred_seq[j]) + pred_seq = deduped_pred + + true_seq = all_labels[i, :label_len].numpy().tolist() + + # Calculate edit distance using proper library if available + if editdistance is not None: + edit_dist = editdistance.eval(pred_seq, true_seq) + else: + # Fallback to simple approximation if editdistance not available + edit_dist = self._simple_edit_distance(pred_seq, true_seq) + + total_edit_distance += edit_dist + total_seq_length += label_len + num_batches += 1 avg_loss = total_loss / max(num_batches, 1) @@ -867,50 +868,95 @@ class BrainToTextDecoderTrainerTF: 'num_batches': num_batches } - def _save_checkpoint(self, name: str, step: int): - """Save model checkpoint""" - checkpoint_path = os.path.join(self.args['checkpoint_dir'], name) + def _simple_edit_distance(self, seq1, seq2): + """Simple edit distance implementation as fallback""" + # Dynamic programming implementation of edit distance + m, n = len(seq1), len(seq2) + dp = [[0] * (n + 1) for _ in range(m + 1)] - # Save model weights - self.model.save_weights(checkpoint_path + '.weights.h5') + # Initialize base cases + for i in range(m + 1): + dp[i][0] = i + for j in range(n + 1): + dp[0][j] = j - # Save optimizer state - optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer) - optimizer_checkpoint.save(checkpoint_path + '.optimizer') + # Fill the DP table + for i in range(1, m + 1): + for j in range(1, n + 1): + if seq1[i-1] == seq2[j-1]: + dp[i][j] = dp[i-1][j-1] + else: + dp[i][j] = 1 + min( + dp[i-1][j], # deletion + dp[i][j-1], # insertion + dp[i-1][j-1] # substitution + ) - # Save training state + return dp[m][n] + + def _save_checkpoint(self, step: int, name: str = ""): + """Save checkpoint using the unified CheckpointManager""" + # CheckpointManager automatically handles naming and numbering + # The 'name' parameter is kept for backward compatibility but not used + save_path = self.ckpt_manager.save(checkpoint_number=step) + + if self.logger: + self.logger.info(f"Saved checkpoint for step {step}: {save_path}") + else: + print(f"Saved checkpoint for step {step}: {save_path}") + + # Save non-TensorFlow Python state separately state = { 'step': step, 'best_val_per': float(self.best_val_per), 'best_val_loss': float(self.best_val_loss) } - with open(checkpoint_path + '.state.json', 'w') as f: + # Associate state file with checkpoint + state_path = os.path.join(self.args['checkpoint_dir'], f'state-{step}.json') + with open(state_path, 'w') as f: json.dump(state, f) - # Save config - with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f: - OmegaConf.save(config=self.args, f=f) - - self.logger.info(f"Saved checkpoint: {checkpoint_path}") + # Save config file (only once) + config_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml') + if not os.path.exists(config_path): + with open(config_path, 'w') as f: + OmegaConf.save(config=self.args, f=f) def load_checkpoint(self, checkpoint_path: str): - """Load model checkpoint""" - # Load model weights - self.model.load_weights(checkpoint_path + '.weights.h5') + """Load a specific checkpoint and its associated training state""" + if self.logger: + self.logger.info(f"Loading checkpoint from: {checkpoint_path}") + else: + print(f"Loading checkpoint from: {checkpoint_path}") - # Load optimizer state - optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer) - optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1') + # Restore TensorFlow objects (model and optimizer) + self.ckpt.restore(checkpoint_path).expect_partial() - # Load training state - with open(checkpoint_path + '.state.json', 'r') as f: - state = json.load(f) + # Restore non-TensorFlow training state + try: + # Extract step number from checkpoint path (e.g., ckpt-123 -> 123) + step = int(checkpoint_path.split('-')[-1]) + state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json') - self.best_val_per = state['best_val_per'] - self.best_val_loss = state['best_val_loss'] + with open(state_path, 'r') as f: + state = json.load(f) - self.logger.info(f"Loaded checkpoint: {checkpoint_path}") + self.best_val_per = state['best_val_per'] + self.best_val_loss = state['best_val_loss'] + + if self.logger: + self.logger.info(f"Restored training state from: {state_path}") + else: + print(f"Restored training state from: {state_path}") + + except (IOError, ValueError, KeyError) as e: + warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. " + f"Starting with fresh state. Error: {e}") + if self.logger: + self.logger.warning(warning_msg) + else: + print(f"⚠️ {warning_msg}") def inference(self, features: tf.Tensor, day_indices: tf.Tensor, n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor: