[Question 4-6] Add sklearn confusion matrix
This commit is contained in:
parent
5ae34629d2
commit
292a257ea3
|
|
@ -137,6 +137,13 @@
|
||||||
"image_transform = transforms.Compose(\n",
|
"image_transform = transforms.Compose(\n",
|
||||||
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
|
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# image_transform = transforms.Compose([\n",
|
||||||
|
"# # transforms.Resize(256),\n",
|
||||||
|
"# # transforms.CenterCrop(224),\n",
|
||||||
|
"# transforms.ToTensor(),\n",
|
||||||
|
"# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||||
|
"# ])\n",
|
||||||
|
"\n",
|
||||||
"test_data = torchvision.datasets.ImageFolder('data/EXCV10/val/', transform=image_transform)\n",
|
"test_data = torchvision.datasets.ImageFolder('data/EXCV10/val/', transform=image_transform)\n",
|
||||||
"test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)"
|
"test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)"
|
||||||
]
|
]
|
||||||
|
|
@ -208,7 +215,7 @@
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Test Accuracy: 75.8%\n"
|
"Test Accuracy: 70.05%\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
@ -232,37 +239,25 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def test_cnn(model, test_loader):\n",
|
"def test_cnn(model, test_loader, device='cpu'):\n",
|
||||||
" \"\"\"\n",
|
" model.to(device)\n",
|
||||||
" Test the trained ResNet model on the test dataset.\n",
|
|
||||||
"\n",
|
|
||||||
" Args:\n",
|
|
||||||
" model (nn.Module): The trained ResNet model.\n",
|
|
||||||
" test_loader (DataLoader): Data loader for the test data.\n",
|
|
||||||
" \n",
|
|
||||||
" Returns:\n",
|
|
||||||
" float: Test accuracy.\n",
|
|
||||||
" list: Predicted labels.\n",
|
|
||||||
" list: True labels.\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" model.eval() \n",
|
" model.eval() \n",
|
||||||
" correct = 0\n",
|
|
||||||
" total = 0\n",
|
" total = 0\n",
|
||||||
" predicted_labels = []\n",
|
" correct_num = 0\n",
|
||||||
" true_labels = []\n",
|
" all_predicted_labels = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad(): # No need to track gradients for testing\n",
|
||||||
" for images, labels in test_loader:\n",
|
" for images, labels in test_loader:\n",
|
||||||
|
" images, labels = images.to(device), labels.to(device)\n",
|
||||||
" outputs = model(images)\n",
|
" outputs = model(images)\n",
|
||||||
" _, predicted = torch.max(outputs.data, 1)\n",
|
" _, predicted = torch.max(outputs.data, 1)\n",
|
||||||
" total += labels.size(0)\n",
|
" total += labels.size(0)\n",
|
||||||
" correct += (predicted == labels).sum().item()\n",
|
" correct_num += (predicted == labels).sum().item() \n",
|
||||||
" predicted_labels.extend(predicted.tolist())\n",
|
" all_predicted_labels.append(predicted.cpu().numpy())\n",
|
||||||
" true_labels.extend(labels.tolist())\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" accuracy = correct / total\n",
|
" accuracy = (correct_num / total) * 100\n",
|
||||||
"\n",
|
" all_predicted_labels = np.concatenate(all_predicted_labels)\n",
|
||||||
" return predicted_labels, accuracy*100"
|
" return all_predicted_labels, accuracy"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -275,7 +270,7 @@
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Test Accuracy: 75.8%\n"
|
"Test Accuracy: 70.05%\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
@ -286,7 +281,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "17bc1948",
|
"id": "985e4f91",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Test (Should output ALL PASS)"
|
"### Test (Should output ALL PASS)"
|
||||||
|
|
@ -295,15 +290,15 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 11,
|
||||||
"id": "853c4db3",
|
"id": "694097e2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Test accuracy: 75.8\n",
|
"Test accuracy: 70.05\n",
|
||||||
"Score 100%: 15.0\n",
|
"Score 90%: 13.5\n",
|
||||||
"ALL PASS\n"
|
"ALL PASS\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -388,9 +383,32 @@
|
||||||
"m1_confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
"m1_confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"id": "60591999",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.metrics import confusion_matrix\n",
|
||||||
|
"\n",
|
||||||
|
"def m2_compute_confusion_matrix(true, predictions):\n",
|
||||||
|
" return confusion_matrix(true, predictions)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"id": "44b131f0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"m2_confusion_matrix = m2_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "608265af",
|
"id": "dd78bea6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Put Students' implementations here"
|
"### Put Students' implementations here"
|
||||||
|
|
@ -398,43 +416,33 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 19,
|
||||||
"id": "1dce952c",
|
"id": "1dce952c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_confusion_matrix(true_labels, predicted_labels):\n",
|
"def compute_confusion_matrix(true, predictions):\n",
|
||||||
"\n",
|
" unique_labels = np.unique(np.concatenate((true, predictions)))\n",
|
||||||
" # Ensure inputs are NumPy arrays\n",
|
" confusion_matrix = np.zeros((len(unique_labels), len(unique_labels)), dtype=np.int64)\n",
|
||||||
" true_labels = np.array(true_labels)\n",
|
" for i, true_label in enumerate(unique_labels):\n",
|
||||||
" predicted_labels = np.array(predicted_labels)\n",
|
" for j, predicted_label in enumerate(unique_labels):\n",
|
||||||
"\n",
|
" confusion_matrix[i, j] = np.sum((true == true_label) & (predictions == predicted_label))\n",
|
||||||
" # Determine the number of classes\n",
|
" return confusion_matrix"
|
||||||
" num_classes = len(np.unique(true_labels))\n",
|
|
||||||
"\n",
|
|
||||||
" # Initialize the confusion matrix with zeros\n",
|
|
||||||
" cm = np.zeros((num_classes, num_classes))\n",
|
|
||||||
"\n",
|
|
||||||
" # Count occurrences of true-predicted label pairs\n",
|
|
||||||
" for i in range(len(true_labels)):\n",
|
|
||||||
" cm[true_labels[i]][predicted_labels[i]] += 1\n",
|
|
||||||
"\n",
|
|
||||||
" return cm"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 20,
|
||||||
"id": "21917014",
|
"id": "1945d637",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
"confusion_matrix = compute_confusion_matrix(true_labels, predicted_labels)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "935956b7",
|
"id": "3e1fd6eb",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Test (Should output ALL PASS)"
|
"### Test (Should output ALL PASS)"
|
||||||
|
|
@ -442,8 +450,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 21,
|
||||||
"id": "b77da2e8",
|
"id": "6c87f7d6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
|
@ -455,6 +463,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"assert np.allclose(m1_confusion_matrix, m2_confusion_matrix)\n",
|
||||||
"assert np.allclose(confusion_matrix, m1_confusion_matrix)\n",
|
"assert np.allclose(confusion_matrix, m1_confusion_matrix)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"ALL PASS\")"
|
"print(\"ALL PASS\")"
|
||||||
|
|
@ -463,7 +472,7 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "adc0a7c7",
|
"id": "d9bb6316",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue