BERT中文情感分析训练问题修复记录

闲鱼店铺 小食品大王钟爱土豆片 技术KK

问题1:数据集加载错误

FileNotFoundError: Directory code/my_model_cache/bert_base_chinese/my_dataset/lansinuote___chn_senti_corp is neither a `Dataset` directory nor a `DatasetDict` directory.
修复方案:修改 dataset_load.py,添加自动下载和缓存机制
try:
    self.dataset = load_from_disk(dataset_path)
except FileNotFoundError:
    self.dataset = load_dataset("lansinuote/ChnSentiCorp", 
                              cache_dir=cache_dir,
                              trust_remote_code=True)
                
原因:数据集路径不存在或格式不正确,通过添加自动下载机制解决。

问题2:优化器参数错误

TypeError: 'method' object is not iterable
修复方案:修改 trainer.py 中的优化器初始化
# 错误写法
optimizer = AdamW(model.parameters, lr=5e-4)

# 正确写法
optimizer = AdamW(model.parameters(), lr=5e-4)
                
原因:model.parameters 是一个方法,需要调用它来获取参数列表。

问题3:GPU显存不足

RuntimeError: Could not allocate tensor with 188160000 bytes. There is not enough GPU video memory available!
修复方案:调整批次大小和序列长度
# 修改批次大小
batch_size=16  # 从32改为16

# 修改序列长度
max_length=128  # 从350改为128
                
原因:GPU显存不足以处理大批次数据,通过减小批次大小和序列长度来降低显存使用。

问题4:DirectML设备信息获取错误

module 'torch_directml' has no attribute 'device_properties'
修复方案:重写GPU信息检查代码,适配DirectML
def print_gpu_info():
    print("\n=== GPU信息 ===")
    try:
        dml_device = torch_directml.device()
        print(f"DirectML设备: {dml_device}")
        print("DirectML设备信息:")
        print(f"- 设备ID: {dml_device}")
        print(f"- 设备类型: DirectML")
    except Exception as e:
        print(f"获取GPU信息时出错: {e}")
                
原因:DirectML的API与CUDA不同,需要特别处理设备信息获取。