mirror of https://github.com/arendst/Tasmota.git
Berry: add int8 quantisation to Tensorflow lite driver (#21763)
* add input quantization, minor fixes * prevent divideByZero
This commit is contained in:
parent
23de275dbe
commit
13330eb085
|
@ -21,8 +21,8 @@ BE_FUNC_CTYPE_DECLARE(be_TFL_begin, "b", "@s[(bytes)~]");
|
|||
// Berry bindings for the TFL driver. Each native function is declared `extern`
// and then registered with BE_FUNC_CTYPE_DECLARE(name, return type, argument types).
// NOTE(review): judging from the matching C prototypes, the argument string
// presumably reads '@' = struct bvm *vm, '(bytes)' followed by '~' = a byte
// buffer plus its length, '[...]' = optional trailing arguments - confirm
// against the Berry C-mapping documentation.

// Takes a model buffer, an output buffer (each with its length) and an optional int `arena`.
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
|
||||
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
|
||||
|
||||
// NOTE(review): be_TFL_input is declared twice below because this view shows
// the line before and after a signature change; presumably only the
// four-argument form (with the optional bool) exists in the current file.
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size);
|
||||
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~");
|
||||
// quantize_to_int8: when set, the implementation stores (byte - 128) as int8
// for every input byte; otherwise the buffer is copied unchanged.
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8);
|
||||
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~[b]");
|
||||
|
||||
// Takes a byte buffer and its length; the implementation is not visible in this view.
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
|
||||
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
|
||||
|
|
|
@ -70,7 +70,7 @@ struct TFL_mic_ctx_t{
|
|||
struct {
|
||||
uint32_t is_audio_initialized:1;
|
||||
uint32_t is_first_time:1;
|
||||
uint32_t new_feature_data:1;
|
||||
// uint32_t new_feature_data:1;
|
||||
uint32_t continue_audio_capture:1;
|
||||
uint32_t stop_audio_capture:1;
|
||||
uint32_t audio_capture_ended:1;
|
||||
|
@ -102,7 +102,7 @@ struct TFL_mic_ctx_t{
|
|||
/**
 * @brief Sizes and counters describing the running TFL session; every field starts at 0.
 *
 * Most fields are printed by the statistics JSON output
 * ("used_arena", "loop_stack", "audio_stack", "invocations").
 */
struct TFL_stats_t{
|
||||
// presumably the size of the loaded model - TODO confirm where it is written
uint32_t model_size = 0;
|
||||
// reported as "used_arena" in the statistics JSON
uint32_t used_arena_bytes = 0;
|
||||
// NOTE(review): `invokations` and `invocations` both appear because this view
// interleaves the old and new line of a rename; presumably only
// `invocations` exists in the current file.
uint32_t invokations = 0;
|
||||
// incremented by TFL_task_loop in its invocation path; reported as "invocations"
uint32_t invocations = 0;
|
||||
// set from uxTaskGetStackHighWaterMark(NULL) in TFL_task_loop; reported as "loop_stack"
uint32_t loop_task_free_stack_bytes = 0;
|
||||
// reported as "audio_stack" in the statistics JSON
uint32_t mic_task_free_stack_bytes = 0;
|
||||
};
|
||||
|
@ -114,7 +114,7 @@ TfLiteTensor* output = nullptr;
|
|||
int8_t *berry_output_buf = nullptr;
|
||||
size_t berry_output_bufsize;
|
||||
int TensorArenaSize = 2000;
|
||||
uint8_t max_invocations; // max. invocations per second
|
||||
uint8_t max_invocations = 4; // max. invocations per second
|
||||
|
||||
TaskHandle_t loop_task = nullptr;
|
||||
// QueueHandle_t loop_queue = nullptr;
|
||||
|
@ -127,9 +127,8 @@ union{
|
|||
// uint32_t stop_loop:1;
|
||||
uint32_t loop_ended:1;
|
||||
uint32_t unread_output:1;
|
||||
uint32_t use_cam:1;
|
||||
uint32_t new_input_data:1;
|
||||
uint32_t use_mic:1;
|
||||
uint32_t mode_inference:1;
|
||||
} option;
|
||||
uint32_t options;
|
||||
};
|
||||
|
@ -176,13 +175,6 @@ bool TFL_create_task(){
|
|||
return btrue;
|
||||
}
|
||||
|
||||
/**
 * @brief Placeholder for the webcam input mode, which is not implemented.
 *
 * Logs a debug message, deletes the TFL context and resets the pointer to
 * nullptr, so TFL must not be dereferenced after this call.
 *
 * @return always bfalse
 */
bool TFL_init_CAM(){
|
||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam not implemented yet"));
|
||||
// tear down the context that was allocated before the mode was selected
delete TFL;
|
||||
TFL = nullptr;
|
||||
return bfalse;
|
||||
}
|
||||
|
||||
#ifdef USE_I2S
|
||||
/**
|
||||
* @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing.
|
||||
|
@ -304,10 +296,10 @@ void TFL_capture_samples(void* arg) {
|
|||
|
||||
/* read slice data at once from i2s */
|
||||
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
||||
i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
||||
esp_err_t err = i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
||||
|
||||
if (bytes_read <= 0) {
|
||||
MicroPrintf( PSTR( "Error in I2S read : %d"), bytes_read);
|
||||
MicroPrintf( PSTR( "Error %d in I2S, did read %d bytes"), err, bytes_read);
|
||||
}
|
||||
else {
|
||||
if (bytes_read < i2s_bytes_to_read) {
|
||||
|
@ -328,7 +320,7 @@ void TFL_capture_samples(void* arg) {
|
|||
if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){
|
||||
TFL->mic->feature_buffer_idx = 0;
|
||||
}
|
||||
TFL->mic->flag.new_feature_data = 1;
|
||||
TFL->option.new_input_data = 1;
|
||||
xSemaphoreGive(TFL->mic->feature_buffer_mutex);
|
||||
|
||||
}
|
||||
|
@ -473,7 +465,6 @@ void TFL_task_loop(void *pvParameters){
|
|||
TickType_t xLastWakeTime = xTaskGetTickCount();
|
||||
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
|
||||
|
||||
bool do_invokation = true;
|
||||
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
|
||||
MicroPrintf(PSTR("delay_next_invocation"));
|
||||
vTaskDelay(10/ portTICK_PERIOD_MS);
|
||||
|
@ -483,16 +474,13 @@ void TFL_task_loop(void *pvParameters){
|
|||
#ifdef USE_I2S
|
||||
if(TFL->option.use_mic == 1){
|
||||
TFL->option.delay_next_invocation = 0; // Clean up later
|
||||
if (TFL->mic->flag.new_feature_data == 0){
|
||||
do_invokation = false;
|
||||
}
|
||||
if(TFL->mic->flag.continue_audio_capture == 1){
|
||||
TFL_mic_feature_buf_to_input();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if(do_invokation){
|
||||
// MicroPrintf(PSTR("Invokation requested"));
|
||||
if(TFL->option.new_input_data){
|
||||
// MicroPrintf(PSTR("invocation requested"));
|
||||
TFL->option.running_invocation = 1;
|
||||
int invoke_status = interpreter.Invoke();
|
||||
if (invoke_status != kTfLiteOk) {
|
||||
|
@ -502,15 +490,10 @@ void TFL_task_loop(void *pvParameters){
|
|||
if(TFL->berry_output_buf != nullptr){
|
||||
memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize);
|
||||
}
|
||||
TFL->stats->invokations++;
|
||||
|
||||
TFL->stats->invocations++;
|
||||
TFL->option.unread_output = 1;
|
||||
#ifdef USE_I2S
|
||||
if(TFL->option.use_mic == 1) {
|
||||
TFL->mic->flag.new_feature_data = 0;
|
||||
}
|
||||
TFL->option.running_invocation = 0;
|
||||
#endif //USE_I2S
|
||||
TFL->option.running_invocation = 0;
|
||||
TFL->option.new_input_data == 0;
|
||||
}
|
||||
if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit
|
||||
}
|
||||
|
@ -535,7 +518,7 @@ extern "C" {
|
|||
* @brief Create a context for a tensor flow session, that will later run in a task
|
||||
*
|
||||
* @param vm
|
||||
* @param type BUF - generic byte buffer, CAM - webcam input, MIC - microphone input
|
||||
* @param type BUF - generic byte buffer, MIC - microphone input
|
||||
* @return btrue
|
||||
* @return bfalse
|
||||
*/
|
||||
|
@ -549,21 +532,12 @@ extern "C" {
|
|||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted"));
|
||||
return btrue;
|
||||
}
|
||||
TFL = new TFL_ctx_t;
|
||||
TFL = new TFL_ctx_t{};
|
||||
TFL->options = 0;
|
||||
|
||||
if(*(uint32_t*)type == 0x00465542){ //BUF
|
||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer"));
|
||||
}
|
||||
else if(*(uint32_t*)type == 0x004D4143){ //CAM - not yet implemented
|
||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam"));
|
||||
if(TFL_init_CAM()){
|
||||
TFL->option.use_cam = 1;
|
||||
}
|
||||
else{
|
||||
return bfalse;
|
||||
}
|
||||
}
|
||||
else if(*(uint32_t*)type == 0x0043494D){ //MIC
|
||||
#ifdef USE_I2S
|
||||
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
|
||||
|
@ -620,7 +594,7 @@ extern "C" {
|
|||
if(arena){
|
||||
TFL->TensorArenaSize = arena;
|
||||
}
|
||||
TFL->option.mode_inference = 1;
|
||||
|
||||
TFL->berry_output_buf = (int8_t*)output_buf;
|
||||
TFL->berry_output_bufsize = output_size;
|
||||
TFL_create_task();
|
||||
|
@ -636,24 +610,35 @@ extern "C" {
|
|||
* @brief Send new input data to the tensor flow session
|
||||
*
|
||||
* @param vm
|
||||
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
|
||||
* @param size Size of buffer (auto-calculated by Berry)
|
||||
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
|
||||
* @param size Size of buffer (auto-calculated by Berry)
|
||||
* @param quantize_to_int8 Optional: convert bytes to quantized int8 values
|
||||
* @return btrue
|
||||
* @return bfalse
|
||||
*/
|
||||
|
||||
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size){
|
||||
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8){
|
||||
if(!TFL) return bfalse;
|
||||
if(TFL->option.running_loop == 1){
|
||||
uint32_t timeout = 0;
|
||||
int8_t* tensor_input_buffer = tflite::GetTensorData<int8_t>(TFL->input);
|
||||
while(!TFL->option.delay_next_invocation == 1) {
|
||||
if(timeout>10) break;
|
||||
delay(2);
|
||||
if(timeout>4) return bfalse;;
|
||||
delay(5);
|
||||
timeout++;
|
||||
}
|
||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: imput new data and invoke"));
|
||||
memcpy((uint8_t*)TFL->input->data.data,(uint8_t*)buf,size);
|
||||
if(quantize_to_int8){
|
||||
int16_t temp_int;
|
||||
for(int i = 0; i < size; i++){
|
||||
temp_int = buf[i];
|
||||
tensor_input_buffer[i] = (int8_t)(temp_int - 128);
|
||||
}
|
||||
} else {
|
||||
memcpy(tensor_input_buffer, buf,size);
|
||||
}
|
||||
TFL->option.delay_next_invocation = 0;
|
||||
TFL->option.new_input_data = 1;
|
||||
return btrue;
|
||||
}
|
||||
return bfalse;
|
||||
|
@ -697,7 +682,6 @@ extern "C" {
|
|||
return (const char *)item;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Shows statistics about the model and the running TFL session
|
||||
*
|
||||
|
@ -734,7 +718,7 @@ extern "C" {
|
|||
}
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type);
|
||||
pos += inc;
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"sessiom\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"session\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
|
||||
pos += inc;
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes);
|
||||
pos += inc;
|
||||
|
@ -742,7 +726,7 @@ extern "C" {
|
|||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes);
|
||||
pos += inc;
|
||||
}
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invokations\":%u}}"),TFL->stats->invokations);
|
||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invocations\":%u}}"),TFL->stats->invocations);
|
||||
be_pushstring(vm, s);
|
||||
free(s);
|
||||
return s;
|
||||
|
|
|
@ -2381,6 +2381,7 @@ void MI32sendEnergyWidget(){
|
|||
#endif //USE_MI_ESP32_ENERGY
|
||||
#ifdef USE_WEBCAM
|
||||
void MI32sendCamWidget(){
|
||||
#ifndef USE_BERRY_CAM
|
||||
if (Wc.CamServer && Wc.up) {
|
||||
WSContentSend_P(PSTR("<div class='box"));
|
||||
if(Settings->webcam_config.resolution>7){
|
||||
|
@ -2389,6 +2390,7 @@ void MI32sendCamWidget(){
|
|||
WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"),
|
||||
(uint32_t)WiFi.localIP());
|
||||
}
|
||||
#endif //USE_BERRY_CAM
|
||||
}
|
||||
#endif //USE_WEBCAM
|
||||
|
||||
|
|
Loading…
Reference in New Issue