Add Question 7
parent 292a257ea3
commit 4ee29e676b
@@ -2,3 +2,4 @@
 data/EXCV10
 data/MaskedFace
 __pycache__
+yolov5/
File diff suppressed because one or more lines are too long
@@ -0,0 +1,813 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5457f0e2",
"metadata": {},
"source": [
"## Question 7 - ResNet"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "868f9566",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import torch\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"from tqdm import tqdm\n",
"from collections import Counter\n",
"from xml.etree import ElementTree as ET\n",
"\n",
"from torchvision import transforms, models\n",
"from torch.utils.data import Dataset, DataLoader"
]
},
{
"cell_type": "markdown",
"id": "c4ce3f8a",
"metadata": {},
"source": [
"## Load the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6e215553",
"metadata": {},
"outputs": [],
"source": [
"val_labels = \"./data/MaskedFace/val/labels\"\n",
"val_imgs = \"./data/MaskedFace/val/images\"\n",
"\n",
"y_true = glob.glob(os.path.join(val_labels, \"*.txt\"))\n",
"images = glob.glob(os.path.join(val_imgs, \"*.png\"))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "94af35ab",
"metadata": {},
"outputs": [],
"source": [
"test_dataset = {\n",
"    'images': images,  # list of image paths\n",
"    'y_true': y_true,  # list of label paths\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d1af863d",
"metadata": {},
"outputs": [],
"source": [
"def count_obj(txt_file, n_class):\n",
"    with open(txt_file, 'r') as file:\n",
"        lines = file.readlines()\n",
"    # Extract the class identifier from each line\n",
"    class_ids = [int(line.split()[0]) for line in lines]\n",
"\n",
"    # Count the occurrences of each class\n",
"    class_counts = Counter(class_ids)\n",
"\n",
"    # Build a list of counts ordered by class id (a Counter returns 0 for missing ids)\n",
"    sorted_counts = [class_counts[i] for i in range(n_class)]\n",
"    return sorted_counts"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9f5c65f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"85it [00:00, 96.70it/s] \n"
]
}
],
"source": [
"gt_counts = []\n",
"for img, txt in tqdm(zip(test_dataset['images'], test_dataset['y_true'])):\n",
"    # get ground truth\n",
"    obj_count = count_obj(txt, 3)\n",
"    gt_counts.append(obj_count)"
]
},
{
"cell_type": "markdown",
"id": "71f5f968",
"metadata": {},
"source": [
"## Load the model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e70f6949",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\Anaconda3\\envs\\what\\lib\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
"  warnings.warn(\n",
"D:\\Anaconda3\\envs\\what\\lib\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
"  warnings.warn(msg)\n"
]
},
{
"data": {
"text/plain": [
"ResNet(\n",
"  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
"  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"  (relu): ReLU(inplace=True)\n",
"  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
"  (layer1): Sequential(\n",
"    (0): BasicBlock(\n",
"      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"    )\n",
"    (1): BasicBlock(\n",
"      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"    )\n",
"  )\n",
"  (layer2): Sequential(\n",
"    (0): BasicBlock(\n",
"      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (downsample): Sequential(\n",
"        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
"        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      )\n",
"    )\n",
"    (1): BasicBlock(\n",
"      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"    )\n",
"  )\n",
"  (layer3): Sequential(\n",
"    (0): BasicBlock(\n",
"      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (downsample): Sequential(\n",
"        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
"        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      )\n",
"    )\n",
"    (1): BasicBlock(\n",
"      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"    )\n",
"  )\n",
"  (layer4): Sequential(\n",
"    (0): BasicBlock(\n",
"      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (downsample): Sequential(\n",
"        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
"        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      )\n",
"    )\n",
"    (1): BasicBlock(\n",
"      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"      (relu): ReLU(inplace=True)\n",
"      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
"    )\n",
"  )\n",
"  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
"  (fc): Linear(in_features=512, out_features=3, bias=True)\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class ImageDataset(Dataset):\n",
"    def __init__(self, directory, transformations=None):\n",
"        self.directory = directory\n",
"        self.transformations = transformations\n",
"        self.filenames = [file for file in os.listdir(directory) if file.endswith('.png')]\n",
"        self.labels_array = np.zeros((len(self.filenames), 3), dtype=np.int64)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.filenames)\n",
"\n",
"    def __getitem__(self, index):\n",
"        file_path = os.path.join(self.directory, self.filenames[index])\n",
"        img = Image.open(file_path).convert('RGB')\n",
"        labels = self.extract_labels(file_path.replace('.png', '.xml'))\n",
"\n",
"        if self.transformations:\n",
"            img = self.transformations(img)\n",
"\n",
"        self.labels_array[index] = labels\n",
"        return img, torch.tensor(labels, dtype=torch.float32)\n",
"\n",
"    def extract_labels(self, xml_path):\n",
"        xml_data = ET.parse(xml_path)\n",
"        categories = {'with_mask': 0, 'without_mask': 0, 'mask_weared_incorrect': 0}\n",
"        for item in xml_data.getroot().findall('object'):\n",
"            categories[item.find('name').text] += 1\n",
"        return list(categories.values())\n",
"\n",
"# Define image transformations\n",
"image_transforms = {\n",
"    'train': transforms.Compose([\n",
"        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),\n",
"        transforms.RandomRotation(degrees=15),\n",
"        transforms.ColorJitter(),\n",
"        transforms.RandomHorizontalFlip(),\n",
"        transforms.CenterCrop(size=224),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"    ]),\n",
"    'val': transforms.Compose([\n",
"        transforms.Resize(size=256),\n",
"        transforms.CenterCrop(size=224),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"    ])\n",
"}\n",
"\n",
"# Initialize the pretrained ResNet18 model and modify the fully connected layer\n",
"pretrained_model = models.resnet18(pretrained=True)\n",
"pretrained_model.fc = torch.nn.Linear(pretrained_model.fc.in_features, 3)\n",
"\n",
"# Create the dataset and dataloaders\n",
"training_data = ImageDataset('data/MaskedFace/train', transformations=image_transforms['train'])\n",
"validation_data = ImageDataset('data/MaskedFace/val', transformations=image_transforms['val'])\n",
"\n",
"train_data_loader = DataLoader(training_data, batch_size=32, shuffle=True)\n",
"validation_data_loader = DataLoader(validation_data, batch_size=32)\n",
"\n",
"# Select the device and move the model onto it\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"pretrained_model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "61ad7442",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:28<00:00, 6.76s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 13.687, Validation Loss: 0.191\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:12<00:00, 6.02s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2, Loss: 10.426, Validation Loss: 0.219\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:18<00:00, 6.28s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3, Loss: 11.348, Validation Loss: 0.227\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:17<00:00, 6.26s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4, Loss: 9.872, Validation Loss: 0.163\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:08<00:00, 5.85s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5, Loss: 8.712, Validation Loss: 0.190\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:08<00:00, 5.84s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6, Loss: 10.092, Validation Loss: 0.150\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:10<00:00, 5.94s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7, Loss: 9.503, Validation Loss: 0.321\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:09<00:00, 5.88s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8, Loss: 6.198, Validation Loss: 0.123\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:09<00:00, 5.87s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9, Loss: 5.333, Validation Loss: 0.128\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [02:36<00:00, 7.11s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 4.443, Validation Loss: 0.125\n",
"Validation MAE: 12.31%\n",
"[[1 2 0]\n",
" [8 1 0]\n",
" [3 0 1]\n",
" ...\n",
" [3 0 0]\n",
" [1 0 0]\n",
" [1 1 0]]\n"
]
}
],
"source": [
"import copy\n",
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"# Setup device, loss function, optimizer, and learning rate scheduler\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"pretrained_model.to(device)\n",
"\n",
"loss_function = torch.nn.MSELoss()\n",
"optimizer = torch.optim.SGD(pretrained_model.parameters(), lr=0.001, momentum=0.9)\n",
"learning_rate_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)\n",
"\n",
"# Function to evaluate the model's performance on validation data\n",
"def evaluate_performance(model, loader):\n",
"    model.eval()\n",
"    total_error = 0.0\n",
"    for imgs, lbls in loader:\n",
"        imgs, lbls = imgs.to(device), lbls.to(device)\n",
"        with torch.no_grad():\n",
"            predictions = model(imgs)\n",
"        error = mean_absolute_error(lbls.cpu().detach().numpy(), predictions.cpu().detach().numpy(), multioutput='raw_values')\n",
"        total_error += np.sum(error)\n",
"    return total_error / len(loader.dataset)\n",
"\n",
"# Early stopping and model saving setup\n",
"best_model_wts = copy.deepcopy(pretrained_model.state_dict())\n",
"best_loss = float('inf')\n",
"early_stopping_patience = 3\n",
"patience_counter = 0\n",
"\n",
"# Training loop\n",
"epochs = 10\n",
"for epoch in range(epochs):\n",
"    pretrained_model.train()\n",
"    epoch_loss = 0.0\n",
"    for imgs, lbls in tqdm(train_data_loader):\n",
"        imgs, lbls = imgs.to(device), lbls.to(device)\n",
"        optimizer.zero_grad()\n",
"        predictions = pretrained_model(imgs)\n",
"        loss = loss_function(predictions, lbls)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        epoch_loss += loss.item()\n",
"\n",
"    learning_rate_scheduler.step()\n",
"\n",
"    # Validation phase\n",
"    validation_loss = evaluate_performance(pretrained_model, validation_data_loader)\n",
"    print(f'Epoch {epoch+1}, Loss: {epoch_loss / len(train_data_loader):.3f}, Validation Loss: {validation_loss:.3f}')\n",
"\n",
"    # Save the best model and check for early stopping\n",
"    if validation_loss < best_loss:\n",
"        best_loss = validation_loss\n",
"        best_model_wts = copy.deepcopy(pretrained_model.state_dict())\n",
"        torch.save(pretrained_model.state_dict(), 'best_model.pth')\n",
"        patience_counter = 0\n",
"    else:\n",
"        patience_counter += 1\n",
"        if patience_counter >= early_stopping_patience:\n",
"            break\n",
"\n",
"# Load the best model weights\n",
"pretrained_model.load_state_dict(torch.load('best_model.pth'))\n",
"\n",
"# Final evaluation on the validation dataset\n",
"validation_error = evaluate_performance(pretrained_model, validation_data_loader)\n",
"print(f'Validation MAE: {validation_error * 100:.2f}%')\n",
"\n",
"# Print the per-image label counts recorded for the training dataset\n",
"print(training_data.labels_array)"
]
},
{
"cell_type": "markdown",
"id": "ab063eb1",
"metadata": {},
"source": [
"## Evaluate on the test set"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "19729427",
"metadata": {},
"outputs": [],
"source": [
"# Redefinition of the evaluation function: also print the per-class MAE for each batch\n",
"def evaluate_performance(model, loader):\n",
"    model.eval()\n",
"    total_error = 0.0\n",
"    for imgs, lbls in loader:\n",
"        imgs, lbls = imgs.to(device), lbls.to(device)\n",
"        with torch.no_grad():\n",
"            predictions = model(imgs)\n",
"        error = mean_absolute_error(lbls.cpu().detach().numpy(), predictions.cpu().detach().numpy(), multioutput='raw_values')\n",
"        print(error)\n",
"        total_error += np.sum(error)\n",
"    # Average the summed per-class errors over the number of batches\n",
"    return total_error / len(loader)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d2b3f825",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.5965363 1.3042079 0.25560504]\n",
"[1.8177493 1.5732876 0.45420742]\n",
"[1.9562395 1.3338923 0.17067692]\n",
"Validation MAE: 348.75%\n"
]
}
],
"source": [
"# Load the best model weights\n",
"pretrained_model.load_state_dict(torch.load('best_model.pth'))\n",
"\n",
"# Final evaluation on the validation dataset\n",
"validation_error = evaluate_performance(pretrained_model, validation_data_loader)\n",
"print(f'Validation MAE: {validation_error * 100:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e893f885",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [00:11<00:00, 7.45it/s]\n"
]
}
],
"source": [
"counts = []\n",
"for img, lbls in tqdm(validation_data):\n",
"    img, lbls = img.to(device), lbls.to(device)\n",
"    with torch.no_grad():\n",
"        predictions = pretrained_model(torch.unsqueeze(img, 0))[0]\n",
"    # Move to CPU before converting to numpy so this also works when device is 'cuda'\n",
"    counts.append(predictions.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "16f48e23",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([14.919903 , 1.9947946 , 0.65775687], dtype=float32),\n",
" array([ 6.8552303 , -0.18041131, 0.07070862], dtype=float32),\n",
" array([ 1.0139127 , 0.2854728 , -0.08013925], dtype=float32),\n",
" array([4.438932 , 0.7808308 , 0.05505312], dtype=float32),\n",
" array([7.2354264 , 3.4551375 , 0.30724907], dtype=float32),\n",
" array([5.588563 , 0.6697209 , 0.17777884], dtype=float32),\n",
" array([ 1.150365 , 0.6162016 , -0.10613517], dtype=float32),\n",
" array([ 8.920831 , 0.6018489 , -0.06503136], dtype=float32),\n",
" array([1.95457 , 0.17450362, 0.05267046], dtype=float32),\n",
" array([2.5774434 , 0.9776695 , 0.18825674], dtype=float32),\n",
" array([ 2.6093178 , 0.48708877, -0.17551954], dtype=float32),\n",
" array([ 0.16392124, 0.5478727 , -0.19237904], dtype=float32),\n",
" array([ 4.6288757 , 0.1531632 , -0.03170557], dtype=float32),\n",
" array([ 1.8586371 , 0.6651823 , -0.02203152], dtype=float32),\n",
" array([3.2771707 , 3.1532748 , 0.14557752], dtype=float32),\n",
" array([3.8890243 , 2.2504125 , 0.05863352], dtype=float32),\n",
" array([10.092557 , 0.9448385 , 0.26969808], dtype=float32),\n",
" array([ 0.27157634, 0.17475206, -0.23231247], dtype=float32),\n",
" array([ 2.3398385 , 0.6199454 , -0.06315048], dtype=float32),\n",
" array([7.9481864 , 0.86970013, 0.3186779 ], dtype=float32),\n",
" array([ 5.4592905e+00, 3.0020913e-01, -5.3105988e-03], dtype=float32),\n",
" array([ 0.97144395, 0.82078457, -0.06586552], dtype=float32),\n",
" array([ 1.3530452 , 0.5240793 , -0.06924771], dtype=float32),\n",
" array([1.1931357 , 0.5295599 , 0.20559092], dtype=float32),\n",
" array([5.624632 , 0.8383505 , 0.37541458], dtype=float32),\n",
" array([ 0.78418005, 0.9187632 , -0.0636546 ], dtype=float32),\n",
" array([10.465925 , 1.8872681 , 0.38873053], dtype=float32),\n",
" array([ 4.920414 , 1.9515185 , -0.12888059], dtype=float32),\n",
" array([1.518215 , 1.5924859 , 0.14355288], dtype=float32),\n",
" array([6.9586325, 1.1489052, 0.2852966], dtype=float32),\n",
" array([0.5843046 , 1.45111 , 0.00412361], dtype=float32),\n",
" array([12.129912 , 1.7866051 , 0.31929207], dtype=float32),\n",
" array([38.12094 , 6.549285 , 1.1005894], dtype=float32),\n",
" array([ 1.2271879 , 0.2557486 , -0.22623575], dtype=float32),\n",
" array([-0.06689173, 0.0394736 , 0.631119 ], dtype=float32),\n",
" array([17.32966 , 2.792189 , 0.54758376], dtype=float32),\n",
" array([3.3420715 , 0.09269053, 0.02531072], dtype=float32),\n",
" array([1.5794499 , 0.42056152, 0.06615666], dtype=float32),\n",
" array([20.351597 , 3.7114801, 0.7863975], dtype=float32),\n",
" array([8.772988 , 0.9012797 , 0.20384854], dtype=float32),\n",
" array([0.8031712 , 0.46975204, 0.10056265], dtype=float32),\n",
" array([1.3446803 , 0.8946388 , 0.12165649], dtype=float32),\n",
" array([ 0.32257232, -0.06660413, -0.22496015], dtype=float32),\n",
" array([3.845796 , 0.8221053 , 0.03321841], dtype=float32),\n",
" array([ 0.7769756 , 0.30658063, -0.3144942 ], dtype=float32),\n",
" array([0.9002108 , 0.38418356, 0.25538492], dtype=float32),\n",
" array([11.137635 , 1.4070593 , 0.46713832], dtype=float32),\n",
" array([1.0896404 , 0.3867779 , 0.03269624], dtype=float32),\n",
" array([-0.29543436, 0.58017415, -0.08616602], dtype=float32),\n",
" array([4.886879 , 1.328992 , 0.08463573], dtype=float32),\n",
" array([20.802843 , 2.5175433, 0.1205664], dtype=float32),\n",
" array([4.472849 , 1.8497019 , 0.07973102], dtype=float32),\n",
" array([3.800993 , 1.2847486 , 0.40869945], dtype=float32),\n",
" array([ 3.2214005, 2.3649635, -0.05755 ], dtype=float32),\n",
" array([6.194131 , 1.039898 , 0.19118609], dtype=float32),\n",
" array([5.946366 , 1.9515687, 0.0739623], dtype=float32),\n",
" array([ 1.548485 , -0.26474452, 0.13542093], dtype=float32),\n",
" array([-0.12953067, 2.0475016 , 0.12173931], dtype=float32),\n",
" array([ 3.2755911 , 2.0698051 , -0.03214201], dtype=float32),\n",
" array([ 4.795667 , -0.3839026, -0.324237 ], dtype=float32),\n",
" array([1.4601235 , 0.9413236 , 0.15387204], dtype=float32),\n",
" array([0.60179263, 0.18167558, 0.06993645], dtype=float32),\n",
" array([2.5860176 , 0.96621907, 0.1660994 ], dtype=float32),\n",
" array([2.3293552 , 2.248715 , 0.05637825], dtype=float32),\n",
" array([1.5858288 , 0.75048965, 0.5053718 ], dtype=float32),\n",
" array([4.6874514 , 2.613487 , 0.02177998], dtype=float32),\n",
" array([ 3.015262 , 1.2428983 , -0.06558037], dtype=float32),\n",
" array([ 5.4304247 , 1.3663604 , -0.18734889], dtype=float32),\n",
" array([1.169702 , 0.29014575, 0.07055575], dtype=float32),\n",
" array([ 2.785139 , 1.7807665 , -0.14221995], dtype=float32),\n",
" array([ 6.0665565e+00, -1.1839047e-03, -2.0407777e-01], dtype=float32),\n",
" array([ 4.0390615 , 1.0952463 , -0.17736901], dtype=float32),\n",
" array([ 2.0545983 , -1.0606133 , -0.20474596], dtype=float32),\n",
" array([14.975636 , 2.6628957 , 0.41037458], dtype=float32),\n",
" array([ 1.532108 , 1.0259324 , -0.02336033], dtype=float32),\n",
" array([ 1.6325457 , 2.1987557 , -0.23485237], dtype=float32),\n",
" array([ 0.9079408 , 0.1572775 , -0.20104134], dtype=float32),\n",
" array([ 1.0071435 , 1.1668189 , -0.06868404], dtype=float32),\n",
" array([ 1.153094 , 0.40935773, -0.05768288], dtype=float32),\n",
" array([0.5880935 , 0.42007735, 0.12577775], dtype=float32),\n",
" array([8.898152 , 0.9833183 , 0.27929026], dtype=float32),\n",
" array([ 0.46698472, 0.8412469 , -0.2756693 ], dtype=float32),\n",
" array([ 2.401714 , 1.1422199 , -0.04599947], dtype=float32),\n",
" array([6.7554636 , 0.9809863 , 0.21429788], dtype=float32),\n",
" array([ 2.7404675 , 0.83549696, -0.06813517], dtype=float32)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"counts"
]
},
{
"cell_type": "markdown",
"id": "97afedd6",
"metadata": {},
"source": [
"## MAPE"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2c935860",
"metadata": {},
"outputs": [],
"source": [
"def compute_mape(prediction, truth):\n",
"    # Clamp the denominator at 1 so classes with a zero ground-truth count do not divide by zero\n",
"    mape = np.mean(np.abs(truth - prediction) / np.maximum(truth, np.ones_like(truth))) * 100\n",
"    return mape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ea0405a6",
"metadata": {},
"outputs": [],
"source": [
"MAPE = compute_mape(np.array(counts), np.array(gt_counts))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bdda69e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"68.38530732497205\n"
]
}
],
"source": [
"print(MAPE)"
]
},
{
"cell_type": "markdown",
"id": "d11e9ede",
"metadata": {},
"source": [
"## Final Score"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b7aaaaca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score: 0\n"
]
}
],
"source": [
"if MAPE <= 10:\n",
"    print(\"Score: \", 25*1.0)\n",
"elif MAPE <= 15:\n",
"    print(\"Score: \", 25*0.875)\n",
"elif MAPE <= 20:\n",
"    print(\"Score: \", 25*0.75)\n",
"elif MAPE <= 25:\n",
"    print(\"Score: \", 25*0.625)\n",
"elif MAPE <= 30:\n",
"    print(\"Score: \", 25*0.5)\n",
"else:\n",
"    print(\"Score: \", 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bf0f953",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "what",
"language": "python",
"name": "what"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because it is too large
BIN data/points.npy (binary file not shown)