Berry: add int8 quantisation to TensorFlow Lite driver (#21763)

* add input quantization, minor fixes

* prevent divideByZero
Christian Baars 2024-07-11 23:05:47 +02:00 committed by GitHub
parent 23de275dbe
commit 13330eb085
3 changed files with 38 additions and 52 deletions

@@ -21,8 +21,8 @@ BE_FUNC_CTYPE_DECLARE(be_TFL_begin, "b", "@s[(bytes)~]");
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~");
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8);
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~[b]");
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
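
The ctype strings in these declarations appear to encode the C signature for the Berry mapping: read together, "b" is the bbool return value, "@" the bvm pointer, "(bytes)~" a buffer passed as pointer plus size, and "[" opens the optional tail, so the new "[b]" makes quantize_to_int8 an optional boolean. A minimal sketch with a hypothetical function, inferred from this diff rather than from the mapping documentation:

// hypothetical declaration mirroring be_TFL_input above
// "b"        -> bbool return value
// "@"        -> struct bvm *vm
// "(bytes)~" -> bytes() buffer passed as pointer + size
// "[b]"      -> '[' starts the optional arguments, 'b' is a bool
extern bbool be_TFL_example(struct bvm *vm, const uint8_t *buf, size_t size, bbool flag);
BE_FUNC_CTYPE_DECLARE(be_TFL_example, "b", "@(bytes)~[b]");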

@@ -70,7 +70,7 @@ struct TFL_mic_ctx_t{
struct {
uint32_t is_audio_initialized:1;
uint32_t is_first_time:1;
uint32_t new_feature_data:1;
// uint32_t new_feature_data:1;
uint32_t continue_audio_capture:1;
uint32_t stop_audio_capture:1;
uint32_t audio_capture_ended:1;
@@ -102,7 +102,7 @@ struct TFL_mic_ctx_t{
struct TFL_stats_t{
uint32_t model_size = 0;
uint32_t used_arena_bytes = 0;
uint32_t invokations = 0;
uint32_t invocations = 0;
uint32_t loop_task_free_stack_bytes = 0;
uint32_t mic_task_free_stack_bytes = 0;
};
@@ -114,7 +114,7 @@ TfLiteTensor* output = nullptr;
int8_t *berry_output_buf = nullptr;
size_t berry_output_bufsize;
int TensorArenaSize = 2000;
uint8_t max_invocations; // max. invocations per second
uint8_t max_invocations = 4; // max. invocations per second
TaskHandle_t loop_task = nullptr;
// QueueHandle_t loop_queue = nullptr;
@@ -127,9 +127,8 @@ union{
// uint32_t stop_loop:1;
uint32_t loop_ended:1;
uint32_t unread_output:1;
uint32_t use_cam:1;
uint32_t new_input_data:1;
uint32_t use_mic:1;
uint32_t mode_inference:1;
} option;
uint32_t options;
};
@@ -176,13 +175,6 @@ bool TFL_create_task(){
return btrue;
}
bool TFL_init_CAM(){
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam not implemented yet"));
delete TFL;
TFL = nullptr;
return bfalse;
}
#ifdef USE_I2S
/**
* @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing.
@@ -304,10 +296,10 @@ void TFL_capture_samples(void* arg) {
/* read slice data at once from i2s */
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
esp_err_t err = i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
if (bytes_read <= 0) {
MicroPrintf( PSTR( "Error in I2S read : %d"), bytes_read);
MicroPrintf( PSTR( "Error %d in I2S, read %d bytes"), err, bytes_read);
}
else {
if (bytes_read < i2s_bytes_to_read) {
@@ -328,7 +320,7 @@ void TFL_capture_samples(void* arg) {
if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){
TFL->mic->feature_buffer_idx = 0;
}
TFL->mic->flag.new_feature_data = 1;
TFL->option.new_input_data = 1;
xSemaphoreGive(TFL->mic->feature_buffer_mutex);
}
@@ -473,7 +465,6 @@ void TFL_task_loop(void *pvParameters){
TickType_t xLastWakeTime = xTaskGetTickCount();
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
bool do_invokation = true;
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
MicroPrintf(PSTR("delay_next_invocation"));
vTaskDelay(10/ portTICK_PERIOD_MS);
@@ -483,16 +474,13 @@ void TFL_task_loop(void *pvParameters){
#ifdef USE_I2S
if(TFL->option.use_mic == 1){
TFL->option.delay_next_invocation = 0; // Clean up later
if (TFL->mic->flag.new_feature_data == 0){
do_invokation = false;
}
if(TFL->mic->flag.continue_audio_capture == 1){
TFL_mic_feature_buf_to_input();
}
}
#endif
if(do_invokation){
// MicroPrintf(PSTR("Invokation requested"));
if(TFL->option.new_input_data){
// MicroPrintf(PSTR("invocation requested"));
TFL->option.running_invocation = 1;
int invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
@@ -502,15 +490,10 @@ void TFL_task_loop(void *pvParameters){
if(TFL->berry_output_buf != nullptr){
memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize);
}
TFL->stats->invokations++;
TFL->stats->invocations++;
TFL->option.unread_output = 1;
#ifdef USE_I2S
if(TFL->option.use_mic == 1) {
TFL->mic->flag.new_feature_data = 0;
}
TFL->option.running_invocation = 0;
#endif //USE_I2S
TFL->option.running_invocation = 0;
TFL->option.new_input_data = 0;
}
if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit
}
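
Read as a whole, the loop above now gates inference on option.new_input_data (set by the mic task or by be_TFL_input) and paces itself with vTaskDelayUntil; with max_invocations defaulting to 4 the delay works out to pdMS_TO_TICKS(1000 / 4) = 250 ms per pass, which is presumably the divide-by-zero fix the commit message refers to (a zero-initialised global would otherwise hit 1000 / 0). A condensed, self-contained sketch of that pattern (helper names are placeholders, not the driver's own):

#include "freertos/FreeRTOS.h"
#include "freertos/task.h"

static volatile uint32_t new_input_data = 0;  // set by the producer (mic task or be_TFL_input)
static uint8_t max_invocations = 4;           // non-zero default keeps 1000 / max_invocations safe

static void run_inference(void) { /* stands in for interpreter.Invoke() and the output copy */ }

static void inference_loop(void *pv) {
  TickType_t last_wake = xTaskGetTickCount();
  for (;;) {
    if (new_input_data) {
      run_inference();
      new_input_data = 0;                     // consume the flag until new data arrives
    }
    // at most max_invocations iterations per second
    vTaskDelayUntil(&last_wake, pdMS_TO_TICKS(1000 / max_invocations));
  }
}
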
@@ -535,7 +518,7 @@ extern "C" {
* @brief Create a context for a TensorFlow session that will later run in a task
*
* @param vm
* @param type BUF - generic byte buffer, CAM - webcam input, MIC - microphone input
* @param type BUF - generic byte buffer, MIC - microphone input
* @return btrue
* @return bfalse
*/
@@ -549,21 +532,12 @@ extern "C" {
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted"));
return btrue;
}
TFL = new TFL_ctx_t;
TFL = new TFL_ctx_t{};
TFL->options = 0;
if(*(uint32_t*)type == 0x00465542){ //BUF
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer"));
}
else if(*(uint32_t*)type == 0x004D4143){ //CAM - not yet implemented
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam"));
if(TFL_init_CAM()){
TFL->option.use_cam = 1;
}
else{
return bfalse;
}
}
else if(*(uint32_t*)type == 0x0043494D){ //MIC
#ifdef USE_I2S
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
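
The type check in be_TFL_begin works because a three-letter tag such as "BUF" or "MIC" fits, NUL terminator included, into one little-endian uint32_t: 'B'=0x42, 'U'=0x55, 'F'=0x46 read as a word give 0x00465542, and 'M'=0x4D, 'I'=0x49, 'C'=0x43 give 0x0043494D. A standalone check of that arithmetic (not driver code; memcpy sidesteps the alignment assumption of the *(uint32_t*) cast):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  uint32_t v;
  memcpy(&v, "BUF", 4);            // copies 'B','U','F','\0'
  printf("%08X\n", (unsigned)v);   // prints 00465542 on a little-endian target such as the ESP32
  memcpy(&v, "MIC", 4);
  printf("%08X\n", (unsigned)v);   // prints 0043494D
  return 0;
}
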
@@ -620,7 +594,7 @@ extern "C" {
if(arena){
TFL->TensorArenaSize = arena;
}
TFL->option.mode_inference = 1;
TFL->berry_output_buf = (int8_t*)output_buf;
TFL->berry_output_bufsize = output_size;
TFL_create_task();
@@ -636,24 +610,35 @@ extern "C" {
* @brief Send new input data to the TensorFlow session
*
* @param vm
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
* @param size Size of buffer (auto-calculated by Berry)
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
* @param size Size of buffer (auto-calculated by Berry)
* @param quantize_to_int8 Optional: convert bytes to quantized int8 values
* @return btrue
* @return bfalse
*/
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size){
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8){
if(!TFL) return bfalse;
if(TFL->option.running_loop == 1){
uint32_t timeout = 0;
int8_t* tensor_input_buffer = tflite::GetTensorData<int8_t>(TFL->input);
while(!TFL->option.delay_next_invocation == 1) {
if(timeout>10) break;
delay(2);
if(timeout>4) return bfalse;
delay(5);
timeout++;
}
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: input new data and invoke"));
memcpy((uint8_t*)TFL->input->data.data,(uint8_t*)buf,size);
if(quantize_to_int8){
int16_t temp_int;
for(int i = 0; i < size; i++){
temp_int = buf[i];
tensor_input_buffer[i] = (int8_t)(temp_int - 128);
}
} else {
memcpy(tensor_input_buffer, buf,size);
}
TFL->option.delay_next_invocation = 0;
TFL->option.new_input_data = 1;
return btrue;
}
return bfalse;
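
The quantize_to_int8 branch above maps every unsigned input byte to a signed int8 by subtracting 128, so 0..255 becomes -128..127. That matches the common fully-integer-quantized model whose input tensor uses a zero point of -128; the driver does not read the tensor's own quantization parameters, so a model quantized differently would need its own conversion. A standalone sketch of the same offset:

#include <stdint.h>
#include <stddef.h>

// same per-element offset as the loop above, written against plain buffers
static void quantize_u8_to_i8(const uint8_t *src, int8_t *dst, size_t n) {
  for (size_t i = 0; i < n; i++) {
    dst[i] = (int8_t)((int16_t)src[i] - 128);  // 255 -> 127, 128 -> 0, 0 -> -128
  }
}
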
@@ -697,7 +682,6 @@ extern "C" {
return (const char *)item;
}
/**
* @brief Shows statistics about the model and the running TFL session
*
@@ -734,7 +718,7 @@ extern "C" {
}
inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type);
pos += inc;
inc = snprintf_P(s + pos, size-pos, PSTR(",\"sessiom\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
inc = snprintf_P(s + pos, size-pos, PSTR(",\"session\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
pos += inc;
inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes);
pos += inc;
@@ -742,7 +726,7 @@ extern "C" {
inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes);
pos += inc;
}
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invokations\":%u}}"),TFL->stats->invokations);
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invocations\":%u}}"),TFL->stats->invocations);
be_pushstring(vm, s);
free(s);
return s;

@@ -2381,6 +2381,7 @@ void MI32sendEnergyWidget(){
#endif //USE_MI_ESP32_ENERGY
#ifdef USE_WEBCAM
void MI32sendCamWidget(){
#ifndef USE_BERRY_CAM
if (Wc.CamServer && Wc.up) {
WSContentSend_P(PSTR("<div class='box"));
if(Settings->webcam_config.resolution>7){
@@ -2389,6 +2390,7 @@ void MI32sendCamWidget(){
WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"),
(uint32_t)WiFi.localIP());
}
#endif //USE_BERRY_CAM
}
#endif //USE_WEBCAM