|
| 1 | +#include <TensorFlowLite.h> |
| 2 | +#include "tensorflow/lite/micro/all_ops_resolver.h" |
| 3 | +#include "tensorflow/lite/micro/micro_error_reporter.h" |
| 4 | +#include "tensorflow/lite/micro/micro_interpreter.h" |
| 5 | +#include "tensorflow/lite/schema/schema_generated.h" |
| 6 | +#include "tensorflow/lite/version.h" |
| 7 | +#include "model.h" |
| 8 | +#include "mnist.h" |
| 9 | +#include"TFT_eSPI.h" |
| 10 | + |
| 11 | +TFT_eSPI tft; |
| 12 | +TFT_eSprite spr = TFT_eSprite(&tft); // Sprite |
| 13 | + |
| 14 | +namespace { |
| 15 | +// Globals |
| 16 | +const tflite::Model* model = nullptr; |
| 17 | +tflite::MicroInterpreter* interpreter = nullptr; |
| 18 | +tflite::ErrorReporter* reporter = nullptr; |
| 19 | +TfLiteTensor* input = nullptr; |
| 20 | +TfLiteTensor* output = nullptr; |
| 21 | +constexpr int kTensorArenaSize = 8000; // Just pick a big enough number |
| 22 | +uint8_t tensor_arena[ kTensorArenaSize ] = { 0 }; |
| 23 | +float *input_buffer=nullptr; |
| 24 | +} // namespace |
| 25 | + |
| 26 | + |
| 27 | +float display_buffer[28224]; |
| 28 | + |
| 29 | +void bitmap_to_float_array( float* dest, const unsigned char* bitmap ) { // Populate input_vec with the monochrome 1bpp bitmap |
| 30 | + int pixel = 0; |
| 31 | + for( int y = 0; y < 28; y++ ) { |
| 32 | + for( int x = 0; x < 28; x++ ) { |
| 33 | + int B = x / 8; // the Byte # of the row |
| 34 | + int b = x % 8; // the Bit # of the Byte |
| 35 | + dest[ pixel ] = ( bitmap[ y * 4 + B ] >> ( 7 - b ) ) & |
| 36 | + 0x1 ? 1.0f : 0.0f; |
| 37 | + pixel++; |
| 38 | + } |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +void draw_input_buffer() { |
| 43 | + |
| 44 | + int pre_i, pre_j, after_i, after_j;//缩放前后对应的像素点坐标 |
| 45 | + |
| 46 | + for (int i = 0; i<168; i++) |
| 47 | + { |
| 48 | + for (int j = 0; j<168; j++) |
| 49 | + { |
| 50 | + after_i = i; |
| 51 | + after_j = j; |
| 52 | + pre_i = (int)(after_i / 6);/////取整,插值方法为:最邻近插值(近邻取样法) |
| 53 | + pre_j = (int)(after_j / 6); |
| 54 | + if (pre_i >= 0 && pre_i < 28 && pre_j >= 0 && pre_j < 28)//在原图范围内 |
| 55 | + *(display_buffer + i * 168 + j) = *(input_buffer + pre_i * 28 + pre_j); |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + for(int cy = 0; cy < 168; cy++) |
| 60 | + { |
| 61 | + for(int cx = 0; cx < 168; cx++) |
| 62 | + { |
| 63 | + |
| 64 | + tft.drawPixel( cx + 76, cy, display_buffer[ cy * 168 + cx ] > 0 ? 0xFFFFFFFF : 0xFF000000 ); |
| 65 | + |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | + |
| 70 | +} |
| 71 | + |
| 72 | +void print_input_buffer() { |
| 73 | + char output[ 28 * 29 ]; // Each row should end row newline |
| 74 | + for( int y = 0; y < 28; y++ ) { |
| 75 | + for( int x = 0; x < 28; x++ ) { |
| 76 | + output[ y * 29 + x ] = input_buffer[ y * 28 + x ] > 0 ? ' ' : '#'; |
| 77 | + } |
| 78 | + output[ y * 29 + 28 ] = '\n'; |
| 79 | + } |
| 80 | + reporter->Report( output ); |
| 81 | +} |
| 82 | + |
| 83 | + |
| 84 | + |
| 85 | +void setup() { |
| 86 | + // Load Model |
| 87 | + //Serial.begin(115200); |
| 88 | + static tflite::MicroErrorReporter error_reporter; |
| 89 | + reporter = &error_reporter; |
| 90 | + reporter->Report( "Let's use AI to recognize some numbers!" ); |
| 91 | + |
| 92 | + model = tflite::GetModel( tf_model ); |
| 93 | + if( model->version() != TFLITE_SCHEMA_VERSION ) { |
| 94 | + reporter->Report( "Model is schema version: %d\nSupported schema version is: %d", model->version(), TFLITE_SCHEMA_VERSION ); |
| 95 | + return; |
| 96 | + } |
| 97 | + // Setup our TF runner |
| 98 | + static tflite::AllOpsResolver resolver; |
| 99 | + static tflite::MicroInterpreter static_interpreter( |
| 100 | + model, resolver, tensor_arena, kTensorArenaSize, reporter ); |
| 101 | + interpreter = &static_interpreter; |
| 102 | + |
| 103 | + // Allocate memory from the tensor_arena for the model's tensors. |
| 104 | + TfLiteStatus allocate_status = interpreter->AllocateTensors(); |
| 105 | + if( allocate_status != kTfLiteOk ) { |
| 106 | + reporter->Report( "AllocateTensors() failed" ); |
| 107 | + return; |
| 108 | + } |
| 109 | + |
| 110 | + // Obtain pointers to the model's input and output tensors. |
| 111 | + input = interpreter->input(0); |
| 112 | + output = interpreter->output(0); |
| 113 | + |
| 114 | + // Save the input buffer to put our MNIST images into |
| 115 | + input_buffer = input->data.f; |
| 116 | + tft.begin(); |
| 117 | + tft.setRotation(3); |
| 118 | + tft.fillScreen(TFT_BLACK); // fills entire the screen with colour red |
| 119 | + tft.setTextColor(TFT_RED); |
| 120 | + tft.setTextSize(2); |
| 121 | + tft.drawString("It looks like the number:",0,200);//prints string at (70,80) |
| 122 | + TFT_eSprite spr = TFT_eSprite(&tft); // Sprite |
| 123 | + |
| 124 | + |
| 125 | +} |
| 126 | + |
| 127 | +void loop() { |
| 128 | + // Pick a random test image for input |
| 129 | + const int num_test_images = ( sizeof( test_images ) / sizeof( *test_images ) ); |
| 130 | + |
| 131 | + |
| 132 | + bitmap_to_float_array( input_buffer, test_images[ rand() % num_test_images ] ); |
| 133 | + |
| 134 | + draw_input_buffer(); |
| 135 | + //print_input_buffer(); |
| 136 | + // Run our model |
| 137 | + TfLiteStatus invoke_status = interpreter->Invoke(); |
| 138 | + if( invoke_status != kTfLiteOk ) { |
| 139 | + reporter->Report( "Invoke failed" ); |
| 140 | + return; |
| 141 | + } |
| 142 | + |
| 143 | + float* result = output->data.f; |
| 144 | + |
| 145 | + reporter->Report( "It looks like the number: %d", std::distance( result, std::max_element( result, result + 10 ) ) ); |
| 146 | +// tft.drawNumber(std::distance( result, std::max_element( result, result + 10 ) ), 300, 200), |
| 147 | + spr.createSprite(20, 20); |
| 148 | +// spr.fillSprite(TFT_RED); |
| 149 | + spr.setTextColor(TFT_RED); |
| 150 | + spr.setTextSize(2); |
| 151 | + spr.drawNumber(std::distance( result, std::max_element( result, result + 10 ) ),0, 0); |
| 152 | + spr.pushSprite(300, 200); |
| 153 | + spr.deleteSprite(); |
| 154 | + // Wait 1-sec til before running again |
| 155 | + delay( 200 ); |
| 156 | + |
| 157 | +} |
0 commit comments