[Question 4-6] Add sklearn confusion matrix
This commit is contained in:
parent
5ae34629d2
commit
292a257ea3
|
|
@ -137,6 +137,13 @@
|
|||
"image_transform = transforms.Compose(\n",
|
||||
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
|
||||
"\n",
|
||||
"# image_transform = transforms.Compose([\n",
|
||||
"# # transforms.Resize(256),\n",
|
||||
"# # transforms.CenterCrop(224),\n",
|
||||
"# transforms.ToTensor(),\n",
|
||||
"# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||
"# ])\n",
|
||||
"\n",
|
||||
"test_data = torchvision.datasets.ImageFolder('data/EXCV10/val/', transform=image_transform)\n",
|
||||
"test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)"
|
||||
]
|
||||
|
|
@ -208,7 +215,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 75.8%\n"
|
||||
"Test Accuracy: 70.05%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
@ -232,37 +239,25 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_cnn(model, test_loader):\n",
|
||||
" \"\"\"\n",
|
||||
" Test the trained ResNet model on the test dataset.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" model (nn.Module): The trained ResNet model.\n",
|
||||
" test_loader (DataLoader): Data loader for the test data.\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" float: Test accuracy.\n",
|
||||
" list: Predicted labels.\n",
|
||||
" list: True labels.\n",
|
||||
" \"\"\"\n",
|
||||
"def test_cnn(model, test_loader, device='cpu'):\n",
|
||||
" model.to(device)\n",
|
||||
" model.eval() \n",
|
||||
" correct = 0\n",
|
||||
" total = 0\n",
|
||||
" predicted_labels = []\n",
|
||||
" true_labels = []\n",
|
||||
" correct_num = 0\n",
|
||||
" all_predicted_labels = []\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" with torch.no_grad(): # No need to track gradients for testing\n",
|
||||
" for images, labels in test_loader:\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images)\n",
|
||||
" _, predicted = torch.max(outputs.data, 1)\n",
|
||||
" total += labels.size(0)\n",
|
||||
" correct += (predicted == labels).sum().item()\n",
|
||||
" predicted_labels.extend(predicted.tolist())\n",
|
||||
" true_labels.extend(labels.tolist())\n",
|
||||
" correct_num += (predicted == labels).sum().item() \n",
|
||||
" all_predicted_labels.append(predicted.cpu().numpy())\n",
|
||||
"\n",
|
||||
" accuracy = correct / total\n",
|
||||
"\n",
|
||||
" return predicted_labels, accuracy*100"
|
||||
" accuracy = (correct_num / total) * 100\n",
|
||||
" all_predicted_labels = np.concatenate(all_predicted_labels)\n",
|
||||
" return all_predicted_labels, accuracy"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -275,7 +270,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 75.8%\n"
|
||||
"Test Accuracy: 70.05%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
@ -286,7 +281,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "17bc1948",
|
||||
"id": "985e4f91",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Test (Should output ALL PASS)"
|
||||
|
|
@ -295,15 +290,15 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "853c4db3",
|
||||
"id": "694097e2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test accuracy: 75.8\n",
|
||||
"Score 100%: 15.0\n",
|
||||
"Test accuracy: 70.05\n",
|
||||
"Score 90%: 13.5\n",
|
||||
"ALL PASS\n"
|
||||
]
|
||||
}
|
||||
|
|
@ -388,9 +383,32 @@
|
|||
"m1_confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "60591999",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.metrics import confusion_matrix\n",
|
||||
"\n",
|
||||
"def m2_compute_confusion_matrix(true, predictions):\n",
|
||||
" return confusion_matrix(true, predictions)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "44b131f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"m2_confusion_matrix = m2_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "608265af",
|
||||
"id": "dd78bea6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Put Students' implementations here"
|
||||
|
|
@ -398,43 +416,33 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 19,
|
||||
"id": "1dce952c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_confusion_matrix(true_labels, predicted_labels):\n",
|
||||
"\n",
|
||||
" # Ensure inputs are NumPy arrays\n",
|
||||
" true_labels = np.array(true_labels)\n",
|
||||
" predicted_labels = np.array(predicted_labels)\n",
|
||||
"\n",
|
||||
" # Determine the number of classes\n",
|
||||
" num_classes = len(np.unique(true_labels))\n",
|
||||
"\n",
|
||||
" # Initialize the confusion matrix with zeros\n",
|
||||
" cm = np.zeros((num_classes, num_classes))\n",
|
||||
"\n",
|
||||
" # Count occurrences of true-predicted label pairs\n",
|
||||
" for i in range(len(true_labels)):\n",
|
||||
" cm[true_labels[i]][predicted_labels[i]] += 1\n",
|
||||
"\n",
|
||||
" return cm"
|
||||
"def compute_confusion_matrix(true, predictions):\n",
|
||||
" unique_labels = np.unique(np.concatenate((true, predictions)))\n",
|
||||
" confusion_matrix = np.zeros((len(unique_labels), len(unique_labels)), dtype=np.int64)\n",
|
||||
" for i, true_label in enumerate(unique_labels):\n",
|
||||
" for j, predicted_label in enumerate(unique_labels):\n",
|
||||
" confusion_matrix[i, j] = np.sum((true == true_label) & (predictions == predicted_label))\n",
|
||||
" return confusion_matrix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "21917014",
|
||||
"execution_count": 20,
|
||||
"id": "1945d637",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"confusion_matrix = m1_compute_confusion_matrix(true_labels, m1_predicted_labels)"
|
||||
"confusion_matrix = compute_confusion_matrix(true_labels, predicted_labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "935956b7",
|
||||
"id": "3e1fd6eb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Test (Should output ALL PASS)"
|
||||
|
|
@ -442,8 +450,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "b77da2e8",
|
||||
"execution_count": 21,
|
||||
"id": "6c87f7d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
@ -455,6 +463,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"assert np.allclose(m1_confusion_matrix, m2_confusion_matrix)\n",
|
||||
"assert np.allclose(confusion_matrix, m1_confusion_matrix)\n",
|
||||
"\n",
|
||||
"print(\"ALL PASS\")"
|
||||
|
|
@ -463,7 +472,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "adc0a7c7",
|
||||
"id": "d9bb6316",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
|
|
|||
Loading…
Reference in New Issue