@@ -100,8 +100,13 @@ class BrainToTextDecoderTrainerTF:
 
         @tf.function
         def init_optimizer_slots():
-            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
-            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            # Use ALL trainable variables for slot initialization, not just filtered ones
+            # This ensures slot variables are created for all variables that might need gradients
+            all_variables = self.model.trainable_variables
+            dummy_gradients = [tf.zeros_like(var) for var in all_variables]
+
+            # Apply gradients for all variables to ensure all slots are created
+            self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
             return tf.constant(True) # Return something to satisfy strategy.run
 
         # Run the slot initialization in replica context
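
The hunk above pre-creates the optimizer's slot variables by applying all-zero gradients to every trainable variable once, before the real training loop, so later apply_gradients calls inside a compiled train step never have to create variables on the fly. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and strategy names are placeholders rather than the trainer's actual attributes, and tf.distribute.get_strategy() simply falls back to the default (single-device) strategy when no TPU strategy is in scope.

import tensorflow as tf

strategy = tf.distribute.get_strategy()

with strategy.scope():
    # Placeholder model and optimizer standing in for self.model / self.optimizer
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)

@tf.function
def init_optimizer_slots():
    # Zero gradients for every trainable variable: applying them creates the
    # optimizer's slot variables (e.g. Adam's m/v) without changing the weights.
    all_variables = model.trainable_variables
    dummy_gradients = [tf.zeros_like(var) for var in all_variables]
    optimizer.apply_gradients(zip(dummy_gradients, all_variables))
    return tf.constant(True)  # return a value, mirroring the diff above

# Run once in replica context so all slots exist before the real train step.
strategy.run(init_optimizer_slots)
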
@@ -583,34 +588,30 @@ class BrainToTextDecoderTrainerTF:
         # Calculate gradients - TensorFlow handles mixed precision automatically
         gradients = tape.gradient(loss, self.model.trainable_variables)
 
-        # Filter out None gradients (for h0 variables that don't need gradients)
-        filtered_grads_and_vars = []
-        for grad, var in zip(gradients, self.model.trainable_variables):
-            if grad is not None:
-                filtered_grads_and_vars.append((grad, var))
-            else:
-                # Log which variables don't have gradients (informational)
-                tf.print(f"No gradient for variable: {var.name}")
+        # For TPU compatibility, use all variables (TensorFlow will handle None gradients automatically)
+        # This ensures consistency with slot variable initialization
+        all_variables = self.model.trainable_variables
 
-        # Extract filtered gradients and variables
-        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
-        filtered_variables = [var for _, var in filtered_grads_and_vars]
+        # Replace None gradients with zeros to maintain consistency
+        safe_gradients = []
+        for grad, var in zip(gradients, all_variables):
+            if grad is not None:
+                safe_gradients.append(grad)
+            else:
+                # Create zero gradient for variables without gradients
+                safe_gradients.append(tf.zeros_like(var))
 
         # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
-            filtered_gradients, grad_norm = tf.clip_by_global_norm(
-                filtered_gradients, self.args['grad_norm_clip_value']
+        if self.args['grad_norm_clip_value'] > 0:
+            safe_gradients, grad_norm = tf.clip_by_global_norm(
+                safe_gradients, self.args['grad_norm_clip_value']
             )
-        elif len(filtered_gradients) > 0:
-            grad_norm = tf.global_norm(filtered_gradients)
         else:
-            grad_norm = tf.constant(0.0)
+            grad_norm = tf.global_norm(safe_gradients)
 
-        # Apply gradients (only for variables that have gradients)
-        if len(filtered_gradients) > 0:
-            # Apply gradients directly - optimizer should be pre-built and ready
-            # In @tf.function, we need to keep error handling simple
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+        # Apply gradients to ALL variables (consistent with initialization)
+        # TensorFlow optimizer will handle zero gradients correctly
+        self.optimizer.apply_gradients(zip(safe_gradients, all_variables))
 
         return loss, grad_norm
 
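
The second hunk swaps gradient filtering for gradient padding: any None returned by tape.gradient is replaced with tf.zeros_like(var), so the gradient list always covers the same variable list used during slot initialization, and clipping and apply_gradients never see a shorter list. Below is a minimal standalone sketch of that pattern as a single train step; the model, optimizer, clip_value, and data are illustrative stand-ins for the trainer's own attributes, and tf.linalg.global_norm is used here in place of the tf.global_norm call shown in the diff.

import tensorflow as tf

# Placeholder model/optimizer/config standing in for the trainer's attributes
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 3))
optimizer = tf.keras.optimizers.Adam(1e-3)
clip_value = 10.0  # stands in for self.args['grad_norm_clip_value']

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    gradients = tape.gradient(loss, model.trainable_variables)

    # Replace None gradients with zeros so the gradient list always lines up
    # with the full variable list.
    all_variables = model.trainable_variables
    safe_gradients = [
        grad if grad is not None else tf.zeros_like(var)
        for grad, var in zip(gradients, all_variables)
    ]

    # Clip by global norm when enabled; otherwise just report the norm.
    if clip_value > 0:
        safe_gradients, grad_norm = tf.clip_by_global_norm(safe_gradients, clip_value)
    else:
        grad_norm = tf.linalg.global_norm(safe_gradients)

    # Apply to ALL variables, matching the slot-initialization variable list.
    optimizer.apply_gradients(zip(safe_gradients, all_variables))
    return loss, grad_norm

loss, grad_norm = train_step(tf.random.normal((4, 3)), tf.random.normal((4, 1)))
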