fix 'NoneType' object has no attribute 'extended'

Author: Zchen
Date: 2025-10-16 20:57:40 +08:00
parent 1e7077bba7
commit ed6e21bfe4


@@ -517,16 +517,32 @@ class BrainToTextDecoderTrainerTF:
         # Calculate gradients - TensorFlow handles mixed precision automatically
         gradients = tape.gradient(loss, self.model.trainable_variables)
-        # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0:
-            gradients, grad_norm = tf.clip_by_global_norm(
-                gradients, self.args['grad_norm_clip_value']
-            )
-        else:
-            grad_norm = tf.global_norm(gradients)
+        # Filter out None gradients (for h0 variables that don't need gradients)
+        filtered_grads_and_vars = []
+        for grad, var in zip(gradients, self.model.trainable_variables):
+            if grad is not None:
+                filtered_grads_and_vars.append((grad, var))
+            else:
+                # Log which variables don't have gradients (informational)
+                tf.print(f"No gradient for variable: {var.name}")
-        # Apply gradients
-        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        # Extract filtered gradients and variables
+        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
+        filtered_variables = [var for _, var in filtered_grads_and_vars]
+        # Clip gradients
+        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
+            filtered_gradients, grad_norm = tf.clip_by_global_norm(
+                filtered_gradients, self.args['grad_norm_clip_value']
+            )
+        elif len(filtered_gradients) > 0:
+            grad_norm = tf.global_norm(filtered_gradients)
+        else:
+            grad_norm = tf.constant(0.0)
+        # Apply gradients (only for variables that have gradients)
+        if len(filtered_gradients) > 0:
+            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
         return loss, grad_norm
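
For reference, a minimal standalone sketch of the None-gradient filtering pattern this hunk introduces. The toy model, optimizer, data, and `grad_norm_clip_value` below are illustrative placeholders, not the trainer's real configuration:

```python
import tensorflow as tf

# Toy stand-ins for the real model/optimizer/config (assumptions, not repo code)
model = tf.keras.Sequential([tf.keras.layers.Dense(4), tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam(1e-3)
grad_norm_clip_value = 1.0  # stand-in for self.args['grad_norm_clip_value']

x = tf.random.normal([8, 16])
y = tf.random.normal([8, 2])

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))

gradients = tape.gradient(loss, model.trainable_variables)

# Keep only (grad, var) pairs with a real gradient; variables that are unused
# in the forward pass come back from tape.gradient as None.
grads_and_vars = [(g, v) for g, v in zip(gradients, model.trainable_variables)
                  if g is not None]
filtered_gradients = [g for g, _ in grads_and_vars]
filtered_variables = [v for _, v in grads_and_vars]

if filtered_gradients:
    filtered_gradients, grad_norm = tf.clip_by_global_norm(
        filtered_gradients, grad_norm_clip_value)
    optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
else:
    grad_norm = tf.constant(0.0)  # nothing to clip or apply

print(float(loss), float(grad_norm))
```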
@@ -663,9 +679,11 @@ class BrainToTextDecoderTrainerTF:
             # Distributed training step
             self.logger.info("Running distributed training step...")
-            per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                self._train_step, args=(batch, step)
-            )
+            # Ensure we're in the correct TPU strategy scope
+            with self.strategy.scope():
+                per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                    self._train_step, args=(batch, step)
+                )
             # Reduce across replicas
             self.logger.info("Reducing results across replicas...")