Commit 89032a1: feat: article v2 (parent 4b10d17)
3 files changed: +42 / -38 lines

custom_pytorch_yolov5/custom_pytorch.md (42 additions, 38 deletions)

# YOLOv5 and OCI: Custom PyTorch Code From Scratch

## Introduction

I believed custom PyTorch code would be great, because simply using YOLOv5's repository doesn't give you 100% flexibility and responsiveness (real-time), so I decided to __very slightly__ add some *extra* functionalities (we'll talk about them below). If you're trying to use [the standard GitHub repository for YOLOv5](https://github.com/ultralytics/yolov5), you'll find that you can use code like [this detector](https://github.com/ultralytics/yolov5/blob/master/detect.py) to post-process video or image files. You can also use it directly with a YouTube video, and an integrated YouTube downloader will fetch frames and process them.

But what is the definition of real time? I want every frame that I see on my computer (be it a frame from my webcam, a YouTube video, or even my screen) to display the results of my detection immediately. This is why I created my own custom detection code with PyTorch.

Finally, I'd like to mention my journey through a *very painful* road of finding *a few* bugs on the Windows Operating System while trying to *virtualize* my webcam feed. There was this great plugin that could "replicate" your camera feed into a virtual version usable in any program: you could feed any program or input into the virtual webcam stream, so that it looked like your webcam feed was coming from somewhere else:

2828

29-
Finally, I'd like to mention my journey through a very painful road of finding __a few__ bugs on the Windows Operating System and trying to *virtualize* your webcam feed. There's this great plugin that could "replicate" your camera feed into a virtual version that you could use in any program (you could give your computer any program / input and feed it into the webcam stream, so that it looked like your webcam feed was coming from somewhere else), and it was really great:
30-
31-
![OBS great outdated plugin](./images/obs_1.PNG)
29+
![OBS great outdated plugin](./images/obs_1.png)
3230

3331
This was an [OBS (Open Broadcaster Software)](https://obsproject.com/) plugin. OBS is the go-to program to use when you're planning to make a livestream. However, this plugin was discontinued in OBS version 28, and all problems came with this update. I prepared this bug-compilation image so you can feel the pain too:
3432

35-
![OBS bug compilation](./images/obs_errors.PNG)
33+
![OBS bug compilation](./images/obs_errors.png)
3634

37-
So, once we've established that there are various impediments that prevent us from happily developing in a stable environment, let's start implementing.
35+
So, once we've established that there are several roadblocks that prevent us from happily developing in a stable environment, we finally understand the "why" of this article. Let's begin implementing.
3836

3937
## Implementation
Technical requirements are Python 3.8 or higher, and PyTorch 1.7 or higher.

First, we will use the `argparse` module to include additional parameters in our Python script. I added three optional parameters:

![argparse](./images/argparse.png)
> **Note**: the confidence threshold will only display detected objects if the confidence score of the model's prediction is higher than the given value (0.0-1.0).

These argparse parameters' default values can always be modified. The _frequency_ parameter determines how often to run detection (i.e. a frame step): if the user specifies a number N, only 1 frame out of every N frames will be used for detection. This can be useful if you're expecting your data to be very similar across sequential frames, as one detection of the object will suffice. Specifying this frame step is also beneficial to avoid cost overages when making these predictions (electricity bill beware, or OCI costs if you're using Oracle Cloud).

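Since the exact flags live in the screenshot above, here's a minimal `argparse` sketch of what this setup could look like; the parameter names and default values are my own assumptions, not necessarily the ones in the screenshot:

```python
import argparse

# Hypothetical names/defaults for the three optional parameters described above.
parser = argparse.ArgumentParser(description="Custom YOLOv5 PyTorch detector")
parser.add_argument("--confidence", type=float, default=0.5,
    help="minimum confidence score (0.0-1.0) required to display a detection")
parser.add_argument("--frequency", type=int, default=1,
    help="frame step: run detection on 1 out of every N frames")
parser.add_argument("--source", type=str, default="webcam",
    help="input source: 'webcam' or 'screen'")
args = parser.parse_args()
```
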
After this initial configuration, we're ready to load our custom model. You can find the pre-trained custom model's weights for the Mask Detection Model being featured [in this link](https://www.kaggle.com/datasets/jasperan/covid-19-mask-detection?select=best.pt). You'll need to have this file within reach of your Python execution environment for the code to work.

So, we now load the custom weights file:

![loading PyTorch model](./images/load_model.png)
> **Note**: we specify the model as a custom YOLO detector, and give it the model's weights file as input.

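For reference, loading custom weights through PyTorch Hub typically looks like the following sketch (the weights path is an assumption; adjust it to wherever you stored `best.pt`):

```python
import torch

# Load the custom mask-detection weights through PyTorch Hub.
model = torch.hub.load("ultralytics/yolov5", "custom", path="best.pt")
model.conf = args.confidence  # confidence threshold from the argparse sketch above
```
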
Now, we're ready to get started. We create our main loop, which constantly gets a new image from the input source (either a webcam feed or a screenshot of what we're seeing on our screen) and displays it with the bounding box detections in place:

![main loop](./images/main_loop.png)

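As a simplified sketch of that loop (assuming `infer()` takes the capture object and the current frame counter, which may differ from the exact signature in the screenshot):

```python
import cv2

cap = cv2.VideoCapture(0)  # webcam feed; the screenshot variant works similarly
frame_count = 0
while True:
    # infer() grabs a frame, runs detection and returns the annotated image
    annotated, results = infer(cap, frame_count)
    frame_count += 1
    cv2.imshow("YOLOv5 mask detection", annotated)
    if cv2.waitKey(1) & 0xFF == ord("q"):  # press 'q' to quit
        break
cap.release()
cv2.destroyAllWindows()
```
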
The most important function in that code is the `infer()` function, which returns an image and the result object, transformed into a pandas object for your convenience.

![infer 1](./images/infer_1.png)

In this first part of the function, we obtain a new image from the OpenCV video capture object. You can also find a similar implementation that takes screenshots instead of using the webcam feed in [this file](https://github.com/oracle-devrel/devo.publishing.other/custom_pytorch_yolov5/files/lightweight_screen_torch_inference.py).

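Grabbing that frame is the standard OpenCV pattern (a sketch, not the screenshot's exact code):

```python
# Inside infer(): read the next frame from the capture object.
ret, frame = cap.read()
if not ret:
    raise RuntimeError("Could not read a frame from the capture device")
```
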
Now, we need to pass this to our Torch model, which will return bounding box results. However, there's an important consideration: since we want our predictions to be as fast as possible, and we know that the mask detection model has been trained with thousands of images of all shapes and sizes, we can use a technique that benefits our program's Frames Per Second (FPS): **rescaling** images to a lower resolution (my webcam feed is 1080p, and my 2560x1440 monitor makes screenshots more detailed than I need).
For this, I chose a `SCALE_FACTOR` variable to hold this value (between 0 and 1). Currently, all images will be downscaled to 640 pixels in width, with the height adjusted to maintain the original image's aspect ratio.

![infer 2](./images/infer_2.png)

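A sketch of that downscaling step (the helper name is mine; the article stores the ratio in `SCALE_FACTOR`):

```python
import cv2

TARGET_WIDTH = 640

def downscale(frame):
    # SCALE_FACTOR: ratio between the target width and the original width (0-1)
    scale_factor = TARGET_WIDTH / frame.shape[1]
    new_height = int(frame.shape[0] * scale_factor)
    return cv2.resize(frame, (TARGET_WIDTH, new_height)), scale_factor
```
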
Now that we have our downscaled image, we pass it to the model, and it returns the object we wanted:

![infer 3](./images/infer_3.png)
> **Note**: the `size=640` option tells the model we're going to pass it images with that width, so the model will predict results of those dimensions.

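The call itself follows YOLOv5's standard PyTorch Hub interface; `results.pandas().xyxy[0]` yields one row per detection with `xmin`/`ymin`/`xmax`/`ymax`, `confidence`, `class` and `name` columns (variable names come from the sketches above):

```python
small_frame, scale_factor = downscale(frame)
# OpenCV frames are BGR; YOLOv5 hub models expect RGB, so flip the channels
results = model(small_frame[..., ::-1], size=640)
detections = results.pandas().xyxy[0]  # one row per detected object
```
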
The last thing we do is draw the bounding boxes that we obtained onto the image, and return the image to display it later.

![infer return](./images/infer_5.png)

### 1. Sorting Detections

This first technique is the simplest, and can be useful to add value to the standard YOLO functionality in a unique way. The idea is to quickly manipulate the PyTorch-pandas object to sort values according to one of the columns.

For this, I suggest two ideas: sorting by confidence score, or by detection coordinates. To illustrate how these techniques can be useful, let's look at the following image:

![speed figure](./images/figure_speed.png)
> **Note**: this image illustrates how sorting detections can be useful. [(image credits)](https://www.linkedin.com/in/muhammad-moin-7776751a0/)

In the image above, an imaginary line is drawn between both sides of the roadway, in this case **horizontally**. Any object passing from one side of the line to the other in a specific direction is counted as an "inward" or "outward" vehicle. This can be achieved by specifying (x,y) bounds; any item in the PyTorch-pandas object that crosses them in either direction is detected.

For processing purposes, sorting these values from the lowest y coordinate to the highest will return all cars in order, from the top to the bottom of the image, which makes it easy to process them sequentially.

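With the pandas representation, both sorting ideas are one-liners over the standard YOLOv5 columns:

```python
# Sort by confidence score, highest first...
by_confidence = detections.sort_values("confidence", ascending=False)
# ...or by vertical position, top of the image first.
by_position = detections.sort_values("ymin")
```
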
### 2. Cropping & Saving Detections

It's useful because we can manipulate or use the cropped image of an object, instead of the whole image. For example, we could:

1. Detect a car's license plate
2. Give this cropped image to an Optical Character Recognition (OCR) program
3. Extract the text in real-time

![bmw license plate](./images/bmw_car.png)
> **Note**: this is an example of a car's license plate being passed through an OCR in Keras. [(image credits)](https://medium.com/@theophilebuyssens)

This approach wouldn't work if we gave the whole image to the OCR, as it wouldn't be able to confidently recognize such small text (which represents only a fraction of the frame), unless the image's resolution were very high (which in most cases it isn't, especially with surveillance cameras).

To implement this, we will base everything we do on **bounding boxes**. Our PyTorch code will return an object with bounding box coordinates for detected objects (and the detection's confidence scores), and we will use this object to create newly cropped images with the bounding box sizes.
> **Note**: you can always modify the range of pixels you want to crop in each image, by being either more **permissive** (getting extra pixels around the bounding box) or more **restrictive**, removing the edges of the detected object.

An important consideration is that, since we're passing images to our model with a width of 640 pixels, we need to keep our previously-mentioned `SCALE_FACTOR` variable around. The problem is that the original image is larger than the downscaled image (the one we pass to the model), so the bounding box coordinates the model returns will be downscaled too. We need to scale these coordinates back up (dividing by `SCALE_FACTOR`) in order to _draw_ the bounding boxes over the original image, and then display it:

![infer 4](./images/infer_4.png)

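A sketch of that upscaling-and-drawing step, reusing `scale_factor` and `detections` from the sketches above:

```python
# Map each detection from the downscaled image back onto the original frame.
for det in detections.itertuples():
    x1, y1 = int(det.xmin / scale_factor), int(det.ymin / scale_factor)
    x2, y2 = int(det.xmax / scale_factor), int(det.ymax / scale_factor)
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green box
```
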
We use the `save_cropped_images()` function to save images, while also accounting for the frequency parameter we set: we'll only save the cropped detections if the frame is one we're supposed to process.
Inside this function, we will **upscale** the bounding box detections. Also, we'll only save the crop if it is larger than a minimum (x,y) width and height:

![save cropped images](./images/save_cropped_images.png)

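Here's a sketch of what such a function could look like; the minimum-size values, output directory and file naming are illustrative assumptions (`cv2` is imported in the earlier sketches):

```python
import os

MIN_WIDTH, MIN_HEIGHT = 32, 32  # hypothetical minimum crop size

def save_cropped_images(frame, detections, scale_factor, frame_count, frequency, out_dir="crops"):
    if frame_count % frequency != 0:  # honor the frame-step parameter
        return
    os.makedirs(out_dir, exist_ok=True)
    for i, det in enumerate(detections.itertuples()):
        # Upscale the bounding box back to the original frame's coordinates
        x1, y1 = int(det.xmin / scale_factor), int(det.ymin / scale_factor)
        x2, y2 = int(det.xmax / scale_factor), int(det.ymax / scale_factor)
        if (x2 - x1) < MIN_WIDTH or (y2 - y1) < MIN_HEIGHT:
            continue  # skip detections that are too small to be useful
        crop = frame[y1:y2, x1:x2]
        cv2.imwrite(os.path.join(out_dir, f"frame{frame_count}_det{i}.png"), crop)
```
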
The last thing we do is save the cropped image with OpenCV:

![save image](./images/save_image.png)

And we successfully implemented the functionality.

### 3. Counting Detections

This last technique we're going to learn about is very straightforward and easy to implement: counting detections. You can either keep a cumulative count of every object detected so far, or count only the objects detected in the current frame.

Depending on the problem, you may want to choose one of these two options. In our case, we'll implement the second option:

![draw 2](./images/draw_1.png)
> **Note**: to implement the first option, you just need to *increment* the variable every time, instead of setting it. However, you might benefit from looking at implementations like [DeepSORT](https://github.com/ZQPei/deep_sort_pytorch) or [Zero-Shot Tracking](https://github.com/roboflow/zero-shot-object-tracking), which are able to recognize the same object/detection across sequential frames, and count it as one entity rather than several.

With our newly-created global variable, we'll hold the count we're interested in. For example, in the code above, I'm detecting the _`mask`_ class. Then, I just need to draw the number of detected objects with OpenCV, along with the bounding boxes, on top of the original image:

![draw 1](./images/draw_2.png)

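A sketch of that counting-and-drawing step (the second option, where the count is set anew every frame; the class name `mask` comes from the model's labels):

```python
# Count this frame's detections of the class we care about...
mask_count = int((detections["name"] == "mask").sum())
# ...and draw the total onto the original frame with OpenCV.
cv2.putText(frame, f"masks: {mask_count}", (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
```
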
As an example, I created this GIF:

![myself](./images/myself.gif)

Note that when I tested this on my own computer with an RTX 3080, I got about 25 FPS **with no downscaling** of images (I used their standard resolution of 2560x1440). If you use downscaling, you'll notice huge improvements in FPS.

## Conclusions

I've shown three additional features not currently present in YOLO models, by just adding a Python layer on top. PyTorch's Model Hub ultimately made this possible, as did RoboFlow, which made creating and exporting the mask detection model easy.

In the future, I'm planning on releasing an implementation with either DeepSORT or Zero-Shot Tracking, together with YOLO, to track objects. If you'd like to see any additional use cases or features implemented, let me know in the comments!