Skip to content

Commit 8e8a543

Browse files
committed
Add WioTerminal_TF_MNIST example
1 parent 8ea9906 commit 8e8a543

File tree

5 files changed

+22049
-2
lines changed

5 files changed

+22049
-2
lines changed

.travis.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ before_install:
2424
cd "${HOME}/Arduino/libraries/${repositoryName}"
2525
cd "${TRAVIS_BUILD_DIR}"
2626
}
27-
- buildExampleSketch() { arduino-cli compile --warnings all --fqbn $BOARD $PWD/examples/$1 --verbose; }
28-
- buildExampleUtilitySketch() { arduino-cli compile --warnings all --fqbn $BOARD $PWD/examples/utility/$1 --verbose; }
27+
- buildExampleSketch() { arduino-cli compile --warnings all --fqbn $BOARD $PWD/examples/$1 ; }
28+
- buildExampleUtilitySketch() { arduino-cli compile --warnings all --fqbn $BOARD $PWD/examples/utility/$1 ; }
2929

3030
install:
3131
- mkdir -p $HOME/Arduino/libraries
@@ -139,6 +139,7 @@ script:
139139
- export BOARD=Seeeduino:samd:seeed_XIAO_m0
140140
- buildExampleSketch SeeeduinoXIAO_SPISlave;
141141

142+
142143
notifications:
143144
webhooks:
144145
urls:

examples/WioTerminal_TF_MNIST/TF_MNIST.ipynb

Lines changed: 355 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#include <TensorFlowLite.h>
2+
#include "tensorflow/lite/micro/all_ops_resolver.h"
3+
#include "tensorflow/lite/micro/micro_error_reporter.h"
4+
#include "tensorflow/lite/micro/micro_interpreter.h"
5+
#include "tensorflow/lite/schema/schema_generated.h"
6+
#include "tensorflow/lite/version.h"
7+
#include "model.h"
8+
#include "mnist.h"
9+
#include"TFT_eSPI.h"
10+
11+
TFT_eSPI tft;
12+
TFT_eSprite spr = TFT_eSprite(&tft); // Sprite
13+
14+
namespace {
15+
// Globals
16+
const tflite::Model* model = nullptr;
17+
tflite::MicroInterpreter* interpreter = nullptr;
18+
tflite::ErrorReporter* reporter = nullptr;
19+
TfLiteTensor* input = nullptr;
20+
TfLiteTensor* output = nullptr;
21+
constexpr int kTensorArenaSize = 8000; // Just pick a big enough number
22+
uint8_t tensor_arena[ kTensorArenaSize ] = { 0 };
23+
float *input_buffer=nullptr;
24+
} // namespace
25+
26+
27+
float display_buffer[28224];
28+
29+
void bitmap_to_float_array( float* dest, const unsigned char* bitmap ) { // Populate input_vec with the monochrome 1bpp bitmap
30+
int pixel = 0;
31+
for( int y = 0; y < 28; y++ ) {
32+
for( int x = 0; x < 28; x++ ) {
33+
int B = x / 8; // the Byte # of the row
34+
int b = x % 8; // the Bit # of the Byte
35+
dest[ pixel ] = ( bitmap[ y * 4 + B ] >> ( 7 - b ) ) &
36+
0x1 ? 1.0f : 0.0f;
37+
pixel++;
38+
}
39+
}
40+
}
41+
42+
void draw_input_buffer() {
43+
44+
int pre_i, pre_j, after_i, after_j;//缩放前后对应的像素点坐标
45+
46+
for (int i = 0; i<168; i++)
47+
{
48+
for (int j = 0; j<168; j++)
49+
{
50+
after_i = i;
51+
after_j = j;
52+
pre_i = (int)(after_i / 6);/////取整,插值方法为:最邻近插值(近邻取样法)
53+
pre_j = (int)(after_j / 6);
54+
if (pre_i >= 0 && pre_i < 28 && pre_j >= 0 && pre_j < 28)//在原图范围内
55+
*(display_buffer + i * 168 + j) = *(input_buffer + pre_i * 28 + pre_j);
56+
}
57+
}
58+
59+
for(int cy = 0; cy < 168; cy++)
60+
{
61+
for(int cx = 0; cx < 168; cx++)
62+
{
63+
64+
tft.drawPixel( cx + 76, cy, display_buffer[ cy * 168 + cx ] > 0 ? 0xFFFFFFFF : 0xFF000000 );
65+
66+
}
67+
}
68+
69+
70+
}
71+
72+
void print_input_buffer() {
73+
char output[ 28 * 29 ]; // Each row should end row newline
74+
for( int y = 0; y < 28; y++ ) {
75+
for( int x = 0; x < 28; x++ ) {
76+
output[ y * 29 + x ] = input_buffer[ y * 28 + x ] > 0 ? ' ' : '#';
77+
}
78+
output[ y * 29 + 28 ] = '\n';
79+
}
80+
reporter->Report( output );
81+
}
82+
83+
84+
85+
void setup() {
86+
// Load Model
87+
//Serial.begin(115200);
88+
static tflite::MicroErrorReporter error_reporter;
89+
reporter = &error_reporter;
90+
reporter->Report( "Let's use AI to recognize some numbers!" );
91+
92+
model = tflite::GetModel( tf_model );
93+
if( model->version() != TFLITE_SCHEMA_VERSION ) {
94+
reporter->Report( "Model is schema version: %d\nSupported schema version is: %d", model->version(), TFLITE_SCHEMA_VERSION );
95+
return;
96+
}
97+
// Setup our TF runner
98+
static tflite::AllOpsResolver resolver;
99+
static tflite::MicroInterpreter static_interpreter(
100+
model, resolver, tensor_arena, kTensorArenaSize, reporter );
101+
interpreter = &static_interpreter;
102+
103+
// Allocate memory from the tensor_arena for the model's tensors.
104+
TfLiteStatus allocate_status = interpreter->AllocateTensors();
105+
if( allocate_status != kTfLiteOk ) {
106+
reporter->Report( "AllocateTensors() failed" );
107+
return;
108+
}
109+
110+
// Obtain pointers to the model's input and output tensors.
111+
input = interpreter->input(0);
112+
output = interpreter->output(0);
113+
114+
// Save the input buffer to put our MNIST images into
115+
input_buffer = input->data.f;
116+
tft.begin();
117+
tft.setRotation(3);
118+
tft.fillScreen(TFT_BLACK); // fills entire the screen with colour red
119+
tft.setTextColor(TFT_RED);
120+
tft.setTextSize(2);
121+
tft.drawString("It looks like the number:",0,200);//prints string at (70,80)
122+
TFT_eSprite spr = TFT_eSprite(&tft); // Sprite
123+
124+
125+
}
126+
127+
void loop() {
128+
// Pick a random test image for input
129+
const int num_test_images = ( sizeof( test_images ) / sizeof( *test_images ) );
130+
131+
132+
bitmap_to_float_array( input_buffer, test_images[ rand() % num_test_images ] );
133+
134+
draw_input_buffer();
135+
//print_input_buffer();
136+
// Run our model
137+
TfLiteStatus invoke_status = interpreter->Invoke();
138+
if( invoke_status != kTfLiteOk ) {
139+
reporter->Report( "Invoke failed" );
140+
return;
141+
}
142+
143+
float* result = output->data.f;
144+
145+
reporter->Report( "It looks like the number: %d", std::distance( result, std::max_element( result, result + 10 ) ) );
146+
// tft.drawNumber(std::distance( result, std::max_element( result, result + 10 ) ), 300, 200),
147+
spr.createSprite(20, 20);
148+
// spr.fillSprite(TFT_RED);
149+
spr.setTextColor(TFT_RED);
150+
spr.setTextSize(2);
151+
spr.drawNumber(std::distance( result, std::max_element( result, result + 10 ) ),0, 0);
152+
spr.pushSprite(300, 200);
153+
spr.deleteSprite();
154+
// Wait 1-sec til before running again
155+
delay( 200 );
156+
157+
}

0 commit comments

Comments
 (0)