Open
Description
Accuracy should be TP / (TP + FP), whereas CPA is the per-class accuracy. So the original code should be changed to:
```python
from collections import OrderedDict

import torch

# total_area_intersect, total_area_union, total_area_pred_label,
# total_area_label, metrics, beta and f_score are defined in the
# surrounding code (not shown).

# Overall accuracy: (TP + TN) / (TP + TN + FP + FN)
all_acc = total_area_intersect.sum() / total_area_pred_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
# Compute the requested metrics
for metric in metrics:
    # mIoU
    if metric == 'mIoU':
        # Intersection over union
        iou = total_area_intersect / total_area_union
        # Per-class accuracy: this class's intersection / this class's
        # predicted area, i.e. TP / (TP + FP)
        acc = total_area_intersect / total_area_pred_label
        ret_metrics['IoU'] = iou
        ret_metrics['Acc'] = acc
    elif metric == 'mDice':
        dice = 2 * total_area_intersect / (
            total_area_pred_label + total_area_label)
        acc = total_area_intersect / total_area_pred_label
        ret_metrics['Dice'] = dice
        ret_metrics['Acc'] = acc
    elif metric == 'mFscore':
        precision = total_area_intersect / total_area_pred_label
        recall = total_area_intersect / total_area_label
        f_value = torch.tensor([
            f_score(x[0], x[1], beta) for x in zip(precision, recall)
        ])
        ret_metrics['Fscore'] = f_value
        ret_metrics['Precision'] = precision
        ret_metrics['Recall'] = recall
```
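To see what the proposed change does numerically, here is a minimal pure-Python sketch that counts per-class intersection, predicted area and label area for a toy flattened prediction, then evaluates both denominators side by side. The function names `per_class_areas` and `compare_accuracy` are hypothetical and not taken from the codebase above; the sketch also assumes that pixels labelled `ignore_index` are dropped from both the prediction and the label before counting.

```python
from collections import OrderedDict


def _safe_div(n, d):
    # Return NaN instead of raising when a class has zero area.
    return n / d if d else float("nan")


def per_class_areas(pred, label, num_classes, ignore_index=255):
    """Count per-class intersection, predicted area and label area.

    Assumption for this sketch: pixels whose label equals ignore_index
    are excluded from BOTH the prediction and the label.
    """
    intersect = [0] * num_classes
    area_pred = [0] * num_classes
    area_label = [0] * num_classes
    for p, t in zip(pred, label):
        if t == ignore_index:
            continue
        area_pred[p] += 1
        area_label[t] += 1
        if p == t:
            intersect[p] += 1  # true positive for class p
    return intersect, area_pred, area_label


def compare_accuracy(intersect, area_pred, area_label):
    """Evaluate overall and per-class accuracy with both denominators."""
    return OrderedDict(
        aAcc_over_label=_safe_div(sum(intersect), sum(area_label)),
        aAcc_over_pred=_safe_div(sum(intersect), sum(area_pred)),
        # TP / (TP + FN): share of each class's true pixels that were found
        acc_over_label=[_safe_div(i, a) for i, a in zip(intersect, area_label)],
        # TP / (TP + FP): share of each class's predicted pixels that are correct
        acc_over_pred=[_safe_div(i, a) for i, a in zip(intersect, area_pred)],
    )


label = [0, 0, 0, 1, 1, 2, 2, 255]  # 255 marks an ignored pixel
pred = [0, 0, 1, 1, 2, 2, 2, 0]
inter, a_pred, a_label = per_class_areas(pred, label, num_classes=3)
m = compare_accuracy(inter, a_pred, a_label)
print(m["aAcc_over_label"], m["aAcc_over_pred"])  # 0.7142857142857143 0.7142857142857143
print(m["acc_over_label"])  # [0.6666666666666666, 0.5, 1.0]
print(m["acc_over_pred"])   # [1.0, 0.5, 0.6666666666666666]
```

With ignored pixels excluded from both sides, the total predicted area equals the total label area, so the overall accuracy is the same under either denominator; only the per-class values change. Note also that `intersect / pred_label` is exactly the per-class `precision` that the `mFscore` branch already reports.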