fix 'NoneType' object has no attribute 'extended'
This commit is contained in:
@@ -517,16 +517,32 @@ class BrainToTextDecoderTrainerTF:
|
||||
# Calculate gradients - TensorFlow自动处理混合精度
|
||||
gradients = tape.gradient(loss, self.model.trainable_variables)
|
||||
|
||||
# Clip gradients
|
||||
if self.args['grad_norm_clip_value'] > 0:
|
||||
gradients, grad_norm = tf.clip_by_global_norm(
|
||||
gradients, self.args['grad_norm_clip_value']
|
||||
)
|
||||
else:
|
||||
grad_norm = tf.global_norm(gradients)
|
||||
# Filter out None gradients (for h0 variables that don't need gradients)
|
||||
filtered_grads_and_vars = []
|
||||
for grad, var in zip(gradients, self.model.trainable_variables):
|
||||
if grad is not None:
|
||||
filtered_grads_and_vars.append((grad, var))
|
||||
else:
|
||||
# Log which variables don't have gradients (informational)
|
||||
tf.print(f"No gradient for variable: {var.name}")
|
||||
|
||||
# Apply gradients
|
||||
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
|
||||
# Extract filtered gradients and variables
|
||||
filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
|
||||
filtered_variables = [var for _, var in filtered_grads_and_vars]
|
||||
|
||||
# Clip gradients
|
||||
if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
|
||||
filtered_gradients, grad_norm = tf.clip_by_global_norm(
|
||||
filtered_gradients, self.args['grad_norm_clip_value']
|
||||
)
|
||||
elif len(filtered_gradients) > 0:
|
||||
grad_norm = tf.global_norm(filtered_gradients)
|
||||
else:
|
||||
grad_norm = tf.constant(0.0)
|
||||
|
||||
# Apply gradients (only for variables that have gradients)
|
||||
if len(filtered_gradients) > 0:
|
||||
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
|
||||
|
||||
return loss, grad_norm
|
||||
|
||||
@@ -663,9 +679,11 @@ class BrainToTextDecoderTrainerTF:
|
||||
|
||||
# Distributed training step
|
||||
self.logger.info("Running distributed training step...")
|
||||
per_replica_losses, per_replica_grad_norms = self.strategy.run(
|
||||
self._train_step, args=(batch, step)
|
||||
)
|
||||
# Ensure we're in the correct TPU strategy scope
|
||||
with self.strategy.scope():
|
||||
per_replica_losses, per_replica_grad_norms = self.strategy.run(
|
||||
self._train_step, args=(batch, step)
|
||||
)
|
||||
|
||||
# Reduce across replicas
|
||||
self.logger.info("Reducing results across replicas...")
|
||||
|
Reference in New Issue
Block a user