python - PyTorch:“CrossEntropyLoss”对象没有属性“项目”
问题描述
目前正在部署一个 CNN 模型。
model = CNN(height=96, width=96, channels=3)
并观察其交叉熵损失。
criterion = nn.CrossEntropyLoss()
培训师课程如下,
class Trainer:
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
criterion: nn.Module,
optimizer: Optimizer,
summary_writer: SummaryWriter,
device: torch.device,
):
self.model = model.to(device)
self.device = device
self.train_loader = train_loader
self.val_loader = val_loader
self.criterion = criterion
self.optimizer = optimizer
self.summary_writer = summary_writer
self.step = 0
def train(
self,
epochs: int,
val_frequency: int,
print_frequency: int = 20,
log_frequency: int = 5,
start_epoch: int = 0
):
self.model.train()
for epoch in range(start_epoch, epochs):
self.model.train()
data_load_start_time = time.time()
for batch, labels in self.train_loader:
batch = batch.to(self.device)
labels = labels.to(self.device)
data_load_end_time = time.time()
loss=self.criterion
logits=self.model.forward(batch)
with torch.no_grad():
preds = logits
accuracy = compute_accuracy(labels, preds)
data_load_time = data_load_end_time - data_load_start_time
step_time = time.time() - data_load_end_time
if ((self.step + 1) % log_frequency) == 0:
self.log_metrics(epoch, accuracy, loss, data_load_time, step_time)
if ((self.step + 1) % print_frequency) == 0:
self.print_metrics(epoch, accuracy, loss, data_load_time, step_time)
self.step += 1
data_load_start_time = time.time()
self.summary_writer.add_scalar("epoch", epoch, self.step)
if ((epoch + 1) % val_frequency) == 0:
self.validate()
self.model.train()
记录损失的函数是,
def log_metrics(self, epoch, accuracy, loss, data_load_time, step_time):
self.summary_writer.add_scalar("epoch", epoch, self.step)
self.summary_writer.add_scalars(
"accuracy",
{"train": accuracy},
self.step
)
self.summary_writer.add_scalars(
"loss",
{"train": float(loss.item())},
self.step
)
self.summary_writer.add_scalar(
"time/data", data_load_time, self.step
)
self.summary_writer.add_scalar(
"time/data", step_time, self.step
)
我收到了一个属性错误“ ‘CrossEntropyLoss’对象没有属性‘item’ ”。我尝试删除几种方法,例如从代码的不同部分删除“item()”,并尝试不同类型的损失函数,如 MSELoss 等。任何解决方案或方向都将受到高度赞赏。谢谢你。
编辑-1:
这是错误回溯
Traceback (most recent call last):
File "/Users/xyz/main.py", line 316, in <module>
main(parser.parse_args())
File "/Users/xyz/main.py", line 128, in main
log_frequency=args.log_frequency,
File "/Users/xyz/main.py", line 198, in train
self.log_metrics(epoch, accuracy, loss, data_load_time, step_time)
File "/Users/xyz/main.py", line 232, in log_metrics
{"train": float(loss.item)},
File "/Users/xyz/main.py", line 585, in __getattr__
type(self).__name__, name))
AttributeError: 'CrossEntropyLoss' object has no attribute 'item'
解决方案
看起来loss
调用self.log_metrics(epoch, accuracy, loss, data_load_time, step_time)
中的 是标准本身(CrossEntropyLoss 对象),而不是调用它的结果。
您的训练循环需要调用标准来计算损失,我在您提供的代码中看不到它。
推荐阅读
- java - 使用二维数组比两个一维数组快吗?
- javascript - JS获取元素返回未定义
- python - 无法在 Colab 中安装 allennlp 0.5.0
- javascript - 按 ID 标记线程
- c++ - c++ 为什么带有 uint32_t 的结构和带有 uint8_t[4] 的结构有不同的大小?
- javascript - 在类组件渲染之外反应上下文
- javascript - 带有可滚动子 div 的可滚动父 div
- c# - 从 vb.net 转换为 c#,切换功能不起作用
- netlogo - Netlogo:如何创建到turtle 的两个直接邻居的链接?
- javascript - 如何在 React 中动态使用调度?