[Question 4-6] Add sklearn confusion matrix

This commit is contained in:
wuhanstudio 2024-03-12 18:05:59 +00:00
parent 5ae34629d2
commit 292a257ea3
1 changed files with 65 additions and 56 deletions

View File

@ -137,6 +137,13 @@
"image_transform = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"# image_transform = transforms.Compose([\n",
"# # transforms.Resize(256),\n",
"# # transforms.CenterCrop(224),\n",
"# transforms.ToTensor(),\n",
"# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"# ])\n",
"\n",
"test_data = torchvision.datasets.ImageFolder('data/EXCV10/val/', transform=image_transform)\n",
"test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)"
]
@ -208,7 +215,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 75.8%\n"
"Test Accuracy: 70.05%\n"
]
}
],
@ -232,37 +239,25 @@
"metadata": {},
"outputs": [],
"source": [
"def test_cnn(model, test_loader):\n",
" \"\"\"\n",
" Test the trained ResNet model on the test dataset.\n",
"\n",
" Args:\n",
" model (nn.Module): The trained ResNet model.\n",
" test_loader (DataLoader): Data loader for the test data.\n",
" \n",
" Returns:\n",
" float: Test accuracy.\n",
" list: Predicted labels.\n",
" list: True labels.\n",
" \"\"\"\n",
" model.eval()\n",
" correct = 0\n",
"def test_cnn(model, test_loader, device='cpu'):\n",
" model.to(device)\n",
" model.eval() \n",
" total = 0\n",
" predicted_labels = []\n",
" true_labels = []\n",
" correct_num = 0\n",
" all_predicted_labels = []\n",
"\n",
" with torch.no_grad():\n",
" with torch.no_grad(): # No need to track gradients for testing\n",
" for images, labels in test_loader:\n",
" images, labels = images.to(device), labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" predicted_labels.extend(predicted.tolist())\n",
" true_labels.extend(labels.tolist())\n",
" correct_num += (predicted == labels).sum().item() \n",
" all_predicted_labels.append(predicted.cpu().numpy())\n",
"\n",
" accuracy = correct / total\n",
"\n",
" return predicted_labels, accuracy*100"
" accuracy = (correct_num / total) * 100\n",
" all_predicted_labels = np.concatenate(all_predicted_labels)\n",
" return all_predicted_labels, accuracy"
]
},
{
@ -275,7 +270,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 75.8%\n"
"Test Accuracy: 70.05%\n"
]
}
],
@ -286,7 +281,7 @@
},
{
"cell_type": "markdown",
"id": "17bc1948",
"id": "985e4f91",
"metadata": {},
"source": [
"### Test (Should output ALL PASS)"
@ -295,15 +290,15 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "853c4db3",
"id": "694097e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 75.8\n",
"Score 100%: 15.0\n",
"Test accuracy: 70.05\n",
"Score 90%: 13.5\n",
"ALL PASS\n"
]
}
@ -388,9 +383,32 @@
"m1_confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "60591999",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"def m2_compute_confusion_matrix(true, predictions):\n",
"    \"\"\"Compute the reference confusion matrix with scikit-learn.\n",
"\n",
"    Args:\n",
"        true: Ground-truth class labels, one per sample.\n",
"        predictions: Predicted class labels, one per sample.\n",
"\n",
"    Returns:\n",
"        Whatever `sklearn.metrics.confusion_matrix(true, predictions)` returns.\n",
"    \"\"\"\n",
"    # Bind the sklearn function under a private alias at call time. This\n",
"    # notebook later assigns a result array to the global name\n",
"    # `confusion_matrix`; looking that global up here would then fail with\n",
"    # \"'numpy.ndarray' object is not callable\" when this cell's caller is re-run.\n",
"    from sklearn.metrics import confusion_matrix as _sk_confusion_matrix\n",
"\n",
"    return _sk_confusion_matrix(true, predictions)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "44b131f0",
"metadata": {},
"outputs": [],
"source": [
"m2_confusion_matrix = m2_compute_confusion_matrix(true_labels, m1_predicted_labels)"
]
},
{
"cell_type": "markdown",
"id": "608265af",
"id": "dd78bea6",
"metadata": {},
"source": [
"### Put Students' implementations here"
@ -398,43 +416,33 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 19,
"id": "1dce952c",
"metadata": {},
"outputs": [],
"source": [
"def compute_confusion_matrix(true_labels, predicted_labels):\n",
"    \"\"\"Count (true label, predicted label) pairs into a square matrix.\n",
"\n",
"    Rows and columns follow the sorted union of the labels found in either\n",
"    input, so a class that only ever appears among the predictions still\n",
"    gets its own row and column instead of triggering an IndexError.\n",
"\n",
"    Args:\n",
"        true_labels: Ground-truth labels (any array-like; flattened).\n",
"        predicted_labels: Predicted labels with the same number of elements.\n",
"\n",
"    Returns:\n",
"        np.ndarray of shape (num_classes, num_classes) with float counts;\n",
"        entry [i, j] counts samples of true class i predicted as class j.\n",
"\n",
"    Raises:\n",
"        ValueError: If the two inputs hold a different number of labels.\n",
"    \"\"\"\n",
"    # Ensure inputs are flat NumPy arrays\n",
"    true_labels = np.asarray(true_labels).ravel()\n",
"    predicted_labels = np.asarray(predicted_labels).ravel()\n",
"    if true_labels.shape != predicted_labels.shape:\n",
"        raise ValueError('true_labels and predicted_labels must have the same length')\n",
"\n",
"    # Map every label to its rank in the sorted union of both inputs\n",
"    classes, codes = np.unique(\n",
"        np.concatenate((true_labels, predicted_labels)), return_inverse=True)\n",
"    num_classes = len(classes)\n",
"    num_samples = true_labels.size\n",
"\n",
"    # Initialize the confusion matrix with zeros (float, as before)\n",
"    cm = np.zeros((num_classes, num_classes))\n",
"\n",
"    # Unbuffered add so repeated (row, col) pairs are each counted\n",
"    np.add.at(cm, (codes[:num_samples], codes[num_samples:]), 1)\n",
"\n",
"    return cm"
"def compute_confusion_matrix(true, predictions):\n",
"    \"\"\"Count (true label, predicted label) pairs into a square matrix.\n",
"\n",
"    Rows and columns follow the sorted union of the labels found in either\n",
"    input, so entry [i, j] is the number of samples whose true label is the\n",
"    i-th such label and whose prediction is the j-th.\n",
"\n",
"    Args:\n",
"        true: Ground-truth labels (any array-like; flattened).\n",
"        predictions: Predicted labels with the same number of elements.\n",
"\n",
"    Returns:\n",
"        np.ndarray of shape (n_labels, n_labels) and dtype int64.\n",
"\n",
"    Raises:\n",
"        ValueError: If the two inputs hold a different number of labels.\n",
"    \"\"\"\n",
"    true = np.asarray(true).ravel()\n",
"    predictions = np.asarray(predictions).ravel()\n",
"    # Reject mismatched inputs explicitly; elementwise comparison would\n",
"    # otherwise broadcast a length-1 input silently and miscount.\n",
"    if true.shape != predictions.shape:\n",
"        raise ValueError('true and predictions must have the same number of labels')\n",
"\n",
"    # One pass: encode each label as its rank in the sorted label union.\n",
"    unique_labels, encoded = np.unique(\n",
"        np.concatenate((true, predictions)), return_inverse=True)\n",
"    n_samples = true.size\n",
"    counts = np.zeros((len(unique_labels), len(unique_labels)), dtype=np.int64)\n",
"    # np.add.at is unbuffered, so repeated (row, col) pairs are each counted.\n",
"    np.add.at(counts, (encoded[:n_samples], encoded[n_samples:]), 1)\n",
"    return counts"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "21917014",
"execution_count": 20,
"id": "1945d637",
"metadata": {},
"outputs": [],
"source": [
"confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
"confusion_matrix = compute_confusion_matrix(true_labels, predicted_labels)"
]
},
{
"cell_type": "markdown",
"id": "935956b7",
"id": "3e1fd6eb",
"metadata": {},
"source": [
"### Test (Should output ALL PASS)"
@ -442,8 +450,8 @@
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b77da2e8",
"execution_count": 21,
"id": "6c87f7d6",
"metadata": {},
"outputs": [
{
@ -455,6 +463,7 @@
}
],
"source": [
"assert np.allclose(m1_confusion_matrix, m2_confusion_matrix)\n",
"assert np.allclose(confusion_matrix, m1_confusion_matrix)\n",
"\n",
"print(\"ALL PASS\")"
@ -463,7 +472,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "adc0a7c7",
"id": "d9bb6316",
"metadata": {},
"outputs": [],
"source": []