mirror of https://github.com/arendst/Tasmota.git
Berry: add int8 quantisation to Tensorflow lite driver (#21763)
* add input quantization, minor fixes * prevent divideByZero
This commit is contained in:
parent
23de275dbe
commit
13330eb085
|
@ -21,8 +21,8 @@ BE_FUNC_CTYPE_DECLARE(be_TFL_begin, "b", "@s[(bytes)~]");
|
||||||
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
|
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
|
||||||
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
|
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
|
||||||
|
|
||||||
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size);
|
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8);
|
||||||
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~");
|
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~[b]");
|
||||||
|
|
||||||
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
|
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
|
||||||
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
|
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
|
||||||
|
|
|
@ -70,7 +70,7 @@ struct TFL_mic_ctx_t{
|
||||||
struct {
|
struct {
|
||||||
uint32_t is_audio_initialized:1;
|
uint32_t is_audio_initialized:1;
|
||||||
uint32_t is_first_time:1;
|
uint32_t is_first_time:1;
|
||||||
uint32_t new_feature_data:1;
|
// uint32_t new_feature_data:1;
|
||||||
uint32_t continue_audio_capture:1;
|
uint32_t continue_audio_capture:1;
|
||||||
uint32_t stop_audio_capture:1;
|
uint32_t stop_audio_capture:1;
|
||||||
uint32_t audio_capture_ended:1;
|
uint32_t audio_capture_ended:1;
|
||||||
|
@ -102,7 +102,7 @@ struct TFL_mic_ctx_t{
|
||||||
struct TFL_stats_t{
|
struct TFL_stats_t{
|
||||||
uint32_t model_size = 0;
|
uint32_t model_size = 0;
|
||||||
uint32_t used_arena_bytes = 0;
|
uint32_t used_arena_bytes = 0;
|
||||||
uint32_t invokations = 0;
|
uint32_t invocations = 0;
|
||||||
uint32_t loop_task_free_stack_bytes = 0;
|
uint32_t loop_task_free_stack_bytes = 0;
|
||||||
uint32_t mic_task_free_stack_bytes = 0;
|
uint32_t mic_task_free_stack_bytes = 0;
|
||||||
};
|
};
|
||||||
|
@ -114,7 +114,7 @@ TfLiteTensor* output = nullptr;
|
||||||
int8_t *berry_output_buf = nullptr;
|
int8_t *berry_output_buf = nullptr;
|
||||||
size_t berry_output_bufsize;
|
size_t berry_output_bufsize;
|
||||||
int TensorArenaSize = 2000;
|
int TensorArenaSize = 2000;
|
||||||
uint8_t max_invocations; // max. invocations per second
|
uint8_t max_invocations = 4; // max. invocations per second
|
||||||
|
|
||||||
TaskHandle_t loop_task = nullptr;
|
TaskHandle_t loop_task = nullptr;
|
||||||
// QueueHandle_t loop_queue = nullptr;
|
// QueueHandle_t loop_queue = nullptr;
|
||||||
|
@ -127,9 +127,8 @@ union{
|
||||||
// uint32_t stop_loop:1;
|
// uint32_t stop_loop:1;
|
||||||
uint32_t loop_ended:1;
|
uint32_t loop_ended:1;
|
||||||
uint32_t unread_output:1;
|
uint32_t unread_output:1;
|
||||||
uint32_t use_cam:1;
|
uint32_t new_input_data:1;
|
||||||
uint32_t use_mic:1;
|
uint32_t use_mic:1;
|
||||||
uint32_t mode_inference:1;
|
|
||||||
} option;
|
} option;
|
||||||
uint32_t options;
|
uint32_t options;
|
||||||
};
|
};
|
||||||
|
@ -176,13 +175,6 @@ bool TFL_create_task(){
|
||||||
return btrue;
|
return btrue;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TFL_init_CAM(){
|
|
||||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam not implemented yet"));
|
|
||||||
delete TFL;
|
|
||||||
TFL = nullptr;
|
|
||||||
return bfalse;
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef USE_I2S
|
#ifdef USE_I2S
|
||||||
/**
|
/**
|
||||||
* @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing.
|
* @brief Set up some buffers and tables for feature extraction of audio samples. Must run once before starting audio capturing.
|
||||||
|
@ -304,10 +296,10 @@ void TFL_capture_samples(void* arg) {
|
||||||
|
|
||||||
/* read slice data at once from i2s */
|
/* read slice data at once from i2s */
|
||||||
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
||||||
i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
esp_err_t err = i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
|
||||||
|
|
||||||
if (bytes_read <= 0) {
|
if (bytes_read <= 0) {
|
||||||
MicroPrintf( PSTR( "Error in I2S read : %d"), bytes_read);
|
MicroPrintf( PSTR( "Error %d in I2S, did read %d bytes"), err, bytes_read);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (bytes_read < i2s_bytes_to_read) {
|
if (bytes_read < i2s_bytes_to_read) {
|
||||||
|
@ -328,7 +320,7 @@ void TFL_capture_samples(void* arg) {
|
||||||
if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){
|
if(TFL->mic->feature_buffer_idx == TFL->mic->slice_count){
|
||||||
TFL->mic->feature_buffer_idx = 0;
|
TFL->mic->feature_buffer_idx = 0;
|
||||||
}
|
}
|
||||||
TFL->mic->flag.new_feature_data = 1;
|
TFL->option.new_input_data = 1;
|
||||||
xSemaphoreGive(TFL->mic->feature_buffer_mutex);
|
xSemaphoreGive(TFL->mic->feature_buffer_mutex);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -473,7 +465,6 @@ void TFL_task_loop(void *pvParameters){
|
||||||
TickType_t xLastWakeTime = xTaskGetTickCount();
|
TickType_t xLastWakeTime = xTaskGetTickCount();
|
||||||
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
|
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
|
||||||
|
|
||||||
bool do_invokation = true;
|
|
||||||
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
|
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
|
||||||
MicroPrintf(PSTR("delay_next_invocation"));
|
MicroPrintf(PSTR("delay_next_invocation"));
|
||||||
vTaskDelay(10/ portTICK_PERIOD_MS);
|
vTaskDelay(10/ portTICK_PERIOD_MS);
|
||||||
|
@ -483,16 +474,13 @@ void TFL_task_loop(void *pvParameters){
|
||||||
#ifdef USE_I2S
|
#ifdef USE_I2S
|
||||||
if(TFL->option.use_mic == 1){
|
if(TFL->option.use_mic == 1){
|
||||||
TFL->option.delay_next_invocation = 0; // Clean up later
|
TFL->option.delay_next_invocation = 0; // Clean up later
|
||||||
if (TFL->mic->flag.new_feature_data == 0){
|
|
||||||
do_invokation = false;
|
|
||||||
}
|
|
||||||
if(TFL->mic->flag.continue_audio_capture == 1){
|
if(TFL->mic->flag.continue_audio_capture == 1){
|
||||||
TFL_mic_feature_buf_to_input();
|
TFL_mic_feature_buf_to_input();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
if(do_invokation){
|
if(TFL->option.new_input_data){
|
||||||
// MicroPrintf(PSTR("Invokation requested"));
|
// MicroPrintf(PSTR("invocation requested"));
|
||||||
TFL->option.running_invocation = 1;
|
TFL->option.running_invocation = 1;
|
||||||
int invoke_status = interpreter.Invoke();
|
int invoke_status = interpreter.Invoke();
|
||||||
if (invoke_status != kTfLiteOk) {
|
if (invoke_status != kTfLiteOk) {
|
||||||
|
@ -502,15 +490,10 @@ void TFL_task_loop(void *pvParameters){
|
||||||
if(TFL->berry_output_buf != nullptr){
|
if(TFL->berry_output_buf != nullptr){
|
||||||
memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize);
|
memcpy(TFL->berry_output_buf,(int8_t*)TFL->output->data.data,TFL->berry_output_bufsize);
|
||||||
}
|
}
|
||||||
TFL->stats->invokations++;
|
TFL->stats->invocations++;
|
||||||
|
|
||||||
TFL->option.unread_output = 1;
|
TFL->option.unread_output = 1;
|
||||||
#ifdef USE_I2S
|
|
||||||
if(TFL->option.use_mic == 1) {
|
|
||||||
TFL->mic->flag.new_feature_data = 0;
|
|
||||||
}
|
|
||||||
TFL->option.running_invocation = 0;
|
TFL->option.running_invocation = 0;
|
||||||
#endif //USE_I2S
|
TFL->option.new_input_data = 0; // clear (assign, not compare) so the next Invoke() waits for fresh input data
|
||||||
}
|
}
|
||||||
if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit
|
if(TFL->option.running_loop == 1) vTaskDelayUntil(&xLastWakeTime, pdMS_TO_TICKS(1000 / TFL->max_invocations)); //maybe we already want to exit
|
||||||
}
|
}
|
||||||
|
@ -535,7 +518,7 @@ extern "C" {
|
||||||
* @brief Create a context for a tensor flow session, that will later run in a task
|
* @brief Create a context for a tensor flow session, that will later run in a task
|
||||||
*
|
*
|
||||||
* @param vm
|
* @param vm
|
||||||
* @param type BUF - generic byte buffer, CAM - webcam input, MIC - microphone input
|
* @param type BUF - generic byte buffer, MIC - microphone input
|
||||||
* @return btrue
|
* @return btrue
|
||||||
* @return bfalse
|
* @return bfalse
|
||||||
*/
|
*/
|
||||||
|
@ -549,21 +532,12 @@ extern "C" {
|
||||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted"));
|
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: context deleted"));
|
||||||
return btrue;
|
return btrue;
|
||||||
}
|
}
|
||||||
TFL = new TFL_ctx_t;
|
TFL = new TFL_ctx_t{};
|
||||||
TFL->options = 0;
|
TFL->options = 0;
|
||||||
|
|
||||||
if(*(uint32_t*)type == 0x00465542){ //BUF
|
if(*(uint32_t*)type == 0x00465542){ //BUF
|
||||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer"));
|
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode generic buffer"));
|
||||||
}
|
}
|
||||||
else if(*(uint32_t*)type == 0x004D4143){ //CAM - not yet implemented
|
|
||||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: mode webcam"));
|
|
||||||
if(TFL_init_CAM()){
|
|
||||||
TFL->option.use_cam = 1;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
return bfalse;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if(*(uint32_t*)type == 0x0043494D){ //MIC
|
else if(*(uint32_t*)type == 0x0043494D){ //MIC
|
||||||
#ifdef USE_I2S
|
#ifdef USE_I2S
|
||||||
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
|
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
|
||||||
|
@ -620,7 +594,7 @@ extern "C" {
|
||||||
if(arena){
|
if(arena){
|
||||||
TFL->TensorArenaSize = arena;
|
TFL->TensorArenaSize = arena;
|
||||||
}
|
}
|
||||||
TFL->option.mode_inference = 1;
|
|
||||||
TFL->berry_output_buf = (int8_t*)output_buf;
|
TFL->berry_output_buf = (int8_t*)output_buf;
|
||||||
TFL->berry_output_bufsize = output_size;
|
TFL->berry_output_bufsize = output_size;
|
||||||
TFL_create_task();
|
TFL_create_task();
|
||||||
|
@ -638,22 +612,33 @@ extern "C" {
|
||||||
* @param vm
|
* @param vm
|
||||||
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
|
* @param buf Arbitrary data in a byte buffer, must fit to the TF model
|
||||||
* @param size Size of buffer (auto-calculated by Berry)
|
* @param size Size of buffer (auto-calculated by Berry)
|
||||||
|
* @param quantize_to_int8 Optional: convert bytes to quantized int8 values
|
||||||
* @return btrue
|
* @return btrue
|
||||||
* @return bfalse
|
* @return bfalse
|
||||||
*/
|
*/
|
||||||
|
|
||||||
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size){
|
bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size, bbool quantize_to_int8){
|
||||||
if(!TFL) return bfalse;
|
if(!TFL) return bfalse;
|
||||||
if(TFL->option.running_loop == 1){
|
if(TFL->option.running_loop == 1){
|
||||||
uint32_t timeout = 0;
|
uint32_t timeout = 0;
|
||||||
|
int8_t* tensor_input_buffer = tflite::GetTensorData<int8_t>(TFL->input);
|
||||||
while(!TFL->option.delay_next_invocation == 1) {
|
while(!TFL->option.delay_next_invocation == 1) {
|
||||||
if(timeout>10) break;
|
if(timeout>4) return bfalse;;
|
||||||
delay(2);
|
delay(5);
|
||||||
timeout++;
|
timeout++;
|
||||||
}
|
}
|
||||||
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: imput new data and invoke"));
|
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: imput new data and invoke"));
|
||||||
memcpy((uint8_t*)TFL->input->data.data,(uint8_t*)buf,size);
|
if(quantize_to_int8){
|
||||||
|
int16_t temp_int;
|
||||||
|
for(int i = 0; i < size; i++){
|
||||||
|
temp_int = buf[i];
|
||||||
|
tensor_input_buffer[i] = (int8_t)(temp_int - 128);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
memcpy(tensor_input_buffer, buf,size);
|
||||||
|
}
|
||||||
TFL->option.delay_next_invocation = 0;
|
TFL->option.delay_next_invocation = 0;
|
||||||
|
TFL->option.new_input_data = 1;
|
||||||
return btrue;
|
return btrue;
|
||||||
}
|
}
|
||||||
return bfalse;
|
return bfalse;
|
||||||
|
@ -697,7 +682,6 @@ extern "C" {
|
||||||
return (const char *)item;
|
return (const char *)item;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Shows statistics about the model and the running TFL session
|
* @brief Shows statistics about the model and the running TFL session
|
||||||
*
|
*
|
||||||
|
@ -734,7 +718,7 @@ extern "C" {
|
||||||
}
|
}
|
||||||
inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type);
|
inc = snprintf_P(s + pos, size-pos, PSTR("],\"output_type\":%u}"),TFL->output->type);
|
||||||
pos += inc;
|
pos += inc;
|
||||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"sessiom\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
|
inc = snprintf_P(s + pos, size-pos, PSTR(",\"session\":{\"used_arena\":%u"),TFL->stats->used_arena_bytes);
|
||||||
pos += inc;
|
pos += inc;
|
||||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes);
|
inc = snprintf_P(s + pos, size-pos, PSTR(",\"loop_stack\":%u"),TFL->stats->loop_task_free_stack_bytes);
|
||||||
pos += inc;
|
pos += inc;
|
||||||
|
@ -742,7 +726,7 @@ extern "C" {
|
||||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes);
|
inc = snprintf_P(s + pos, size-pos, PSTR(",\"audio_stack\":%u"),TFL->stats->mic_task_free_stack_bytes);
|
||||||
pos += inc;
|
pos += inc;
|
||||||
}
|
}
|
||||||
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invokations\":%u}}"),TFL->stats->invokations);
|
inc = snprintf_P(s + pos, size-pos, PSTR(",\"invocations\":%u}}"),TFL->stats->invocations);
|
||||||
be_pushstring(vm, s);
|
be_pushstring(vm, s);
|
||||||
free(s);
|
free(s);
|
||||||
return s;
|
return s;
|
||||||
|
|
|
@ -2381,6 +2381,7 @@ void MI32sendEnergyWidget(){
|
||||||
#endif //USE_MI_ESP32_ENERGY
|
#endif //USE_MI_ESP32_ENERGY
|
||||||
#ifdef USE_WEBCAM
|
#ifdef USE_WEBCAM
|
||||||
void MI32sendCamWidget(){
|
void MI32sendCamWidget(){
|
||||||
|
#ifndef USE_BERRY_CAM
|
||||||
if (Wc.CamServer && Wc.up) {
|
if (Wc.CamServer && Wc.up) {
|
||||||
WSContentSend_P(PSTR("<div class='box"));
|
WSContentSend_P(PSTR("<div class='box"));
|
||||||
if(Settings->webcam_config.resolution>7){
|
if(Settings->webcam_config.resolution>7){
|
||||||
|
@ -2389,6 +2390,7 @@ void MI32sendCamWidget(){
|
||||||
WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"),
|
WSContentSend_P(PSTR("' id='cam' style='background-image:url(http://%_I:81/stream);background-repeat:no-repeat;background-size:cover;'></div>"),
|
||||||
(uint32_t)WiFi.localIP());
|
(uint32_t)WiFi.localIP());
|
||||||
}
|
}
|
||||||
|
#endif //USE_BERRY_CAM
|
||||||
}
|
}
|
||||||
#endif //USE_WEBCAM
|
#endif //USE_WEBCAM
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue