Author: Zchen
Date: 2025-10-16 21:26:00 +08:00
parent dde6378481
commit 426b72ef25


@@ -102,18 +102,25 @@ class BrainToTextDecoderTrainerTF:
         try:
             # Check if strategy is properly initialized before applying gradients
             if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
+                print("✅ Strategy has merge_call, building optimizer properly...")
+                # Build optimizer by explicitly calling build method
+                self.optimizer.build(self.model.trainable_variables)
+                print("✅ Optimizer built with model variables")
                 # Test with dummy gradients to ensure everything works
                 dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
                 self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
                 print("✅ Optimizer state pre-built successfully with TPU strategy")
             else:
                 # Fallback: just build optimizer variables without applying gradients
-                print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
-                # Alternative: trigger optimizer variable creation
-                _ = self.optimizer.iterations
-                print("✅ Optimizer state initialized (fallback mode)")
+                print("⚠️ Strategy not fully initialized, using fallback optimizer build")
+                # Force build the optimizer with the model variables
+                self.optimizer.build(self.model.trainable_variables)
+                print("✅ Optimizer built in fallback mode")
         except Exception as e:
             print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
-            print("✅ Continuing without optimizer pre-build")
+            print("✅ Continuing without optimizer pre-build - optimizer will build during first training step")
 
         print("📅 Setting up learning rate scheduler...")
         self.lr_scheduler = self._create_lr_scheduler()
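
Note: the pre-build path in this hunk amounts to creating the optimizer's slot variables inside the distribution strategy scope and then applying one round of zero gradients so the remaining state (e.g. the iteration counter) exists before the first real step. A minimal standalone sketch of that pattern, assuming a toy Keras model and the default strategy (names here are illustrative, not from this repository):

```python
import tensorflow as tf

def prebuild_optimizer_state(strategy: tf.distribute.Strategy):
    """Sketch: force optimizer state to exist before the first training step.

    Variable creation (model weights and optimizer slots) must happen inside
    strategy.scope(); the toy model below stands in for the real decoder.
    """
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
        model.build(input_shape=(None, 4))  # create the weights up front
        optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)

        # Explicitly create slot variables (momenta) for every trainable weight.
        optimizer.build(model.trainable_variables)

        # Apply one round of zero gradients so any remaining lazily created
        # state (e.g. the iteration counter) is initialized as well.
        zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
        optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))
    return model, optimizer

# Usage with the default (single-device) strategy:
model, optimizer = prebuild_optimizer_state(tf.distribute.get_strategy())
```

Under a real TPUStrategy, whether a bare apply_gradients outside strategy.run is accepted can vary, which is exactly the failure mode the later hunk guards against.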
@@ -414,6 +421,9 @@ class BrainToTextDecoderTrainerTF:
"""Create AdamW optimizer with parameter groups"""
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
# We'll use a single optimizer and handle different learning rates in the scheduler
# Create optimizer within strategy scope to ensure proper initialization
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
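
The trailing context is cut off mid-call here. For reference, a complete AdamW construction along these lines might look like the sketch below; only 'lr_max' and 'beta0' appear in the hunk, so the remaining config keys are assumptions:

```python
import tensorflow as tf

def create_adamw(args: dict) -> tf.keras.optimizers.Optimizer:
    """Sketch of a full AdamW setup; key names beyond 'lr_max'/'beta0' are assumed."""
    return tf.keras.optimizers.AdamW(
        learning_rate=args['lr_max'],
        beta_1=args['beta0'],
        beta_2=args.get('beta1', 0.999),             # assumed config key
        epsilon=args.get('epsilon', 1e-7),           # assumed config key
        weight_decay=args.get('weight_decay', 0.0),  # assumed config key
    )
```

As the comment above notes, PyTorch-style parameter groups are not available here, so per-group learning rates would have to be emulated (e.g. with separate optimizers or a custom schedule) rather than configured on a single AdamW instance.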
@@ -565,7 +575,18 @@ class BrainToTextDecoderTrainerTF:
         # Apply gradients (only for variables that have gradients)
         if len(filtered_gradients) > 0:
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            # Ensure we're in the strategy scope when applying gradients
+            # This prevents the 'NoneType' extended attribute error
+            try:
+                self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            except AttributeError as e:
+                if "'NoneType' object has no attribute 'extended'" in str(e):
+                    # Strategy context was lost, this should not happen in a @tf.function
+                    tf.print(f"ERROR: Strategy context lost during gradient application: {e}")
+                    tf.print("This indicates a serious issue with the distributed training setup")
+                    raise RuntimeError(f"Strategy context lost during training: {e}")
+                else:
+                    raise
 
         return loss, grad_norm
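
The "'NoneType' object has no attribute 'extended'" AttributeError usually means apply_gradients ran outside an active replica context, i.e. the step function was not entered through strategy.run. Instead of catching the error after the fact, the more conventional structure keeps gradient application inside the per-replica function, roughly as sketched below (function and argument names are illustrative, not from this repository):

```python
import tensorflow as tf

def make_distributed_train_step(strategy, model, optimizer, loss_fn):
    """Sketch: enter the per-replica step via strategy.run so that
    optimizer.apply_gradients always executes inside a replica context."""

    def step_fn(x, y):
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        # Keep only variables that actually received a gradient,
        # mirroring the filtering done in the hunk above.
        pairs = [(g, v) for g, v in zip(grads, model.trainable_variables)
                 if g is not None]
        if pairs:
            optimizer.apply_gradients(pairs)
        return loss

    @tf.function
    def train_step(x, y):
        per_replica_loss = strategy.run(step_fn, args=(x, y))
        # Average the per-replica losses for logging; scaling the loss by the
        # global batch size is omitted in this sketch.
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

    return train_step
```

With this structure the optimizer always sees a live replica context, so the AttributeError guard above should not be needed.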