BERT中文情感分析训练问题修复记录
闲鱼店铺 小食品大王钟爱土豆片 技术KK
问题1:数据集加载错误
FileNotFoundError: Directory code/my_model_cache/bert_base_chinese/my_dataset/lansinuote___chn_senti_corp is neither a `Dataset` directory nor a `DatasetDict` directory.
修复方案:修改 dataset_load.py,添加自动下载和缓存机制
try:
self.dataset = load_from_disk(dataset_path)
except FileNotFoundError:
self.dataset = load_dataset("lansinuote/ChnSentiCorp",
cache_dir=cache_dir,
trust_remote_code=True)
原因:数据集路径不存在或格式不正确,通过添加自动下载机制解决。
问题2:优化器参数错误
TypeError: 'method' object is not iterable
修复方案:修改 trainer.py 中的优化器初始化
# 错误写法
optimizer = AdamW(model.parameters, lr=5e-4)
# 正确写法
optimizer = AdamW(model.parameters(), lr=5e-4)
原因:model.parameters 是一个方法,需要调用它来获取参数列表。
问题3:GPU显存不足
RuntimeError: Could not allocate tensor with 188160000 bytes. There is not enough GPU video memory available!
修复方案:调整批次大小和序列长度
# 修改批次大小
batch_size=16 # 从32改为16
# 修改序列长度
max_length=128 # 从350改为128
原因:GPU显存不足以处理大批次数据,通过减小批次大小和序列长度来降低显存使用。
问题4:DirectML设备信息获取错误
module 'torch_directml' has no attribute 'device_properties'
修复方案:重写GPU信息检查代码,适配DirectML
def print_gpu_info():
print("\n=== GPU信息 ===")
try:
dml_device = torch_directml.device()
print(f"DirectML设备: {dml_device}")
print("DirectML设备信息:")
print(f"- 设备ID: {dml_device}")
print(f"- 设备类型: DirectML")
except Exception as e:
print(f"获取GPU信息时出错: {e}")
原因:DirectML的API与CUDA不同,需要特别处理设备信息获取。