fix 'NoneType' object has no attribute 'extended'

This commit is contained in:
Zchen
2025-10-16 20:57:40 +08:00
parent 1e7077bba7
commit ed6e21bfe4

View File

@@ -517,16 +517,32 @@ class BrainToTextDecoderTrainerTF:
             # Calculate gradients - TensorFlow自动处理混合精度
             gradients = tape.gradient(loss, self.model.trainable_variables)
 
-            # Clip gradients
-            if self.args['grad_norm_clip_value'] > 0:
-                gradients, grad_norm = tf.clip_by_global_norm(
-                    gradients, self.args['grad_norm_clip_value']
-                )
-            else:
-                grad_norm = tf.global_norm(gradients)
+            # Filter out None gradients (for h0 variables that don't need gradients)
+            filtered_grads_and_vars = []
+            for grad, var in zip(gradients, self.model.trainable_variables):
+                if grad is not None:
+                    filtered_grads_and_vars.append((grad, var))
+                else:
+                    # Log which variables don't have gradients (informational)
+                    tf.print(f"No gradient for variable: {var.name}")
 
-            # Apply gradients
-            self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+            # Extract filtered gradients and variables
+            filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
+            filtered_variables = [var for _, var in filtered_grads_and_vars]
+
+            # Clip gradients
+            if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
+                filtered_gradients, grad_norm = tf.clip_by_global_norm(
+                    filtered_gradients, self.args['grad_norm_clip_value']
+                )
+            elif len(filtered_gradients) > 0:
+                grad_norm = tf.global_norm(filtered_gradients)
+            else:
+                grad_norm = tf.constant(0.0)
+
+            # Apply gradients (only for variables that have gradients)
+            if len(filtered_gradients) > 0:
+                self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
 
             return loss, grad_norm
@@ -663,9 +679,11 @@ class BrainToTextDecoderTrainerTF:
             # Distributed training step
             self.logger.info("Running distributed training step...")
-            per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                self._train_step, args=(batch, step)
-            )
+            # Ensure we're in the correct TPU strategy scope
+            with self.strategy.scope():
+                per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                    self._train_step, args=(batch, step)
+                )
 
             # Reduce across replicas
             self.logger.info("Reducing results across replicas...")