fix Tensorflow driver for IDF5.x (#21348)

This commit is contained in:
Christian Baars 2024-05-04 19:59:47 +02:00 committed by GitHub
parent abd013eda8
commit ca5a2d322d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 194 deletions

View File

@ -27,9 +27,6 @@ BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~");
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
extern void be_TFL_rec(struct bvm *vm, const char* filename, size_t seconds);
BE_FUNC_CTYPE_DECLARE(be_TFL_rec, "", "@si");
#include "be_fixed_TFL.h"
/* @const_object_info_begin
@ -40,7 +37,6 @@ module TFL (scope: global) {
output, ctype_func(be_TFL_output)
log, ctype_func(be_TFL_log)
stats, ctype_func(be_TFL_stats)
rec, ctype_func(be_TFL_rec)
}
@const_object_info_end */

View File

@ -35,7 +35,6 @@
#include "tensorflow/lite/c/common.h"
#ifdef USE_I2S
#include <driver/i2s.h>
#include "mfcc.h"
#endif //USE_I2S
@ -48,12 +47,9 @@
#define kAudioSampleFrequency 16000
#define kAudioSampleBits 16
#endif //USE_I2S
struct TFL_mic_descriptor_t{
// uint8_t i2s_comm_format; // i2s_comm_format_t - enum as uint8_t
uint8_t channel_fmt; // i2s_channelformat_t - enum as uint8_t (right = 3, left = 4)
uint8_t preamp; // factor
uint8_t channel_fmt; // UNUSED NOW !!!
uint8_t preamp; // UNUSED NOW !!!
uint8_t slice_dur; // milliseconds
uint8_t slice_stride; // milliseconds
uint8_t num_filter; // mfe bins
@ -69,8 +65,7 @@ struct TFL_mic_ctx_t{
SemaphoreHandle_t feature_buffer_mutex = nullptr;
MFCC * mfcc = nullptr;
int8_t* model_input_buffer = nullptr;
File *file = nullptr;
int32_t file_bytes_left;
union{
struct {
uint32_t is_audio_initialized:1;
@ -81,17 +76,14 @@ struct TFL_mic_ctx_t{
uint32_t audio_capture_ended:1;
uint32_t use_mfcc:1;
uint32_t use_gain_filter:1;
uint32_t mode_record_audio:1;
uint32_t file_is_open:1;
} flag;
uint32_t flags;
};
int feature_buffer_idx = 0;
int8_t *feature_buffer;
// user input
// int32_t i2s_comm_format;
i2s_channel_fmt_t channel_fmt;
int32_t preamp; // factor
// int32_t channel_fmt; // UNUSED
int32_t preamp; // setup by I2S driver
int32_t slice_dur; // milliseconds
int32_t slice_stride; // milliseconds
uint8_t num_filter; // mfe filter bins
@ -105,6 +97,8 @@ struct TFL_mic_ctx_t{
float preemphasis;
};
#endif //USE_I2S
struct TFL_stats_t{
uint32_t model_size = 0;
uint32_t used_arena_bytes = 0;
@ -141,7 +135,7 @@ union{
};
#ifdef USE_I2S
TFL_mic_ctx_t *mic = nullptr;
#endif
#endif // USE_I2S
TFL_stats_t *stats = nullptr;
};
@ -248,68 +242,6 @@ int TFL_GenerateFeatures(const int16_t* input, int input_size,
return kTfLiteOk;
}
/**
* @brief Open a WAV file with specified audio length in seconds. Writing and closing will happen in the audio capture task.
*
* @param fname - file name with suffix, i.e. "/1.wav"
* @param record_time - duration in seconds, capture task willl close the file according to this value
* @return true - success
* @return false - failure
*/
bool TFL_init_wave_file(const char* fname, size_t record_time){
if(TfsFileExists(fname)){
return false;
}
if(ufsp == nullptr){
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: got no fs handle!!!"));
return false;
}
if(TFL->mic->file != nullptr){
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: close open file"));
TFL->mic->file->close();
delete TFL->mic->file;
}
TFL->mic->file = new File(ufsp->open(fname, "w"));
if(TFL->mic->file != nullptr){
TFL->mic->flag.file_is_open = 1;
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: file open"));
}
else{
return false;
}
#define WAVE_HEADER_SIZE 44
uint32_t sample_rate = kAudioSampleFrequency;
uint8_t sample_bits = kAudioSampleBits;
uint32_t byte_rate = sample_rate * (sample_bits/8);
uint32_t wav_size = byte_rate * record_time;
uint32_t file_size = wav_size + WAVE_HEADER_SIZE - 8;
#define U32_BYTE(x,y) (uint8_t)((x>>(y*8))&0xff)
const char set_wav_header[] = {
'R','I','F','F', // ChunkID
U32_BYTE(file_size,0),U32_BYTE(file_size,1),U32_BYTE(file_size,2),U32_BYTE(file_size,3), // ChunkSize
'W','A','V','E', // Format
'f','m','t',' ', // Subchunk1ID
sample_bits, 0x00, 0x00, 0x00, // Subchunk1Size (16 for PCM)
0x01, 0x00, // AudioFormat (1 for PCM)
0x01, 0x00, // NumChannels (1 channel)
U32_BYTE(sample_rate,0),U32_BYTE(sample_rate,1),U32_BYTE(sample_rate,2),U32_BYTE(sample_rate,3), // ChuSampleRatekSize
U32_BYTE(byte_rate,0),U32_BYTE(byte_rate,1),U32_BYTE(byte_rate,2),U32_BYTE(byte_rate,3), // ByteRate
0x02, 0x00, // BlockAlign
sample_bits, 0x00, // BitsPerSample (16 bits)
'd','a','t','a', // Subchunk2ID
U32_BYTE(wav_size,0),U32_BYTE(wav_size,1),U32_BYTE(wav_size,2),U32_BYTE(wav_size,3), // ByteRate
};
TFL->mic->file->write((uint8_t*)set_wav_header,WAVE_HEADER_SIZE);
TFL->mic->file_bytes_left = wav_size;
return true;
}
/**
* @brief Init I2S microphone. Pins must be configured in the "usual" Tasmota way. Some properties are variables stored in the descriptor.
*
@ -318,80 +250,29 @@ bool TFL_init_wave_file(const char* fname, size_t record_time){
* @return false - failure
*/
bool TFL_init_MIC(const uint8_t* descriptor){
if (PinUsed(GPIO_I2S_BCLK) && PinUsed(GPIO_I2S_WS) && PinUsed(GPIO_I2S_DIN)) {
if (audio_i2s.in) {
if(audio_i2s.in->getRxRate() != kAudioSampleFrequency || audio_i2s.in->getRxBitsPerSample() != kAudioSampleBits){
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: please configure microphone to 16 bits per sample at 16000 Hz"));
return bfalse;
}
audio_i2s.in->startRx();
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: init mic"));
}
else{
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: I2S GPIO's not set for mic input!"));
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: could not connect to I2S driver"));
return bfalse;
}
#define I2S_NUM (i2s_port_t)I2S_NUM_0 // 0 or 1
TFL->mic = new TFL_mic_ctx_t;
TFL->mic->flags = 0;
TFL_set_mic_config(descriptor);
i2s_config_t i2s_config = {
.mode = (i2s_mode_t)(I2S_MODE_MASTER | I2S_MODE_RX),
.sample_rate = kAudioSampleFrequency,
.bits_per_sample = (i2s_bits_per_sample_t)kAudioSampleBits,
.channel_format = TFL->mic->channel_fmt,
.communication_format = I2S_COMM_FORMAT_STAND_I2S, ///i2s_comm_format_t(1), // ?? I2S_COMM_FORMAT_STAND_I2S | I2S_COMM_FORMAT_STAND_MSB
.intr_alloc_flags = ESP_INTR_FLAG_LEVEL1,
.dma_buf_count = 5,
.dma_buf_len = 320,
.use_apll = false,
.tx_desc_auto_clear = false,
.fixed_mclk = 0
};
i2s_pin_config_t pin_config = {
.bck_io_num = Pin(GPIO_I2S_BCLK),
.ws_io_num = Pin(GPIO_I2S_WS),
.data_out_num = I2S_PIN_NO_CHANGE,
.data_in_num = Pin(GPIO_I2S_DIN)
};
esp_err_t ret = ESP_OK;
ret = i2s_driver_install(I2S_NUM, &i2s_config, 0, NULL);
if (ret != ESP_OK) {
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: Error in i2s_driver_install"));
return bfalse;
}
ret = i2s_set_pin(I2S_NUM, &pin_config);
if (ret != ESP_OK) {
AddLog(LOG_LEVEL_ERROR, PSTR("Error in i2s_set_pin"));
return bfalse;
}
TFL->mic->feature_buffer_mutex = xSemaphoreCreateMutex();
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: MIC ctx created"));
return btrue;
}
void TFL_append_audio_to_file(uint8_t *byte_buffer, uint16_t length){
int16_t *samples = (int16_t *)byte_buffer;
for(int i=0;i<length/2;i++){
samples[i] *= TFL->mic->preamp; //gain ... a lot
}
TFL->mic->file->write(byte_buffer,length);
TFL->mic->file_bytes_left -= length;
if(TFL->mic->file_bytes_left<0){
TFL->mic->file->close();
delete TFL->mic->file;
TFL->mic->flag.file_is_open = 0;
TFL->mic->flag.continue_audio_capture = 0;
TFL->option.running_loop = 0;
MicroPrintf( PSTR( "Closing file."));
}
}
/**
* @brief Function spawned as a task for capturing audio. Used for recording or inference.
*
@ -399,9 +280,9 @@ void TFL_append_audio_to_file(uint8_t *byte_buffer, uint16_t length){
*/
void TFL_capture_samples(void* arg) {
MicroPrintf( PSTR( "Capture task started"));
int i2s_bytes_to_read = TFL->mic->i2s_samples_to_get * 2; // according to slice duration
const int i2s_bytes_to_read = TFL->mic->i2s_samples_to_get * 2; // according to slice duration
int buffer_size = (i2s_bytes_to_read * TFL->mic->slice_dur)/TFL->mic->slice_stride; // in bytes, current slice duration plus (potential) history data
const int buffer_size = (i2s_bytes_to_read * TFL->mic->slice_dur)/TFL->mic->slice_stride; // in bytes, current slice duration plus (potential) history data
size_t samples_to_read;
size_t bytes_read;
@ -412,37 +293,22 @@ void TFL_capture_samples(void* arg) {
uint32_t *i2s_long_buffer = (uint32_t*)i2s_sample_buffer;
uint8_t *i2s_read_buffer = i2s_byte_buffer + (buffer_size - i2s_bytes_to_read); // behind the history data, if slice duration != slice stride
// read to "nowhere" to get no startup noise on some mics
i2s_read(I2S_NUM, i2s_byte_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(100));
vTaskDelay(pdMS_TO_TICKS(1000));
TFL_InitializeFeatures(); // TODO: check or not for success
if(TFL->option.mode_inference == 1){
TFL_InitializeFeatures(); // TODO: check or not for success
}
TFL->mic->flag.continue_audio_capture = 1;
MicroPrintf( PSTR( "Enter capture samples loop"));
// "clean" the DMA buffers a last time
i2s_zero_dma_buffer(I2S_NUM);
vTaskDelay(pdMS_TO_TICKS(TFL->mic->slice_stride));
while (TFL->mic->flag.continue_audio_capture == 1) {
TFL->stats->mic_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
TickType_t xLastWakeTime = xTaskGetTickCount();
/* read slice data at once from i2s */
i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
// i2s_read(I2S_NUM, i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
i2s_channel_read(audio_i2s.in->getRxHandle(), (void*)i2s_read_buffer, i2s_bytes_to_read, &bytes_read, pdMS_TO_TICKS(TFL->mic->slice_stride));
if (bytes_read <= 0) {
MicroPrintf( PSTR( "Error in I2S read : %d"), bytes_read);
}
else if(TFL->mic->flag.file_is_open == 1){
TFL_append_audio_to_file(i2s_read_buffer,bytes_read);
if (bytes_read < i2s_bytes_to_read) {
MicroPrintf(PSTR("Partial I2S read: %d"), bytes_read);
}
}
else {
if (bytes_read < i2s_bytes_to_read) {
MicroPrintf(PSTR("Partial I2S read: %d"), bytes_read);
@ -470,7 +336,8 @@ void TFL_capture_samples(void* arg) {
if(TFL->mic->flag.continue_audio_capture == 1) vTaskDelayUntil( &xLastWakeTime, pdMS_TO_TICKS(TFL->mic->slice_stride) );
}
i2s_driver_uninstall(I2S_NUM);
audio_i2s.in->stopRx();
if(TFL->mic->mfcc != nullptr){
delete TFL->mic->mfcc;
TFL->mic->mfcc = nullptr;
@ -487,8 +354,9 @@ void TFL_capture_samples(void* arg) {
*/
void TFL_set_mic_config(const uint8_t *descriptor_buffer){
TFL_mic_descriptor_t *mic_descriptor = (TFL_mic_descriptor_t*)descriptor_buffer;
TFL->mic->channel_fmt = (i2s_channel_fmt_t)mic_descriptor->channel_fmt;
TFL->mic->preamp = mic_descriptor->preamp;
// TFL->mic->channel_fmt = mic_descriptor->channel_fmt; // UNUSED!! - setup by I2S driver
// TFL->mic->preamp = mic_descriptor->preamp; // UNUSED !!
TFL->mic->preamp = audio_i2s.Settings->rx.gain / 16; // setup by I2S driver
TFL->mic->slice_dur = mic_descriptor->slice_dur;
TFL->mic->slice_stride = mic_descriptor->slice_stride;
TFL->mic->num_filter = mic_descriptor->num_filter;
@ -552,9 +420,9 @@ void TFL_delete_tasks(){
vTaskDelay(pdMS_TO_TICKS(10));
}
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: task loop did stop"));
if(TFL->mic != nullptr){
delete TFL->mic;
}
#ifdef USE_I2S
if(TFL->mic != nullptr) {delete TFL->mic;}
#endif //USE_I2S
delete TFL;
TFL = nullptr;
}
@ -604,14 +472,6 @@ void TFL_task_loop(void *pvParameters){
{
TickType_t xLastWakeTime = xTaskGetTickCount();
TFL->stats->loop_task_free_stack_bytes = uxTaskGetStackHighWaterMark(NULL);
#ifdef USE_I2S
if(TFL->option.use_mic == 1){
if(TFL->mic->flag.mode_record_audio == 1){
vTaskDelay(1000/ portTICK_PERIOD_MS); // sit and wait while recording
continue;
}
}
#endif //USE_I2S
bool do_invokation = true;
while(TFL->option.delay_next_invocation == 1 && TFL->option.running_loop == 1){
@ -658,7 +518,9 @@ void TFL_task_loop(void *pvParameters){
// end of loop section
loop_task_exit:
delete TFL->stats;
#ifdef USE_I2S
if(TFL->option.use_mic == 1) {TFL_stop_audio_capture();}
#endif //USE_I2S
MicroPrintf(PSTR("end loop task"));
TFL->option.loop_ended = 1;
vTaskDelete( NULL );
@ -703,6 +565,7 @@ extern "C" {
}
}
else if(*(uint32_t*)type == 0x0043494D){ //MIC
#ifdef USE_I2S
if(descriptor && len==sizeof(TFL_mic_descriptor_t)){
if(TFL_init_MIC(descriptor)){
TFL->option.use_mic = 1;
@ -715,6 +578,9 @@ extern "C" {
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: expected descriptor of size: %u"), sizeof(TFL_mic_descriptor_t));
return bfalse;
}
#else
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: firmware with I2S audio required !!"));
#endif //USE_I2S
}
else{
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: unknown mode"));
@ -882,28 +748,7 @@ extern "C" {
return s;
}
void be_TFL_rec(struct bvm *vm, const char* filename, size_t seconds){
if(TFL){
TFL->option.loop_ended = 1; // just in case someone wants to stop this from another scope
if(TFL->mic != nullptr){
if(TFL->mic->flag.continue_audio_capture == 1){
AddLog(LOG_LEVEL_DEBUG, PSTR("TFL: running recording, requesting termination"));
TFL->mic->flag.continue_audio_capture = 0;
return;
}
if(TFL_init_wave_file(filename,seconds)){
TFL->mic->flag.mode_record_audio = 1;
TFL->stats = new TFL_stats_t;
xTaskCreatePinnedToCore(TFL_capture_samples, "tfl_mic", 1024 * 3, NULL, 10, &TFL->mic->audio_capture_task, 0);
}
}
}
else{
AddLog(LOG_LEVEL_ERROR, PSTR("TFL: no MIC context initialized"));
}
}
} //extern "C"
#endif // USE_BERRY_TF_LITE
#endif // USE_BERRY
#endif // USE_BERRY