f
This commit is contained in:
@@ -100,8 +100,13 @@ class BrainToTextDecoderTrainerTF:
|
||||
|
||||
@tf.function
|
||||
def init_optimizer_slots():
|
||||
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
|
||||
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
|
||||
# Use ALL trainable variables for slot initialization, not just filtered ones
|
||||
# This ensures slot variables are created for all variables that might need gradients
|
||||
all_variables = self.model.trainable_variables
|
||||
dummy_gradients = [tf.zeros_like(var) for var in all_variables]
|
||||
|
||||
# Apply gradients for all variables to ensure all slots are created
|
||||
self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
|
||||
return tf.constant(True) # Return something to satisfy strategy.run
|
||||
|
||||
# Run the slot initialization in replica context
|
||||
@@ -583,34 +588,30 @@ class BrainToTextDecoderTrainerTF:
|
||||
# Calculate gradients - TensorFlow自动处理混合精度
|
||||
gradients = tape.gradient(loss, self.model.trainable_variables)
|
||||
|
||||
# Filter out None gradients (for h0 variables that don't need gradients)
|
||||
filtered_grads_and_vars = []
|
||||
for grad, var in zip(gradients, self.model.trainable_variables):
|
||||
if grad is not None:
|
||||
filtered_grads_and_vars.append((grad, var))
|
||||
else:
|
||||
# Log which variables don't have gradients (informational)
|
||||
tf.print(f"No gradient for variable: {var.name}")
|
||||
# For TPU compatibility, use all variables (TensorFlow will handle None gradients automatically)
|
||||
# This ensures consistency with slot variable initialization
|
||||
all_variables = self.model.trainable_variables
|
||||
|
||||
# Extract filtered gradients and variables
|
||||
filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
|
||||
filtered_variables = [var for _, var in filtered_grads_and_vars]
|
||||
# Replace None gradients with zeros to maintain consistency
|
||||
safe_gradients = []
|
||||
for grad, var in zip(gradients, all_variables):
|
||||
if grad is not None:
|
||||
safe_gradients.append(grad)
|
||||
else:
|
||||
# Create zero gradient for variables without gradients
|
||||
safe_gradients.append(tf.zeros_like(var))
|
||||
|
||||
# Clip gradients
|
||||
if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
|
||||
filtered_gradients, grad_norm = tf.clip_by_global_norm(
|
||||
filtered_gradients, self.args['grad_norm_clip_value']
|
||||
if self.args['grad_norm_clip_value'] > 0:
|
||||
safe_gradients, grad_norm = tf.clip_by_global_norm(
|
||||
safe_gradients, self.args['grad_norm_clip_value']
|
||||
)
|
||||
elif len(filtered_gradients) > 0:
|
||||
grad_norm = tf.global_norm(filtered_gradients)
|
||||
else:
|
||||
grad_norm = tf.constant(0.0)
|
||||
grad_norm = tf.global_norm(safe_gradients)
|
||||
|
||||
# Apply gradients (only for variables that have gradients)
|
||||
if len(filtered_gradients) > 0:
|
||||
# Apply gradients directly - optimizer should be pre-built and ready
|
||||
# In @tf.function, we need to keep error handling simple
|
||||
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
|
||||
# Apply gradients to ALL variables (consistent with initialization)
|
||||
# TensorFlow optimizer will handle zero gradients correctly
|
||||
self.optimizer.apply_gradients(zip(safe_gradients, all_variables))
|
||||
|
||||
return loss, grad_norm
|
||||
|
||||
|
Reference in New Issue
Block a user