Commit 96dc577

Merge pull request #3 from oracle-devrel/custom_pytorch
Custom PyTorch article

2 parents 70375ab + 12fcb10 commit 96dc577
29 files changed: +483 -0 lines changed
# YOLOv5 and OCI: Custom PyTorch Code From Scratch

## Introduction

In this article, we're going to learn how to load a YOLOv5 model into PyTorch, and then augment its detections with three different techniques:

1. Sorting detections
2. Cropping and saving detections
3. Counting detected objects
If you're a little confused about how we got here, you can check out the first and second articles (this article's predecessors):

- [Creating a CMask Detection Model on OCI with YOLOv5: Data Labeling with RoboFlow](https://medium.com/oracledevs/creating-a-cmask-detection-model-on-oci-with-yolov5-data-labeling-with-roboflow-5cff89cf9b0b)
- [Creating a Mask Model on OCI with YOLOv5: Training and Real-Time Inference](https://medium.com/oracledevs/creating-a-mask-model-on-oci-with-yolov5-training-and-real-time-inference-3534c7f9eb21)

Additionally, [I offer a Kaggle link](https://www.kaggle.com/datasets/jasperan/covid-19-mask-detection/) where you can download the pre-trained weights file for the model itself. I appreciate any contribution that improves the model's mAP (Mean Average Precision).

I recommend you check out all of these resources, as the previous articles clarify what we're doing and how we obtained the data.
## Why PyTorch?

I decided to create a "modified" version of what YOLOv5 does by taking advantage of Ultralytics' integration with PyTorch.

I believed custom PyTorch code would be great, because simply using YOLOv5's repository doesn't give you 100% flexibility and real-time responsiveness. I decided to __very slightly__ add some *extra* functionality (we'll talk about it below). If you use [the standard GitHub repository for YOLOv5](https://github.com/ultralytics/yolov5), you'll find that you can use their code, like [this detector](https://github.com/ultralytics/yolov5/blob/master/detect.py), to post-process video or image files. You can also use it directly with a YouTube video, and an integrated YouTube downloader will download frames and process them.

But what is the definition of real-time? I want every frame that I see on my computer, be it a frame from my webcam, a YouTube video, or even my screen, to display the results of my detection immediately. This is why I created my own custom code to detect with PyTorch.

Finally, I'd like to mention my journey down a *very painful* road of finding *a few* bugs on the Windows operating system while trying to *virtualize* my webcam feed. There was this plugin that could "replicate" your camera feed into a virtual version usable in any program -- you could feed any program's output into the virtual webcam stream so that it looked like your webcam feed was coming from somewhere else -- and it was really great:

![OBS great outdated plugin](./images/obs_1.PNG)

This was an [OBS (Open Broadcaster Software)](https://obsproject.com/) plugin. OBS is the go-to program when you're planning to make a live stream. However, this plugin was discontinued in OBS version 28, and some problems came with that update. I prepared this bug-compilation image so you can feel the pain too:

![OBS bug compilation](./images/obs_errors.PNG)

So, now that we've established that several roadblocks prevent us from happily developing in a stable environment, we understand the "why" of this article. Let's begin implementing.
## Implementation

We are going to focus on the three techniques explained in the introduction: sorting detections, cropping and saving them, and counting detected objects. These techniques can be reused in any computer vision project, so once you understand how to implement them, you're good to go.

Technical requirements are Python 3.8 or higher and PyTorch 1.7 or higher. [Here's a list](https://github.com/oracle-devrel/devo.publishing.other/custom_pytorch_yolov5/files/requirements.txt) of the project's requirements if you want to reuse the code (which you can find in [the GitHub repository](https://github.com/oracle-devrel/devo.publishing.other/custom_pytorch_yolov5) along with everything we publish).
### 0. General Setup

First, we use the `argparse` module to accept additional parameters in our Python script. I added three optional parameters:

![argparse](./images/argparse.PNG)

> **Note**: the confidence threshold will only display detected objects if the confidence score of the model's prediction is higher than the given value (0.0-1.0).

These parameters' default values can always be modified. The _frequency_ parameter determines a frame step: if the user specifies a number `N`, then only 1 frame out of every `N` will be used for detection. This is useful when you expect sequential frames to be very similar to one another, and it also helps avoid cost overages when making these predictions (beware your electricity bill, or your OCI costs if you're running on Oracle Cloud).
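
As a rough sketch, the setup could look like this; the exact flags live in the screenshot above, so take names like `--confidence` and `--frequency` as illustrative assumptions:

```python
import argparse

# Hypothetical flag names for illustration; the real ones are in the
# screenshot above.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, required=True,
                    help='Path to the custom weights file (e.g. best.pt)')
parser.add_argument('-c', '--confidence', type=float, default=0.5,
                    help='Only display detections above this score (0.0-1.0)')
parser.add_argument('-f', '--frequency', type=int, default=1,
                    help='Run detection on 1 out of every N frames')
args = parser.parse_args()
```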
After this initial configuration, we're ready to load our custom model. You can find the pre-trained weights for the Mask Detection Model featured here [in the COVID-19 mask detection example](https://www.kaggle.com/datasets/jasperan/covid-19-mask-detection?select=best.pt). You'll need this file within reach of your Python execution environment for the code to work.

So, we now load the custom weights file:

![loading PyTorch model](./images/load_model.PNG)

> **Note**: we specify the model as a custom YOLO detector, and give it the model's weights file as input.
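
For reference, the Model Hub call is essentially a one-liner:

```python
import torch

# Load the custom-trained YOLOv5 detector from PyTorch's Model Hub,
# pointing it at our own weights file (e.g. best.pt).
model = torch.hub.load('ultralytics/yolov5',
                       'custom',
                       path=args.model,
                       force_reload=False)
```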
Now we're ready to get started. We create our main loop, which constantly grabs a new image from the input source (either a webcam feed or a screenshot of what we're seeing on screen) and displays it with the bounding box detections in place:

![main loop](./images/main_loop.PNG)

The most important function in that code is the `infer()` function, which returns an image and the result object, transformed into a pandas object for your convenience.

![infer 1](./images/infer_1.PNG)

In this first part of the function, we obtain a new image from the OpenCV video capture object. You can also find a similar implementation that takes screenshots instead of using the webcam feed in [this file](https://github.com/oracle-devrel/devo.publishing.other/custom_pytorch_yolov5/files/lightweight_screen_torch_inference.py).
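
A minimal sketch of that loop, assuming a webcam source and an `infer()` function shaped as described above (its exact signature lives in the screenshots):

```python
import cv2

cap = cv2.VideoCapture(0)  # webcam as the input source

while True:
    ok, frame = cap.read()
    if not ok:
        break
    # infer() is assumed to return the annotated image and the
    # detections as a pandas DataFrame.
    annotated, detections = infer(frame)
    cv2.imshow('Detections', annotated)
    if cv2.waitKey(1) == ord('q'):  # press "q" to quit
        break

cap.release()
cv2.destroyAllWindows()
```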
Now, we need to pass this image to our Torch model, which will return bounding box results. There's an important consideration, though: since we want our predictions to be as fast as possible, and we know that the mask detection model has been trained on thousands of images of all shapes and sizes, we can use a technique that benefits the program's Frames Per Second (FPS): **rescaling** images to a lower resolution (my webcam feed is 1080p, and I use a 2560x1440 monitor, which makes screenshots more detailed than I need).

For this, I chose a `SCALE_FACTOR` variable to hold this value (between 0 and 1). Currently, all images are downscaled to 640 pixels in width, with the height scaled proportionally to maintain the original image's aspect ratio.

![infer 2](./images/infer_2.PNG)
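
Here's a sketch of that downscaling step, assuming the 640-pixel target width; the helper name `downscale()` is mine:

```python
import cv2

TARGET_WIDTH = 640

def downscale(img):
    """Resize img to 640px wide, keeping the aspect ratio."""
    h, w = img.shape[:2]
    scale_factor = TARGET_WIDTH / w  # between 0 and 1 for wide inputs
    new_size = (TARGET_WIDTH, int(h * scale_factor))
    return cv2.resize(img, new_size), scale_factor
```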
Now that we have our downscaled image, we pass it to the model, and it returns the object we wanted:

![infer 3](./images/infer_3.PNG)

> **Note**: the `size=640` option tells the model we're passing images with that width, so inference runs at those dimensions.
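
In code, the call and the pandas conversion look roughly like this, continuing the sketches above:

```python
# Downscale the frame, then run inference at size=640.
small_img, scale_factor = downscale(img)
results = model(small_img, size=640)

# One DataFrame row per detection: xmin, ymin, xmax, ymax,
# confidence, class, name.
df = results.pandas().xyxy[0]
```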
The last thing we do is draw the bounding boxes we obtained onto the image, and return it so we can display it later.

![infer return](./images/infer_5.PNG)
### 1. Sorting Detections

This first technique is the simplest, and it can add value to the standard YOLO functionality in a unique way. The idea is to quickly manipulate the PyTorch-pandas object to sort values by one of its columns.

For this, I have two suggestions: sorting by confidence score, or by detection coordinates. To illustrate how these techniques are useful, let's look at the following image:

![speed figure](./images/figure_speed.png)

> **Note**: this image illustrates how sorting detections can be useful. [(image credits)](https://www.linkedin.com/in/muhammad-moin-7776751a0/)

In the image above, an imaginary line is drawn across the roadway -- in this case, **horizontally**. Any object that crosses the line in a specific direction is counted as an "upward" or "downward" vehicle. This can be achieved by specifying (x,y) bounds and flagging any item in the PyTorch-pandas object that passes them in either direction.

For processing purposes, sorting these values from the lowest to the highest y coordinate returns all cars in order, from the top to the bottom of the image, which makes it easy to process them sequentially.
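
Both sorts are one-liners on the detections DataFrame; the counting-line value below is a hypothetical example:

```python
# Most confident detections first.
by_confidence = df.sort_values('confidence', ascending=False)

# Top-to-bottom ordering: a lower ymin means closer to the top of the image.
top_to_bottom = df.sort_values('ymin')

# A horizontal counting line at y = 400 (hypothetical value):
# detections whose top edge sits below the line.
crossed = df[df['ymin'] > 400]
```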
### 2. Cropping & Saving Detections

With this second technique, we just need to crop the detected area we want and save it.

It's useful because we can manipulate or use the cropped image of an object, instead of the whole image, and turn this to our advantage. As an example, you could:

1. Crop images where you detect text
2. Give this cropped image to an Optical Character Recognition (OCR) program
3. Extract the text in real-time

![bmw license plate](./images/bmw_car.png)

> **Note**: this is an example of a car's license plate being passed through an OCR in Keras. [(image credits)](https://medium.com/@theophilebuyssens)

This approach wouldn't work if we gave the whole image to the OCR, as it wouldn't be able to confidently recognize such small text (which represents only a fraction of the screen) unless the image resolution was very high, which often isn't the case, just as with surveillance cameras.
To implement this, we base everything on **bounding boxes**. Our PyTorch code returns an object with bounding box coordinates for detected objects (plus the detections' confidence scores), and we use this object to create newly cropped images with the bounding boxes' sizes.

> **Note**: you can always modify the range of pixels you want to crop in each image, by being either more *permissive* (grabbing extra pixels around the bounding box) or more *restrictive* (trimming the edges of the detected object).

An important consideration: since we pass images to our model with a width of 640 pixels, we need to keep track of our previously-mentioned `SCALE_FACTOR` variable. The problem is that the original image is larger than the downscaled image (the one we pass to the model), so the bounding box coordinates it returns are downscaled too. We need to rescale these coordinates by the scale factor in order to _draw_ the bounding boxes over the original image, and then display it:

![infer 4](./images/infer_4.PNG)
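
As a sketch, mapping the coordinates back to the original resolution with the `scale_factor` from the earlier `downscale()` sketch (a ratio between 0 and 1, so we divide):

```python
# Map box coordinates from the 640px-wide image back to the original frame.
for col in ('xmin', 'ymin', 'xmax', 'ymax'):
    df[col] = df[col] / scale_factor
```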
We use the `save_cropped_images()` function to save images, while also honoring the frequency parameter we set: we only save the cropped detections if the frame is one we're supposed to process.

Inside this function, we **upscale** the bounding box detections. We also only save an image if the detected region is larger than a minimum (x,y) width and height:

![save cropped images](./images/save_cropped_images.PNG)

The last thing we do is save the cropped image with OpenCV:

![save image](./images/save_image.PNG)
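
A minimal sketch of such a function; the function name matches the description above, but `MIN_W`, `MIN_H`, and the output path are illustrative assumptions:

```python
import os
import cv2

MIN_W, MIN_H = 50, 50  # hypothetical minimum crop size in pixels

def save_cropped_images(img, df, frame_idx, frequency):
    """Crop each detection out of the full-resolution image and save it."""
    if frame_idx % frequency != 0:
        return  # honor the frame step
    os.makedirs('crops', exist_ok=True)
    for idx, row in df.iterrows():
        x1, y1 = int(row['xmin']), int(row['ymin'])
        x2, y2 = int(row['xmax']), int(row['ymax'])
        if (x2 - x1) < MIN_W or (y2 - y1) < MIN_H:
            continue  # skip crops that are too small
        crop = img[y1:y2, x1:x2]
        cv2.imwrite(f'crops/{frame_idx}_{idx}_{row["name"]}.png', crop)
```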
Now we've successfully implemented the functionality.
### 3. Counting Detected Objects

This last technique is the most straightforward to implement: since we want to count the number of detected objects on the screen, we need a global variable (in memory) or a database of some sort to store the count. We can design the variable to either:

1. Always increment, keeping a running total of all objects detected since the Python program started
2. Only hold the number of objects currently detected on the screen

Depending on the problem, you may want to choose one of these two options. In our case, we'll implement the second one:
![draw 2](./images/draw_1.PNG)

> **Note**: to implement the first option, you just need to *increment* the variable every time instead of setting it. However, you might benefit from looking at implementations like [DeepSORT](https://github.com/ZQPei/deep_sort_pytorch) or [Zero-Shot Tracking](https://github.com/roboflow/zero-shot-object-tracking), which can recognize the same object across sequential frames and count it once, rather than as separate entities.

This newly-created global variable holds the value we care about. For example, in the code above, I'm counting the _`mask`_ class. Then, I just need to draw the number of detected objects with OpenCV, along with the bounding boxes, on top of the original image:

![draw 1](./images/draw_2.PNG)
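
A sketch of the per-frame count and the OpenCV overlay, assuming the detections DataFrame `df` and image `img` from the earlier sketches (the `mask` class name comes from our model's labels):

```python
import cv2

# Number of 'mask' detections in the current frame only (option 2).
mask_count = int((df['name'] == 'mask').sum())

# Draw the count in the top-left corner of the original image.
cv2.putText(img, f'masks: {mask_count}', (10, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
            color=(0, 255, 0), thickness=2)
```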
As an example, I created this GIF:

![myself](./images/myself.gif)

Note that I tested this on my own computer with an RTX 3080 graphics card, and I got about 25 FPS *with no downscaling* of images (I used my screen's native resolution of 2560x1440). With downscaling on top of this, you'd see a huge improvement in FPS.
## Conclusions

I've shown three additional features not currently present in YOLO models, added just by putting a Python layer on top of them. PyTorch's Model Hub ultimately made this possible, as did RoboFlow (which made creating and exporting the mask detection model easy).

In the future, I'm planning to release an implementation with either DeepSORT or Zero-Shot Tracking, together with YOLO, to track objects. If you'd like to see any additional use cases or features implemented, let me know in the comments!

If you're curious about the goings-on of Oracle Developers like me in their natural habitat, come join us [on our public Slack channel!](https://bit.ly/odevrel_slack) We don't mind being your fish bowl 🐠.

Stay tuned...

## Acknowledgments

* **Author** - [Nacho Martinez](https://www.linkedin.com/in/ignacio-g-martinez/), Data Science Advocate @ Oracle Developer Relations
* **Last Updated By/Date** - February 8th, 2023

Lines changed: 90 additions & 0 deletions
1+
import torch
2+
from PIL import ImageGrab
3+
import argparse
4+
import time
5+
import cv2
6+
import numpy as np
7+
8+
# parse arguments for different execution modes.
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument('-m', '--model', help='Model path',
11+
type=str,
12+
required=True)
13+
parser.add_argument('-d', '--detect', help='Detection mode (league / screen)',
14+
choices=['league', 'screenshot'],
15+
default='screenshot',
16+
type=str,
17+
required=False
18+
)
19+
20+
args = parser.parse_args()
21+
22+
23+
# Model
24+
model = torch.hub.load('ultralytics/yolov5',
25+
'custom',
26+
path=args.model,
27+
force_reload=False)
28+
29+
30+
def draw_over_image(img, df):
31+
32+
draw_color = (255, 255, 255)
33+
yellow = (128, 128, 0)
34+
green = (0, 255, 0)
35+
red = (255, 0, 0)
36+
for idx, row in df.iterrows():
37+
# FONT_HERSHEY_SIMPLEX
38+
if row['name'] == 'mask':
39+
draw_color = green
40+
elif row['name'] == 'incorrect':
41+
draw_color = yellow
42+
else:
43+
draw_color = red
44+
img = cv2.rectangle(img=img, pt1=(int(row['xmin']), int(row['ymin'])),
45+
pt2=(int(row['xmax']), int(row['ymax'])),
46+
color=draw_color,
47+
thickness=5
48+
)
49+
50+
cv2.putText(img, row['name'], (int(row['xmin'])-10, int(row['ymin'])-10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=draw_color, thickness=2
51+
)
52+
53+
cv2.putText(img, row['name'], (int(row['xmin'])-10, int(row['ymin'])-10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=draw_color, thickness=2
54+
)
55+
56+
return img
57+
58+
# Main loop; infers sequentially until you press "q"
59+
while True:
60+
61+
# Image
62+
if args.detect == 'league':
63+
im = ImageGrab.grab(bbox=(2140+100, 1030+100, 2560-100, 1440-100)) # bbox=(2140, 1030, 2560, 1440))
64+
else:
65+
im = ImageGrab.grab() # take a screenshot
66+
67+
img = np.array(im)
68+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69+
70+
#img = cv2.resize(img, (1280, 1024))
71+
72+
# Inference
73+
results = model(img)
74+
# Capture start time to calculate fps
75+
start = time.time()
76+
77+
print(results.pandas().xyxy[0])
78+
79+
#results.show()
80+
81+
82+
83+
cv2.imshow('Image', draw_over_image(img, results.pandas().xyxy[0]))
84+
key = cv2.waitKey(30)
85+
if key == ord('q'):
86+
cv2.destroyAllWindows()
87+
break
88+
89+
# Print frames per second
90+
print('{} fps'.format(1/(time.time()-start)))
