Berry: add int8 quantization to TensorFlow Lite driver (#21763)

* add input quantization, minor fixes

* prevent divideByZero
This commit is contained in:
Christian Baars 2024-07-11 23:05:47 +02:00 committed by GitHub
parent 23de275dbe
commit 13330eb085
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 52 deletions

View File

@ -21,8 +21,8 @@ BE_FUNC_CTYPE_DECLARE(be_TFL_begin, "b", "@s[(bytes)~]");
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena); extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]"); BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size); extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8);
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~"); BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~[b]");
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size); extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~"); BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");

View File

@ -70,7 +70,7 @@ struct TFL_mic_ctx_t{
struct { struct {
uint32_t is_audio_initialized:1; uint32_t is_audio_initialized:1;
uint32_t is_first_time:1; uint32_t is_first_time:1;
uint32_t new_feature_data:1; // uint32_t new_feature_data:1;
uint32_t continue_audio_capture:1; uint32_t continue_audio_capture:1;
uint32_t stop_audio_capture:1; uint32_t stop_audio_capture:1;
uint32_t audio_capture_ended:1; uint32_t audio_capture_ended:1;
@ -102,7 +102,7 @@ struct TFL_mic_ctx_t{
struct TFL_stats_t{ struct TFL_stats_t{
uint32_t model_size = 0; uint32_t model_size = 0;
uint32_t used_arena_bytes = 0; uint32_t used_arena_bytes = 0;
uint32_t invokations = 0; uint32_t invocations = 0;
uint32_t loop_task_free_stack_bytes = 0; uint32_t loop_task_free_stack_bytes = 0;
uint32_t mic_task_free_stack_bytes = 0; uint32_t mic_task_free_stack_bytes = 0;
}; };
@ -114,7 +114,7 @@ TfLiteTensor* output = nullptr;
int8_t *berry_output_buf = nullptr; int8_t *berry_output_buf = nullptr;
size_t berry_output_bufsize; size_t berry_output_bufsize;
int TensorArenaSize = 2000; int TensorArenaSize = 2000;
uint8_t max_invocations; // max. invocations per second uint8_t max_invocations = 4; // max. invocations per second
TaskHandle_t loop_task = nullptr; TaskHandle_t loop_task = nullptr;
// QueueHandle_t loop_queue = nullptr; // QueueHandle_t loop_queue = nullptr;
@ -127,9 +127,8 @@ union{
// uint32_t stop_loop:1; // uint32_t stop_loop:1;
uint32_t loop_ended:1; uint32_t loop_ended:1;
uint32_t unread_output:1; uint32_t unread_output:1;
uint32_t use_cam:1; uint32_t new_input_data:1;
uint32_t use_mic:1; uint32_t use_mic:1;
uint32_t mode_inference:1;
} option; } option;
uint32_t options; uint32_t options;
}; };
@ -176,13 +175,6 @@ bool TFL_create_task(){
return btrue; return btrue;
} }
bool TFL_init_CAM(){
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam not implemented yet"));
delete TFL;
TFL = nullptr;
return bfalse;
}
#ifdef USE_I2S #ifdef USE_I2S
/** /**
* @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing. * @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing.
@ -304,10 +296,10 @@ void TFL_capture_samples(void* arg) {
/* read slice data at once from i2s */ /* read slice data at once from i2s */
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride)); // i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride)); esp_err_t err = i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
if (bytes_read <= 0) { if (bytes_read <= 0) {
MicroPrintf( PSTR( "Error in I2S read : %d"), bytes_read); MicroPrintf( PSTR( "Error %d in I2S, did read %d bytes"), err, bytes_read);
} }
else { else {
if (bytes_read < i2s_bytes_to_read) { if (bytes_read < i2s_bytes_to_read) {
@ -328,7 +320,7 @@ void TFL_capture_samples(void* arg) {
if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){ if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){
TFL->mic->feature_buffer_idx = 0; TFL->mic->feature_buffer_idx = 0;
} }
TFL->mic->flag.new_feature_data = 1; TFL->option.new_input_data = 1;
xSemaphoreGive(TFL->mic->feature_buffer_mutex); xSemaphoreGive(TFL->mic->feature_buffer_mutex);
} }
@ -473,7 +465,6 @@ void TFL_task_loop(void *pvParameters){
TickType_t xLastWakeTime = xTaskGetTickCount(); TickType_t xLastWakeTime = xTaskGetTickCount();
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL); TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
bool do_invokation = true;
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){ while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
MicroPrintf(PSTR("delay_next_invocation")); MicroPrintf(PSTR("delay_next_invocation"));
vTaskDelay(10/ portTICK_PERIOD_MS); vTaskDelay(10/ portTICK_PERIOD_MS);
@ -483,16 +474,13 @@ void TFL_task_loop(void *pvParameters){
#ifdef USE_I2S #ifdef USE_I2S
if(TFL->option.use_mic == 1){ if(TFL->option.use_mic == 1){
TFL->option.delay_next_invocation = 0; // Clean up later TFL->option.delay_next_invocation = 0; // Clean up later
if (TFL->mic->flag.new_feature_data == 0){
do_invokation = false;
}
if(TFL->mic->flag.continue_audio_capture == 1){ if(TFL->mic->flag.continue_audio_capture == 1){
TFL_mic_feature_buf_to_input(); TFL_mic_feature_buf_to_input();
} }
} }
#endif #endif
if(do_invokation){ if(TFL->option.new_input_data){
// MicroPrintf(PSTR("Invokation requested")); // MicroPrintf(PSTR("invocation requested"));
TFL->option.running_invocation = 1; TFL->option.running_invocation = 1;
int invoke_status = interpreter.Invoke(); int invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) { if (invoke_status != kTfLiteOk) {
@ -502,15 +490,10 @@ void TFL_task_loop(void *pvParameters){
if(TFL->berry_output_buf != nullptr){ if(TFL->berry_output_buf != nullptr){
memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize); memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize);
} }
TFL->stats->invokations++; TFL->stats->invocations++;
TFL->option.unread_output = 1; TFL->option.unread_output = 1;
#ifdef USE_I2S TFL->option.running_invocation = 0;
if(TFL->option.use_mic == 1) { TFL->option.new_input_data == 0;
TFL->mic->flag.new_feature_data = 0;
}
TFL->option.running_invocation = 0;
#endif //USE_I2S
} }
if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit
} }
@ -535,7 +518,7 @@ extern "C" {
* @brief Create a context for a tensor flow session, that will later run in a task * @brief Create a context for a tensor flow session, that will later run in a task
* *
* @param vm * @param vm
* @param type BUF - generic byte buffer, CAM - webcam input, MIC - microphone input * @param type BUF - generic byte buffer, MIC - microphone input
* @return btrue * @return btrue
* @return bfalse * @return bfalse
*/ */
@ -549,21 +532,12 @@ extern "C" {
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted")); AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted"));
return btrue; return btrue;
} }
TFL = new TFL_ctx_t; TFL = new TFL_ctx_t{};
TFL->options = 0; TFL->options = 0;
if(*(uint32_t*)type == 0x00465542){ //BUF if(*(uint32_t*)type == 0x00465542){ //BUF
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer")); AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer"));
} }
else if(*(uint32_t*)type == 0x004D4143){ //CAM - not yet implemented
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam"));
if(TFL_init_CAM()){
TFL->option.use_cam = 1;
}
else{
return bfalse;
}
}
else if(*(uint32_t*)type == 0x0043494D){ //MIC else if(*(uint32_t*)type == 0x0043494D){ //MIC
#ifdef USE_I2S #ifdef USE_I2S
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){ if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
@ -620,7 +594,7 @@ extern "C" {
if(arena){ if(arena){
TFL->TensorArenaSize = arena; TFL->TensorArenaSize = arena;
} }
TFL->option.mode_inference = 1;
TFL->berry_output_buf = (int8_t*)output_buf; TFL->berry_output_buf = (int8_t*)output_buf;
TFL->berry_output_bufsize = output_size; TFL->berry_output_bufsize = output_size;
TFL_create_task(); TFL_create_task();
@ -636,24 +610,35 @@ extern "C" {
* @brief Send new input data to the tensor flow session * @brief Send new input data to the tensor flow session
* *
* @param vm * @param vm
* @param buf Arbitrary data in a byte buffer, must fit to the TF model * @param buf Arbitrary data in a byte buffer, must fit to the TF model
* @param size Size of buffer (auto-calculated by Berry) * @param size Size of buffer (auto-calculated by Berry)
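* @param quantize_to_int8 Optional: when true, each input byte is shifted by -128 into the int8 range before being written to the input tensor; otherwise the bytes are copied unchanged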
* @return btrue * @return btrue
* @return bfalse * @return bfalse
*/ */
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size){ bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8){
if(!TFL) return bfalse; if(!TFL) return bfalse;
if(TFL->option.running_loop == 1){ if(TFL->option.running_loop == 1){
uint32_t timeout = 0; uint32_t timeout = 0;
int8_t* tensor_input_buffer = tflite::GetTensorData<int8_t>(TFL->input);
while(!TFL->option.delay_next_invocation == 1) { while(!TFL->option.delay_next_invocation == 1) {
if(timeout>10) break; if(timeout>4) return bfalse;;
delay(2); delay(5);
timeout++; timeout++;
} }
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: imput new data and invoke")); AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: imput new data and invoke"));
memcpy((uint8_t*)TFL->input->data.data,(uint8_t*)buf,size); if(quantize_to_int8){
int16_t temp_int;
for(int i = 0; i < size; i++){
temp_int = buf[i];
tensor_input_buffer[i] = (int8_t)(temp_int - 128);
}
} else {
memcpy(tensor_input_buffer, buf,size);
}
TFL->option.delay_next_invocation = 0; TFL->option.delay_next_invocation = 0;
TFL->option.new_input_data = 1;
return btrue; return btrue;
} }
return bfalse; return bfalse;
@ -697,7 +682,6 @@ extern "C" {
return (const char *)item; return (const char *)item;
} }
/** /**
* @brief Shows statistics about the model and the running TFL session * @brief Shows statistics about the model and the running TFL session
* *
@ -734,7 +718,7 @@ extern "C" {
} }
inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type); inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type);
pos += inc; pos += inc;
inc = snprintf_P(s + pos, size-pos, PSTR(",\"sessiom\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes); inc = snprintf_P(s + pos, size-pos, PSTR(",\"session\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
pos += inc; pos += inc;
inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes); inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes);
pos += inc; pos += inc;
@ -742,7 +726,7 @@ extern "C" {
inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes); inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes);
pos += inc; pos += inc;
} }
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invokations\":%u}}"),TFL->stats->invokations); inc = snprintf_P(s + pos, size-pos, PSTR(",\"invocations\":%u}}"),TFL->stats->invocations);
be_pushstring(vm, s); be_pushstring(vm, s);
free(s); free(s);
return s; return s;

View File

@ -2381,6 +2381,7 @@ void MI32sendEnergyWidget(){
#endif //USE_MI_ESP32_ENERGY #endif //USE_MI_ESP32_ENERGY
#ifdef USE_WEBCAM #ifdef USE_WEBCAM
void MI32sendCamWidget(){ void MI32sendCamWidget(){
#ifndef USE_BERRY_CAM
if (Wc.CamServer && Wc.up) { if (Wc.CamServer && Wc.up) {
WSContentSend_P(PSTR("<div class='box")); WSContentSend_P(PSTR("<div class='box"));
if(Settings->webcam_config.resolution>7){ if(Settings->webcam_config.resolution>7){
@ -2389,6 +2390,7 @@ void MI32sendCamWidget(){
WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"), WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"),
(uint32_t)WiFi.localIP()); (uint32_t)WiFi.localIP());
} }
#endif //USE_BERRY_CAM
} }
#endif //USE_WEBCAM #endif //USE_WEBCAM