# Source: b2txt25/model_training_nnn_tpu/jupyter_xla_compatibility.py
# (Python, 93 lines, 2.4 KiB; revision dated 2025-10-15 14:33:49 +08:00)
# --------------------------------------------------------------------
# Notebook cell: check PyTorch/XLA version compatibility and define
# version-tolerant wrappers for the XLA APIs used by later cells.
# --------------------------------------------------------------------
import torch
import torch.nn as nn

print("🔧 PyTorch XLA版本兼容性检查...")

# torch_xla is required here; if it is not installed this import raises
# ImportError and the remainder of the cell does not run.
import torch_xla.core.xla_model as xm

print("✅ PyTorch XLA导入成功!")

# Version-compatibility helpers
def get_xla_world_size():
    """Return the XLA world size, tolerating API differences across torch_xla versions.

    Lookup order:
      1. ``xm.xrt_world_size()`` (legacy name)
      2. ``xm.get_world_size()``
      3. ``torch_xla.runtime.world_size()`` (runtime API in newer releases)

    Returns 1 when none of these APIs is available.

    Only a *missing* API triggers the next fallback. The lookup uses
    ``getattr``, so an ``AttributeError`` raised inside a call propagates
    instead of being silently converted into the default value.
    """
    for name in ("xrt_world_size", "get_world_size"):
        world_size_fn = getattr(xm, name, None)
        if world_size_fn is not None:
            return world_size_fn()
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return 1  # no known API: assume a single process
    world_size_fn = getattr(xr, "world_size", None)
    return world_size_fn() if world_size_fn is not None else 1
def get_xla_ordinal():
    """Return this process's XLA ordinal, tolerating API differences across versions.

    Lookup order:
      1. ``xm.get_ordinal()``
      2. ``torch_xla.runtime.global_ordinal()`` (runtime API in newer releases)

    Returns 0 when neither API is available.

    Only a *missing* API triggers the fallback. An ``AttributeError``
    raised inside the call itself propagates rather than being masked.
    """
    ordinal_fn = getattr(xm, "get_ordinal", None)
    if ordinal_fn is not None:
        return ordinal_fn()
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return 0  # no known API: assume the first (only) process
    ordinal_fn = getattr(xr, "global_ordinal", None)
    return ordinal_fn() if ordinal_fn is not None else 0
def xla_mark_step():
    """Synchronize pending XLA work using whichever API this torch_xla provides.

    Calls ``xm.mark_step()`` if it exists. Otherwise it calls
    ``xm.wait_device_ops()``. If neither exists, it does nothing
    (best effort).

    Each API is looked up with ``getattr`` before it is called. An
    ``AttributeError`` raised *inside* ``mark_step`` therefore propagates,
    instead of being mistaken for "API not available" and silently
    replaced by the fallback.
    """
    for name in ("mark_step", "wait_device_ops"):
        sync_fn = getattr(xm, name, None)
        if sync_fn is not None:
            # NOTE(review): wait_device_ops waits on already-dispatched device
            # work; it may not execute pending lazy ops the way mark_step
            # does. Confirm this fallback is sufficient for callers.
            sync_fn()
            return
    # Neither API exists in this torch_xla build: intentional no-op.
def check_xla_device():
    """Inspect the default XLA device and report which hardware backs it.

    Prints the device, the world size and the ordinal, then classifies the
    device.

    Returns:
        tuple[bool, str]: ``(available, device_type)``. ``device_type`` is
        one of:

        - ``"TPU"`` or ``"XLA_CPU"``.
        - ``"XLA_<HW>"`` for other XLA hardware reported by torch_xla
          (e.g. ``"XLA_GPU"``).
        - ``"CPU"`` when no XLA device was detected.
        - ``"ERROR"`` when querying the device raised.

    Any exception is caught and reported here, so this function never raises.
    """
    try:
        device = xm.xla_device()
        print(f"📱 XLA设备: {device}")
        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()
        print(f"🌍 World Size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # An XLA torch.device stringifies as "xla:<index>" whatever hardware
        # backs it, so the string alone cannot tell TPU from CPU/GPU. Ask
        # torch_xla for the hardware type when that API exists.
        device_str = str(device)
        hw_fn = getattr(xm, "xla_device_hw", None)
        if hw_fn is not None:
            hw = str(hw_fn(device)).upper()
        elif 'xla' in device_str:
            # Legacy heuristic, used only when xla_device_hw is unavailable.
            hw = "CPU" if 'cpu' in device_str else "TPU"
        else:
            hw = None

        if hw == "TPU":
            print("✅ 检测到TPU设备")
            return True, "TPU"
        if hw == "CPU":
            print("⚠️ XLA CPU模拟模式")
            return True, "XLA_CPU"
        if hw:
            # Other XLA hardware (e.g. GPU) must not be misreported as TPU.
            print(f"✅ 检测到XLA {hw} 设备")
            return True, f"XLA_{hw}"
        print("❌ 未检测到XLA设备")
        return False, "CPU"
    except Exception as e:
        print(f"❌ XLA设备检查失败: {e}")
        return False, "ERROR"
# Run the compatibility check when this cell executes.
device_available, device_type = check_xla_device()
if device_available:
    print(f"✅ XLA环境正常设备类型: {device_type}")
    # Smoke-test a basic computation on the XLA device.
    print("🧪 测试基本XLA操作...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)
        # Exercise the version-tolerant sync wrapper.
        xla_mark_step()
        # Copying to the host waits for the device result. Execution errors
        # therefore surface inside this try, not after "success" is printed.
        y_host = y.cpu()
        if y_host.shape != (2, 2) or not torch.isfinite(y_host).all():
            raise RuntimeError(f"unexpected matmul result: {y_host}")
        print("✅ 基本XLA操作测试成功")
    except Exception as e:
        print(f"❌ XLA操作测试失败: {e}")
else:
    print("❌ XLA环境不可用")
print("\n💡 兼容性检查完成,可以运行后续单元格")