fix 'NoneType' object has no attribute 'extended'
@@ -517,16 +517,32 @@ class BrainToTextDecoderTrainerTF:
         # Calculate gradients - TensorFlow handles mixed precision automatically
         gradients = tape.gradient(loss, self.model.trainable_variables)
 
-        # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0:
-            gradients, grad_norm = tf.clip_by_global_norm(
-                gradients, self.args['grad_norm_clip_value']
-            )
-        else:
-            grad_norm = tf.global_norm(gradients)
+        # Filter out None gradients (for h0 variables that don't need gradients)
+        filtered_grads_and_vars = []
+        for grad, var in zip(gradients, self.model.trainable_variables):
+            if grad is not None:
+                filtered_grads_and_vars.append((grad, var))
+            else:
+                # Log which variables don't have gradients (informational)
+                tf.print(f"No gradient for variable: {var.name}")
 
-        # Apply gradients
-        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        # Extract filtered gradients and variables
+        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
+        filtered_variables = [var for _, var in filtered_grads_and_vars]
 
+        # Clip gradients
+        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
+            filtered_gradients, grad_norm = tf.clip_by_global_norm(
+                filtered_gradients, self.args['grad_norm_clip_value']
+            )
+        elif len(filtered_gradients) > 0:
+            grad_norm = tf.global_norm(filtered_gradients)
+        else:
+            grad_norm = tf.constant(0.0)
+
+        # Apply gradients (only for variables that have gradients)
+        if len(filtered_gradients) > 0:
+            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+
         return loss, grad_norm
 
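For context on the hunk above: tf.GradientTape.gradient returns None for any watched variable that does not lie on the path to the loss (here the h0 initial-state variables), and forwarding those None entries into tf.clip_by_global_norm or optimizer.apply_gradients fails deep inside TensorFlow rather than at the call site. Below is a minimal, self-contained sketch of the filter-clip-apply pattern the commit adopts; the toy names (w, unused_h0) and the literal clip value 10.0 are illustrative, not taken from this repository.

import tensorflow as tf

# Toy setup: one variable on the loss path, one deliberately unused
# (hypothetical stand-ins for the trainer's h0 variables).
w = tf.Variable(2.0, name="w")
unused_h0 = tf.Variable(0.0, name="h0")
variables = [w, unused_h0]
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

with tf.GradientTape() as tape:
    loss = 3.0 * w  # unused_h0 never touches the loss

# gradient() yields None for variables not connected to the loss.
gradients = tape.gradient(loss, variables)  # [tf.Tensor(3.0), None]

# Keep only (grad, var) pairs that actually have a gradient.
filtered = [(g, v) for g, v in zip(gradients, variables) if g is not None]
grads = [g for g, _ in filtered]
vars_with_grads = [v for _, v in filtered]

# Clip and apply only the real gradients; None never reaches the optimizer.
if grads:
    grads, grad_norm = tf.clip_by_global_norm(grads, 10.0)
    optimizer.apply_gradients(zip(grads, vars_with_grads))
else:
    grad_norm = tf.constant(0.0)
print(grad_norm)  # 3.0 here, below the clip value

Note that tf.clip_by_global_norm returns the pre-clipping global norm alongside the clipped tensors, which is why grad_norm can be logged whether or not clipping actually triggered.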
@@ -663,6 +679,8 @@ class BrainToTextDecoderTrainerTF:
 
         # Distributed training step
         self.logger.info("Running distributed training step...")
+        # Ensure we're in the correct TPU strategy scope
+        with self.strategy.scope():
             per_replica_losses, per_replica_grad_norms = self.strategy.run(
                 self._train_step, args=(batch, step)
             )
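For context on the scope change in the second hunk: variables and optimizer slots must be created under the strategy's scope, and keeping the replica step inside it ensures the distribution strategy is discoverable when apply_gradients runs on each replica. (The 'extended' attribute from the commit title belongs to tf.distribute.Strategy, so a missing scope is a plausible source of that error, though the commit does not say so explicitly.) A minimal sketch of the strategy.run call pattern, using MirroredStrategy as a stand-in for the TPU setup, with a toy model and step that are not from this repository:

import tensorflow as tf

# MirroredStrategy stands in for tf.distribute.TPUStrategy here.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Variable creation (model weights, optimizer slots) happens in scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.ones((8, 4))
y = tf.zeros((8, 1))

# strategy.run executes train_step once per replica and returns
# per-replica values; reduce them to a single scalar for logging.
per_replica_losses = strategy.run(train_step, args=(x, y))
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
print(float(mean_loss))

The same reduce step is what would turn the per_replica_losses and per_replica_grad_norms returned in the hunk above into loggable scalars.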