fixed: ensure create_tpu_strategy() never returns None and guard TPU strategy use in BrainToTextDecoderTrainerTF
@@ -832,7 +832,16 @@ def create_tpu_strategy():
         print(" Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")
 
         print("🔄 Falling back to default strategy (CPU/GPU)")
-        return tf.distribute.get_strategy()
+        fallback_strategy = tf.distribute.get_strategy()
+        print(f"🎯 Fallback strategy created: {type(fallback_strategy).__name__}")
+        print(f"📊 Fallback strategy replicas: {fallback_strategy.num_replicas_in_sync}")
+
+        # Ensure we never return None
+        if fallback_strategy is None:
+            print("⚠️ Warning: Default strategy is None, creating OneDeviceStrategy")
+            fallback_strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
+
+        return fallback_strategy
 
 
 def build_model_for_tpu(config):
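Note: the detect-then-fall-back pattern this hunk hardens typically looks like the sketch below. The resolver call, the exception types caught, and the helper name create_tpu_strategy_sketch are illustrative assumptions, not code from this commit.

import tensorflow as tf

def create_tpu_strategy_sketch():
    # Try to locate and initialize a TPU; the resolver reads COLAB_TPU_ADDR / TPU_NAME.
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except (ValueError, RuntimeError) as exc:
        # No TPU available: fall back to the default (CPU/GPU) strategy and
        # guarantee a non-None return, mirroring the change above.
        print(f"No TPU found ({exc}); falling back to default strategy")
        fallback = tf.distribute.get_strategy()
        return fallback if fallback is not None else tf.distribute.OneDeviceStrategy("/CPU:0")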
@@ -48,7 +48,11 @@ class BrainToTextDecoderTrainerTF:
 
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
+        if self.strategy is None:
+            raise RuntimeError("Failed to create TPU strategy - strategy is None")
+
         print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
+        print(f"Strategy type: {type(self.strategy).__name__}")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
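Note: the None guard matters because everything the trainer builds in __init__ has to be created under strategy.scope(), and calling .scope() or .run() on None fails immediately with an AttributeError. A minimal illustration with a placeholder model and optimizer (not the trainer's real ones):

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in for create_tpu_strategy()
if strategy is None:
    raise RuntimeError("Failed to create TPU strategy - strategy is None")

with strategy.scope():
    # Variables created here are placed/replicated according to the strategy.
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)

print(f"Training on {strategy.num_replicas_in_sync} replicas")
print(f"Strategy type: {type(strategy).__name__}")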
@@ -95,9 +99,21 @@ class BrainToTextDecoderTrainerTF:
             print("🔧 Pre-building optimizer state for TPU...")
             # Force optimizer to build its internal state within strategy scope
             # This prevents the 'NoneType' strategy error during first apply_gradients
-            dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
-            self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
-            print("✅ Optimizer state pre-built successfully")
+            try:
+                # Check if strategy is properly initialized before applying gradients
+                if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
+                    dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
+                    self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
+                    print("✅ Optimizer state pre-built successfully with TPU strategy")
+                else:
+                    # Fallback: just build optimizer variables without applying gradients
+                    print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
+                    # Alternative: trigger optimizer variable creation
+                    _ = self.optimizer.iterations
+                    print("✅ Optimizer state initialized (fallback mode)")
+            except Exception as e:
+                print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
+                print("✅ Continuing without optimizer pre-build")
 
             print("📅 Setting up learning rate scheduler...")
             self.lr_scheduler = self._create_lr_scheduler()
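Note: the "pre-build" trick applies all-zero gradients once so the optimizer's slot variables (e.g. Adam moments) are created inside the strategy scope rather than during the first real training step; the new try/except exists because this call can fail when the strategy is not fully initialized. A self-contained sketch with a placeholder model and optimizer (not the repo's):

import tensorflow as tf

strategy = tf.distribute.get_strategy()
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)
    # Applying zero gradients forces creation of the optimizer's slot variables now.
    zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))

print("optimizer steps after pre-build:", int(optimizer.iterations))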
@@ -687,10 +703,21 @@ class BrainToTextDecoderTrainerTF:
             # Distributed training step
             self.logger.info("Running distributed training step...")
             # Ensure we're in the correct TPU strategy scope
-            with self.strategy.scope():
-                per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                    self._train_step, args=(batch, step)
-                )
+            try:
+                with self.strategy.scope():
+                    per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                        self._train_step, args=(batch, step)
+                    )
+            except AttributeError as e:
+                if "merge_call" in str(e):
+                    error_msg = f"Strategy merge_call error at step {step}: {e}"
+                    print(error_msg)
+                    if self.logger:
+                        self.logger.error(error_msg)
+                        self.logger.error("This indicates the strategy is not properly initialized")
+                    raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
+                else:
+                    raise
 
             # Reduce across replicas
             self.logger.info("Reducing results across replicas...")
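Note: the code wrapped in try/except is the usual run-then-reduce step: strategy.run executes the per-replica step function, and strategy.reduce (the "Reduce across replicas" context at the end of the hunk) folds the per-replica results into a single value. A minimal sketch with a placeholder step function instead of the trainer's _train_step:

import tensorflow as tf

strategy = tf.distribute.get_strategy()

@tf.function
def distributed_step(batch):
    def step_fn(x):
        return tf.reduce_mean(x)  # stand-in for the real per-replica loss
    per_replica_loss = strategy.run(step_fn, args=(batch,))
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

print(distributed_step(tf.ones((4, 3))))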