f
@@ -90,38 +90,31 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             self.model = self._build_model()
             self.optimizer = self._create_optimizer()
-            print("🔧 Pre-building optimizer state for TPU...")
-            # For TPU, we must ensure optimizer is completely ready before training
-            # since @tf.function doesn't allow dynamic building
+            print("🔧 Initializing optimizer for TPU training...")
+            # For TPU, we initialize the optimizer by accessing its basic properties
+            # The optimizer will be properly built when first used in training
             try:
-                print("✅ Building optimizer with model variables...")
+                print("✅ Checking optimizer initialization...")
 
-                # Explicitly build the optimizer with model variables
-                print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
-                self.optimizer.build(self.model.trainable_variables)
-                print("✅ Optimizer built with model variables")
+                # Access optimizer properties to ensure it's properly initialized
+                # This is safe and works with all TensorFlow/Keras optimizer versions
+                print(f"Optimizer type: {type(self.optimizer).__name__}")
+                print(f"Learning rate: {self.optimizer.learning_rate}")
 
-                # Verify optimizer is properly built - just check iterations
-                print(f"Optimizer iterations: {self.optimizer.iterations}")
+                # Access iterations to ensure optimizer state tracking is ready
+                # This creates the iterations variable without building the full state
+                iterations = self.optimizer.iterations
+                print(f"Optimizer iterations initialized: {iterations}")
 
-                # For TPU training, we should also ensure the optimizer has all its state
-                # variables created. We can do this by creating dummy variables that match
-                # the model variables, but we don't apply them (avoid the replica context issue)
-                print("🔄 Ensuring optimizer state variables are created...")
-
-                # Force creation of optimizer variables by accessing them
-                # This is safe and doesn't require replica context
-                _ = self.optimizer.iterations # This ensures basic state is created
-
-                print("✅ Optimizer fully ready for TPU training")
-                print("📝 Note: Optimizer will work correctly in @tf.function context")
+                print("✅ Optimizer ready for TPU training")
+                print("📝 Note: Optimizer state will be built automatically during first training step")
 
             except Exception as e:
-                print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
+                print(f"❌ CRITICAL: Could not initialize optimizer: {e}")
                 print(f"Error type: {type(e).__name__}")
                 import traceback
                 print(f"Full traceback: {traceback.format_exc()}")
-                raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
+                raise RuntimeError(f"Optimizer initialization failed: {e}") from e
 
             self.lr_scheduler = self._create_lr_scheduler()
             self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
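Reviewer note: the new code only touches optimizer.iterations up front and defers full state creation to the first training step. The sketch below (not part of the commit) illustrates why that is sufficient: a Keras optimizer builds its per-variable slot state on the first apply_gradients call. The toy model, input shapes, and Adam optimizer here are invented for illustration, and it runs single-device rather than under the project's TPUStrategy.

import tensorflow as tf

# Hypothetical stand-in for the model returned by _build_model().
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Touching `iterations` only creates the step counter, mirroring the commit;
# the per-variable slot state does not exist yet.
print("iterations before training:", optimizer.iterations.numpy())

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # The first apply_gradients call builds the optimizer's slot variables
    # (Adam moments) for exactly these variables, inside the active
    # distribution context, so no manual pre-build is required.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

loss = train_step(tf.random.normal([2, 8]), tf.random.normal([2, 4]))
print("first-step loss:", float(loss))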