diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 1917585..d5dfb91 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -517,16 +517,32 @@ class BrainToTextDecoderTrainerTF:
         # Calculate gradients - TensorFlow handles mixed precision automatically
         gradients = tape.gradient(loss, self.model.trainable_variables)

-        # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0:
-            gradients, grad_norm = tf.clip_by_global_norm(
-                gradients, self.args['grad_norm_clip_value']
-            )
-        else:
-            grad_norm = tf.global_norm(gradients)
+        # Filter out None gradients (for h0 variables that don't need gradients)
+        filtered_grads_and_vars = []
+        for grad, var in zip(gradients, self.model.trainable_variables):
+            if grad is not None:
+                filtered_grads_and_vars.append((grad, var))
+            else:
+                # Log which variables don't have gradients (informational)
+                tf.print(f"No gradient for variable: {var.name}")

-        # Apply gradients
-        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        # Extract filtered gradients and variables
+        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
+        filtered_variables = [var for _, var in filtered_grads_and_vars]
+
+        # Clip gradients
+        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
+            filtered_gradients, grad_norm = tf.clip_by_global_norm(
+                filtered_gradients, self.args['grad_norm_clip_value']
+            )
+        elif len(filtered_gradients) > 0:
+            grad_norm = tf.global_norm(filtered_gradients)
+        else:
+            grad_norm = tf.constant(0.0)
+
+        # Apply gradients (only for variables that have gradients)
+        if len(filtered_gradients) > 0:
+            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))

         return loss, grad_norm

@@ -663,9 +679,11 @@ class BrainToTextDecoderTrainerTF:

         # Distributed training step
         self.logger.info("Running distributed training step...")
-        per_replica_losses, per_replica_grad_norms = self.strategy.run(
-            self._train_step, args=(batch, step)
-        )
+        # Ensure we're in the correct TPU strategy scope
+        with self.strategy.scope():
+            per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                self._train_step, args=(batch, step)
+            )

         # Reduce across replicas
         self.logger.info("Reducing results across replicas...")
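
For reference, the None-gradient filtering introduced in _train_step can be exercised in isolation. The sketch below is illustrative only: the helper name clip_and_apply and the toy variables are assumptions, not part of trainer_tf.py.

import tensorflow as tf

def clip_and_apply(optimizer, gradients, variables, clip_value):
    # Keep only (grad, var) pairs where a gradient actually exists,
    # mirroring the filtering added in the diff above.
    pairs = [(g, v) for g, v in zip(gradients, variables) if g is not None]
    grads = [g for g, _ in pairs]
    kept_vars = [v for _, v in pairs]

    # Clip by global norm when requested and when anything survived filtering.
    if clip_value > 0 and grads:
        grads, grad_norm = tf.clip_by_global_norm(grads, clip_value)
    elif grads:
        grad_norm = tf.linalg.global_norm(grads)
    else:
        grad_norm = tf.constant(0.0)

    # Apply updates only to variables that actually received gradients.
    if grads:
        optimizer.apply_gradients(zip(grads, kept_vars))
    return grad_norm

# Example: h0 is unused in the loss, so its gradient comes back as None
# and is skipped instead of crashing apply_gradients.
w = tf.Variable([1.0, 2.0])
h0 = tf.Variable([0.0])
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w * w)
grads = tape.gradient(loss, [w, h0])
grad_norm = clip_and_apply(opt, grads, [w, h0], clip_value=1.0)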