fix: pre-build optimizer state under the TPU strategy and guard apply_gradients against lost strategy context
@@ -102,18 +102,25 @@ class BrainToTextDecoderTrainerTF:
         try:
             # Check if strategy is properly initialized before applying gradients
             if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
+                print("✅ Strategy has merge_call, building optimizer properly...")
+
+                # Build optimizer by explicitly calling build method
+                self.optimizer.build(self.model.trainable_variables)
+                print("✅ Optimizer built with model variables")
+
+                # Test with dummy gradients to ensure everything works
                 dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
                 self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
                 print("✅ Optimizer state pre-built successfully with TPU strategy")
             else:
                 # Fallback: just build optimizer variables without applying gradients
-                print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
-                # Alternative: trigger optimizer variable creation
-                _ = self.optimizer.iterations
-                print("✅ Optimizer state initialized (fallback mode)")
+                print("⚠️ Strategy not fully initialized, using fallback optimizer build")
+                # Force build the optimizer with the model variables
+                self.optimizer.build(self.model.trainable_variables)
+                print("✅ Optimizer built in fallback mode")
         except Exception as e:
             print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
-            print("✅ Continuing without optimizer pre-build")
+            print("✅ Continuing without optimizer pre-build - optimizer will build during first training step")
 
         print("📅 Setting up learning rate scheduler...")
         self.lr_scheduler = self._create_lr_scheduler()
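For reference, a minimal standalone sketch of the pre-build pattern this hunk applies. The strategy, model, and optimizer below are stand-ins (default strategy, a small Dense model, a fresh AdamW), not code from this repo: `optimizer.build()` creates the slot variables up front, and one `apply_gradients` call with zero gradients exercises the full update path before real training begins.

```python
import tensorflow as tf

# Stand-in objects; the trainer uses its own strategy, model, and optimizer.
strategy = tf.distribute.get_strategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)

    # Explicitly create the optimizer's slot variables for every model variable.
    optimizer.build(model.trainable_variables)

    # Exercise apply_gradients once with zero gradients so that all optimizer
    # state exists before the first real training step.
    dummy_grads = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(dummy_grads, model.trainable_variables))
```

The zero-gradient pass is harmless to the weights and surfaces any strategy or variable-placement problems early rather than inside the first compiled training step.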
@@ -414,6 +421,9 @@ class BrainToTextDecoderTrainerTF:
         """Create AdamW optimizer with parameter groups"""
         # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
         # We'll use a single optimizer and handle different learning rates in the scheduler
+
+        # Create optimizer within strategy scope to ensure proper initialization
+        print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
         optimizer = tf.keras.optimizers.AdamW(
             learning_rate=self.args['lr_max'],
             beta_1=self.args['beta0'],
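As the comments in this hunk note, Keras optimizers have no direct equivalent of PyTorch parameter groups, and this trainer handles per-group learning rates in the scheduler instead. A hedged sketch of one common alternative (hypothetical names, not this repo's approach): split the variable list and route each group to its own optimizer.

```python
import tensorflow as tf

# Two optimizers emulate two "parameter groups" with different learning rates.
base_opt = tf.keras.optimizers.AdamW(learning_rate=1e-3)
slow_opt = tf.keras.optimizers.AdamW(learning_rate=1e-4)

def apply_grouped_gradients(grads, variables):
    """Route each (grad, var) pair to the optimizer for its group."""
    fast, slow = [], []
    for g, v in zip(grads, variables):
        # "day_weights" is a hypothetical name pattern for the slow group.
        (slow if "day_weights" in v.name else fast).append((g, v))
    if fast:
        base_opt.apply_gradients(fast)
    if slow:
        slow_opt.apply_gradients(slow)

# Tiny usage example with two dummy variables.
w_fast = tf.Variable(tf.zeros([2]), name="dense_kernel")
w_slow = tf.Variable(tf.zeros([2]), name="day_weights_kernel")
apply_grouped_gradients([tf.ones([2]), tf.ones([2])], [w_fast, w_slow])
```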
@@ -565,7 +575,18 @@ class BrainToTextDecoderTrainerTF:
 
         # Apply gradients (only for variables that have gradients)
         if len(filtered_gradients) > 0:
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            # Ensure we're in the strategy scope when applying gradients
+            # This prevents the 'NoneType' extended attribute error
+            try:
+                self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            except AttributeError as e:
+                if "'NoneType' object has no attribute 'extended'" in str(e):
+                    # Strategy context was lost, this should not happen in a @tf.function
+                    tf.print(f"ERROR: Strategy context lost during gradient application: {e}")
+                    tf.print("This indicates a serious issue with the distributed training setup")
+                    raise RuntimeError(f"Strategy context lost during training: {e}")
+                else:
+                    raise
 
         return loss, grad_norm
 
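The AttributeError guarded in this hunk ("'NoneType' object has no attribute 'extended'") typically surfaces when `apply_gradients` runs outside a replica context. A minimal sketch, with placeholder model, optimizer, and data (none of these names come from the diff), of driving the per-replica step through `strategy.run` so that context is preserved:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.build((None, 4))
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)

def step_fn(x, y):
    # Runs once per replica; apply_gradients stays inside the replica context.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(x, y):
    per_replica_loss = strategy.run(step_fn, args=(x, y))
    # Average the per-replica losses into one scalar for logging.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

# Example call with a tiny random batch.
loss = distributed_train_step(tf.random.normal([8, 4]), tf.zeros([8, 1]))
```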