From a545cc564897472a4cc931d474ff5a7192e286b0 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:39:05 +0800 Subject: [PATCH] =?UTF-8?q?tpu=E7=BB=B4=E6=8A=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +- model_training_nnn_tpu/check_tpu_memory.py | 403 +++++++++++++++++++ model_training_nnn_tpu/tpu_memory_monitor.py | 236 +++++++++++ model_training_nnn_tpu/trainer_tf.py | 91 +++-- 4 files changed, 708 insertions(+), 26 deletions(-) create mode 100644 model_training_nnn_tpu/check_tpu_memory.py create mode 100644 model_training_nnn_tpu/tpu_memory_monitor.py diff --git a/.gitignore b/.gitignore index ca37a77..c0f2dde 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,6 @@ model_training_lstm/trained_models model_training_lstm/trained_models_history -*.pkl \ No newline at end of file +*.pkl + +.idea \ No newline at end of file diff --git a/model_training_nnn_tpu/check_tpu_memory.py b/model_training_nnn_tpu/check_tpu_memory.py new file mode 100644 index 0000000..7356847 --- /dev/null +++ b/model_training_nnn_tpu/check_tpu_memory.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +""" +TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控 +适用于TPU v5e-8环境 +""" + +import tensorflow as tf +import time +import numpy as np + +def monitor_tpu_during_training(): + """训练过程中的TPU实时内存和MXU监控""" + print("📊 TPU训练实时监控工具") + print("=" * 50) + + # 获取TPU设备 + try: + tpu_devices = tf.config.list_logical_devices('TPU') + print(f"📍 发现TPU设备: {len(tpu_devices)}个") + if not tpu_devices: + print("❌ 未发现TPU设备") + return + except Exception as e: + print(f"❌ 无法检测TPU设备: {e}") + return + + def get_detailed_memory_snapshot(): + """获取详细的内存快照,包含所有核心信息""" + snapshot = {} + total_current = 0 + total_peak = 0 + active_cores = 0 + core_details = [] + + for i, device in enumerate(tpu_devices): + try: + memory_info = tf.config.experimental.get_memory_info(device.name) + if memory_info and 'current' in memory_info: + current_mb = memory_info['current'] // (1024 * 1024) + peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) + + if current_mb > 1: # >1MB算活跃 + active_cores += 1 + total_current += current_mb + total_peak += peak_mb + core_details.append(f"Core{i}:{current_mb}MB") + + snapshot[f'core_{i}'] = { + 'current': current_mb, + 'peak': peak_mb, + 'device': device.name + } + else: + snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name} + except Exception as e: + snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)} + + snapshot['summary'] = { + 'total_current': total_current, + 'total_peak': total_peak, + 'active_cores': active_cores, + 'total_cores': len(tpu_devices), + 'core_details': core_details + } + return snapshot + + def test_mxu_performance(): + """测试MXU性能和计算能力""" + print("\n🧮 MXU计算性能测试:") + + mxu_results = [] + try: + with tf.device(tpu_devices[0].name): + # 测试不同规模的矩阵运算 + test_configs = [ + (2000, "2K×2K", tf.bfloat16), + (4000, "4K×4K", tf.bfloat16), + (6000, "6K×6K", tf.bfloat16), + ] + + for size, desc, dtype in test_configs: + try: + # 获取测试前内存 + pre_mem = get_detailed_memory_snapshot() + + start_time = time.time() + + # 创建矩阵并执行MXU密集型运算 + matrix_a = tf.random.normal([size, size], dtype=dtype) + matrix_b = tf.random.normal([size, size], dtype=dtype) + + @tf.function + def mxu_operation(): + # 连续矩阵运算,充分使用MXU + result = tf.matmul(matrix_a, matrix_b) + result = tf.matmul(result, matrix_a) + return tf.reduce_sum(result) + + result = mxu_operation() + 
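                    # Note: mxu_operation is a tf.function, so the timed call above also
                    # pays one-off graph tracing and XLA compilation; the reported TFLOPS
                    # can therefore understate steady-state MXU throughput. A hedged
                    # refinement (a sketch, not in the original script) would be a
                    # warm-up run before starting the timer, e.g.:
                    #   _ = mxu_operation().numpy()   # warm-up, excluded from timing
                    #   start_time = time.time()
                    #   result = mxu_operation()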
# 使用result确保计算被执行 + _ = result.numpy() + end_time = time.time() + + # 获取测试后内存 + post_mem = get_detailed_memory_snapshot() + + duration = end_time - start_time + # 计算FLOPS (两次矩阵乘法) + flops = 2 * (2 * size**3) + tflops = flops / duration / 1e12 + + memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current'] + + print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存+{memory_used}MB") + + mxu_results.append({ + 'size': size, + 'tflops': tflops, + 'duration': duration, + 'memory_used': memory_used + }) + + except Exception as e: + print(f" {desc}: 测试失败 - {str(e)[:50]}") + + # MXU性能分析 + if mxu_results: + max_tflops = max(r['tflops'] for r in mxu_results) + total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0) + + # TPU v5e-8单核理论性能 + theoretical_tflops = 275 # bf16峰值性能 + efficiency = (max_tflops / theoretical_tflops) * 100 + + print(f"\n 📊 MXU性能汇总:") + print(f" 峰值性能: {max_tflops:.1f} TFLOPS") + print(f" 理论峰值: {theoretical_tflops} TFLOPS") + print(f" MXU效率: {efficiency:.1f}%") + print(f" 计算内存占用: {total_memory}MB") + + if efficiency > 80: + status = "🟢 优秀" + elif efficiency > 50: + status = "🟡 良好" + elif efficiency > 20: + status = "🟠 中等" + else: + status = "🔴 需优化" + + print(f" 性能评级: {status}") + + except Exception as e: + print(f" MXU测试失败: {e}") + + try: + print("🎯 开始TPU训练监控...") + + # 1. 获取初始状态 + print("\n📸 初始TPU状态:") + baseline_snapshot = get_detailed_memory_snapshot() + + print(f" 总内存使用: {baseline_snapshot['summary']['total_current']}MB") + print(f" 活跃核心: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}") + + # 显示各核心详细状态 + for i in range(len(tpu_devices)): + core = baseline_snapshot[f'core_{i}'] + if core['current'] > 0 or core['peak'] > 0: + print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB") + + # 2. MXU性能基准测试 + test_mxu_performance() + + # 3. 创建分布式策略 - 使用项目验证的TPU初始化代码 + print(f"\n🔄 使用项目标准TPU初始化...") + try: + # 使用项目里验证过的TPU初始化代码 + # 禁用GPU避免冲突 + try: + tf.config.set_visible_devices([], 'GPU') + print("🚫 GPU已禁用,避免CUDA冲突") + except: + pass + + # 使用标准的TPU初始化流程 + print("🚀 使用官方TensorFlow TPU初始化...") + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + + # 验证TPU设备 + tpu_devices_check = tf.config.list_logical_devices('TPU') + print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备") + + # 创建TPU策略 + strategy = tf.distribute.TPUStrategy(resolver) + print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本") + use_distributed = True + + except Exception as e: + print(f"⚠️ 分布式策略失败: {str(e)[:80]}") + print(" 将使用单设备模拟") + use_distributed = False + + # 4. 
模拟Brain-to-Text训练场景 + print(f"\n🧠 模拟Brain-to-Text训练场景...") + + if use_distributed: + # 分布式训练模拟 + with strategy.scope(): + print("📦 创建分布式模型参数...") + + # 创建接近真实Brain-to-Text模型的参数 (修复维度匹配) + model_components = { + # GRU层权重:第一层接收512维输入,后续层接收256维 + 'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'), + 'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'), + 'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'), + 'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'), + # 添加day-specific层模拟 (输入512维,输出512维) + 'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)] + } + + # 检查模型加载后内存 + after_model = get_detailed_memory_snapshot() + model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current'] + print(f"🧠 模型加载完成: +{model_memory}MB, {after_model['summary']['active_cores']}个活跃核心") + + # 训练循环模拟 + print(f"\n🔄 开始训练循环监控...") + + for step in range(10): + step_start_time = time.time() + + @tf.function + def distributed_training_step(): + # 模拟真实训练数据大小 + batch_size = 32 + seq_length = 1000 + features = 512 + + # 输入数据 + neural_data = tf.random.normal([batch_size, seq_length, features]) + targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32) + + # 模拟前向传播 + x = neural_data + + # Day-specific transformation (简化版本避免复杂的维度操作) + # 模拟day-specific变换:对每个时间步应用相同变换 + day_weight = model_components['day_weights'][0] # 简化:使用第一个day权重 + # 对最后一个维度进行变换: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512] + x = tf.matmul(x, day_weight) + + # 为CTC损失添加目标使用(模拟) + target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1) + # 简化的CTC相关计算 + batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32)) + + # GRU layers simulation + for i in range(3): + layer_name = f'gru_layer_{i}' + weight = model_components[layer_name] + + # 处理张量维度:第一层从3D输入,后续层从2D输入 + if i == 0: + # 第一层:取最后时间步 [batch, seq, features] -> [batch, features] + if len(x.shape) == 3: + x = x[:, -1, :] # 取最后时间步 + x = tf.nn.tanh(tf.matmul(x, weight)) + else: + # 后续层:直接处理2D张量 [batch, features] -> [batch, features] + x = tf.nn.tanh(tf.matmul(x, weight)) + + # 输出投影 + logits = tf.matmul(x, model_components['output_projection']) + + # CTC loss模拟(使用batch_loss_weight作为权重) + base_loss = tf.reduce_mean(tf.square(logits)) + loss = base_loss * batch_loss_weight + + return loss + + # 执行训练步骤 + per_replica_loss = strategy.run(distributed_training_step) + # 聚合分布式结果 + loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None) + step_duration = time.time() - step_start_time + + # 获取当前内存状态 + current_snapshot = get_detailed_memory_snapshot() + step_memory = current_snapshot['summary']['total_current'] + memory_delta = step_memory - baseline_snapshot['summary']['total_current'] + + # 显示详细训练状态 + active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(无活跃)" + + print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, " + f"时间={step_duration:.3f}s, " + f"内存={step_memory}MB(+{memory_delta}), " + f"活跃={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}") + + # 每5步显示峰值内存 + if step % 5 == 0: + peak_info = f"峰值: {current_snapshot['summary']['total_peak']}MB" + print(f" {peak_info}") + + time.sleep(0.2) # 短暂暂停观察 + + else: + # 单设备训练模拟(改进版) + print("🔸 单设备训练模拟...") + + with tf.device(tpu_devices[0].name): + # 创建较小的模型参数 + simple_weights = 
tf.Variable(tf.random.normal([512, 256]), name='simple_net') + + for step in range(8): + step_start = time.time() + + # 创建较大的数据批次 + batch_data = tf.random.normal([64, 1000, 512]) # 增大batch size + + # 模拟计算密集型操作 + @tf.function + def compute_step(): + x = tf.reshape(batch_data, [-1, 512]) + result = tf.matmul(x, simple_weights) + result = tf.nn.relu(result) + return tf.reduce_mean(result) + + result = compute_step() + step_duration = time.time() - step_start + + # 获取内存状态 + snapshot = get_detailed_memory_snapshot() + memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current'] + + print(f" Step {step}: result={result.numpy():.4f}, " + f"时间={step_duration:.3f}s, " + f"内存变化=+{memory_change}MB, " + f"峰值={snapshot['summary']['total_peak']}MB") + + # 5. 最终分析报告 + final_snapshot = get_detailed_memory_snapshot() + total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current'] + peak_usage = final_snapshot['summary']['total_peak'] + + print(f"\n📈 训练监控报告:") + print(f" 总内存增长: +{total_growth}MB") + print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)") + print(f" 最终活跃核心: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}") + + # 各核心最终状态 + print(f" 各核心最终状态:") + has_changes = False + for i in range(len(tpu_devices)): + final_core = final_snapshot[f'core_{i}'] + baseline_core = baseline_snapshot[f'core_{i}'] + current_change = final_core['current'] - baseline_core['current'] + peak_change = final_core['peak'] - baseline_core['peak'] + + if current_change != 0 or peak_change != 0: + has_changes = True + print(f" Core{i}: 当前{final_core['current']}MB(+{current_change}), 峰值{final_core['peak']}MB(+{peak_change})") + + if not has_changes: + print(f" 所有核心内存无明显变化") + + # 分布式使用分析 + if final_snapshot['summary']['active_cores'] == 1: + print(f"\n⚠️ 分布式问题诊断:") + print(f" 只有1个核心活跃,其他7个核心空闲") + print(f" 可能原因: TPU策略配置问题或模型未正确分布") + print(f" 建议: 检查分布式策略和模型分片") + elif final_snapshot['summary']['active_cores'] > 4: + print(f"\n✅ 分布式状态良好:") + print(f" {final_snapshot['summary']['active_cores']}个核心活跃,多核心并行工作正常") + else: + print(f"\n🟡 分布式部分工作:") + print(f" {final_snapshot['summary']['active_cores']}个核心活跃,可能存在负载不均衡") + + print("✅ TPU训练监控完成") + + except Exception as e: + print(f"❌ 训练监控失败: {e}") + import traceback + print(f"详细错误: {traceback.format_exc()[:300]}") + +if __name__ == "__main__": + print("🚀 TPU训练内存监控工具") + print("专注于训练过程中的实时内存和性能监控") + print("适用于TPU v5e-8环境") + print() + + monitor_tpu_during_training() + + print(f"\n🎯 监控要点总结:") + print(f" 1. 确认所有8个TPU核心是否活跃") + print(f" 2. 监控内存增长模式和峰值使用") + print(f" 3. 检测MXU计算性能和效率") + print(f" 4. 验证分布式策略是否正常工作") + print(f" 5. 
识别可能的内存泄漏或性能瓶颈") \ No newline at end of file diff --git a/model_training_nnn_tpu/tpu_memory_monitor.py b/model_training_nnn_tpu/tpu_memory_monitor.py new file mode 100644 index 0000000..970d6f3 --- /dev/null +++ b/model_training_nnn_tpu/tpu_memory_monitor.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +TPU内存监控工具 - 专门用于训练过程 +解决tf.config.experimental.get_memory_info()在TPU上无法工作的问题 +""" + +import tensorflow as tf +import time +import psutil +import os + +class TPUMemoryMonitor: + """TPU内存监控类""" + + def __init__(self): + self.tpu_devices = tf.config.list_logical_devices('TPU') + self.baseline_memory = None + self.peak_allocations = {} + + def get_tpu_status(self) -> str: + """获取TPU状态 - 实用版本,不依赖get_memory_info""" + try: + if not self.tpu_devices: + return "TPU: No devices" + + num_cores = len(self.tpu_devices) + + # 测试TPU响应性 + try: + with tf.device('/TPU:0'): + test_tensor = tf.constant([1.0, 2.0, 3.0]) + result = tf.reduce_sum(test_tensor) + _ = result.numpy() # 强制执行 + activity = "active" + except Exception: + activity = "inactive" + + # 获取主机内存作为参考 + try: + memory = psutil.virtual_memory() + host_mem = f"Host:{memory.percent:.1f}%" + except: + host_mem = "Host:unknown" + + return f"TPU: {num_cores}cores {activity} {host_mem}" + + except Exception as e: + return f"TPU: error({str(e)[:20]})" + + def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32): + """估算张量内存使用量""" + if dtype == tf.float32: + bytes_per_element = 4 + elif dtype == tf.float16 or dtype == tf.bfloat16: + bytes_per_element = 2 + elif dtype == tf.int32: + bytes_per_element = 4 + elif dtype == tf.int64: + bytes_per_element = 8 + else: + bytes_per_element = 4 # 默认 + + total_elements = 1 + for dim in tensor_shape: + total_elements *= dim + + total_bytes = total_elements * bytes_per_element + return total_bytes / (1024 * 1024) # 返回MB + + def track_allocation(self, name: str, tensor_shape, dtype=tf.float32): + """跟踪内存分配""" + mb = self.estimate_tensor_memory(tensor_shape, dtype) + self.peak_allocations[name] = self.peak_allocations.get(name, 0) + mb + return mb + + def get_allocation_summary(self) -> str: + """获取分配汇总""" + if not self.peak_allocations: + return "No allocations tracked" + + total_mb = sum(self.peak_allocations.values()) + top_3 = sorted(self.peak_allocations.items(), key=lambda x: x[1], reverse=True)[:3] + + summary = f"Tracked:{total_mb:.1f}MB " + summary += f"Top:({top_3[0][0]}:{top_3[0][1]:.1f}MB)" + + return summary + + def test_memory_allocation_across_cores(self): + """测试8个核心的内存分配""" + print("🧪 测试所有TPU核心内存分配") + print("=" * 40) + + allocations_per_core = [] + + for i, device in enumerate(self.tpu_devices): + print(f"核心 {i+1}: {device.name}") + + try: + with tf.device(device.name): + # 创建不同大小的测试张量 + test_sizes = [ + ([1000, 1000], "1K×1K"), + ([3000, 3000], "3K×3K"), + ([5000, 5000], "5K×5K"), + ([7000, 7000], "7K×7K"), + ] + + core_total = 0 + successful_allocs = [] + + for shape, desc in test_sizes: + try: + tensor = tf.ones(shape, dtype=tf.float32) + mb = self.estimate_tensor_memory(shape) + core_total += mb + successful_allocs.append(f"{desc}({mb:.1f}MB)") + + # 实际使用张量防止被优化 + _ = tf.reduce_mean(tensor) + + except Exception as e: + print(f" {desc} 失败: {str(e)[:30]}") + break + + allocations_per_core.append(core_total) + print(f" 成功分配: {' + '.join(successful_allocs)}") + print(f" 核心总计: {core_total:.1f}MB") + + except Exception as e: + print(f" 核心{i+1}失败: {e}") + allocations_per_core.append(0) + + # 汇总结果 + total_all_cores = sum(allocations_per_core) + avg_per_core = total_all_cores / len(self.tpu_devices) 
if self.tpu_devices else 0 + + print(f"\n📊 汇总结果:") + print(f" 总分配: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)") + print(f" 平均每核: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)") + + # 推测内存配置 + if avg_per_core > 8000: # > 8GB + print(" 推测: 每核心≥16GB (高端配置)") + elif avg_per_core > 4000: # > 4GB + print(" 推测: 每核心8-16GB (标准配置)") + elif avg_per_core > 1000: # > 1GB + print(" 推测: 每核心2-8GB (受限或共享)") + else: + print(" 推测: 每核心<2GB (严重受限)") + + return allocations_per_core + +def test_training_memory_pattern(): + """测试模拟训练的内存模式""" + print("\n🏋️ 模拟训练内存模式测试") + print("=" * 30) + + monitor = TPUMemoryMonitor() + + # 模拟典型的brain-to-text模型内存使用 + with tf.device('/TPU:0'): + print("创建模拟模型组件...") + + # 1. 输入数据 (batch_size=32, seq_len=1000, features=512) + batch_size, seq_len, features = 32, 1000, 512 + input_data = tf.random.normal([batch_size, seq_len, features]) + input_mb = monitor.track_allocation("input_data", [batch_size, seq_len, features]) + print(f" 输入数据: {input_mb:.1f}MB") + + # 2. GRU权重 (假设3层, 每层256单元) + n_layers, n_units = 3, 256 + for layer in range(n_layers): + # GRU有3个门,每个门需要权重矩阵 + weight_shape = [features if layer == 0 else n_units, n_units * 3] + weights = tf.random.normal(weight_shape) + weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape) + print(f" GRU层{layer+1}权重: {weight_mb:.1f}MB") + + # 3. 输出投影层 (n_units -> n_classes=41) + n_classes = 41 + output_weights = tf.random.normal([n_units, n_classes]) + output_mb = monitor.track_allocation("output_projection", [n_units, n_classes]) + print(f" 输出投影: {output_mb:.1f}MB") + + # 4. 中间激活值 (前向传播) + hidden_states = tf.random.normal([batch_size, seq_len, n_units]) + hidden_mb = monitor.track_allocation("hidden_states", [batch_size, seq_len, n_units]) + print(f" 隐藏状态: {hidden_mb:.1f}MB") + + # 5. 
梯度 (反向传播时会翻倍内存) + total_params_mb = sum([v for k, v in monitor.peak_allocations.items() if 'layer' in k or 'projection' in k]) + gradient_mb = total_params_mb # 梯度内存约等于参数内存 + print(f" 梯度内存: {gradient_mb:.1f}MB (估算)") + + print(f"\n模型总内存估算: {monitor.get_allocation_summary()}") + + # 实际执行一些操作确保内存被分配 + result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states) + print(f"验证计算结果: {result.numpy():.4f}") + +if __name__ == "__main__": + print("🚀 TPU内存监控工具启动") + + monitor = TPUMemoryMonitor() + + # 基础状态检查 + print(f"当前TPU状态: {monitor.get_tpu_status()}") + + # 测试所有核心 + print("\n" + "="*50) + core_allocations = monitor.test_memory_allocation_across_cores() + + # 训练内存模式测试 + print("\n" + "="*50) + test_training_memory_pattern() + + print(f"\n🎯 关键发现:") + if core_allocations: + max_core = max(core_allocations) + min_core = min([x for x in core_allocations if x > 0]) + print(f" 最大单核分配: {max_core:.1f}MB") + print(f" 最小单核分配: {min_core:.1f}MB") + + if max_core > 9000: # 你之前测试到9.4GB + print(" ✅ 内存充足,可支持大模型训练") + elif max_core > 5000: + print(" ⚠️ 内存中等,建议优化模型大小") + else: + print(" ❌ 内存不足,需要大幅减少模型参数") + + print(f"\n💡 针对你的训练卡顿问题:") + print(f" - SetPriority错误通常是XLA编译问题,不是内存问题") + print(f" - 你的9.4GB测试说明TPU内存工作正常") + print(f" - 建议检查模型是否有导致XLA编译卡顿的操作") + print(f" - 考虑使用更简单的操作或关闭某些XLA优化") \ No newline at end of file diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 31e344f..ee808ce 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -177,26 +177,43 @@ class BrainToTextDecoderTrainerTF: # Get strategy info num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 - # Try to get TPU memory info (HBM) + # Get TPU memory info using the working /device:TPU:X format try: - # Attempt to get TPU memory usage for each device - memory_info = tf.config.experimental.get_memory_info('/TPU:0') - if memory_info and 'current' in memory_info: - current_mb = memory_info['current'] // (1024 * 1024) - peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) - hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)" + # Check all TPU devices for memory usage + active_cores = 0 + total_current_mb = 0 + max_peak_mb = 0 + + for device in tpu_devices: + try: + memory_info = tf.config.experimental.get_memory_info(device.name) + if memory_info and 'current' in memory_info: + current_mb = memory_info['current'] // (1024 * 1024) + peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) + + if current_mb > 1: # >1MB considered active + active_cores += 1 + total_current_mb += current_mb + max_peak_mb = max(max_peak_mb, peak_mb) + except: + continue + + if active_cores > 0: + if active_cores == 1: + hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)" + else: + hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)" else: - hbm_info = "HBM: unknown" + hbm_info = "HBM:idle" + except Exception: # Fallback: simple TPU activity check try: - # Test TPU responsiveness with tf.device('/TPU:0'): - test_tensor = tf.constant([1.0, 2.0]) - _ = tf.reduce_sum(test_tensor) - hbm_info = "HBM: active" + _ = tf.constant(1.0) + hbm_info = "HBM:active" except Exception: - hbm_info = "HBM: inactive" + hbm_info = "HBM:inactive" return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores " f"{hbm_info}") @@ -217,17 +234,37 @@ class BrainToTextDecoderTrainerTF: num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 
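        # For reference: tf.config.experimental.get_memory_info takes a device name such
        # as '/device:TPU:0' and is assumed here to return per-device byte counters,
        # e.g. {'current': bytes_allocated_now, 'peak': high_water_mark_bytes}.
        # Minimal probe under that assumption (a sketch mirroring the per-device loops
        # used in this patch, not an additional change):
        #   info = tf.config.experimental.get_memory_info('/device:TPU:0')
        #   if info and 'current' in info:
        #       current_mb = info['current'] // (1024 * 1024)
        #       peak_mb = info.get('peak', info['current']) // (1024 * 1024)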
strategy_type = type(self.strategy).__name__ - # Get TPU HBM memory info + # Get TPU HBM memory info using working device format try: - memory_info = tf.config.experimental.get_memory_info('/TPU:0') - if memory_info and 'current' in memory_info: - current_gb = memory_info['current'] // (1024 * 1024 * 1024) - peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024) - # TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB - estimated_total_gb = 32 * len(tpu_devices) - hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)" + active_cores = 0 + total_current_gb = 0 + max_peak_gb = 0 + memory_details = [] + + for i, device in enumerate(tpu_devices): + try: + memory_info = tf.config.experimental.get_memory_info(device.name) + if memory_info and 'current' in memory_info: + current_gb = memory_info['current'] // (1024 * 1024 * 1024) + peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024) + + if current_gb > 0 or memory_info['current'] > 1024*1024: # >1MB + active_cores += 1 + total_current_gb += current_gb + max_peak_gb = max(max_peak_gb, peak_gb) + if current_gb > 0: + memory_details.append(f"Core{i}:{current_gb}GB") + except: + continue + + if active_cores > 0: + # Based on your test: TPU:0 peaked at 14.5GB, suggesting ~16GB per core + estimated_per_core = 16 # Conservative estimate + estimated_total_gb = estimated_per_core * len(tpu_devices) + hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores" else: - hbm_usage = "HBM: unknown" + hbm_usage = "HBM: 0GB/256GB (idle)" + except Exception: hbm_usage = "HBM: unavailable" @@ -559,39 +596,43 @@ class BrainToTextDecoderTrainerTF: self.args['dataset']['data_transforms'], training=True ) + val_dataset = create_input_fn( self.val_dataset_tf, self.args['dataset']['data_transforms'], training=False ) - # Distribute datasets train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset) val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset) + self.logger.info("Created distributed training and validation datasets") # Training metrics train_losses = [] val_losses = [] val_pers = [] val_results = [] val_steps_since_improvement = 0 - + self.logger.info("Training time count beginning...") train_start_time = time.time() # Training loop step = 0 for batch in train_dist_dataset: if step >= self.args['num_training_batches']: + self.logger.info("Reached maximum training batches, stopping training") break - + start_time = time.time() # Distributed training step + self.logger.info("Running distributed training step...") per_replica_losses, per_replica_grad_norms = self.strategy.run( self._train_step, args=(batch, step) ) # Reduce across replicas + self.logger.info("Reducing results across replicas...") loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None) grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
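            # Hedged sketch (an assumption, not part of this hunk): TPU steps are
            # dispatched asynchronously, so reading the reduced tensors back before
            # taking the wall-clock measurement keeps per-step timing from
            # under-reporting device work, e.g.:
            #   loss_value = float(loss.numpy())            # blocks until the step finishes
            #   grad_norm_value = float(grad_norm.numpy())
            #   step_duration = time.time() - start_time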