Files
b2txt25/model_training_nnn_tpu/setup_tensorflow_tpu.sh
Zchen 7965f7dbfe TPU
2025-10-15 16:55:52 +08:00

150 lines
4.5 KiB
Bash

#!/bin/bash
# Setup script for TensorFlow Brain-to-Text training on TPU v5e-8
#
# Usage: ./setup_tensorflow_tpu.sh
#
# This script prepares the environment for training the brain-to-text model
# using TensorFlow on TPU v5e-8 hardware.
# Abort the script as soon as any command exits with a non-zero status.
set -e

printf '%s\n' "=== TensorFlow TPU v5e-8 Setup Script ==="
printf '%s\n' "Setting up environment for brain-to-text training..."

# If neither TPU_NAME nor COLAB_TPU_ADDR carries a value, warn the user.
# This is advisory only: the script keeps going either way.
if [[ -z "${TPU_NAME}${COLAB_TPU_ADDR}" ]]; then
  printf '%s\n' "Warning: TPU environment variables not detected."
  printf '%s\n' "Make sure you're running on a TPU v5e-8 instance."
fi
# Create (if needed) and activate the conda environment used for
# TensorFlow TPU training.
ENV_NAME="b2txt_tf"
echo "Creating conda environment: ${ENV_NAME}"

# Fail with a clear message instead of a bare "command not found".
if ! command -v conda > /dev/null 2>&1; then
  echo "Error: conda was not found on PATH; install conda before running this script." >&2
  exit 1
fi

# `conda activate` only works after conda's shell hook has been evaluated in
# the current shell. A script does not inherit that from the user's
# interactive session, so load it explicitly here.
eval "$(conda shell.bash hook)"

# An environment is considered present when a line of `conda env list`
# starts with its name followed by a space.
if conda env list | grep -q "^${ENV_NAME} "; then
  echo "Environment ${ENV_NAME} already exists. Activating..."
else
  echo "Creating new environment..."
  conda create -n "${ENV_NAME}" python=3.10 -y
fi

# Both paths end with the environment active for the rest of the script.
conda activate "${ENV_NAME}"
# Install TensorFlow with TPU support
echo "Installing TensorFlow with TPU support..."
# The requirement string must be quoted. Unquoted, the shell parses
# `>=2.15.0` as a redirection of pip's stdout into a file named "=2.15.0"
# (so the version constraint never reaches pip), and `[and-cuda]` is treated
# as a glob character class.
# NOTE(review): the `and-cuda` extra looks GPU-oriented; confirm it is the
# intended extra for a TPU host.
pip install "tensorflow[and-cuda]>=2.15.0"

# Install additional requirements
echo "Installing additional requirements..."
# Path is relative to the current working directory.
pip install -r requirements_tf.txt
# Set up TPU environment variables
echo "Configuring TPU environment variables..."

# Single definition of the settings, used both for persisting them to
# ~/.bashrc and for applying them to the running shell. The quoted heredoc
# delimiter keeps the text literal (no expansion while capturing it).
tpu_env_block=$(cat << 'EOF'
# TPU v5e-8 Environment Variables
export TPU_ML_PLATFORM="TensorFlow"
export XLA_USE_BF16=1
export TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
export TPU_MEGACORE=1
export LIBTPU_INIT_ARGS="--xla_tpu_spmd_threshold_for_allgather_cse=10000"
# Disable TensorFlow warnings for cleaner output
export TF_CPP_MIN_LOG_LEVEL=2
# Memory optimizations
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_THREAD_MODE=gpu_private
EOF
)

tpu_env_marker="# TPU v5e-8 Environment Variables"
bashrc_file="${HOME}/.bashrc"

# Append only when the marker line is not already present, so re-running the
# script does not pile up duplicate copies of the block. A missing ~/.bashrc
# makes grep fail (error output suppressed on purpose), which leads to the
# file being created by the append.
if ! grep -qF -- "${tpu_env_marker}" "${bashrc_file}" 2> /dev/null; then
  printf '\n%s\n' "${tpu_env_block}" >> "${bashrc_file}"
fi

# Apply the settings to this shell directly rather than sourcing ~/.bashrc:
# the user's file may return early in non-interactive shells or run commands
# that fail and abort the script. eval is safe here because the text is the
# fixed literal defined above.
eval "${tpu_env_block}"
# Test TPU connectivity
echo "Testing TPU connectivity..."
# Run an inline Python check. The heredoc delimiter is quoted so the shell
# passes the Python text through without expanding anything in it. Python
# block structure depends on indentation, so the try/except bodies below
# must stay indented.
python3 << 'EOF'
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Best-effort TPU bring-up: a failure is reported but does not stop the
# check, so the mixed-precision step below still runs.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("TPU cluster initialized successfully!")
    print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
    print(f"TPU devices: {tf.config.list_logical_devices('TPU')}")
except Exception as e:
    print(f"TPU initialization failed: {e}")
    print("You may be running on CPU/GPU instead of TPU")

# Test mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision policy: {policy.name}")
EOF
#######################################
# Count the entries directly inside a directory whose names match the
# session pattern t*.20*.
# Uses shell globbing instead of parsing `ls`, so directory names containing
# spaces or other special characters are handled correctly.
# Arguments: $1 - directory to scan
# Outputs:   number of matching entries, written to stdout
#######################################
count_sessions() {
  local dir=$1
  local entry
  local count=0
  for entry in "${dir}"/t*.20*; do
    # A glob with no matches is left as the literal pattern; the existence
    # test filters that out (-L keeps dangling symlinks counted).
    if [[ -e "${entry}" || -L "${entry}" ]]; then
      count=$((count + 1))
    fi
  done
  printf '%s\n' "${count}"
}

# Verify data directory exists
DATA_DIR="../data/hdf5_data_final"
if [[ -d "${DATA_DIR}" ]]; then
  echo "Data directory found: ${DATA_DIR}"
  # Count available sessions
  SESSION_COUNT=$(count_sessions "${DATA_DIR}")
  echo "Available sessions: ${SESSION_COUNT}"
else
  # A missing dataset is only a warning; setup continues.
  echo "Warning: Data directory not found at ${DATA_DIR}"
  echo "Please ensure the dataset is available before training."
fi
# Ensure the output directories exist (paths are relative to the current
# working directory; -p makes this a no-op when they are already present).
echo "Creating output directories..."
for out_dir in trained_models/tensorflow_tpu logs/tensorflow_tpu eval_output; do
  mkdir -p "${out_dir}"
done

# Mark the training and evaluation entry points as executable.
echo "Setting script permissions..."
for script_file in train_model_tf.py evaluate_model_tf.py; do
  chmod +x "${script_file}"
done
# Print a short summary of the host: interpreter version, active conda
# environment, the 7th column of the `Mem:` row from `free -h`, and the
# core count reported by nproc.
printf '%s\n' "=== System Information ==="
printf '%s\n' "Python version: $(python --version)"
printf '%s\n' "Conda environment: ${CONDA_DEFAULT_ENV}"
printf '%s\n' "Available memory: $(free -h | awk '/^Mem:/ {print $7}')"
printf '%s\n' "CPU cores: $(nproc)"
# Report which accelerators are visible from this shell.
echo "=== Hardware Information ==="

# nvidia-smi succeeding is taken as "GPUs present"; its own output is
# discarded for the probe and only the GPU list is shown afterwards.
if ! nvidia-smi > /dev/null 2>&1; then
  echo "No NVIDIA GPUs detected"
else
  echo "NVIDIA GPUs detected:"
  nvidia-smi --list-gpus
fi

# Pick the TPU line to print. Assignments are ordered so that TPU_NAME,
# applied last, takes precedence over COLAB_TPU_ADDR.
tpu_summary="No TPU environment variables detected"
if [[ -n "${COLAB_TPU_ADDR}" ]]; then
  tpu_summary="Colab TPU Address: ${COLAB_TPU_ADDR}"
fi
if [[ -n "${TPU_NAME}" ]]; then
  tpu_summary="TPU Name: ${TPU_NAME}"
fi
echo "${tpu_summary}"
# Closing summary with next-step commands. The heredoc delimiter is left
# unquoted on purpose so that ${ENV_NAME} is expanded in the message.
cat << EOF

=== Setup Complete ===
Environment '${ENV_NAME}' is ready for TensorFlow TPU training.

To activate the environment:
 conda activate ${ENV_NAME}

To start training:
 python train_model_tf.py --config_path rnn_args.yaml

To run evaluation:
 python evaluate_model_tf.py --model_path path/to/checkpoint --config_path rnn_args.yaml

For more options, use --help with any script.
EOF