首页 > 解决方案 > 通过跳过头部层加载张量流检查点

问题描述

我正在使用模型(SimCLR)从图像中学习表示。在预训练时,模型是针对单个虚拟标签进行训练的。现在我想用 8 类数据微调模型。在将预训练模型检查点加载到尚未经过微调的具有 8 类头的模型时,我遇到了 ValueError。

ValueError: Tensor's shape (2048, 1) is not compatible with supplied shape [2048, 8]

在加载到检查点以微调模型之前,是否有排除最后一个头层权重的解决方案?


系统信息

标签: pythontensorflowtensorflow2.0

解决方案


好吧,为了让您的预训练模型能够成功处理您的新输入,它们需要与它期望​​的旧输入具有完全相同的形状(来自旧的 1D 模型)。要让您的 8 类数据与此模型一起使用,您需要更改模型本身以处理 8 个类的输入。这可能需要您编辑模型本身的属性,并且如果没有代码的可视化,很难确切地说出您需要在哪里进行更改。


推荐阅读