Berry: add tensorflow lite for microcontrollers (#18119)

Co-authored-by: Christian Baars <christianbaars@MacBook-Pro-von-Christian.local>
Christian Baars 2023-03-05 16:46:18 +01:00 committed by GitHub
parent 0f9bece011
commit 83f039cdf7
466 changed files with 125929 additions and 0 deletions


@@ -50,6 +50,7 @@ be_extern_native_module(partition_core);
be_extern_native_module(crc);
be_extern_native_module(crypto);
be_extern_native_module(ULP);
be_extern_native_module(TFL);
be_extern_native_module(mdns);
#ifdef USE_ZIGBEE
be_extern_native_module(zigbee);
@@ -171,6 +172,9 @@ BERRY_LOCAL const bntvmodule* const be_module_table[] = {
#if defined(USE_BERRY_ULP) && ((CONFIG_IDF_TARGET_ESP32) || defined(CONFIG_IDF_TARGET_ESP32S2) || defined(CONFIG_IDF_TARGET_ESP32S3))
&be_native_module(ULP),
#endif // USE_BERRY_ULP
#if defined(USE_BERRY_TF_LITE)
&be_native_module(TFL),
#endif //USE_BERRY_TF_LITE
#if defined(USE_MI_ESP32) && !defined(USE_BLE_ESP32)
&be_native_module(MI32),
&be_native_module(BLE),
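The module is only compiled in when `USE_BERRY_TF_LITE` is defined. Assuming the usual Tasmota override mechanism (`user_config_override.h` is the standard place for build flags and is not part of this commit), enabling it could look like this sketch:

```c
// user_config_override.h (sketch only): enable the Berry TFL module at build time
#ifndef USE_BERRY_TF_LITE
#define USE_BERRY_TF_LITE
#endif
```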


@@ -0,0 +1,47 @@
/********************************************************************
* Tasmota lib
*
* To use: `import TFL`
*******************************************************************/
#include "be_constobj.h"
#include "be_mapping.h"
#ifdef USE_BERRY_TF_LITE
extern const char* be_TFL_log(struct bvm *vm);
BE_FUNC_CTYPE_DECLARE(be_TFL_log, "s", "@");
extern const char* be_TFL_stats(struct bvm *vm);
BE_FUNC_CTYPE_DECLARE(be_TFL_stats, "s", "@");
extern bbool be_TFL_begin(struct bvm *vm, const char* type, const uint8_t *descriptor, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_begin, "b", "@s[(bytes)~]");
extern bbool be_TFL_load(struct bvm *vm, const uint8_t *model_buf, size_t model_size, const uint8_t *output_buf, size_t output_size,int arena);
BE_FUNC_CTYPE_DECLARE(be_TFL_load, "b", "@(bytes)~(bytes)~[i]");
extern bbool be_TFL_input(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_input, "b", "@(bytes)~");
extern bbool be_TFL_output(struct bvm *vm, const uint8_t *buf, size_t size);
BE_FUNC_CTYPE_DECLARE(be_TFL_output, "b", "@(bytes)~");
extern void be_TFL_rec(struct bvm *vm, const char* filename, size_t seconds);
BE_FUNC_CTYPE_DECLARE(be_TFL_rec, "", "@si");
#include "be_fixed_TFL.h"
/* @const_object_info_begin
module TFL (scope: global) {
begin, ctype_func(be_TFL_begin)
load, ctype_func(be_TFL_load)
input, ctype_func(be_TFL_input)
output, ctype_func(be_TFL_output)
log, ctype_func(be_TFL_log)
stats, ctype_func(be_TFL_stats)
rec, ctype_func(be_TFL_rec)
}
@const_object_info_end */
#endif // USE_BERRY_TF_LITE
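The `BE_FUNC_CTYPE_DECLARE` strings map Berry values onto plain C signatures: the second argument is the return type (`s` string, `b` boolean, empty string for void) and the third lists the parameters, where `@` is the Berry VM pointer, `s` a string, `i` an integer, `(bytes)~` a bytes() buffer passed as pointer plus length, and `[` marks the start of optional arguments. As a rough sketch of the native side (the body below is illustrative, not the implementation shipped in this commit), `be_TFL_log` could look like:

```c
// Sketch only: a C function matching the declaration
//   BE_FUNC_CTYPE_DECLARE(be_TFL_log, "s", "@");
// i.e. it receives the VM pointer and returns a string.
extern const char* be_TFL_log(struct bvm *vm) {
  (void)vm;                 // the VM handle is unused in this sketch
  return "TFL: log empty";  // converted to a Berry string by the mapping layer
}
```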


@@ -0,0 +1,7 @@
name=MElFreqencyExtractor
version=1.0
author=Christian Baars
maintainer=Christian Baars
sentence=Feature Extractor using mel frequencies
paragraph=Uses ESP-DSP library.
architectures=esp32


@@ -0,0 +1,315 @@
/*
mfcc.h - mel frequency extractor for ESP32
Computes features for slices of audio data similar to speechpy
This is intended to provide a stripped-down implementation that can work with Edge Impulse trained models
based on:
https://github.com/astorfi/speechpy
https://github.com/AIWintermuteAI/Speech-to-Intent-Micro/blob/main/inference_code/Wio_Terminal/wio_speech_to_intent_150_10/mfcc.cpp
Copyright (C) 2022 Christian Baars
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef MELFREQUENCYEXTRACTOR_H
#define MELFREQUENCYEXTRACTOR_H
#include <Arduino.h>
#include <string.h>
#include "float.h"
#include "esp_dsp.h"
class MFCC{
private:
int num_mfcc_features;
int frame_len;
int frame_len_padded;
int num_bank_bins;
float * m_frame;
float * m_buffer;
float * m_mel_energies;
float * m_dct_matrix;
float * m_mel_fbank;
uint8_t m_amplification;
float m_preemphasis;
float * create_dct_matrix(int32_t input_length, int32_t coefficient_count);
void create_mel_filterbank(int samp_freq, int low_freq, int high_freq);
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf (mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf (1.0f + freq / 700.0f);
}
public:
MFCC(int num_mfcc_features, int frame_len,int num_bank_bins, int samp_freq, int low_freq, int high_freq);
~MFCC();
void set_preamp(uint8_t amplification);
void set_preemphasis(float preemphasis);
void mfcc_compute(const int16_t* data, float* mfcc_out);
void log10_normalize(float* out_buf, int out_buf_len, int noise_floor_db);
};
MFCC::MFCC(int num_mfcc_features, int frame_len, int num_bank_bins, int samp_freq, int low_freq, int high_freq)
:num_mfcc_features(num_mfcc_features),
frame_len(frame_len),
num_bank_bins(num_bank_bins)
{
// Round-up to nearest power of 2.
frame_len_padded = pow(2,ceil((log(frame_len)/log(2))));
m_frame = new float[frame_len_padded];
m_buffer = new float[frame_len_padded * 2];
m_mel_energies = new float[num_bank_bins];
//create window function
// window_func = new float[frame_len];
// dsps_wind_hann_f32(window_func, frame_len);
m_amplification = 1;
m_preemphasis = 0.0;
//create mel filterbank
create_mel_filterbank(samp_freq, low_freq, high_freq);
//create DCT matrix for mfcc mode
if(num_mfcc_features != 0){
m_dct_matrix = create_dct_matrix(num_bank_bins, num_mfcc_features);
}
//initialize FFT
int ret = dsps_fft2r_init_fc32(NULL, frame_len_padded);
if(ret==0){
MicroPrintf("Framelength: %u, (rounded: %u)", frame_len,frame_len_padded);
}
else{
MicroPrintf("dsps_fft2r_init_fc32 error: %d",ret);
}
}
MFCC::~MFCC() {
delete []m_frame;
delete []m_buffer;
delete []m_mel_energies;
// delete []window_func;
delete []m_dct_matrix;
// for(int i=0;i<num_bank_bins;i++)
// delete mel_fbank[i];
delete [] m_mel_fbank;
dsps_fft2r_deinit_fc32();
}
float * MFCC::create_dct_matrix(int32_t input_length, int32_t coefficient_count) {
int32_t k, n;
float * M = new float[input_length*coefficient_count];
float normalizer = sqrt(2.0/(float)input_length);
for (k = 0; k < coefficient_count; k++) {
for (n = 0; n < input_length; n++) {
M[k*input_length+n] = normalizer * cos( ((double)M_PI)/input_length * (n + 0.5) * k );
}
}
return M;
}
void MFCC::create_mel_filterbank(int samp_freq, int low_freq, int high_freq) {
// MicroPrintf("Create FB ...");
int coefficients = frame_len_padded/2 + 1;
m_mel_fbank = new float[num_bank_bins * coefficients](); // zero-init
float delta = (MelScale(high_freq) - MelScale(low_freq))/(num_bank_bins + 1); // keep mel spacing as float to avoid truncating the filter edges
float mels[num_bank_bins+2];
float hertz[num_bank_bins+2];
int freq_index[num_bank_bins+2];
for (int i = 0;i<num_bank_bins+2;i++){
mels[i] = MelScale(low_freq) + (i * delta);
hertz[i] = InverseMelScale(mels[i]);
freq_index[i] = (coefficients + 1) * hertz[i] /samp_freq;
}
for (int i = 0;i<num_bank_bins;i++ ){
int left = int(freq_index[i]);
int middle = int(freq_index[i + 1]);
int right = int(freq_index[i + 2]);
float slope_up = 1/float(middle-left);
float slope_down = 1/float(right-middle);
// MicroPrintf("%u %u %u %f %f",left,middle,right,slope_up,slope_down);
for (int j = 0;j<(right-left+1);j++){
if (j < middle-left + 1){
m_mel_fbank[(i*coefficients)+left+j] = j*slope_up;
}
else{
m_mel_fbank[(i*coefficients)+left+j] = m_mel_fbank[(i*coefficients)+left+j-1]-slope_down;
}
}
}
// MicroPrintf("%f %f %f %f %f %f %f %f ",m_mel_fbank[5],m_mel_fbank[6] ,m_mel_fbank[7] ,m_mel_fbank[8] ,m_mel_fbank[9] ,m_mel_fbank[10] ,m_mel_fbank[11] ,m_mel_fbank[12]);
// MicroPrintf("FB done");
}
void MFCC::log10_normalize(float* out_buf, int out_buf_len, int noise_floor_db) {
const float noise = static_cast<float>(noise_floor_db * -1);
const float noise_scale = 1.0f / (static_cast<float>(noise_floor_db * -1) + 12.0f);
for (size_t ix = 0; ix < out_buf_len; ix++) {
float f = out_buf[ix];
if (f < 1e-30f) { // avoid log of ~zero; skip this element instead of aborting the whole normalization
out_buf[ix] = 0;
continue;
}
f = 10.0f * log10(f); // scale by 10
f += noise;
f *= noise_scale;
// clip again
if (f < 0.0f) f = 0.0f;
else if (f > 1.0f) f = 1.0f;
out_buf[ix] = f;
}
}
void MFCC::set_preamp(uint8_t amplification){
m_amplification = amplification;
}
void MFCC::set_preemphasis(float preemphasis){
m_preemphasis = preemphasis;
// Speechpy computes this over the whole window of a sample; here we compute it only over the slice!
}
void MFCC::mfcc_compute(const int16_t * audio_data, float* mfcc_out) {
int32_t i, j, bin;
int coefficients = frame_len_padded/2 + 1;
int data_clipped = 0;
int data_clipped_low = 0;
float conv_factor = m_amplification;
float clip_thres = 0.99f * (float)(1<<15);
// MicroPrintf("%d %d %d %d %d %d %d %d",audio_data[0],audio_data[1] ,audio_data[2] ,audio_data[3] ,audio_data[4] ,audio_data[5] ,audio_data[6] ,audio_data[7]);
//TensorFlow way of normalizing .wav data to (-1,1) for speechpy's MFE
if(num_mfcc_features == 0){
conv_factor /= (float)(1<<15);
clip_thres /= (float)(1<<15);
}
for (int i = 0; i < frame_len; i++) {
m_buffer[i] = audio_data[i] * conv_factor; //mfe -1..1, mfcc int16_t as float, both with additional pre_amp factor
if(m_buffer[i] > clip_thres){ // clipping check belongs inside the loop, per sample
m_buffer[i] /= m_amplification;
data_clipped++;
}
else if(m_buffer[i] < -clip_thres){
m_buffer[i] /= m_amplification;
data_clipped_low++;
}
}
if(data_clipped>0)
MicroPrintf("Clip: %d __ %d",data_clipped, data_clipped_low);
// MicroPrintf("%f %f %f %f %f %f %f %f ",m_buffer[0],m_buffer[1] ,m_buffer[2] ,m_buffer[3] ,m_buffer[4] ,m_buffer[5] ,m_buffer[6] ,m_buffer[7]);
//pre-emphasis
if(m_preemphasis!=0.0){
m_frame[0] = m_buffer[0] - m_preemphasis * m_buffer[frame_len - 1]; // roll through the frame "back" to the end
for (i = 1; i < frame_len; i++){
m_frame[i] = m_buffer[i] - m_preemphasis * m_buffer[i - 1];
}
}
else{
for (i = 0; i < frame_len; i++){ // start at 0 so the first sample is copied too
m_frame[i] = m_buffer[i];
}
}
// prepare buffer for FFT
for (i = 0; i < frame_len_padded; i++) {
m_buffer[i * 2] = i<frame_len ? m_frame[i] : 0;
// m_buffer[i * 2] = i<frame_len ? frame[i] * window_func[i] : 0; // in case we want to use a window function
m_buffer[i*2 + 1] = 0;
}
// MicroPrintf("%f %f %f %f %f %f %f %f ",frame[0],frame[1] ,frame[2] ,frame[3] ,frame[4] ,frame[5] ,frame[6] ,frame[7]);
//Compute FFT
int err = dsps_fft2r_fc32(m_buffer, frame_len_padded);
err += dsps_bit_rev_fc32(m_buffer, frame_len_padded); //Bit reverse
// err += dsps_cplx2reC_fc32(m_buffer, frame_len_padded);// Complex spectrum in y_cf
if(err!=0){
MicroPrintf("dsps_fft2r error: %u %f %f %f %f %f %f %f %f ",err,m_buffer[0],m_buffer[1] ,m_buffer[2] ,m_buffer[3] ,m_buffer[4] ,m_buffer[5] ,m_buffer[512] ,m_buffer[513]);
}
for (int i = 0 ; i < coefficients ; i++) {
m_buffer[i] = (m_buffer[i*2] * m_buffer[i*2] + m_buffer[i*2 + 1] * m_buffer[i*2 + 1])/frame_len_padded;
}
// MicroPrintf(" pow spec: %f %f %f %f %f %f %f %f ",m_buffer[0],m_buffer[1] ,m_buffer[2] ,m_buffer[3] ,m_buffer[4] ,m_buffer[5] ,m_buffer[255] ,m_buffer[256]);
//Apply mel filterbanks
for (int i = 0;i<num_bank_bins;i++ ){
m_mel_energies[i] = 0;
for (int j = 0;j<coefficients;j++ ){
m_mel_energies[i] += m_buffer[j] * m_mel_fbank[(i * coefficients) + j];
}
}
// for MFE copy and return - compute 10 * log10() later explicitly
if(num_mfcc_features == 0){
for (bin = 0; bin < num_bank_bins; bin++){
mfcc_out[bin] = m_mel_energies[bin];
}
// MicroPrintf("%u feat: %f %f %f %f %f %f %f %f ",num_bank_bins,mfcc_out[0],mfcc_out[1] ,mfcc_out[2] ,mfcc_out[3] ,mfcc_out[4] ,mfcc_out[5] ,mfcc_out[6] ,mfcc_out[7]);
return;
}
// Continue for MFCC
// Take log
for (bin = 0; bin < num_bank_bins; bin++){
m_mel_energies[bin] = logf(m_mel_energies[bin]);
}
// Take DCT. Uses matrix mul.
for (i = 1; i < num_mfcc_features; i++) {
float sum = 0.0;
for (j = 0; j < num_bank_bins; j++) {
sum += m_dct_matrix[i*num_bank_bins+j] * m_mel_energies[j];
}
mfcc_out[i] = sum;
}
// replace first cepstral coefficient with log of frame energy for DC elimination
mfcc_out[0] = 0.0f; // initialize before accumulating the frame energy
for (i=0; i<frame_len_padded; i++){
mfcc_out[0] += m_buffer[i];
}
mfcc_out[0] = logf(mfcc_out[0]);
}
#endif //MELFREQUENCYEXTRACTOR_H
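As a usage sketch of the class above (the parameters are illustrative: 13 cepstral features, 512-sample frames, 40 mel bins, 16 kHz audio, an 80-7600 Hz band; nothing in this commit mandates these values):

```cpp
#include "mfcc.h"

// Extract 13 MFCCs from one 512-sample frame of 16 kHz audio.
void extract_features(const int16_t* frame, float* features) {
  static MFCC mfcc(13, 512, 40, 16000, 80, 7600);
  mfcc.set_preamp(1);                 // no extra amplification
  mfcc.set_preemphasis(0.97f);        // a commonly used pre-emphasis coefficient
  mfcc.mfcc_compute(frame, features); // writes 13 coefficients to features
}
```

Note that passing `num_mfcc_features = 0` switches the class into MFE mode: `mfcc_compute()` then returns the raw mel energies, intended to be scaled afterwards with `log10_normalize()`.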


@@ -0,0 +1,9 @@
# This is the list of Tensorflow's significant contributors.
#
# This does not necessarily list everyone who has contributed code,
# especially since many employees of one corporation may be contributing.
# To see the full list of contributors, see the revision history in
# source control.
Google LLC
Yuan Tang <terrytangyuan@gmail.com>


@@ -0,0 +1,4 @@
* @tensorflow/micro @ddavis-2015
/.github/ @advaitjain
/ci/ @advaitjain


@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,80 @@
# TensorFlow Lite Micro Library for Arduino-Espressif32
This repository has the code (including examples) needed to use TensorFlow Lite Micro on an Arduino.
## Table of contents
<!--ts-->
* [Table of contents](#table-of-contents)
* [Build Status](#build-status)
* [How to Install](#how-to-install)
* [GitHub](#github)
* [Checking your Installation](#checking-your-installation)
* [Compatibility](#compatibility)
* [License](#license)
* [Contributing](#contributing)
<!--te-->
## Build Status
Build Type | Status |
--------------- | ------------- |
Arduino CLI on Linux | [![Arduino](https://github.com/tensorflow/tflite-micro-arduino-examples/actions/workflows/ci.yml/badge.svg?event=schedule)](https://github.com/tensorflow/tflite-micro-arduino-examples/actions/workflows/ci.yml)
Sync from tflite-micro | [![Sync from tflite-micro](https://github.com/tensorflow/tflite-micro-arduino-examples/actions/workflows/sync.yml/badge.svg)](https://github.com/tensorflow/tflite-micro-arduino-examples/actions/workflows/sync.yml)
## How to Install
### GitHub
The officially supported TensorFlow Lite Micro library for Arduino resides
in the [tflite-micro-arduino-examples](https://github.com/tensorflow/tflite-micro-arduino-examples)
GitHub repository.
To install the in-development version of this library, you can use the
latest version directly from the GitHub repository. This requires you to clone the
repo into the folder that holds libraries for the Arduino IDE. The location of
this folder varies by operating system, but typically it's
`~/Arduino/libraries` on Linux, `~/Documents/Arduino/libraries/` on macOS, and
`My Documents\Arduino\Libraries` on Windows.
Once you're in that folder in the terminal, you can then grab the code using the
git command line tool:
```
git clone https://github.com/tensorflow/tflite-micro-arduino-examples Arduino_TensorFlowLite
```
To update your clone of the repository to the latest code, use the following terminal commands:
```
cd Arduino_TensorFlowLite
git pull
```
### Checking your Installation
Once the library has been installed, you should then start the Arduino IDE.
You will now see an `Arduino_TensorFlowLite`
entry in the `File -> Examples` menu of the Arduino IDE. This submenu contains a list
of sample projects you can try out.
![Hello World](docs/hello_world_screenshot.png)
## Compatibility
This library is designed for the `Arduino Nano 33 BLE Sense` board. The framework
code for running machine learning models should be compatible with most Arm Cortex
M-based boards, such as the `Raspberry Pi Pico`, but the code to access peripherals
like microphones, cameras, and accelerometers is specific to the `Nano 33 BLE Sense`.
## License
This code is made available under the Apache 2 license.
## Contributing
Forks of this library are welcome and encouraged. If you have bug reports or
fixes to contribute, the source of this code is at [https://github.com/tensorflow/tflite-micro](https://github.com/tensorflow/tflite-micro)
and all issues and pull requests should be directed there.
The code here is created through an automatic project generation process
and may differ from
that source of truth, since it's cross-platform and needs to be modified to
work within the Arduino IDE.


@@ -0,0 +1,59 @@
<!-- mdformat off(b/169948621#comment2) -->
# Hello World Example
This example is designed to demonstrate the absolute basics of using [TensorFlow
Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers).
It includes the full end-to-end workflow of training a model, converting it for
use with TensorFlow Lite for Microcontrollers, and running inference on a
microcontroller.
The model is trained to replicate a `sine` function and generates a pattern of
data to blink the built-in LED in a fade in/out pattern.
## Table of contents
<!--ts-->
* [Table of contents](#table-of-contents)
* [Deploy to Arduino](#deploy-to-arduino)
* [Install the Arduino_TensorFlowLite library](#install-the-arduino_tensorflowlite-library)
* [Load and run the example](#load-and-run-the-example)
<!--te-->
## Deploy to Arduino
The following instructions will help you build and deploy this sample
to [Arduino](https://www.arduino.cc/) devices.
The sample has been tested with the following devices:
- [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers)
- [Arduino Tiny Machine Learning Kit](https://store-usa.arduino.cc/products/arduino-tiny-machine-learning-kit)
The sample will use PWM to fade an LED on and off according to the model's
output. In the code, the `LED_BUILTIN` constant is used to specify the board's
built-in LED as the one being controlled. However, on some boards, this built-in
LED is not attached to a pin with PWM capabilities. In this case, the LED will
blink instead of fading.
![Animation on Nano 33 BLE Sense](../../docs/hello_world_animation.gif)
### Install the Arduino_TensorFlowLite library
To install the TensorFlow Lite Micro for Arduino library, see the
[how to install](../../README.md#how-to-install) instructions.
### Load and run the example
Once the library has been added, go to `File -> Examples`. You should see an
entry within the list named `Arduino_TensorFlowLite`. Select
it and click `hello_world` to load the example.
Use the Arduino IDE to build and upload the example. Once it is running,
you should see the built-in LED on your device flashing.
The Arduino Desktop IDE includes a plotter that we can use to display the sine
wave graphically. To view it, go to `Tools -> Serial Plotter`. You will see one
datapoint being logged for each inference cycle, expressed as a number between 0
and 255.
![Serial Plotter with Nano 33 BLE Sense](../../docs/hello_world_serial_plotter.png)


@@ -0,0 +1,20 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "constants.h"
// This is tuned so that a full cycle takes ~6.6 seconds on an
// Arduino Nano 33 BLE.
const int kInferencesPerCycle = 200;


@@ -0,0 +1,20 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
// Arduino automatically calls the setup() and loop() functions in a sketch, so
// where other systems need their own main routine in this file, it can be left
// empty.


@@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "Arduino.h"
#include "constants.h"
#include "output_handler.h"
// The pin of the Arduino's built-in LED
int led = LED_BUILTIN;
// Track whether the function has run at least once
bool initialized = false;
// Animates a dot across the screen to represent the current x and y values
void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value,
float y_value) {
// Do this only once
if (!initialized) {
// Set the LED pin to output
pinMode(led, OUTPUT);
initialized = true;
}
// Calculate the brightness of the LED such that y=-1 is fully off
// and y=1 is fully on. The LED's brightness can range from 0-255.
int brightness = (int)(127.5f * (y_value + 1));
// The y value is not actually constrained to the range [-1, 1], so we need to
// clamp the brightness value before sending it to the PWM/LED.
int brightness_clamped = std::min(255, std::max(0, brightness));
// Set the brightness of the LED. If the specified pin does not support PWM,
// this will result in the LED being on when brightness_clamped > 127, off
// otherwise.
analogWrite(led, brightness_clamped);
// Log the current brightness value for display in the Arduino plotter
TF_LITE_REPORT_ERROR(error_reporter, "%d\n", brightness);
delay(33);
}


@@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_
// This constant represents the range of x values our model was trained on,
// which is from 0 to (2 * Pi). We approximate Pi to avoid requiring additional
// libraries.
const float kXrange = 2.f * 3.14159265359f;
// This constant determines the number of inferences to perform across the range
// of x values defined above. Since each inference takes time, the higher this
// number, the more time it will take to run through the entire range. The value
// of this constant can be tuned so that one full cycle takes a desired amount
// of time. Since different devices take different amounts of time to perform
// inference, this value should be defined per-device.
extern const int kInferencesPerCycle;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_


@@ -0,0 +1,123 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <TensorFlowLite.h>
#include "main_functions.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "constants.h"
#include "model.h"
#include "output_handler.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;
constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
tflite::InitializeTarget();
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// This pulls in all the operation implementations we need.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::AllOpsResolver resolver;
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
// Obtain pointers to the model's input and output tensors.
input = interpreter->input(0);
output = interpreter->output(0);
// Keep track of how many inferences we have performed.
inference_count = 0;
}
// The name of this function is important for Arduino compatibility.
void loop() {
// Calculate an x value to feed into the model. We compare the current
// inference_count to the number of inferences per cycle to determine
// our position within the range of possible x values the model was
// trained on, and use this to calculate a value.
float position = static_cast<float>(inference_count) /
static_cast<float>(kInferencesPerCycle);
float x = position * kXrange;
// Quantize the input from floating-point to integer
int8_t x_quantized = x / input->params.scale + input->params.zero_point;
// Place the quantized input in the model's input tensor
input->data.int8[0] = x_quantized;
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
static_cast<double>(x));
return;
}
// Obtain the quantized output from model's output tensor
int8_t y_quantized = output->data.int8[0];
// Dequantize the output from integer to floating-point
float y = (y_quantized - output->params.zero_point) * output->params.scale;
// Output the results. A custom HandleOutput function can be implemented
// for each supported hardware target.
HandleOutput(error_reporter, x, y);
// Increment the inference_counter, and reset it if we have reached
// the total number per cycle
inference_count += 1;
if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
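Since the quantize/dequantize step in `loop()` above is easy to get wrong, here is a standalone numeric sketch of the same math; the scale and zero point are made-up illustrative values, not the parameters stored in `g_model`:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.024f;  // hypothetical input->params.scale
  const int zero_point = -128; // hypothetical input->params.zero_point

  float x = 1.57f;                                          // roughly pi/2
  int8_t x_q = static_cast<int8_t>(x / scale + zero_point); // quantize: -62
  float x_back = (x_q - zero_point) * scale;                // dequantize: ~1.58

  std::printf("quantized=%d dequantized=%f\n", x_q, x_back);
  return 0;
}
```

The small round trip error (1.57 in, ~1.58 out) is the quantization error inherent in mapping floats onto 256 int8 levels.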


@@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_
// Expose a C friendly interface for main functions.
#ifdef __cplusplus
extern "C" {
#endif
// Initializes all data needed for the example. The name is important, and needs
// to be setup() for Arduino compatibility.
void setup();
// Runs one iteration of data gathering and inference. This should be called
// repeatedly from the application code. The name needs to be loop() for Arduino
// compatibility.
void loop();
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_


@@ -0,0 +1,237 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Automatically created from a TensorFlow Lite flatbuffer using the command:
// xxd -i model.tflite > model.cc
// This is a standard TensorFlow Lite model file that has been converted into a
// C data array, so it can be easily compiled into a binary for devices that
// don't have a file system.
// See train/README.md for a full description of the creation process.
#include "model.h"
// Keep model aligned to 8 bytes to guarantee aligned 64-bit accesses.
alignas(8) const unsigned char g_model[] = {
0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x14, 0x00, 0x20, 0x00,
0x1c, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x08, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
0x98, 0x00, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00,
0x2c, 0x03, 0x00, 0x00, 0x30, 0x09, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x60, 0xf7, 0xff, 0xff,
0x10, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00,
0x44, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x73, 0x65, 0x72, 0x76,
0x65, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x73, 0x65, 0x72, 0x76,
0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xff, 0xff, 0xff,
0x09, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x34, 0x00, 0x01, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x76, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
0x0d, 0x00, 0x00, 0x00, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f,
0x69, 0x6e, 0x70, 0x75, 0x74, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00,
0x08, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x13, 0x00, 0x00, 0x00, 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x75, 0x6e, 0x74,
0x69, 0x6d, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x50, 0x02, 0x00, 0x00, 0x48, 0x02, 0x00, 0x00,
0x34, 0x02, 0x00, 0x00, 0xdc, 0x01, 0x00, 0x00, 0x8c, 0x01, 0x00, 0x00,
0x6c, 0x01, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00,
0x34, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0xfa, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x31, 0x2e, 0x35, 0x2e, 0x30, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x84, 0xfd, 0xff, 0xff,
0x88, 0xfd, 0xff, 0xff, 0x8c, 0xfd, 0xff, 0xff, 0x22, 0xfe, 0xff, 0xff,
0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x21, 0xa5, 0x8b, 0xca,
0x5e, 0x1d, 0xce, 0x42, 0x9d, 0xce, 0x1f, 0xb0, 0xdf, 0x54, 0x2f, 0x81,
0x3e, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0xee, 0xfc, 0x00, 0xec, 0x05, 0x17, 0xef, 0xec, 0xe6, 0xf8, 0x03, 0x01,
0x00, 0xfa, 0xf8, 0xf5, 0xdc, 0xeb, 0x27, 0x14, 0xf1, 0xde, 0xe2, 0xdb,
0xf0, 0xde, 0x31, 0x06, 0x02, 0xe6, 0xee, 0xf9, 0x00, 0x16, 0x07, 0xe0,
0xfe, 0xff, 0xe9, 0x06, 0xe7, 0xef, 0x81, 0x1b, 0x18, 0xea, 0xc9, 0x01,
0x0f, 0x00, 0xda, 0xf7, 0x0e, 0xec, 0x13, 0x1f, 0x04, 0x13, 0xb4, 0xe6,
0xfd, 0x06, 0xb9, 0xe0, 0x0d, 0xec, 0xf0, 0xde, 0xeb, 0xf7, 0x05, 0x26,
0x1a, 0xe4, 0x6f, 0x1a, 0xea, 0x1e, 0x35, 0xdf, 0x1a, 0xf3, 0xf1, 0x19,
0x0f, 0x03, 0x1b, 0xe1, 0xde, 0x13, 0xf6, 0x19, 0xff, 0xf6, 0x1b, 0x18,
0xf0, 0x1c, 0xda, 0x1b, 0x1b, 0x20, 0xe5, 0x1a, 0xf5, 0xff, 0x96, 0x0b,
0x00, 0x01, 0xcd, 0xde, 0x0d, 0xf6, 0x16, 0xe3, 0xed, 0xfc, 0x0e, 0xe9,
0xfa, 0xeb, 0x5c, 0xfc, 0x1d, 0x02, 0x5b, 0xe2, 0xe1, 0xf5, 0x15, 0xec,
0xf4, 0x00, 0x13, 0x05, 0xec, 0x0c, 0x1d, 0x14, 0x0e, 0xe7, 0x0b, 0xf4,
0x19, 0x00, 0xd7, 0x05, 0x27, 0x02, 0x15, 0xea, 0xea, 0x02, 0x9b, 0x00,
0x0c, 0xfa, 0xe8, 0xea, 0xfd, 0x00, 0x14, 0xfd, 0x0b, 0x02, 0xef, 0xee,
0x06, 0xee, 0x01, 0x0d, 0x06, 0xe6, 0xf7, 0x11, 0xf7, 0x09, 0xf8, 0xf1,
0x21, 0xff, 0x0e, 0xf3, 0xec, 0x12, 0x26, 0x1d, 0xf2, 0xe9, 0x28, 0x18,
0xe0, 0xfb, 0xf3, 0xf4, 0x05, 0x1d, 0x1d, 0xfb, 0xfd, 0x1e, 0xfc, 0x11,
0xe8, 0x07, 0x09, 0x03, 0x12, 0xf2, 0x36, 0xfb, 0xdc, 0x1c, 0xf9, 0xef,
0xf3, 0xe7, 0x6f, 0x0c, 0x1d, 0x00, 0x45, 0xfd, 0x0e, 0xf0, 0x0b, 0x19,
0x1a, 0xfa, 0xe0, 0x19, 0x1f, 0x13, 0x36, 0x1c, 0x12, 0xeb, 0x3b, 0x0c,
0xb4, 0xcb, 0xe6, 0x13, 0xfa, 0xeb, 0xf1, 0x06, 0x1c, 0xfa, 0x18, 0xe5,
0xeb, 0xcb, 0x0c, 0xf4, 0x4a, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x75, 0x1c, 0x11, 0xe1, 0x0c, 0x81, 0xa5, 0x42,
0xfe, 0xd5, 0xd4, 0xb2, 0x61, 0x78, 0x19, 0xdf, 0x66, 0xff, 0xff, 0xff,
0x04, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x77, 0x0b, 0x00, 0x00, 0x53, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0x77, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xd3, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x72, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2f, 0x07, 0x00, 0x00,
0x67, 0xf5, 0xff, 0xff, 0x34, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0xb2, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xb5, 0x04, 0x00, 0x00, 0x78, 0x0a, 0x00, 0x00,
0x2d, 0x06, 0x00, 0x00, 0x71, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0x9a, 0x0a, 0x00, 0x00, 0xfe, 0xf7, 0xff, 0xff, 0x0e, 0x05, 0x00, 0x00,
0xd4, 0x09, 0x00, 0x00, 0x47, 0xfe, 0xff, 0xff, 0xb6, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xac, 0xf7, 0xff, 0xff, 0x4b, 0xf9, 0xff, 0xff,
0x4a, 0x05, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00,
0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x8c, 0xef, 0xff, 0xff, 0x84, 0xff, 0xff, 0xff, 0x88, 0xff, 0xff, 0xff,
0x0f, 0x00, 0x00, 0x00, 0x4d, 0x4c, 0x49, 0x52, 0x20, 0x43, 0x6f, 0x6e,
0x76, 0x65, 0x72, 0x74, 0x65, 0x64, 0x2e, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0xdc, 0x00, 0x00, 0x00,
0xe0, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x84, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x96, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08,
0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0xca, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x08, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0xba, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
0x05, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x16, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x04, 0x00,
0x0e, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08,
0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00,
0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
0x01, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x4c, 0x04, 0x00, 0x00,
0xd0, 0x03, 0x00, 0x00, 0x68, 0x03, 0x00, 0x00, 0x0c, 0x03, 0x00, 0x00,
0x98, 0x02, 0x00, 0x00, 0x24, 0x02, 0x00, 0x00, 0xb0, 0x01, 0x00, 0x00,
0x24, 0x01, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0xf0, 0xfb, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x54, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
0x6c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0xdc, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x4a, 0xce, 0x0a, 0x3c, 0x01, 0x00, 0x00, 0x00,
0x34, 0x84, 0x85, 0x3f, 0x01, 0x00, 0x00, 0x00, 0xc5, 0x02, 0x8f, 0xbf,
0x1e, 0x00, 0x00, 0x00, 0x53, 0x74, 0x61, 0x74, 0x65, 0x66, 0x75, 0x6c,
0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x65, 0x64, 0x43,
0x61, 0x6c, 0x6c, 0x3a, 0x30, 0x5f, 0x69, 0x6e, 0x74, 0x38, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x80, 0xfc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x54, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
0x64, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
0x10, 0x00, 0x00, 0x00, 0x6c, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0x93, 0xd0, 0xc0, 0x3b, 0x01, 0x00, 0x00, 0x00,
0xc2, 0x0f, 0xc0, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x74, 0x66, 0x6c, 0x2e, 0x66, 0x75, 0x6c, 0x6c,
0x79, 0x5f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x31,
0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x08, 0xfd, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09, 0x64, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0xff, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0xf4, 0xfc, 0xff, 0xff,
0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x24, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0xe0, 0xdb, 0x47, 0x3c, 0x01, 0x00, 0x00, 0x00, 0x04, 0x14, 0x47, 0x40,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00,
0x74, 0x66, 0x6c, 0x2e, 0x66, 0x75, 0x6c, 0x6c, 0x79, 0x5f, 0x63, 0x6f,
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x00, 0x02, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x02, 0xfe, 0xff, 0xff,
0x14, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09, 0x50, 0x00, 0x00, 0x00, 0x6c, 0xfd, 0xff, 0xff,
0x10, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xfb, 0x4b, 0x0b, 0x3c,
0x01, 0x00, 0x00, 0x00, 0x40, 0x84, 0x4b, 0x3f, 0x01, 0x00, 0x00, 0x00,
0x63, 0x35, 0x8a, 0xbf, 0x0d, 0x00, 0x00, 0x00, 0x73, 0x74, 0x64, 0x2e,
0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x32, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x72, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00,
0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x50, 0x00, 0x00, 0x00,
0xdc, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x60, 0x01, 0x4f, 0x3c, 0x01, 0x00, 0x00, 0x00, 0x47, 0x6d, 0xb3, 0x3f,
0x01, 0x00, 0x00, 0x00, 0x5d, 0x63, 0xcd, 0xbf, 0x0d, 0x00, 0x00, 0x00,
0x73, 0x74, 0x64, 0x2e, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74,
0x31, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0xe2, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x48, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
0x50, 0x00, 0x00, 0x00, 0x4c, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0xd5, 0x6b, 0x8a, 0x3b, 0x01, 0x00, 0x00, 0x00,
0xab, 0x49, 0x01, 0x3f, 0x01, 0x00, 0x00, 0x00, 0xfd, 0x56, 0x09, 0xbf,
0x0c, 0x00, 0x00, 0x00, 0x73, 0x74, 0x64, 0x2e, 0x63, 0x6f, 0x6e, 0x73,
0x74, 0x61, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x52, 0xff, 0xff, 0xff,
0x14, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, 0x44, 0xff, 0xff, 0xff,
0x08, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x28, 0xb3, 0xd9, 0x38, 0x0c, 0x00, 0x00, 0x00,
0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x2f, 0x62, 0x69, 0x61, 0x73,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0xaa, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x38, 0x00, 0x00, 0x00,
0x9c, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0xdd, 0x9b, 0x21, 0x39, 0x0c, 0x00, 0x00, 0x00,
0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x33, 0x2f, 0x62, 0x69, 0x61, 0x73,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x13, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
0x48, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0xf4, 0xd4, 0x51, 0x38, 0x0c, 0x00, 0x00, 0x00, 0x64, 0x65, 0x6e, 0x73,
0x65, 0x5f, 0x34, 0x2f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x1c, 0x00,
0x18, 0x00, 0x17, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00,
0x00, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09, 0x84, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x24, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x5d, 0x4f, 0xc9, 0x3c, 0x01, 0x00, 0x00, 0x00, 0x0e, 0x86, 0xc8, 0x40,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00,
0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61,
0x75, 0x6c, 0x74, 0x5f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f,
0x69, 0x6e, 0x70, 0x75, 0x74, 0x3a, 0x30, 0x5f, 0x69, 0x6e, 0x74, 0x38,
0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00,
0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd8, 0xff, 0xff, 0xff,
0x06, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72,
0x0c, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09};
const int g_model_len = 2488;


@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Automatically created from a TensorFlow Lite flatbuffer using the command:
// xxd -i model.tflite > model.cc
// This is a standard TensorFlow Lite model file that has been converted into a
// C data array, so it can be easily compiled into a binary for devices that
// don't have a file system.
// See train/README.md for a full description of the creation process.
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MODEL_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MODEL_H_
extern const unsigned char g_model[];
extern const int g_model_len;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MODEL_H_

View File

@ -0,0 +1,26 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Called by the main loop to produce some output based on the x and y values
void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value,
float y_value);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_

View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,114 @@
# Magic Wand
Magic Wand example for [TensorFlow Lite Micro](https://www.tensorflow.org/lite/microcontrollers) on the [Arduino Nano 33 BLE Sense](https://store-usa.arduino.cc/products/arduino-nano-33-ble-sense).
## Table of contents
<!--ts-->
* [Introduction](#introduction)
* [Hardware Requirements](#hardware-requirements)
* [Installing the Sketch](#installing-the-sketch)
* [Arduino Desktop IDE](#arduino-desktop-ide)
* [Building the Wand](#building-the-wand)
* [Using the wand](#using-the-wand)
* [Viewing Gestures in the Browser](#viewing-gestures-in-the-browser)
* [Pretrained Model](#pretrained-model)
* [Recording Gestures](#recording-gestures)
* [Training](#training)
* [Deployment](#deployment)
<!--te-->
## Introduction
This project shows you how to recognize gestures made by waving a magic wand, using machine learning to analyze accelerometer and gyroscope data. It demonstrates the three main stages of an end-to-end machine learning project:
- **Gathering Data**. Using a Bluetooth connection to a web page, you can capture gestures, label them, and download the results.
- **Training**. A Python notebook on the free Colab service shows how to use TensorFlow to train a model to recognize gestures from your data.
- **Deployment**. You can deploy your trained model to the Arduino board using TensorFlow Lite Micro and the Arduino IDE.
## Hardware Requirements
You'll need the following:
- Arduino Nano 33 BLE Sense board. These are available as part of [the TinyML Starter Kit](https://store-usa.arduino.cc/products/arduino-tiny-machine-learning-kit), or separately from Arduino or resellers. Unfortunately, other Arduinos won't work, because the Bluetooth and sensor code rely on the particular hardware of the Nano 33 BLE Sense.
- MicroUSB cable. This is included in the TinyML Kit, but you'll need a USB-A adaptor too if your computer only has USB-C ports.
- Computer. The Arduino toolchain runs on Linux, Windows, and macOS, so you should be able to use most laptops, desktops, or even a Raspberry Pi. For the training process, you'll also need an up-to-date version of the Chrome web browser so you can use the Web Bluetooth APIs.
- Stick. We'll be attaching your Arduino to a 'wand', but this can be practically anything, as long as it's roughly a foot (30 centimeters) long.
## Installing the Sketch
You'll need to ensure you can successfully connect and load sketches onto your Arduino board,
using the desktop IDE.
Once you've made sure you can load a simple sketch successfully, you'll follow these steps:
### Arduino Desktop IDE
If you're using the Arduino IDE application, you'll need to fetch the latest version of this sketch.
To install the TensorFlow Lite Micro for Arduino library, see the
[how to install](../../README.md#how-to-install) instructions.
Open up the magic_wand.ino file in the Arduino editor, and make sure the Arduino board is visible and connected to the right port. You'll need to search for some libraries that the sketch depends on, using `Sketch->Include Library->Manage Libraries` from the main menu. The [Arduino_LSM9DS1](https://github.com/arduino-libraries/Arduino_LSM9DS1) library lets us access the accelerometer and gyroscope readings from the board's IMU, and you need at least version 1.1.0. We'll be using Bluetooth to communicate with the web page, so you should also search for [ArduinoBLE](https://www.arduino.cc/en/Reference/ArduinoBLE) and make sure you've got version 1.1.3 or newer.
You should now be able to press the upload button to compile and install the sketch on your board.
## Building the Wand
The 'wand' itself can be as simple as a stick; it doesn't need to do anything other than keep the board at one end while you hold the other end and wave it about. A cheap wand from an online retailer will work. A simple piece of wood or a ruler works just as well.
You should place the board at the end of the wand, with the USB socket facing downwards, towards where you hold it, so that the cable can run down the handle. The sketch is designed to compensate for any rotation of the board around the wand's shaft, so as long as it's parallel to the wand's length the board's twist won't matter. Use sticky tape or some other easy-to-remove method to attach the board to the wand, and hold the cable in place along the shaft. The end result should look something like this:
![Image of board attached to wand](../../docs/magic_wand_attachment.jpg)
If an ASCII-art diagram is more helpful, here's what you should aim for:
```
____
| |<- Arduino board
| |
| () | <- Reset button
| |
-TT- <- USB port
||
||<- Wand
....
||
||
()
```
## Using the wand
The wand can be used with or without the Nano 33 BLE attached to the Tiny Machine Learning Shield. It is easier to use without the shield, as your hand will not tire as quickly. The wand should be held as you would a pencil or pen, about 8 inches from the USB socket. Use your wrist to make strokes, not your arm. Strokes will need to be made somewhat quickly, without stopping during changes in direction.
## Viewing Gestures in the Browser
To preview and record gestures, we'll be connecting the sketch you've just uploaded to a web page, using Bluetooth and Chrome's WebBLE API. The code for the page is [in this repository](https://github.com/tensorflow/tflite-micro-arduino-examples/tree/main/examples/magic_wand/website), but it's all implemented as browser-side JavaScript in a static HTML page, so you don't need to host it on your own server. Just dragging and dropping the `index.html` file into your browser should work.
If the sketch has uploaded successfully, the Arduino should be advertising itself through Bluetooth. On the web page, press the 'Bluetooth' button to connect, and you should see a dialog appear asking you to pair with a device. After a second or two, there should be an entry that looks something like "BLESense-2F00". Click on that to pair, and you should be returned to the web page.
If everything is working as expected, the Bluetooth button should turn blue, with "Connected" next to it. Now try moving the wand and look at the square below the button. As you gesture, you should see tracks appearing as lines in the web page in real time. Try doing small circles, or a 'Z' like Zorro!
## Pretrained Model
The sketch comes with a model that's been trained to recognize the hand-drawn digits zero to nine. This is based on a small dataset recorded by Google, so your accuracy may vary, but if you bring up the Serial Monitor in the Arduino IDE you can see what the model predicts for each gesture you make, with a confidence score between 0% and 100%, as well as ASCII art of the gesture outline.
## Recording Gestures
As you get familiar with the wand, you should notice that the gestures you have performed start to stack up on the right side of the web page. This is where the data you'll eventually want to use for training is stored. When you leave or refresh the web page, these gestures will be lost, so make sure you use the "Download Data" link to save them locally once you've generated a number of them.
The gestures are automatically split up by times when the wand is kept still. These pauses act like spaces between words, and so when you've finished a gesture you should stop moving the wand so that it ends cleanly.
To get started, you should pick a couple of easy gestures to perform, like a 'Z' and an 'O'. As you make these gestures, you should see them appear in the right-hand stack of gestures. You can look at the shapes shown there to understand whether the gestures came out cleanly. A good rule of thumb is that if you can't tell what the gesture is by looking at it in the stack, then a model will have a hard time recognizing it too.
Once you have ten or so of each gesture, scroll through the stack to review them. If any don't seem very recognizable, or are too 'sloppy' (which is admittedly subjective), you can press the trash can button at the top right of the image to remove it. If you removed any, record some more so you have at least ten of each gesture. If you are happy with a gesture, click on the label at the top left to type in the correct name for it (for example `O` or `Z`).
After you've reviewed and labeled all of your data, you can download it as a JSON text file that can be used for training.
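For reference, the downloaded file mirrors the structure built by the website's `storeStroke` and `onStoreChange` functions (the page source appears later in this commit); the values below are only illustrative:
```
{
  "strokes": [
    {
      "index": 0,
      "strokePoints": [ {"x": -0.15, "y": 0.42}, {"x": -0.11, "y": 0.38} ],
      "label": "Z"
    }
  ]
}
```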
## Training
Once you have data, you should [run the Python training notebook in Colab](https://colab.research.google.com/github/tensorflow/tflite-micro-arduino-examples/blob/main/examples/magic_wand/train/train_magic_wand_model.ipynb) and follow the steps to create and export your own model.
## Deployment
The Python training process should give you a `magic_wand_model_data.cc` file. Use it to replace the file of the same name (but with a `.cpp` suffix) in the sketch you're using. You'll also need to update the `labels` and `label_count` variables near the top of `magic_wand.ino` to reflect any changes you made to the gestures you're trying to recognize.
Upload this modified sketch, and you should be able to perform gestures and see them recognized in the Serial Monitor of your Arduino editor.
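As a minimal sketch of that last step, assuming you trained only the two example gestures above (`O` and `Z` are placeholder labels, not the digits shipped with the sketch):
```
// In magic_wand.ino: must match the number and order of
// your trained model's output classes.
constexpr int label_count = 2;
const char* labels[label_count] = {"O", "Z"};
```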

View File

@ -0,0 +1,708 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ArduinoBLE.h>
#include <Arduino_LSM9DS1.h>
#include <TensorFlowLite.h>
#include <cmath>
#include "magic_wand_model_data.h"
#include "rasterize_stroke.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
#define BLE_SENSE_UUID(val) ("4798e0f2-" val "-4d68-af64-8a8f5258404e")
#undef MAGIC_WAND_DEBUG
namespace {
const int VERSION = 0x00000000;
constexpr int stroke_transmit_stride = 2;
constexpr int stroke_transmit_max_length = 160;
constexpr int stroke_max_length =
stroke_transmit_max_length * stroke_transmit_stride;
constexpr int stroke_points_byte_count =
2 * sizeof(int8_t) * stroke_transmit_max_length;
constexpr int stroke_struct_byte_count =
(2 * sizeof(int32_t)) + stroke_points_byte_count;
constexpr int moving_sample_count = 50;
constexpr int raster_width = 32;
constexpr int raster_height = 32;
constexpr int raster_channels = 3;
constexpr int raster_byte_count =
raster_height * raster_width * raster_channels;
int8_t raster_buffer[raster_byte_count];
BLEService service(BLE_SENSE_UUID("0000"));
BLECharacteristic strokeCharacteristic(BLE_SENSE_UUID("300a"), BLERead,
stroke_struct_byte_count);
// String to calculate the local and device name
String name;
// A buffer holding the last 600 sets of 3-channel values from the
// accelerometer.
constexpr int acceleration_data_length = 600 * 3;
float acceleration_data[acceleration_data_length] = {};
// The next free entry in the data array.
int acceleration_data_index = 0;
float acceleration_sample_rate = 0.0f;
// A buffer holding the last 600 sets of 3-channel values from the gyroscope.
constexpr int gyroscope_data_length = 600 * 3;
float gyroscope_data[gyroscope_data_length] = {};
float orientation_data[gyroscope_data_length] = {};
// The next free entry in the data array.
int gyroscope_data_index = 0;
float gyroscope_sample_rate = 0.0f;
float current_velocity[3] = {0.0f, 0.0f, 0.0f};
float current_position[3] = {0.0f, 0.0f, 0.0f};
float current_gravity[3] = {0.0f, 0.0f, 0.0f};
float current_gyroscope_drift[3] = {0.0f, 0.0f, 0.0f};
int32_t stroke_length = 0;
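// The BLE transmit buffer is laid out as [int32 state][int32 point count]
// followed by stroke_transmit_max_length pairs of int8 (x, y) points;
// the three pointers below alias those regions of the buffer.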
uint8_t stroke_struct_buffer[stroke_struct_byte_count] = {};
int32_t* stroke_state = reinterpret_cast<int32_t*>(stroke_struct_buffer);
int32_t* stroke_transmit_length =
reinterpret_cast<int32_t*>(stroke_struct_buffer + sizeof(int32_t));
int8_t* stroke_points =
reinterpret_cast<int8_t*>(stroke_struct_buffer + (sizeof(int32_t) * 2));
enum {
eWaiting = 0,
eDrawing = 1,
eDone = 2,
};
// Create an area of memory to use for input, output, and intermediate arrays.
// The size of this will depend on the model you're using, and may need to be
// determined by experimentation.
constexpr int kTensorArenaSize = 30 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
constexpr int label_count = 10;
const char* labels[label_count] = {"0", "1", "2", "3", "4",
"5", "6", "7", "8", "9"};
void SetupIMU() {
// Make sure we are pulling measurements into a FIFO.
// If you see an error on this line, make sure you have at least v1.1.0 of the
// Arduino_LSM9DS1 library installed.
IMU.setContinuousMode();
acceleration_sample_rate = IMU.accelerationSampleRate();
gyroscope_sample_rate = IMU.gyroscopeSampleRate();
#ifdef MAGIC_WAND_DEBUG
float rate_frac;
float rate_int;
rate_frac = modf(acceleration_sample_rate, &rate_int);
TF_LITE_REPORT_ERROR(error_reporter, "Acceleration sample rate %d.%d Hz",
static_cast<int32_t>(rate_int),
static_cast<int32_t>(rate_frac * 100));
rate_frac = modf(gyroscope_sample_rate, &rate_int);
TF_LITE_REPORT_ERROR(error_reporter, "Gyroscope sample rate %d.%d Hz",
static_cast<int32_t>(rate_int),
static_cast<int32_t>(rate_frac * 100));
#endif // MAGIC_WAND_DEBUG
}
void ReadAccelerometerAndGyroscope(int* new_accelerometer_samples,
int* new_gyroscope_samples) {
// Keep track of whether we stored any new data
*new_accelerometer_samples = 0;
*new_gyroscope_samples = 0;
// Loop through new samples and add to buffer
while (IMU.accelerationAvailable()) {
const int gyroscope_index = (gyroscope_data_index % gyroscope_data_length);
gyroscope_data_index += 3;
float* current_gyroscope_data = &gyroscope_data[gyroscope_index];
// Read each sample, removing it from the device's FIFO buffer
if (!IMU.readGyroscope(current_gyroscope_data[0], current_gyroscope_data[1],
current_gyroscope_data[2])) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to read gyroscope data");
break;
}
*new_gyroscope_samples += 1;
const int acceleration_index =
(acceleration_data_index % acceleration_data_length);
acceleration_data_index += 3;
float* current_acceleration_data = &acceleration_data[acceleration_index];
// Read each sample, removing it from the device's FIFO buffer
if (!IMU.readAcceleration(current_acceleration_data[0],
current_acceleration_data[1],
current_acceleration_data[2])) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to read acceleration data");
break;
}
*new_accelerometer_samples += 1;
}
}
float VectorMagnitude(const float* vec) {
const float x = vec[0];
const float y = vec[1];
const float z = vec[2];
return sqrtf((x * x) + (y * y) + (z * z));
}
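// Average the most recent accelerometer samples; while the wand is held
// mostly still, the mean acceleration approximates the gravity vector.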
void EstimateGravityDirection(float* gravity) {
int samples_to_average = 100;
if (samples_to_average >= acceleration_data_index) {
samples_to_average = acceleration_data_index;
}
const int start_index =
((acceleration_data_index +
(acceleration_data_length - (3 * (samples_to_average + 1)))) %
acceleration_data_length);
float x_total = 0.0f;
float y_total = 0.0f;
float z_total = 0.0f;
for (int i = 0; i < samples_to_average; ++i) {
const int index = ((start_index + (i * 3)) % acceleration_data_length);
const float* entry = &acceleration_data[index];
const float x = entry[0];
const float y = entry[1];
const float z = entry[2];
x_total += x;
y_total += y;
z_total += z;
}
gravity[0] = x_total / samples_to_average;
gravity[1] = y_total / samples_to_average;
gravity[2] = z_total / samples_to_average;
}
void UpdateVelocity(int new_samples, float* gravity) {
const float gravity_x = gravity[0];
const float gravity_y = gravity[1];
const float gravity_z = gravity[2];
const int start_index =
((acceleration_data_index +
(acceleration_data_length - (3 * (new_samples + 1)))) %
acceleration_data_length);
const float friction_fudge = 0.98f;
for (int i = 0; i < new_samples; ++i) {
const int index = ((start_index + (i * 3)) % acceleration_data_length);
const float* entry = &acceleration_data[index];
const float ax = entry[0];
const float ay = entry[1];
const float az = entry[2];
// Try to remove gravity from the raw acceleration values.
const float ax_minus_gravity = ax - gravity_x;
const float ay_minus_gravity = ay - gravity_y;
const float az_minus_gravity = az - gravity_z;
// Update velocity based on the normalized acceleration.
current_velocity[0] += ax_minus_gravity;
current_velocity[1] += ay_minus_gravity;
current_velocity[2] += az_minus_gravity;
// Dampen the velocity slightly with a fudge factor to stop it exploding.
current_velocity[0] *= friction_fudge;
current_velocity[1] *= friction_fudge;
current_velocity[2] *= friction_fudge;
// Update the position estimate based on the velocity.
current_position[0] += current_velocity[0];
current_position[1] += current_velocity[1];
current_position[2] += current_velocity[2];
}
}
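// While the wand appears stationary, average recent gyroscope readings to
// estimate the sensor's constant bias, which UpdateOrientation subtracts.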
void EstimateGyroscopeDrift(float* drift) {
const bool isMoving = VectorMagnitude(current_velocity) > 0.1f;
if (isMoving) {
return;
}
int samples_to_average = 20;
if (samples_to_average >= gyroscope_data_index) {
samples_to_average = gyroscope_data_index;
}
const int start_index =
((gyroscope_data_index +
(gyroscope_data_length - (3 * (samples_to_average + 1)))) %
gyroscope_data_length);
float x_total = 0.0f;
float y_total = 0.0f;
float z_total = 0.0f;
for (int i = 0; i < samples_to_average; ++i) {
const int index = ((start_index + (i * 3)) % gyroscope_data_length);
const float* entry = &gyroscope_data[index];
const float x = entry[0];
const float y = entry[1];
const float z = entry[2];
x_total += x;
y_total += y;
z_total += z;
}
drift[0] = x_total / samples_to_average;
drift[1] = y_total / samples_to_average;
drift[2] = z_total / samples_to_average;
}
void UpdateOrientation(int new_samples, float* gravity, float* drift) {
const float drift_x = drift[0];
const float drift_y = drift[1];
const float drift_z = drift[2];
const int start_index =
((gyroscope_data_index + (gyroscope_data_length - (3 * new_samples))) %
gyroscope_data_length);
// The gyroscope values are in degrees-per-second, so to approximate
// degrees in the integrated orientation, we need to divide each value
// by the number of samples each second.
const float recip_sample_rate = 1.0f / gyroscope_sample_rate;
for (int i = 0; i < new_samples; ++i) {
const int index = ((start_index + (i * 3)) % gyroscope_data_length);
const float* entry = &gyroscope_data[index];
const float dx = entry[0];
const float dy = entry[1];
const float dz = entry[2];
// Try to remove sensor errors from the raw gyroscope values.
const float dx_minus_drift = dx - drift_x;
const float dy_minus_drift = dy - drift_y;
const float dz_minus_drift = dz - drift_z;
// Convert from degrees-per-second to appropriate units for this
// time interval.
const float dx_normalized = dx_minus_drift * recip_sample_rate;
const float dy_normalized = dy_minus_drift * recip_sample_rate;
const float dz_normalized = dz_minus_drift * recip_sample_rate;
// Update orientation based on the gyroscope data.
float* current_orientation = &orientation_data[index];
const int previous_index =
(index + (gyroscope_data_length - 3)) % gyroscope_data_length;
const float* previous_orientation = &orientation_data[previous_index];
current_orientation[0] = previous_orientation[0] + dx_normalized;
current_orientation[1] = previous_orientation[1] + dy_normalized;
current_orientation[2] = previous_orientation[2] + dz_normalized;
}
}
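// Decide whether the wand is moving by summing the squared orientation
// deltas over the last moving_sample_count samples (ending samples_before
// samples ago) and comparing the total against a fixed threshold.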
bool IsMoving(int samples_before) {
constexpr float moving_threshold = 9.0f;
if ((gyroscope_data_index - samples_before) < moving_sample_count) {
return false;
}
const int start_index =
((gyroscope_data_index + (gyroscope_data_length -
(3 * (moving_sample_count + samples_before)))) %
gyroscope_data_length);
float total = 0.0f;
for (int i = 0; i < moving_sample_count; ++i) {
const int index = ((start_index + (i * 3)) % gyroscope_data_length);
float* current_orientation = &orientation_data[index];
const int previous_index =
(index + (gyroscope_data_length - 3)) % gyroscope_data_length;
const float* previous_orientation = &orientation_data[previous_index];
const float dx = current_orientation[0] - previous_orientation[0];
const float dy = current_orientation[1] - previous_orientation[1];
const float dz = current_orientation[2] - previous_orientation[2];
const float mag_squared = (dx * dx) + (dy * dy) + (dz * dz);
total += mag_squared;
}
const bool is_moving = (total > moving_threshold);
return is_moving;
}
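// Advance the stroke state machine (eWaiting -> eDrawing -> eDone) for each
// new sample, and project the integrated orientation into the 2D point list
// that is transmitted over BLE while drawing and when a stroke completes.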
void UpdateStroke(int new_samples, bool* done_just_triggered) {
constexpr int minimum_stroke_length = moving_sample_count + 10;
constexpr float minimum_stroke_size = 0.2f;
*done_just_triggered = false;
for (int i = 0; i < new_samples; ++i) {
const int current_head = (new_samples - (i + 1));
const bool is_moving = IsMoving(current_head);
const int32_t old_state = *stroke_state;
if ((old_state == eWaiting) || (old_state == eDone)) {
if (is_moving) {
stroke_length = moving_sample_count;
*stroke_state = eDrawing;
}
} else if (old_state == eDrawing) {
if (!is_moving) {
if (stroke_length > minimum_stroke_length) {
*stroke_state = eDone;
} else {
stroke_length = 0;
*stroke_state = eWaiting;
#ifdef MAGIC_WAND_DEBUG
TF_LITE_REPORT_ERROR(error_reporter, "stroke length too small");
#endif // MAGIC_WAND_DEBUG
}
}
}
const bool is_waiting = (*stroke_state == eWaiting);
if (is_waiting) {
continue;
}
stroke_length += 1;
if (stroke_length > stroke_max_length) {
stroke_length = stroke_max_length;
}
// Only recalculate the full stroke if it's needed.
const bool draw_last_point =
((i == (new_samples - 1)) && (*stroke_state == eDrawing));
*done_just_triggered = ((old_state != eDone) && (*stroke_state == eDone));
if (!(*done_just_triggered || draw_last_point)) {
continue;
}
const int start_index =
((gyroscope_data_index +
(gyroscope_data_length - (3 * (stroke_length + current_head)))) %
gyroscope_data_length);
float x_total = 0.0f;
float y_total = 0.0f;
float z_total = 0.0f;
for (int j = 0; j < stroke_length; ++j) {
const int index = ((start_index + (j * 3)) % gyroscope_data_length);
const float* entry = &orientation_data[index];
x_total += entry[0];
y_total += entry[1];
z_total += entry[2];
}
const float y_mean = y_total / stroke_length;
const float z_mean = z_total / stroke_length;
constexpr float range = 45.0f;
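    // Use the gravity direction in the board's y/z plane to build the 2D
    // drawing axes, so the rendered stroke is independent of how the board
    // is twisted around the wand's shaft.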
const float gy = current_gravity[1];
const float gz = current_gravity[2];
float gmag = sqrtf((gy * gy) + (gz * gz));
if (gmag < 0.0001f) {
gmag = 0.0001f;
}
const float ngy = gy / gmag;
const float ngz = gz / gmag;
const float xaxisz = -ngz;
const float xaxisy = -ngy;
const float yaxisz = -ngy;
const float yaxisy = ngz;
*stroke_transmit_length = stroke_length / stroke_transmit_stride;
float x_min = 0;
float y_min = 0;
float x_max = 0;
float y_max = 0;
for (int j = 0; j < *stroke_transmit_length; ++j) {
const int orientation_index =
((start_index + ((j * stroke_transmit_stride) * 3)) %
gyroscope_data_length);
const float* orientation_entry = &orientation_data[orientation_index];
const float orientation_y = orientation_entry[1];
const float orientation_z = orientation_entry[2];
const float ny = (orientation_y - y_mean) / range;
const float nz = (orientation_z - z_mean) / range;
const float x_axis = (xaxisz * nz) + (xaxisy * ny);
const float y_axis = (yaxisz * nz) + (yaxisy * ny);
const int stroke_index = j * 2;
int8_t* stroke_entry = &stroke_points[stroke_index];
int32_t unchecked_x = static_cast<int32_t>(roundf(x_axis * 128.0f));
int8_t stored_x;
if (unchecked_x > 127) {
stored_x = 127;
} else if (unchecked_x < -128) {
stored_x = -128;
} else {
stored_x = unchecked_x;
}
stroke_entry[0] = stored_x;
int32_t unchecked_y = static_cast<int32_t>(roundf(y_axis * 128.0f));
int8_t stored_y;
if (unchecked_y > 127) {
stored_y = 127;
} else if (unchecked_y < -128) {
stored_y = -128;
} else {
stored_y = unchecked_y;
}
stroke_entry[1] = stored_y;
const bool is_first = (j == 0);
if (is_first || (x_axis < x_min)) {
x_min = x_axis;
}
if (is_first || (y_axis < y_min)) {
y_min = y_axis;
}
if (is_first || (x_axis > x_max)) {
x_max = x_axis;
}
if (is_first || (y_axis > y_max)) {
y_max = y_axis;
}
}
// If the stroke is too small, cancel it.
if (*done_just_triggered) {
const float x_range = (x_max - x_min);
const float y_range = (y_max - y_min);
if ((x_range < minimum_stroke_size) && (y_range < minimum_stroke_size)) {
*done_just_triggered = false;
*stroke_state = eWaiting;
*stroke_transmit_length = 0;
stroke_length = 0;
#ifdef MAGIC_WAND_DEBUG
TF_LITE_REPORT_ERROR(error_reporter, "stroke too small");
#endif // MAGIC_WAND_DEBUG
}
}
}
}
} // namespace
void setup() {
tflite::InitializeTarget(); // setup serial port
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
static tflite::MicroErrorReporter micro_error_reporter; // NOLINT
error_reporter = &micro_error_reporter;
TF_LITE_REPORT_ERROR(error_reporter, "Started");
if (!IMU.begin()) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to initialized IMU!");
while (true) {
// NORETURN
}
}
SetupIMU();
if (!BLE.begin()) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to initialized BLE!");
while (true) {
// NORETURN
}
}
String address = BLE.address();
TF_LITE_REPORT_ERROR(error_reporter, "address = %s", address.c_str());
address.toUpperCase();
name = "BLESense-";
name += address[address.length() - 5];
name += address[address.length() - 4];
name += address[address.length() - 2];
name += address[address.length() - 1];
TF_LITE_REPORT_ERROR(error_reporter, "name = %s", name.c_str());
BLE.setLocalName(name.c_str());
BLE.setDeviceName(name.c_str());
BLE.setAdvertisedService(service);
service.addCharacteristic(strokeCharacteristic);
BLE.addService(service);
BLE.advertise();
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
static tflite::MicroMutableOpResolver<4> micro_op_resolver; // NOLINT
micro_op_resolver.AddConv2D();
micro_op_resolver.AddMean();
micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddSoftmax();
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
interpreter->AllocateTensors();
TfLiteTensor* model_input = interpreter->input(0);
if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
(model_input->dims->data[1] != raster_height) ||
(model_input->dims->data[2] != raster_width) ||
(model_input->dims->data[3] != raster_channels) ||
(model_input->type != kTfLiteInt8) ||
(model_input->params.zero_point != -128) ||
(model_input->params.scale != 1.0)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad input tensor parameters in model");
return;
}
TfLiteTensor* model_output = interpreter->output(0);
if ((model_output->dims->size != 2) || (model_output->dims->data[0] != 1) ||
(model_output->dims->data[1] != label_count) ||
(model_output->type != kTfLiteInt8)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad output tensor parameters in model");
return;
}
}
void loop() {
BLEDevice central = BLE.central();
// if a central is connected to the peripheral:
static bool was_connected_last = false;
if (central && !was_connected_last) {
// print the central's BT address:
TF_LITE_REPORT_ERROR(error_reporter, "Connected to central: %s",
central.address().c_str());
}
was_connected_last = central;
const bool data_available =
IMU.accelerationAvailable() || IMU.gyroscopeAvailable();
if (!data_available) {
return;
}
int accelerometer_samples_read;
int gyroscope_samples_read;
ReadAccelerometerAndGyroscope(&accelerometer_samples_read,
&gyroscope_samples_read);
bool done_just_triggered = false;
if (gyroscope_samples_read > 0) {
EstimateGyroscopeDrift(current_gyroscope_drift);
UpdateOrientation(gyroscope_samples_read, current_gravity,
current_gyroscope_drift);
UpdateStroke(gyroscope_samples_read, &done_just_triggered);
if (central && central.connected()) {
strokeCharacteristic.writeValue(stroke_struct_buffer,
stroke_struct_byte_count);
}
}
if (accelerometer_samples_read > 0) {
EstimateGravityDirection(current_gravity);
UpdateVelocity(accelerometer_samples_read, current_gravity);
}
if (done_just_triggered) {
RasterizeStroke(stroke_points, *stroke_transmit_length, 0.6f, 0.6f,
raster_width, raster_height, raster_buffer);
for (int y = 0; y < raster_height; ++y) {
char line[raster_width + 1];
for (int x = 0; x < raster_width; ++x) {
const int8_t* pixel =
&raster_buffer[(y * raster_width * raster_channels) +
(x * raster_channels)];
const int8_t red = pixel[0];
const int8_t green = pixel[1];
const int8_t blue = pixel[2];
char output;
if ((red > -128) || (green > -128) || (blue > -128)) {
output = '#';
} else {
output = '.';
}
line[x] = output;
}
line[raster_width] = 0;
TF_LITE_REPORT_ERROR(error_reporter, line);
}
#ifdef MAGIC_WAND_DEBUG
TF_LITE_REPORT_ERROR(error_reporter, "tx len: %d", *stroke_transmit_length);
#endif // MAGIC_WAND_DEBUG
TfLiteTensor* model_input = interpreter->input(0);
for (int i = 0; i < raster_byte_count; ++i) {
model_input->data.int8[i] = raster_buffer[i];
}
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed");
return;
}
TfLiteTensor* output = interpreter->output(0);
int8_t max_score;
int max_index;
for (int i = 0; i < label_count; ++i) {
const int8_t score = output->data.int8[i];
if ((i == 0) || (score > max_score)) {
max_score = score;
max_index = i;
}
}
float max_score_f =
(max_score - output->params.zero_point) * output->params.scale;
float max_score_int;
float max_score_frac = modf(max_score_f * 100, &max_score_int);
TF_LITE_REPORT_ERROR(error_reporter, "Found %s (%d.%d%%)",
labels[max_index], static_cast<int>(max_score_int),
static_cast<int>(max_score_frac * 100));
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a standard TensorFlow Lite model file that has been converted into a
// C data array, so it can be easily compiled into a binary for devices that
// don't have a file system. It was created using the command:
// xxd -i magic_wand_model.tflite > magic_wand_model_data.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_MAGIC_WAND_MODEL_DATA_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_MAGIC_WAND_MODEL_DATA_H_
extern const unsigned char g_magic_wand_model_data[];
extern const int g_magic_wand_model_data_len;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_MAGIC_WAND_MODEL_DATA_H_

View File

@ -0,0 +1,157 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rasterize_stroke.h"
namespace {
constexpr int kFixedPoint = 4096;
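// 4096 = 2^12, i.e. 12 fractional bits. MulFP and DivFP rescale the result
// of multiplying or dividing two fixed-point values back into this format.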
int32_t MulFP(int32_t a, int32_t b) { return (a * b) / kFixedPoint; }
int32_t DivFP(int32_t a, int32_t b) {
if (b == 0) {
b = 1;
}
return (a * kFixedPoint) / b;
}
int32_t FloatToFP(float a) { return static_cast<int32_t>(a * kFixedPoint); }
int32_t NormToCoordFP(int32_t a_fp, int32_t range_fp, int32_t half_size_fp) {
const int32_t norm_fp = DivFP(a_fp, range_fp);
return MulFP(norm_fp, half_size_fp) + half_size_fp;
}
int32_t RoundFPToInt(int32_t a) {
return static_cast<int32_t>((a + (kFixedPoint / 2)) / kFixedPoint);
}
int32_t Gate(int32_t a, int32_t min, int32_t max) {
if (a < min) {
return min;
} else if (a > max) {
return max;
} else {
return a;
}
}
int32_t Abs(int32_t a) {
if (a > 0) {
return a;
} else {
return -a;
}
}
} // namespace
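// Draws the normalized stroke into a width x height x 3 int8 image,
// stepping each line segment along its major axis one pixel at a time.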
void RasterizeStroke(int8_t* stroke_points, int stroke_points_count,
float x_range, float y_range, int width, int height,
int8_t* out_buffer) {
constexpr int num_channels = 3;
const int buffer_byte_count = height * width * num_channels;
for (int i = 0; i < buffer_byte_count; ++i) {
out_buffer[i] = -128;
}
const int32_t width_fp = width * kFixedPoint;
const int32_t height_fp = height * kFixedPoint;
const int32_t half_width_fp = width_fp / 2;
const int32_t half_height_fp = height_fp / 2;
const int32_t x_range_fp = FloatToFP(x_range);
const int32_t y_range_fp = FloatToFP(y_range);
const int t_inc_fp = kFixedPoint / stroke_points_count;
const int one_half_fp = (kFixedPoint / 2);
for (int point_index = 0; point_index < (stroke_points_count - 1);
++point_index) {
const int8_t* start_point = &stroke_points[point_index * 2];
const int32_t start_point_x_fp = (start_point[0] * kFixedPoint) / 128;
const int32_t start_point_y_fp = (start_point[1] * kFixedPoint) / 128;
const int8_t* end_point = &stroke_points[(point_index + 1) * 2];
const int32_t end_point_x_fp = (end_point[0] * kFixedPoint) / 128;
const int32_t end_point_y_fp = (end_point[1] * kFixedPoint) / 128;
const int32_t start_x_fp =
NormToCoordFP(start_point_x_fp, x_range_fp, half_width_fp);
const int32_t start_y_fp =
NormToCoordFP(-start_point_y_fp, y_range_fp, half_height_fp);
const int32_t end_x_fp =
NormToCoordFP(end_point_x_fp, x_range_fp, half_width_fp);
const int32_t end_y_fp =
NormToCoordFP(-end_point_y_fp, y_range_fp, half_height_fp);
const int32_t delta_x_fp = end_x_fp - start_x_fp;
const int32_t delta_y_fp = end_y_fp - start_y_fp;
const int32_t t_fp = point_index * t_inc_fp;
int32_t red_i32;
int32_t green_i32;
int32_t blue_i32;
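    // Color encodes time along the stroke: fade red -> green over the first
    // half and green -> blue over the second, so the raster also carries
    // the direction in which the stroke was drawn.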
if (t_fp < one_half_fp) {
const int32_t local_t_fp = DivFP(t_fp, one_half_fp);
const int32_t one_minus_t_fp = kFixedPoint - local_t_fp;
red_i32 = RoundFPToInt(one_minus_t_fp * 255) - 128;
green_i32 = RoundFPToInt(local_t_fp * 255) - 128;
blue_i32 = -128;
} else {
const int32_t local_t_fp = DivFP(t_fp - one_half_fp, one_half_fp);
const int32_t one_minus_t_fp = kFixedPoint - local_t_fp;
red_i32 = -128;
green_i32 = RoundFPToInt(one_minus_t_fp * 255) - 128;
blue_i32 = RoundFPToInt(local_t_fp * 255) - 128;
}
const int8_t red_i8 = Gate(red_i32, -128, 127);
const int8_t green_i8 = Gate(green_i32, -128, 127);
const int8_t blue_i8 = Gate(blue_i32, -128, 127);
int line_length;
int32_t x_inc_fp;
int32_t y_inc_fp;
if (Abs(delta_x_fp) > Abs(delta_y_fp)) {
line_length = Abs(RoundFPToInt(delta_x_fp));
if (delta_x_fp > 0) {
x_inc_fp = 1 * kFixedPoint;
y_inc_fp = DivFP(delta_y_fp, delta_x_fp);
} else {
x_inc_fp = -1 * kFixedPoint;
y_inc_fp = -DivFP(delta_y_fp, delta_x_fp);
}
} else {
line_length = Abs(RoundFPToInt(delta_y_fp));
if (delta_y_fp > 0) {
y_inc_fp = 1 * kFixedPoint;
x_inc_fp = DivFP(delta_x_fp, delta_y_fp);
} else {
y_inc_fp = -1 * kFixedPoint;
x_inc_fp = -DivFP(delta_x_fp, delta_y_fp);
}
}
for (int i = 0; i < (line_length + 1); ++i) {
const int32_t x_fp = start_x_fp + (i * x_inc_fp);
const int32_t y_fp = start_y_fp + (i * y_inc_fp);
const int x = RoundFPToInt(x_fp);
const int y = RoundFPToInt(y_fp);
if ((x < 0) or (x >= width) or (y < 0) or (y >= height)) {
continue;
}
const int buffer_index = (y * width * num_channels) + (x * num_channels);
out_buffer[buffer_index + 0] = red_i8;
out_buffer[buffer_index + 1] = green_i8;
out_buffer[buffer_index + 2] = blue_i8;
}
}
}

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_RASTERIZE_STROKE_H
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_RASTERIZE_STROKE_H
#include <cstdint>
void RasterizeStroke(int8_t* stroke_points, int stroke_points_count,
float x_range, float y_range, int width, int height,
int8_t* out_buffer);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_RASTERIZE_STROKE_H

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,474 @@
<!DOCTYPE html>
<html>
<head>
<title>Magic Wand Gesture Recorder</title>
<link href="https://fonts.googleapis.com/css?family=Roboto&display=swap" rel="stylesheet">
<style>
body {
font-family: 'Roboto', sans-serif; color: #888888; background: #000000; font-size: small;
}
.button {
background-color: aqua; border: none; color: black; padding: 1px; text-align: center;
text-decoration: none; font-size: 12px; margin: 12px 20px; height: 20px; width: 100px;
border-radius: 10%; outline:none; font-family: 'Roboto', sans-serif;
float: left;
}
#bigButton {
float:left;
}
#downloadButton {
float: left;
}
.container {
width:860px;height:384px;margin-top:30px;margin-bottom:7.5px; margin: 0 auto; position: relative;
}
.widget {
background: #111111;
border: 1px solid #000000;
border-radius: 0px;
padding: 12px;
margin: 6px;
float: left;
width: 340px;
height: 340px;
}
.status {
background-image: url("logo.png");
background-size: 80px;
background-position: 98% 50%;
background-repeat: no-repeat;
width: 885px; height: 42px;
color: white;
}
.square {
width: 320px; height: 320px;
position: relative;
float: left;
}
.label {
height: 20px;
min-width: 100px;
display: inline;
font-size: 15px;
float: left;
}
.intro {
font-size: 15px;
}
.bluetooth-label {
margin-top: 14px;
}
.count-label {
margin-top: -14px;
float: right;
margin-right: 158px;
text-align: right;
}
.gesture_store {
width: 400px;
height: 500px;
float: left;
overflow-y: scroll;
}
.trash {
float: right;
font-size: x-large;
cursor: pointer;
}
</style>
</head>
<body>
<div class="container">
<div class="intro">
To get started recording magic wand gestures:
<ul>
<li>Upload the <a href="https://create.arduino.cc/editor/petewarden/669b0490-720a-4d5e-a1c0-2055823f4c08/preview">Magic Wand Capture sketch</a> to an Arduino Nano BLE Sense board</li>
<li>Connect to the board using the Bluetooth button below.</li>
<li>Wave the wand to make gestures. They'll be recorded and displayed on the right.</li>
<li>Review the gestures, add labels by clicking on the '?', and remove mistakes.</li>
<li>Download the gestures as a JSON data file, ready for model training.</li>
</ul>
</div>
<div class="status widget">
<a class="button" id="downloadButton">Download Data</a>
<button class="button" id="bigButton">Bluetooth</button>
<div class="bluetooth-label" id="bluetooth">Click button to connect to the board</div>
<div class="count-label" id="count"></div>
</div>
<div class="widget">
<div id="stroke_label" class="label"></div>
<canvas id="stroke" width="640px" height="640px" class="square"></canvas>
</div>
<div class="gesture_store">
</div>
</div>
</body>
<script type="text/javascript">
// Thanks to Dom Pajak for his original Arduino BLE Sense connection demo:
// https://arduino.github.io/ArduinoAI/BLESense-test-dashboard/
var maxRecords = 64;
var STROKE_POINT_COUNT = 160;
// UI elements
const bigButton = document.getElementById('bigButton');
const BLEstatus = document.getElementById('bluetooth');
if ("bluetooth" in navigator) {
bigButton.addEventListener('click', function(event) {
connect();
});
// else the browser doesn't support bluetooth
} else {
msg("Browser not supported"); bigButton.style.backgroundColor = "red";
alert("Error: This browser doesn't support Web Bluetooth. Try using Chrome.");
}
function msg(m){
BLEstatus.innerHTML = m;
}
async function connect() {
bigButton.style.backgroundColor="grey";
msg('Requesting device ...');
const device = await navigator.bluetooth.requestDevice({
filters: [
{
services: [SERVICE_UUID] // SERVICE_UUID
}
]
});
msg('Connecting to device ...');
device.addEventListener('gattserverdisconnected', onDisconnected);
const server = await device.gatt.connect();
msg('Getting primary service ...');
const service = await server.getPrimaryService(SERVICE_UUID);
// Set up the characteristics
for (const sensor of sensors) {
msg('Characteristic '+sensor+"...");
BLEsense[sensor].characteristic = await service.getCharacteristic(BLEsense[sensor].uuid);
// Set up notification
if (BLEsense[sensor].properties.includes("BLENotify")){
BLEsense[sensor].characteristic.addEventListener('characteristicvaluechanged',function(event){handleIncoming(BLEsense[sensor],event.target.value);});
await BLEsense[sensor].characteristic.startNotifications();
}
// Set up polling for read
if (BLEsense[sensor].properties.includes("BLERead")){
BLEsense[sensor].polling = setInterval(function() {
BLEsense[sensor].characteristic.readValue().then(function(data){handleIncoming(BLEsense[sensor],data);})}
, 200);
}
BLEsense[sensor].rendered = false;
}
bigButton.style.backgroundColor = 'green';
msg('Connected.');
}
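// Unpack STROKE_POINT_COUNT pairs of signed bytes into {x, y} floats
// scaled to the range [-1, 1).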
function getStrokePoints(dataview, byteOffset, littleEndian) {
var result = [];
var currentOffset = byteOffset;
for (var i = 0; i < STROKE_POINT_COUNT; ++i) {
var entry = {};
entry.x = dataview.getInt8(currentOffset, littleEndian) / 128.0;
currentOffset += 1;
entry.y = dataview.getInt8(currentOffset, littleEndian) / 128.0;
currentOffset += 1;
result.push(entry);
}
return result;
}
function handleIncoming(sensor, dataReceived) {
const columns = Object.keys(sensor.data); // column headings for this sensor
const typeMap = {
"Uint8": {fn:DataView.prototype.getUint8, bytes:1},
"Uint16": {fn:DataView.prototype.getUint16, bytes:2},
"Int32": {fn:DataView.prototype.getInt32, bytes:4},
"Float32": {fn:DataView.prototype.getFloat32, bytes:4},
"StrokePoints": {fn:getStrokePoints, bytes:(STROKE_POINT_COUNT * 2 * 1)},
};
var packetPointer = 0,i = 0;
// Read each sensor value in the BLE packet and push into the data array
sensor.structure.forEach(function(dataType){
var unpackedValue;
if (dataType === "StrokePoints") {
var dataViewFn = typeMap[dataType].fn;
unpackedValue = dataViewFn(dataReceived, packetPointer,true);
} else {
var dataViewFn = typeMap[dataType].fn.bind(dataReceived);
unpackedValue = dataViewFn(packetPointer,true);
}
// Push sensor reading onto data array
sensor.data[columns[i]].push(unpackedValue);
// Keep array at buffer size
if (sensor.data[columns[i]].length> maxRecords) {sensor.data[columns[i]].shift();}
// move pointer forward in data packet to next value
packetPointer += typeMap[dataType].bytes;
bytesReceived += typeMap[dataType].bytes;
i++;
});
sensor.rendered = false; // flag - vizualization needs to be updated
if (typeof sensor.onUpdate != 'undefined') {
sensor.onUpdate();
}
}
function onDisconnected(event) {
let device = event.target;
bigButton.style.backgroundColor="red";
// clear read polling
for (const sensor of sensors) {
if(typeof BLEsense[sensor].polling !== 'undefined'){
clearInterval(BLEsense[sensor].polling);
}
}
msg('Device ' + device.name + ' is disconnected.');
}
function BLEwriteTo(sensor){
if (BLEsense[sensor].writeBusy) return; // drop writes while one is in progress instead of queuing, since the LED is non-critical / not realtime
BLEsense[sensor].writeBusy = true; // Ensure no write happens when GATT operation in progress
BLEsense[sensor].characteristic.writeValue(BLEsense[sensor].writeValue)
.then(_ => {
BLEsense[sensor].writeBusy = false;
})
.catch(error => {
console.log(error);
});
}
var storedStrokes = [];
function storeStroke(strokePoints) {
var storeIndex = storedStrokes.length;
var template =
' <div class="widget" id="store_' + storeIndex +'">' +
' <div contenteditable="true" class="label"></div>' +
' <div class="trash">&#128465;</div>' +
' <canvas width="640px" height="640px" class="square"></canvas>' +
' </div>';
var storeDiv = document.querySelector('.gesture_store');
var parser = new DOMParser();
var html = parser.parseFromString(template, 'text/html');
storeDiv.prepend(html.body.firstChild);
var strokeLabel = document.querySelector('#store_' + storeIndex +' > .label');
strokeLabel.innerText = "?";
strokeLabel.onfocus = onLabelFocus;
strokeLabel.onblur = onLabelBlur;
strokeLabel.onkeydown = onLabelKeydown;
var strokeCanvas = document.querySelector('#store_' + storeIndex +' > canvas');
const ctx = strokeCanvas.getContext('2d');
ctx.fillStyle = "#111111";
ctx.fillRect(0, 0, strokeCanvas.width, strokeCanvas.height);
drawStrokeGraph(strokeCanvas, strokePoints, strokePoints.length);
storedStrokes.push({
index: storeIndex,
strokePoints: strokePoints,
label: '',
});
onStoreChange();
var strokeTrash = document.querySelector('#store_' + storeIndex +' > .trash');
strokeTrash.onclick = onTrashClick;
}
function onLabelFocus(event) {
if (event.target.innerText === '?') {
event.target.innerText = '';
}
}
function onLabelBlur(event) {
var parent = event.target.parentElement;
var id = parent.id;
var index = Number(id.replace('store_', ''));
var entry = storedStrokes.find(entry => entry.index === index);
entry.label = event.target.innerText;
onStoreChange();
}
function onLabelKeydown(event) {
if (event.keyCode == 13) {
event.preventDefault();
event.target.blur();
}
}
function onTrashClick(event) {
var parent = event.target.parentElement;
var id = parent.id;
parent.remove();
var index = Number(id.replace('store_', ''));
storedStrokes = storedStrokes.filter(entry => entry.index !== index);
onStoreChange();
}
function onStoreChange() {
var data = {
strokes: storedStrokes,
};
var dataStr = "data:text/json;charset=utf-8," + encodeURIComponent(JSON.stringify(data));
var downloadButton = document.querySelector('#downloadButton');
downloadButton.setAttribute('href', dataStr);
downloadButton.setAttribute('download', 'wanddata.json');
var count = document.querySelector('#count');
count.innerText = storedStrokes.length;
}
function initStrokeGraph() {
var canvas = document.getElementById('stroke');
const ctx = canvas.getContext('2d');
ctx.fillStyle = "#111111";
ctx.fillRect(0, 0, canvas.width, canvas.height);
}
function drawStrokeGraph(canvas, strokePoints, strokeDataLength) {
const ctx = canvas.getContext('2d');
var canvasWidth = canvas.width;
var canvasHeight = canvas.height;
var halfHeight = canvasHeight / 2;
var halfWidth = canvasWidth / 2;
ctx.strokeStyle = "#ffffff";
ctx.beginPath();
for (var i = 0; i < strokeDataLength; ++i) {
var x = strokePoints[i].x;
var y = strokePoints[i].y;
var xCanvas = halfWidth + (x * halfWidth);
var yCanvas = halfHeight - (y * halfHeight);
if (i === 0) {
ctx.moveTo(xCanvas, yCanvas);
} else if (i == (strokeDataLength - 1)) {
ctx.lineTo(xCanvas+5, yCanvas+5);
ctx.lineTo(xCanvas-5, yCanvas-5);
ctx.moveTo(xCanvas+5, yCanvas-5);
ctx.lineTo(xCanvas-5, yCanvas+5); // was moveTo, which left the second arm of the end-of-stroke X undrawn
} else {
ctx.lineTo(xCanvas, yCanvas);
}
}
ctx.stroke();
}
var previousStrokeState = 0;
function updateStrokeGraph() {
var strokeData = BLEsense['stroke'].data;
var strokeDataLength = strokeData.length.latest();
var strokeState = strokeData.state.latest();
var strokePoints = strokeData.strokePoints.latest();
strokePoints = strokePoints.slice(0, strokeDataLength);
if ((strokeState == 2) && (previousStrokeState != 2)) {
storeStroke(strokePoints);
}
previousStrokeState = strokeState;
var label = document.getElementById('stroke_label');
if (strokeState == 0) {
label.innerText = "Waiting for gesture";
} else if (strokeState == 1) {
label.innerText = "Drawing";
} else {
label.innerText = "Done";
}
var canvas = document.getElementById('stroke');
const ctx = canvas.getContext('2d');
ctx.fillStyle = "#111111";
ctx.fillRect(0, 0, canvas.width, canvas.height);
if (strokeState === 1) {
drawStrokeGraph(canvas, strokePoints, strokeDataLength);
}
}
var BLEsense =
{
stroke:
{
uuid: '4798e0f2-300a-4d68-af64-8a8f5258404e',
properties: ['BLERead'], // BLENotify only gives us the first 20 bytes.
structure: [
'Int32', 'Int32',
'StrokePoints',
],
data: {
'state': [], 'length': [],
'strokePoints': [],
},
onUpdate: updateStrokeGraph,
},
};
const sensors = Object.keys(BLEsense);
const SERVICE_UUID = '4798e0f2-0000-4d68-af64-8a8f5258404e';
var bytesReceived = 0;
var bytesPrevious = 0;
// return last item of array
Array.prototype.latest = function(){return this[this.length - 1];};
function bytes(){
if (bytesReceived > bytesPrevious){
bytesPrevious= bytesReceived;
msg(bytesReceived+" bytes received");
}
}
var skip_frame = false;
function draw(){
function updateViz(sensor,fns){
if (BLEsense[sensor].rendered == false) { // only render if new values are received
fns.forEach(function(fn){
fn(sensor);
});
BLEsense[sensor].rendered = true;
}
}
if (skip_frame == false){ // TODO update with function to iterate object with viz function as a property
skip_frame = true; // render alternate frames = 30fps
} else {skip_frame=false;}
requestAnimationFrame(draw);
}
initStrokeGraph();
requestAnimationFrame(draw);
</script>
</html>

View File

@ -0,0 +1,84 @@
<!-- mdformat off(b/169948621#comment2) -->
# Micro Speech Example
This example shows how to run a 20 kB model that can recognize 2 keywords,
"yes" and "no", from speech data.
The application listens to its surroundings with a microphone and indicates
when it has detected a word by lighting an LED or displaying data on a
screen, depending on the capabilities of the device.
![Animation on Arduino](../../docs/animation_on_arduino.gif)
The code has a small footprint (for example, around 166 kilobytes on a Cortex
M4) and only uses about 54 kilobytes of additional RAM for working memory.
## Table of contents
<!--ts-->
* [Table of contents](#table-of-contents)
* [Deploy to Arduino](#deploy-to-arduino)
* [Install the Arduino_TensorFlowLite library](#install-the-arduino_tensorflowlite-library)
* [Load and run the example](#load-and-run-the-example)
<!--te-->
## Deploy to Arduino
The following instructions will help you build and deploy this example to
[Arduino](https://www.arduino.cc/) devices.
The example has been tested with the following devices:
- [Arduino Nano 33 BLE Sense](https://store.arduino.cc/usa/nano-33-ble-sense-with-headers)
The Arduino Nano 33 BLE Sense is currently the only Arduino board with a
built-in microphone; it also has a set of LEDs, which are used to indicate
that a word has been recognized. If you're using a different Arduino board
and attaching your own microphone, you'll need to implement your own
`audio_provider.cpp` code.
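If you do write your own provider, the interface to implement is small; the
declarations below are the ones this example expects (they appear in
`audio_provider.h` later in this commit):
```
TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter);
int32_t LatestAudioTimestamp();
TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
                             int start_ms, int duration_ms,
                             int* audio_samples_size, int16_t** audio_samples);
```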
### Install the Arduino_TensorFlowLite library
This example application is included as part of the official TensorFlow Lite Micro
Arduino library.
To install the TensorFlow Lite Micro for Arduino library, see the
[how to install](../../README.md#how-to-install) instructions.
### Load and run the example
Once the library has been added, go to `File -> Examples`. You should see an
entry within the list named `Arduino_TensorFlowLite`. Select
it and click `micro_speech` to load the example.
Use the Arduino IDE to build and upload the example. Once it is running, you
should see the built-in LED on your device flashing. The built-in LED will flash on/off for each inference cycle. Saying the word "yes" will
cause the green LED to remain on for 3 seconds. The current model has fairly low
accuracy, so you may have to repeat "yes" a few times. Saying the word "no" will cause the red LED to light up. The blue LED will be lit for certain "unknown" sounds.
Word recognition should occur at a distance of approximately 1.5 feet in a low-noise environment.
The program also outputs inference results to the serial port, which appear as
follows:
```
Heard yes (201) @4056ms
Heard no (205) @6448ms
Heard unknown (201) @13696ms
Heard yes (205) @15000ms
```
The number after each detected word is its score. By default, the program only
considers matches as valid if their score is over 200, so all of the scores you
see will be at least 200.
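The threshold is not hard-coded in the detection logic; it is the
`detection_threshold` argument of the `RecognizeCommands` constructor (see
`recognize_commands.h` later in this commit). A minimal sketch of how you
might lower it, assuming you edit the `setup()` function where the recognizer
is created:
```
// More sensitive than the default of 200; expect more false positives.
// The other arguments keep their documented defaults.
static RecognizeCommands static_recognizer(error_reporter,
                                           /*average_window_duration_ms=*/1000,
                                           /*detection_threshold=*/180,
                                           /*suppression_ms=*/1500,
                                           /*minimum_count=*/3);
```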
When the program is run, it waits several seconds for a USB-serial connection to be
available. If there is no connection available, it will not output data. To see
the serial output in the Arduino desktop IDE, do the following:
1. Open the Arduino IDE
1. Connect the Arduino board to your computer via USB
1. Press the reset button on the Arduino board
1. Within 5 seconds, go to `Tools -> Serial Monitor` in the Arduino IDE. You may
have to try several times, since the board will take a moment to connect.
If you don't see any output, repeat the process again.

View File

@ -0,0 +1,194 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#define ARDUINO_EXCLUDE_CODE
#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#ifndef ARDUINO_EXCLUDE_CODE
#include <algorithm>
#include <cmath>
#include "PDM.h"
#include "audio_provider.h"
#include "micro_features_micro_model_settings.h"
#include "test_over_serial/test_over_serial.h"
using namespace test_over_serial;
namespace {
bool g_is_audio_initialized = false;
// An internal buffer able to fit 16x our sample size
constexpr int kAudioCaptureBufferSize = DEFAULT_PDM_BUFFER_SIZE * 16;
int16_t g_audio_capture_buffer[kAudioCaptureBufferSize];
// A buffer that holds our output
int16_t g_audio_output_buffer[kMaxAudioSampleSize];
// Mark as volatile so we can check in a while loop to see if
// any samples have arrived yet.
volatile int32_t g_latest_audio_timestamp = 0;
// error reporter
tflite::ErrorReporter* g_error_reporter;
// test_over_serial sample index
uint32_t g_test_sample_index;
// test_over_serial silence insertion flag
bool g_test_insert_silence = true;
} // namespace
void CaptureSamples() {
// This is how many 16-bit samples of new data we have each time this is
// called (the PDM buffer size is given in bytes)
const int number_of_samples = DEFAULT_PDM_BUFFER_SIZE / 2;
// Calculate what timestamp the last audio sample represents
const int32_t time_in_ms =
g_latest_audio_timestamp +
(number_of_samples / (kAudioSampleFrequency / 1000));
// Determine the index, in the history of all samples, of the first new sample
const int32_t start_sample_offset =
g_latest_audio_timestamp * (kAudioSampleFrequency / 1000);
// Determine the index of this sample in our ring buffer
const int capture_index = start_sample_offset % kAudioCaptureBufferSize;
// Read the data to the correct place in our buffer
int num_read =
PDM.read(g_audio_capture_buffer + capture_index, DEFAULT_PDM_BUFFER_SIZE);
if (num_read != DEFAULT_PDM_BUFFER_SIZE) {
TF_LITE_REPORT_ERROR(g_error_reporter, "### short read (%d/%d) @%dms",
num_read, DEFAULT_PDM_BUFFER_SIZE, time_in_ms);
while (true) {
// NORETURN
}
}
// This is how we let the outside world know that new audio data has arrived.
g_latest_audio_timestamp = time_in_ms;
}
TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) {
if (!g_is_audio_initialized) {
g_error_reporter = error_reporter;
// Hook up the callback that will be called with each sample
PDM.onReceive(CaptureSamples);
// Start listening for audio: MONO @ 16KHz
PDM.begin(1, kAudioSampleFrequency);
// gain: -20db (min) + 6.5db (13) + 3.2db (builtin) = -10.3db
PDM.setGain(13);
// Block until we have our first audio sample
while (!g_latest_audio_timestamp) {
}
g_is_audio_initialized = true;
}
return kTfLiteOk;
}
TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
int start_ms, int duration_ms,
int* audio_samples_size, int16_t** audio_samples) {
// This next part should only be called when the main thread notices that the
// latest audio sample data timestamp has changed, so that there's new data
// in the capture ring buffer. The ring buffer will eventually wrap around and
// overwrite the data, but the assumption is that the main thread is checking
// often enough and the buffer is large enough that this call will be made
// before that happens.
// Determine the index, in the history of all samples, of the first
// sample we want
const int start_offset = start_ms * (kAudioSampleFrequency / 1000);
// Determine how many samples we want in total
const int duration_sample_count =
duration_ms * (kAudioSampleFrequency / 1000);
for (int i = 0; i < duration_sample_count; ++i) {
// For each sample, transform its index in the history of all samples into
// its index in g_audio_capture_buffer
const int capture_index = (start_offset + i) % kAudioCaptureBufferSize;
// Write the sample to the output buffer
g_audio_output_buffer[i] = g_audio_capture_buffer[capture_index];
}
// Set pointers to provide access to the audio
*audio_samples_size = duration_sample_count;
*audio_samples = g_audio_output_buffer;
return kTfLiteOk;
}
namespace {
void InsertSilence(const size_t len, int16_t value) {
for (size_t i = 0; i < len; i++) {
const size_t index = (g_test_sample_index + i) % kAudioCaptureBufferSize;
g_audio_capture_buffer[index] = value;
}
g_test_sample_index += len;
}
int32_t ProcessTestInput(TestOverSerial& test) {
constexpr size_t samples_16ms = ((kAudioSampleFrequency / 1000) * 16);
InputHandler handler = [](const InputBuffer* const input) {
if (0 == input->offset) {
// don't insert silence
g_test_insert_silence = false;
}
for (size_t i = 0; i < input->length; i++) {
const size_t index = (g_test_sample_index + i) % kAudioCaptureBufferSize;
g_audio_capture_buffer[index] = input->data.int16[i];
}
g_test_sample_index += input->length;
if (input->total == (input->offset + input->length)) {
// allow silence insertion again
g_test_insert_silence = true;
}
return true;
};
test.ProcessInput(&handler);
if (g_test_insert_silence) {
// add 16ms of silence just like the PDM interface
InsertSilence(samples_16ms, 0);
}
// Round the timestamp to a multiple of 64ms. This emulates the PDM
// interface during inference processing.
g_latest_audio_timestamp = (g_test_sample_index / (samples_16ms * 4)) * 64;
return g_latest_audio_timestamp;
}
} // namespace
int32_t LatestAudioTimestamp() {
TestOverSerial& test = TestOverSerial::Instance(kAUDIO_PCM_16KHZ_MONO_S16);
if (!test.IsTestMode()) {
// check serial port for test mode command
test.ProcessInput(nullptr);
}
if (test.IsTestMode()) {
if (g_is_audio_initialized) {
// stop capture from hardware
PDM.end();
g_is_audio_initialized = false;
g_test_sample_index =
g_latest_audio_timestamp * (kAudioSampleFrequency / 1000);
}
return ProcessTestInput(test);
} else {
// CaptureSamples() updated the timestamp
return g_latest_audio_timestamp;
}
// NOTREACHED
}
#endif // ARDUINO_EXCLUDE_CODE

View File

@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#define ARDUINO_EXCLUDE_CODE
#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#ifndef ARDUINO_EXCLUDE_CODE
#include "Arduino.h"
#include "command_responder.h"
// Toggles the built-in LED every inference, and lights a colored LED depending
// on which word was detected.
void RespondToCommand(tflite::ErrorReporter* error_reporter,
int32_t current_time, const char* found_command,
uint8_t score, bool is_new_command) {
static bool is_initialized = false;
if (!is_initialized) {
pinMode(LED_BUILTIN, OUTPUT);
// Pins for the built-in RGB LEDs on the Arduino Nano 33 BLE Sense
pinMode(LEDR, OUTPUT);
pinMode(LEDG, OUTPUT);
pinMode(LEDB, OUTPUT);
// Ensure the LED is off by default.
// Note: The RGB LEDs on the Arduino Nano 33 BLE
// Sense are on when the pin is LOW, off when HIGH.
digitalWrite(LEDR, HIGH);
digitalWrite(LEDG, HIGH);
digitalWrite(LEDB, HIGH);
is_initialized = true;
}
static int32_t last_command_time = 0;
static int count = 0;
if (is_new_command) {
TF_LITE_REPORT_ERROR(error_reporter, "Heard %s (%d) @%dms", found_command,
score, current_time);
// If we hear a command, light up the appropriate LED
digitalWrite(LEDR, HIGH);
digitalWrite(LEDG, HIGH);
digitalWrite(LEDB, HIGH);
if (found_command[0] == 'y') {
digitalWrite(LEDG, LOW); // Green for yes
} else if (found_command[0] == 'n') {
digitalWrite(LEDR, LOW); // Red for no
} else if (found_command[0] == 'u') {
digitalWrite(LEDB, LOW); // Blue for unknown
} else {
// silence
}
last_command_time = current_time;
}
// If last_command_time is non-zero but was >3 seconds ago, zero it
// and switch off the LED.
if (last_command_time != 0) {
if (last_command_time < (current_time - 3000)) {
last_command_time = 0;
digitalWrite(LEDR, HIGH);
digitalWrite(LEDG, HIGH);
digitalWrite(LEDB, HIGH);
}
}
// Otherwise, toggle the LED every time an inference is performed.
++count;
if (count & 1) {
digitalWrite(LED_BUILTIN, HIGH);
} else {
digitalWrite(LED_BUILTIN, LOW);
}
}
#endif // ARDUINO_EXCLUDE_CODE

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
// Arduino automatically calls the setup() and loop() functions in a sketch,
// so, unlike on other platforms that need their own main routine, this file
// can be left empty.
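// For platforms without the Arduino runtime, a conventional entry point
// would drive the same functions. A minimal sketch (illustrative only, not
// part of this library):
//
//   int main(int argc, char* argv[]) {
//     setup();
//     while (true) {
//       loop();
//     }
//   }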

View File

@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// This is an abstraction around an audio source like a microphone, and is
// expected to return 16-bit PCM sample data for a given point in time. The
// sample data itself should be used as quickly as possible by the caller, since
// to allow memory optimizations there are no guarantees that the samples won't
// be overwritten by new data in the future. In practice, implementations should
// ensure that there's a reasonable time allowed for clients to access the data
// before any reuse.
// The reference implementation can have no platform-specific dependencies, so
// it just returns an array filled with zeros. For real applications, you should
// ensure there's a specialized implementation that accesses hardware APIs.
TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
int start_ms, int duration_ms,
int* audio_samples_size, int16_t** audio_samples);
// Returns the time that audio data was last captured in milliseconds. There's
// no contract about what time zero represents, the accuracy, or the granularity
// of the result. Subsequent calls will generally not return a lower value, but
// even that's not guaranteed if there's an overflow wraparound.
// The reference implementation of this function just returns a constantly
// incrementing value for each call, since it would need a non-portable platform
// call to access time information. For real applications, you'll need to write
// your own platform-specific implementation.
int32_t LatestAudioTimestamp();
// Starts audio capture
TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter);
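// Illustrative call sequence (a sketch based on how FeatureProvider uses
// this interface elsewhere in this example; not an additional contract):
//
//   InitAudioRecording(error_reporter);
//   int32_t now_ms = LatestAudioTimestamp();
//   int16_t* samples = nullptr;
//   int sample_count = 0;
//   // Fetch the most recent 30ms of audio.
//   GetAudioSamples(error_reporter, now_ms - 30, 30,
//                   &sample_count, &samples);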
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides an interface to take an action based on an audio command.
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Called every time the results of an audio recognition run are available. The
// human-readable name of any recognized command is in the `found_command`
// argument, `score` has the numerical confidence, and `is_new_command` is set
// if the previous command was different to this one.
void RespondToCommand(tflite::ErrorReporter* error_reporter,
int32_t current_time, const char* found_command,
uint8_t score, bool is_new_command);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_

View File

@ -0,0 +1,127 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "feature_provider.h"
#include "audio_provider.h"
#include "micro_features_micro_features_generator.h"
#include "micro_features_micro_model_settings.h"
FeatureProvider::FeatureProvider(int feature_size, int8_t* feature_data)
: feature_size_(feature_size),
feature_data_(feature_data),
is_first_run_(true) {
// Initialize the feature data to default values.
for (int n = 0; n < feature_size_; ++n) {
feature_data_[n] = 0;
}
}
FeatureProvider::~FeatureProvider() {}
TfLiteStatus FeatureProvider::PopulateFeatureData(
tflite::ErrorReporter* error_reporter, int32_t last_time_in_ms,
int32_t time_in_ms, int* how_many_new_slices) {
if (feature_size_ != kFeatureElementCount) {
TF_LITE_REPORT_ERROR(error_reporter,
"Requested feature_data_ size %d doesn't match %d",
feature_size_, kFeatureElementCount);
return kTfLiteError;
}
// Quantize the time into steps as long as each window stride, so we can
// figure out which audio data we need to fetch.
const int last_step = (last_time_in_ms / kFeatureSliceStrideMs);
// Number of new 20ms slices from which we can take 30ms samples
int slices_needed =
((((time_in_ms - last_time_in_ms) - kFeatureSliceDurationMs) *
kFeatureSliceStrideMs) /
kFeatureSliceStrideMs +
kFeatureSliceStrideMs) /
kFeatureSliceStrideMs;
// If this is the first call, make sure we don't use any cached information.
if (is_first_run_) {
TfLiteStatus init_status = InitializeMicroFeatures(error_reporter);
if (init_status != kTfLiteOk) {
return init_status;
}
is_first_run_ = false;
return kTfLiteOk;
}
if (slices_needed > kFeatureSliceCount) {
slices_needed = kFeatureSliceCount;
}
if (slices_needed == 0) {
return kTfLiteOk;
}
*how_many_new_slices = slices_needed;
const int slices_to_keep = kFeatureSliceCount - slices_needed;
const int slices_to_drop = kFeatureSliceCount - slices_to_keep;
// If we can avoid recalculating some slices, just move the existing data
// up in the spectrogram, to perform something like this:
// last time = 80ms current time = 120ms
// +-----------+ +-----------+
// | data@20ms | --> | data@60ms |
// +-----------+ -- +-----------+
// | data@40ms | -- --> | data@80ms |
// +-----------+ -- -- +-----------+
// | data@60ms | -- -- | <empty> |
// +-----------+ -- +-----------+
// | data@80ms | -- | <empty> |
// +-----------+ +-----------+
if (slices_to_keep > 0) {
for (int dest_slice = 0; dest_slice < slices_to_keep; ++dest_slice) {
int8_t* dest_slice_data =
feature_data_ + (dest_slice * kFeatureSliceSize);
const int src_slice = dest_slice + slices_to_drop;
const int8_t* src_slice_data =
feature_data_ + (src_slice * kFeatureSliceSize);
for (int i = 0; i < kFeatureSliceSize; ++i) {
dest_slice_data[i] = src_slice_data[i];
}
}
}
// Any slices that need to be filled in with feature data have their
// appropriate audio data pulled, and features calculated for that slice.
if (slices_needed > 0) {
for (int new_slice = slices_to_keep; new_slice < kFeatureSliceCount;
++new_slice) {
const int new_step = last_step + (new_slice - slices_to_keep);
const int32_t slice_start_ms = (new_step * kFeatureSliceStrideMs);
int16_t* audio_samples = nullptr;
int audio_samples_size = 0;
GetAudioSamples(error_reporter, slice_start_ms, kFeatureSliceDurationMs,
&audio_samples_size, &audio_samples);
constexpr int wanted =
kFeatureSliceDurationMs * (kAudioSampleFrequency / 1000);
if (audio_samples_size != wanted) {
TF_LITE_REPORT_ERROR(error_reporter,
"Audio data size %d too small, want %d",
audio_samples_size, wanted);
return kTfLiteError;
}
int8_t* new_slice_data = feature_data_ + (new_slice * kFeatureSliceSize);
size_t num_samples_read;
TfLiteStatus generate_status = GenerateMicroFeatures(
error_reporter, audio_samples, audio_samples_size, kFeatureSliceSize,
new_slice_data, &num_samples_read);
if (generate_status != kTfLiteOk) {
return generate_status;
}
}
}
return kTfLiteOk;
}

View File

@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Binds itself to an area of memory intended to hold the input features for an
// audio-recognition neural network model, and fills that data area with the
// features representing the current audio input, for example from a microphone.
// The audio features themselves are a two-dimensional array, made up of
// horizontal slices representing the frequencies at one point in time, stacked
// on top of each other to form a spectrogram showing how those frequencies
// changed over time.
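// Illustrative layout (a sketch of the indexing used by this class): the
// feature for time slice `s` and frequency channel `c` lives at
//   feature_data[s * kFeatureSliceSize + c]
// with 0 <= s < kFeatureSliceCount and 0 <= c < kFeatureSliceSize, so the
// whole spectrogram occupies kFeatureElementCount int8_t values.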
class FeatureProvider {
public:
// Create the provider, and bind it to an area of memory. This memory should
// remain accessible for the lifetime of the provider object, since subsequent
// calls will fill it with feature data. The provider does no memory
// management of this data.
FeatureProvider(int feature_size, int8_t* feature_data);
~FeatureProvider();
// Fills the feature data with information from audio inputs, and returns how
// many feature slices were updated.
TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
int32_t last_time_in_ms, int32_t time_in_ms,
int* how_many_new_slices);
private:
int feature_size_;
int8_t* feature_data_;
// Make sure we don't try to use cached information if this is the first call
// into the provider.
bool is_first_run_;
};
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_

View File

@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MAIN_FUNCTIONS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MAIN_FUNCTIONS_H_
// Expose a C friendly interface for main functions.
#ifdef __cplusplus
extern "C" {
#endif
// Initializes all data needed for the example. The name is important, and needs
// to be setup() for Arduino compatibility.
void setup();
// Runs one iteration of data gathering and inference. This should be called
// repeatedly from the application code. The name needs to be loop() for Arduino
// compatibility.
void loop();
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MAIN_FUNCTIONS_H_

View File

@ -0,0 +1,116 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "micro_features_micro_features_generator.h"
#include <cmath>
#include <cstring>
#include "micro_features_micro_model_settings.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
// Configure FFT to output 16 bit fixed point.
#define FIXED_POINT 16
namespace {
FrontendState g_micro_features_state;
bool g_is_first_time = true;
} // namespace
TfLiteStatus InitializeMicroFeatures(tflite::ErrorReporter* error_reporter) {
FrontendConfig config;
config.window.size_ms = kFeatureSliceDurationMs;
config.window.step_size_ms = kFeatureSliceStrideMs;
config.noise_reduction.smoothing_bits = 10;
config.filterbank.num_channels = kFeatureSliceSize;
config.filterbank.lower_band_limit = 125.0;
config.filterbank.upper_band_limit = 7500.0;
config.noise_reduction.smoothing_bits = 10;
config.noise_reduction.even_smoothing = 0.025;
config.noise_reduction.odd_smoothing = 0.06;
config.noise_reduction.min_signal_remaining = 0.05;
config.pcan_gain_control.enable_pcan = 1;
config.pcan_gain_control.strength = 0.95;
config.pcan_gain_control.offset = 80.0;
config.pcan_gain_control.gain_bits = 21;
config.log_scale.enable_log = 1;
config.log_scale.scale_shift = 6;
if (!FrontendPopulateState(&config, &g_micro_features_state,
kAudioSampleFrequency)) {
TF_LITE_REPORT_ERROR(error_reporter, "FrontendPopulateState() failed");
return kTfLiteError;
}
g_is_first_time = true;
return kTfLiteOk;
}
// This is not exposed in any header, and is only used for testing, to ensure
// that the state is correctly set up before generating results.
void SetMicroFeaturesNoiseEstimates(const uint32_t* estimate_presets) {
for (int i = 0; i < g_micro_features_state.filterbank.num_channels; ++i) {
g_micro_features_state.noise_reduction.estimate[i] = estimate_presets[i];
}
}
TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter,
const int16_t* input, int input_size,
int output_size, int8_t* output,
size_t* num_samples_read) {
// Both branches of the original first-time check assigned the same pointer,
// so the selection is collapsed; the flag is still cleared to keep the
// initialization bookkeeping intact.
const int16_t* frontend_input = input;
g_is_first_time = false;
FrontendOutput frontend_output = FrontendProcessSamples(
&g_micro_features_state, frontend_input, input_size, num_samples_read);
for (size_t i = 0; i < frontend_output.size; ++i) {
// These scaling values are derived from those used in input_data.py in the
// training pipeline.
// The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
// range. In training, these are then arbitrarily divided by 25.6 to get
// float values in the rough range of 0.0 to 26.0. This scaling is performed
// for historical reasons, to match up with the output of other feature
// generators.
// The process is then further complicated when we quantize the model. This
// means we have to scale the 0.0 to 26.0 real values to the -128 to 127
// signed integer numbers.
// All this means that to get matching values from our integer feature
// output into the tensor input, we have to perform:
// input = (((feature / 25.6) / 26.0) * 256) - 128
// To simplify this and perform it in 32-bit integer math, we rearrange to:
// input = (feature * 256) / (25.6 * 26.0) - 128
constexpr int32_t value_scale = 256;
constexpr int32_t value_div = static_cast<int32_t>((25.6f * 26.0f) + 0.5f);
int32_t value =
((frontend_output.values[i] * value_scale) + (value_div / 2)) /
value_div;
value -= 128;
if (value < -128) {
value = -128;
}
if (value > 127) {
value = 127;
}
output[i] = value;
}
return kTfLiteOk;
}

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Sets up any resources needed for the feature generation pipeline.
TfLiteStatus InitializeMicroFeatures(tflite::ErrorReporter* error_reporter);
// Converts audio sample data into a more compact form that's appropriate for
// feeding into a neural network.
TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter,
const int16_t* input, int input_size,
int output_size, int8_t* output,
size_t* num_samples_read);
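// Illustrative usage (a sketch mirroring how FeatureProvider drives this
// pipeline in this example):
//
//   InitializeMicroFeatures(error_reporter);
//   int8_t slice[kFeatureSliceSize];
//   size_t num_samples_read = 0;
//   GenerateMicroFeatures(error_reporter, audio_samples, audio_samples_size,
//                         kFeatureSliceSize, slice, &num_samples_read);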
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_

View File

@ -0,0 +1,23 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "micro_features_micro_model_settings.h"
const char* kCategoryLabels[kCategoryCount] = {
"silence",
"unknown",
"yes",
"no",
};

View File

@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_
// Keeping these as constant expressions allows us to allocate fixed-sized
// arrays on the stack for our working memory.
// The size of the input time series data we pass to the FFT to produce the
// frequency information. This has to be a power of two, and since we're
// dealing with 30ms of 16KHz inputs, i.e. 480 samples, 512 is the next
// power of two.
constexpr int kMaxAudioSampleSize = 512;
constexpr int kAudioSampleFrequency = 16000;
// The following values are derived from values used during model training.
// If you change the way you preprocess the input, update all these constants.
constexpr int kFeatureSliceSize = 40;
constexpr int kFeatureSliceCount = 49;
constexpr int kFeatureElementCount = (kFeatureSliceSize * kFeatureSliceCount);
constexpr int kFeatureSliceStrideMs = 20;
constexpr int kFeatureSliceDurationMs = 30;
// Variables for the model's output categories.
constexpr int kSilenceIndex = 0;
constexpr int kUnknownIndex = 1;
// If you modify the output categories, you need to update the following values.
constexpr int kCategoryCount = 4;
extern const char* kCategoryLabels[kCategoryCount];
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_

File diff suppressed because it is too large

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a standard TensorFlow Lite FlatBuffer model file that has been
// converted into a C data array, so it can be easily compiled into a binary
// for devices that don't have a file system. It was created using the command:
// xxd -i model.tflite > model.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MODEL_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MODEL_H_
extern const unsigned char g_model[];
extern const int g_model_len;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MODEL_H_

View File

@ -0,0 +1,216 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <TensorFlowLite.h>
#include "audio_provider.h"
#include "command_responder.h"
#include "feature_provider.h"
#include "main_functions.h"
#include "micro_features_micro_model_settings.h"
#include "micro_features_model.h"
#include "recognize_commands.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
#undef PROFILE_MICRO_SPEECH
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* model_input = nullptr;
FeatureProvider* feature_provider = nullptr;
RecognizeCommands* recognizer = nullptr;
int32_t previous_time = 0;
// Create an area of memory to use for input, output, and intermediate arrays.
// The size of this will depend on the model you're using, and may need to be
// determined by experimentation.
constexpr int kTensorArenaSize = 10 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
int8_t feature_buffer[kFeatureElementCount];
int8_t* model_input_buffer = nullptr;
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
tflite::InitializeTarget();
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
//
// tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
return;
}
if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
return;
}
if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
return;
}
if (micro_op_resolver.AddReshape() != kTfLiteOk) {
return;
}
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
// Get information about the memory area to use for the model's input.
model_input = interpreter->input(0);
if ((model_input->dims->size != 2) || (model_input->dims->data[0] != 1) ||
(model_input->dims->data[1] !=
(kFeatureSliceCount * kFeatureSliceSize)) ||
(model_input->type != kTfLiteInt8)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad input tensor parameters in model");
return;
}
model_input_buffer = model_input->data.int8;
// Prepare to access the audio spectrograms from a microphone or other source
// that will provide the inputs to the neural network.
// NOLINTNEXTLINE(runtime-global-variables)
static FeatureProvider static_feature_provider(kFeatureElementCount,
feature_buffer);
feature_provider = &static_feature_provider;
static RecognizeCommands static_recognizer(error_reporter);
recognizer = &static_recognizer;
previous_time = 0;
// start the audio
TfLiteStatus init_status = InitAudioRecording(error_reporter);
if (init_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Unable to initialize audio");
return;
}
TF_LITE_REPORT_ERROR(error_reporter, "Initialization complete");
}
// The name of this function is important for Arduino compatibility.
void loop() {
#ifdef PROFILE_MICRO_SPEECH
const uint32_t prof_start = millis();
static uint32_t prof_count = 0;
static uint32_t prof_sum = 0;
static uint32_t prof_min = std::numeric_limits<uint32_t>::max();
static uint32_t prof_max = 0;
#endif // PROFILE_MICRO_SPEECH
// Fetch the spectrogram for the current time.
const int32_t current_time = LatestAudioTimestamp();
int how_many_new_slices = 0;
TfLiteStatus feature_status = feature_provider->PopulateFeatureData(
error_reporter, previous_time, current_time, &how_many_new_slices);
if (feature_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Feature generation failed");
return;
}
previous_time += how_many_new_slices * kFeatureSliceStrideMs;
// If no new audio samples have been received since last time, don't bother
// running the network model.
if (how_many_new_slices == 0) {
return;
}
// Copy feature buffer to input tensor
for (int i = 0; i < kFeatureElementCount; i++) {
model_input_buffer[i] = feature_buffer[i];
}
// Run the model on the spectrogram input and make sure it succeeds.
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed");
return;
}
// Obtain a pointer to the output tensor
TfLiteTensor* output = interpreter->output(0);
// Determine whether a command was recognized based on the output of inference
const char* found_command = nullptr;
uint8_t score = 0;
bool is_new_command = false;
TfLiteStatus process_status = recognizer->ProcessLatestResults(
output, current_time, &found_command, &score, &is_new_command);
if (process_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter,
"RecognizeCommands::ProcessLatestResults() failed");
return;
}
// Do something based on the recognized command. The default implementation
// just prints to the error console, but you should replace this with your
// own function for a real application.
RespondToCommand(error_reporter, current_time, found_command, score,
is_new_command);
#ifdef PROFILE_MICRO_SPEECH
const uint32_t prof_end = millis();
if (++prof_count > 10) {
uint32_t elapsed = prof_end - prof_start;
prof_sum += elapsed;
if (elapsed < prof_min) {
prof_min = elapsed;
}
if (elapsed > prof_max) {
prof_max = elapsed;
}
if (prof_count % 300 == 0) {
TF_LITE_REPORT_ERROR(error_reporter,
"## time: min %dms max %dms avg %dms", prof_min,
prof_max, prof_sum / prof_count);
}
}
#endif // PROFILE_MICRO_SPEECH
}

View File

@ -0,0 +1,159 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "recognize_commands.h"
#include <limits>
#undef DEBUG_MICRO_SPEECH
RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter,
int32_t average_window_duration_ms,
uint8_t detection_threshold,
int32_t suppression_ms,
int32_t minimum_count)
: error_reporter_(error_reporter),
average_window_duration_ms_(average_window_duration_ms),
detection_threshold_(detection_threshold),
suppression_ms_(suppression_ms),
minimum_count_(minimum_count),
previous_results_(error_reporter) {
previous_top_label_ = kCategoryLabels[0]; // silence
previous_top_label_time_ = std::numeric_limits<int32_t>::min();
}
TfLiteStatus RecognizeCommands::ProcessLatestResults(
const TfLiteTensor* latest_results, const int32_t current_time_ms,
const char** found_command, uint8_t* score, bool* is_new_command) {
if ((latest_results->dims->size != 2) ||
(latest_results->dims->data[0] != 1) ||
(latest_results->dims->data[1] != kCategoryCount)) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"The results for recognition should contain %d elements, but there are "
"%d in an %d-dimensional shape",
kCategoryCount, latest_results->dims->data[1],
latest_results->dims->size);
return kTfLiteError;
}
if (latest_results->type != kTfLiteInt8) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"The results for recognition should be int8_t elements, but are %d",
latest_results->type);
return kTfLiteError;
}
if ((!previous_results_.empty()) &&
(current_time_ms < previous_results_.front().time_)) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Results must be fed in increasing time order, but received a "
"timestamp of %d that was earlier than the previous one of %d",
current_time_ms, previous_results_.front().time_);
return kTfLiteError;
}
// Prune any earlier results that are too old for the averaging window.
const int64_t time_limit = current_time_ms - average_window_duration_ms_;
while ((!previous_results_.empty()) &&
previous_results_.front().time_ < time_limit) {
previous_results_.pop_front();
}
// Add the latest results to the head of the queue.
previous_results_.push_back({current_time_ms, latest_results->data.int8});
// If there are too few results, assume the result will be unreliable and
// bail.
const int64_t how_many_results = previous_results_.size();
const int64_t earliest_time = previous_results_.front().time_;
const int64_t samples_duration = current_time_ms - earliest_time;
if ((how_many_results < minimum_count_) ||
(samples_duration < (average_window_duration_ms_ / 4))) {
*found_command = previous_top_label_;
*score = 0;
*is_new_command = false;
return kTfLiteOk;
}
// Calculate the average score across all the results in the window.
int32_t average_scores[kCategoryCount];
for (int offset = 0; offset < previous_results_.size(); ++offset) {
PreviousResultsQueue::Result previous_result =
previous_results_.from_front(offset);
const int8_t* scores = previous_result.scores;
for (int i = 0; i < kCategoryCount; ++i) {
if (offset == 0) {
average_scores[i] = scores[i] + 128;
} else {
average_scores[i] += scores[i] + 128;
}
}
}
for (int i = 0; i < kCategoryCount; ++i) {
average_scores[i] /= how_many_results;
}
// Find the current highest scoring category.
int current_top_index = 0;
int32_t current_top_score = 0;
for (int i = 0; i < kCategoryCount; ++i) {
if (average_scores[i] > current_top_score) {
current_top_score = average_scores[i];
current_top_index = i;
}
}
const char* current_top_label = kCategoryLabels[current_top_index];
// If we've recently had another label trigger, assume one that occurs too
// soon afterwards is a bad result.
int64_t time_since_last_top;
if ((previous_top_label_ == kCategoryLabels[0]) ||
(previous_top_label_time_ == std::numeric_limits<int32_t>::min())) {
time_since_last_top = std::numeric_limits<int32_t>::max();
} else {
time_since_last_top = current_time_ms - previous_top_label_time_;
}
if ((current_top_score > detection_threshold_) &&
((current_top_label != previous_top_label_) ||
(time_since_last_top > suppression_ms_))) {
#ifdef DEBUG_MICRO_SPEECH
TF_LITE_REPORT_ERROR(
error_reporter_, "Scores: s %d u %d y %d n %d %s -> %s",
average_scores[0], average_scores[1], average_scores[2],
average_scores[3], previous_top_label_, current_top_label);
#endif // DEBUG_MICRO_SPEECH
previous_top_label_ = current_top_label;
previous_top_label_time_ = current_time_ms;
*is_new_command = true;
} else {
#ifdef DEBUG_MICRO_SPEECH
if (current_top_label != previous_top_label_) {
TF_LITE_REPORT_ERROR(
error_reporter_, "#Scores: s %d u %d y %d n %d %s -> %s",
average_scores[0], average_scores[1], average_scores[2],
average_scores[3], previous_top_label_, current_top_label);
previous_top_label_ = current_top_label;
}
#endif // DEBUG_MICRO_SPEECH
*is_new_command = false;
}
*found_command = current_top_label;
*score = current_top_score;
return kTfLiteOk;
}

View File

@ -0,0 +1,159 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#include <cstdint>
#include "micro_features_micro_model_settings.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Partial implementation of std::deque, just providing the functionality
// that's needed to keep a record of previous neural network results over a
// short time period, so they can be averaged together to produce a more
// accurate overall prediction. This doesn't use any dynamic memory allocation
// so it's a better fit for microcontroller applications, but this does mean
// there are hard limits on the number of results it can store.
class PreviousResultsQueue {
public:
PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
: error_reporter_(error_reporter), front_index_(0), size_(0) {}
// Data structure that holds an inference result, and the time when it
// was recorded.
struct Result {
Result() : time_(0), scores() {}
Result(int32_t time, int8_t* input_scores) : time_(time) {
for (int i = 0; i < kCategoryCount; ++i) {
scores[i] = input_scores[i];
}
}
int32_t time_;
int8_t scores[kCategoryCount];
};
int size() { return size_; }
bool empty() { return size_ == 0; }
Result& front() { return results_[front_index_]; }
Result& back() {
int back_index = front_index_ + (size_ - 1);
if (back_index >= kMaxResults) {
back_index -= kMaxResults;
}
return results_[back_index];
}
void push_back(const Result& entry) {
if (size() >= kMaxResults) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Couldn't push_back latest result, too many already!");
return;
}
size_ += 1;
back() = entry;
}
Result pop_front() {
if (size() <= 0) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Couldn't pop_front result, none present!");
return Result();
}
Result result = front();
front_index_ += 1;
if (front_index_ >= kMaxResults) {
front_index_ = 0;
}
size_ -= 1;
return result;
}
// Most of the functions duplicate the std::deque interface, but this is a
// helper that makes it easy to iterate through the contents of the
// queue.
Result& from_front(int offset) {
if ((offset < 0) || (offset >= size_)) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Attempt to read beyond the end of the queue!");
offset = size_ - 1;
}
int index = front_index_ + offset;
if (index >= kMaxResults) {
index -= kMaxResults;
}
return results_[index];
}
private:
tflite::ErrorReporter* error_reporter_;
static constexpr int kMaxResults = 50;
Result results_[kMaxResults];
int front_index_;
int size_;
};
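// Example (illustrative only, assuming the caller's `error_reporter`,
// `current_time_ms`, and `average_window_duration_ms`): the queue behaves
// like a fixed-capacity ring buffer of timestamped score snapshots.
//
//   PreviousResultsQueue queue(error_reporter);
//   int8_t scores[kCategoryCount] = {0};
//   queue.push_back(PreviousResultsQueue::Result(current_time_ms, scores));
//   while (!queue.empty() &&
//          (current_time_ms - queue.front().time_) > average_window_duration_ms) {
//     queue.pop_front();  // drop results that fell out of the averaging window
//   }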
// This class is designed to apply a very primitive decoding model on top of the
// instantaneous results from running an audio recognition model on a single
// window of samples. It applies smoothing over time so that noisy individual
// label scores are averaged, increasing the confidence that apparent matches
// are real.
// To use it, you should create a class object with the configuration you
// want, and then feed results from running a TensorFlow model into the
// processing method. The timestamp for each subsequent call should be
// increasing from the previous, since the class is designed to process a stream
// of data over time.
class RecognizeCommands {
public:
// labels should be a list of the strings associated with each one-hot score.
// The window duration controls the smoothing. Longer durations will give a
// higher confidence that the results are correct, but may miss some commands.
// The detection threshold has a similar effect, with high values increasing
// the precision at the cost of recall. The minimum count controls how many
// results need to be in the averaging window before it's seen as a reliable
// average. This prevents erroneous results when, for example, the averaging
// window is initially being populated. The suppression argument disables
// further recognitions for a set time after one has been triggered, which can
// help reduce spurious recognitions.
explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
int32_t average_window_duration_ms = 1000,
uint8_t detection_threshold = 200,
int32_t suppression_ms = 1500,
int32_t minimum_count = 3);
// Call this with the results of running a model on sample data.
TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
const int32_t current_time_ms,
const char** found_command, uint8_t* score,
bool* is_new_command);
private:
// Configuration
tflite::ErrorReporter* error_reporter_;
int32_t average_window_duration_ms_;
uint8_t detection_threshold_;
int32_t suppression_ms_;
int32_t minimum_count_;
// Working variables
PreviousResultsQueue previous_results_;
const char* previous_top_label_;
int32_t previous_top_label_time_;
};
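// Usage sketch (illustrative only, not part of the API): create one instance
// and feed it every inference result along with a monotonically increasing
// timestamp. `interpreter`, `error_reporter`, and `now_ms` are assumed to come
// from the surrounding micro_speech glue code.
//
//   static RecognizeCommands recognizer(error_reporter);
//   const char* found_command = nullptr;
//   uint8_t score = 0;
//   bool is_new_command = false;
//   if (recognizer.ProcessLatestResults(interpreter->output(0), now_ms,
//                                       &found_command, &score,
//                                       &is_new_command) == kTfLiteOk &&
//       is_new_command) {
//     // react to `found_command` here
//   }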
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_

View File

@ -0,0 +1,69 @@
# Person detection example
This example shows how you can use [Tensorflow Lite Micro](https://www.tensorflow.org/lite/microcontrollers) to run a 300.5 kilobyte neural
network to recognize people in images.
## Table of contents
<!--ts-->
* [Deploy to Arduino](#deploy-to-arduino)
* [Install the Arduino_TensorFlowLite library](#install-the-arduino_tensorflowlite-library)
* [Load and run the example](#load-and-run-the-example)
* [Training your own model](#training-your-own-model)
<!--te-->
## Deploy to Arduino
The following instructions will help you build and deploy this example to
[Arduino](https://www.arduino.cc/) devices.
The example has been tested with the following devices:
- [Tiny Machine Learning Kit](https://store.arduino.cc/products/arduino-tiny-machine-learning-kit)
The Arduino Tiny Machine Learning Kit uses the OV7675 camera attachment. The OV7675 is currently not supported, so the code will simply generate a blank image (support coming _**soon**_). If you're using a different Arduino board and attaching your own
camera, you'll need to implement your own `arduino_image_provider.cpp` code (see the sketch below). The kit also has a
set of LEDs, which are used to indicate whether a person has been recognized.
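As a rough sketch (not an official driver), a replacement `arduino_image_provider.cpp` only needs to satisfy the `GetImage()` contract declared in `image_provider.h`; the loop below substitutes a black frame where your camera capture would go:
```
#include "image_provider.h"
#include "model_settings.h"

TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter,
                      const TfLiteTensor* tensor) {
  // Placeholder: copy one 96x96 grayscale frame from your camera driver into
  // the input tensor as signed 8-bit values (unsigned pixel value - 128).
  for (int i = 0; i < kMaxImageSize; ++i) {
    tensor->data.int8[i] = 0;
  }
  return kTfLiteOk;
}
```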
### Install the Arduino_TensorFlowLite library
This example application is included as part of the official TensorFlow Lite Micro
Arduino library.
To install the TensorFlow Lite Micro for Arduino library, see the
[how to install](../../README.md#how-to-install) instructions.
### Load and run the example
Once the library has been added, go to `File -> Examples`. You should see an
entry within the list named `Arduino_TensorFlowLite`. Select
it and click `person_detection` to load the example.
Use the Arduino IDE to build and upload the example. Once it is running, you
should see the built-in LED on your device flashing on and off for each inference cycle. The green LED will be lit if a person is predicted,
and the blue LED will be lit if the prediction is not-a-person.
The program also outputs inference results to the serial port, which appear as
follows:
```
Cropping image and quantizing
Image cropped and quantized
Person score: 39.6% No person score: 60.93%
```
When the program is run, it waits several seconds for a USB-serial connection to be
available. If there is no connection available, it will not output data. To see
the serial output in the Arduino desktop IDE, do the following:
1. Open the Arduino IDE
1. Connect the Arduino board to your computer via USB
1. Press the reset button on the Arduino board
1. Within 5 seconds, go to `Tools -> Serial Monitor` in the Arduino IDE. You may
have to try several times, since the board will take a moment to connect.
If you don't see any output, repeat the process again.
## Training your own model
You can train your own model with some easy-to-use scripts. See
[training_a_model.md](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/person_detection/training_a_model.md) for instructions.

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#define ARDUINO_EXCLUDE_CODE
#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#ifndef ARDUINO_EXCLUDE_CODE
#include <cmath>
#include "Arduino.h"
#include "detection_responder.h"
// Flash the yellow (builtin) LED after each inference
void RespondToDetection(tflite::ErrorReporter* error_reporter,
float person_score, float no_person_score) {
static bool is_initialized = false;
if (!is_initialized) {
pinMode(LED_BUILTIN, OUTPUT);
digitalWrite(LED_BUILTIN, HIGH);
// Pins for the built-in RGB LEDs on the Arduino Nano 33 BLE Sense
pinMode(LEDR, OUTPUT);
pinMode(LEDG, OUTPUT);
pinMode(LEDB, OUTPUT);
// Switch the LEDs off
digitalWrite(LEDG, HIGH);
digitalWrite(LEDB, HIGH);
digitalWrite(LEDR, HIGH);
is_initialized = true;
}
// Note: The RGB LEDs on the Arduino Nano 33 BLE
// Sense are on when the pin is LOW, off when HIGH.
// Switch on the green LED when a person is detected,
// the blue when no person is detected
if (person_score > no_person_score) {
digitalWrite(LEDG, LOW);
digitalWrite(LEDB, HIGH);
} else {
digitalWrite(LEDG, HIGH);
digitalWrite(LEDB, LOW);
}
// Flash the yellow LED after every inference.
// The builtin LED is on when the pin is HIGH
digitalWrite(LED_BUILTIN, LOW);
delay(100);
digitalWrite(LED_BUILTIN, HIGH);
float person_score_frac, person_score_int;
float no_person_score_frac, no_person_score_int;
person_score_frac = std::modf(person_score * 100, &person_score_int);
no_person_score_frac = std::modf(no_person_score * 100, &no_person_score_int);
TF_LITE_REPORT_ERROR(error_reporter,
"Person score: %d.%d%% No person score: %d.%d%%",
static_cast<int>(person_score_int),
static_cast<int>(person_score_frac * 100),
static_cast<int>(no_person_score_int),
static_cast<int>(no_person_score_frac * 100));
}
#endif // ARDUINO_EXCLUDE_CODE

View File

@ -0,0 +1,199 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <type_traits>
#include "image_provider.h"
#include "model_settings.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "test_over_serial/test_over_serial.h"
using namespace test_over_serial;
#if defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#define ARDUINO_EXCLUDE_CODE
#endif // defined(ARDUINO) && !defined(ARDUINO_ARDUINO_NANO33BLE)
#ifndef ARDUINO_EXCLUDE_CODE
#include "Arduino.h"
namespace {
constexpr size_t kQQVGA_width = 160; // pixels
constexpr size_t kQQVGA_height = 120; // pixels
uint8_t image_buffer[kQQVGA_height * kQQVGA_width];
constexpr size_t kImageBufferLength =
std::extent<decltype(image_buffer)>::value;
// Get the camera module ready
TfLiteStatus InitCamera(tflite::ErrorReporter* error_reporter) {
// This function is kept for future implementation
TF_LITE_REPORT_ERROR(
error_reporter,
"OV7675 not yet supported. Blank image will be substituted.");
return kTfLiteOk;
}
// Begin the capture and wait for it to finish
TfLiteStatus PerformCapture(tflite::ErrorReporter* error_reporter) {
// This function is kept for future implementation
TF_LITE_REPORT_ERROR(error_reporter, "Starting capture");
delay(50);
TF_LITE_REPORT_ERROR(error_reporter, "Image captured");
return kTfLiteOk;
}
// Read data from the camera module into a local buffer
TfLiteStatus ReadData(tflite::ErrorReporter* error_reporter) {
// This function is kept for future implementation;
// until the OV7675 is supported, just fill with zeros (a black image)
std::fill_n(image_buffer, kImageBufferLength, 0);
return kTfLiteOk;
}
// Decode the image, crop it, and convert it to grayscale
TfLiteStatus CropAndQuantizeImage(tflite::ErrorReporter* error_reporter,
size_t image_width, size_t image_height,
const TfLiteTensor* tensor) {
TF_LITE_REPORT_ERROR(error_reporter, "Cropping image and quantizing");
// cropping parameters
const size_t vert_top = (image_height - kNumRows) / 2;
const size_t vert_bottom = vert_top + kNumRows - 1;
const size_t horz_left = (image_width - kNumCols) / 2;
const size_t horz_right = horz_left + kNumCols - 1;
const uint8_t* p = image_buffer + (vert_top * image_width);
p += horz_left;
int8_t* image_data = tensor->data.int8;
for (size_t line = vert_top; line <= vert_bottom; line++) {
for (size_t row = horz_left; row <= horz_right; row++, p++) {
*image_data++ = tflite::FloatToQuantizedType<int8_t>(
p[0] / 255.0f, tensor->params.scale, tensor->params.zero_point);
}
// move to next line
p += ((image_width - 1) - horz_right) + horz_left;
}
TF_LITE_REPORT_ERROR(error_reporter, "Image cropped and quantized");
return kTfLiteOk;
}
// Get an image from the camera module
TfLiteStatus GetCameraImage(tflite::ErrorReporter* error_reporter,
const TfLiteTensor* tensor) {
static bool g_is_camera_initialized = false;
if (!g_is_camera_initialized) {
TfLiteStatus init_status = InitCamera(error_reporter);
if (init_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "InitCamera failed");
return init_status;
}
g_is_camera_initialized = true;
}
TfLiteStatus capture_status = PerformCapture(error_reporter);
if (capture_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "PerformCapture failed");
return capture_status;
}
TfLiteStatus read_data_status = ReadData(error_reporter);
if (read_data_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "ReadData failed");
return read_data_status;
}
TfLiteStatus decode_status =
CropAndQuantizeImage(error_reporter, kQQVGA_width, kQQVGA_height, tensor);
if (decode_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "CropAndQuantizeImage failed");
return decode_status;
}
return kTfLiteOk;
}
TfLiteStatus GetTestImage(tflite::ErrorReporter* error_reporter,
TestOverSerial& test, const TfLiteTensor* tensor) {
volatile bool done = false;
volatile bool aborted = false;
volatile size_t image_width = 0, image_height = 0;
InputHandler handler = [&aborted, &done, &image_width,
&image_height](const InputBuffer* const input) {
if (0 == input->offset) {
if ((kQQVGA_height * kQQVGA_width) == input->total) {
image_width = kQQVGA_width;
image_height = kQQVGA_height;
} else if ((kNumCols * kNumRows) == input->total) {
image_width = kNumCols;
image_height = kNumRows;
} else {
// image dimensions are not supported, abort input processing
aborted = true;
return false;
}
}
std::copy_n(input->data.uint8, input->length, &image_buffer[input->offset]);
if (input->total == (input->offset + input->length)) {
done = true;
}
return true;
};
while (!done) {
test.ProcessInput(&handler);
if (aborted) {
TF_LITE_REPORT_ERROR(error_reporter, "Input processing aborted");
return kTfLiteError;
}
// wait for a full image from serial port before processing
if (done) {
TfLiteStatus decode_status = CropAndQuantizeImage(
error_reporter, image_width, image_height, tensor);
if (decode_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "CropAndQuantizeImage failed");
return decode_status;
}
}
}
return kTfLiteOk;
}
} // namespace
TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter,
const TfLiteTensor* tensor) {
TestOverSerial& test = TestOverSerial::Instance(kIMAGE_GRAYSCALE);
if (!test.IsTestMode()) {
// check serial port for test mode command
test.ProcessInput(nullptr);
}
if (test.IsTestMode()) {
return GetTestImage(error_reporter, test, tensor);
} else {
// get an image from the camera
return GetCameraImage(error_reporter, tensor);
}
// NOTREACHED
}
#endif // ARDUINO_EXCLUDE_CODE

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
// Arduino automatically calls the setup() and loop() functions in a sketch, so
// while other systems would need their own main routine in this file, here it
// can be left empty.

View File

@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides an interface to take an action based on the output from the person
// detection model.
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// Called every time the results of a person detection run are available. The
// `person_score` has the numerical confidence that the captured image contains
// a person, and `no_person_score` has the numerical confidence that the image
// does not contain a person. Typically, if `person_score` > `no_person_score`,
// the image is considered to contain a person. This threshold may be adjusted for
// particular applications.
void RespondToDetection(tflite::ErrorReporter* error_reporter,
float person_score, float no_person_score);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_

View File

@ -0,0 +1,40 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
// This is an abstraction around an image source like a camera, and is
// expected to place 8-bit quantized sample data into the tensor.
//
// The assumption is that this will be
// called in a low duty-cycle fashion in a low-power application. In these
// cases, the imaging sensor need not be run in a streaming mode, but rather can
// be idled in a relatively low-power mode between calls to GetImage(). The
// assumption is that the overhead and time of bringing the low-power sensor out
// of this standby mode is commensurate with the expected duty cycle of the
// application. The underlying sensor may actually be put into a streaming
// configuration, but the tensor provided to GetImage should not be
// overwritten by the driver code until the next call to GetImage();
//
// For real applications, you should
// ensure there's a specialized implementation that accesses hardware APIs.
TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter,
const TfLiteTensor* tensor);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_

View File

@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_
// Expose a C friendly interface for main functions.
#ifdef __cplusplus
extern "C" {
#endif
// Initializes all data needed for the example. The name is important, and needs
// to be setup() for Arduino compatibility.
void setup();
// Runs one iteration of data gathering and inference. This should be called
// repeatedly from the application code. The name needs to be loop() for Arduino
// compatibility.
void loop();
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_

View File

@ -0,0 +1,21 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "model_settings.h"
const char* kCategoryLabels[kCategoryCount] = {
"notperson",
"person",
};

View File

@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_
// Keeping these as constant expressions allows us to allocate fixed-sized arrays
// on the stack for our working memory.
// All of these values are derived from the values used during model training;
// if you change your model, you'll need to update these constants.
constexpr int kNumCols = 96;
constexpr int kNumRows = 96;
constexpr int kNumChannels = 1;
constexpr int kMaxImageSize = kNumCols * kNumRows * kNumChannels;
constexpr int kCategoryCount = 2;
constexpr int kPersonIndex = 1;
constexpr int kNotAPersonIndex = 0;
extern const char* kCategoryLabels[kCategoryCount];
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_

View File

@ -0,0 +1,27 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a standard TensorFlow Lite model file that has been converted into a
// C data array, so it can be easily compiled into a binary for devices that
// don't have a file system. It was created using the command:
// xxd -i person_detect.tflite > person_detect_model_data.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_
extern const unsigned char g_person_detect_model_data[];
extern const int g_person_detect_model_data_len;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <TensorFlowLite.h>
#include "detection_responder.h"
#include "image_provider.h"
#include "main_functions.h"
#include "model_settings.h"
#include "person_detect_model_data.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
// In order to use optimized tensorflow lite kernels, a signed int8_t quantized
// model is preferred over the legacy unsigned model format. This means that
// throughout this project, input images must be converted from unsigned to
// signed format. The easiest and quickest way to convert from unsigned to
// signed 8-bit integers is to subtract 128 from the unsigned value to get a
// signed value.
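// For example (illustrative only), a raw unsigned pixel byte `u` in [0, 255]
// maps to: int8_t s = static_cast<int8_t>(static_cast<int>(u) - 128);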
// An area of memory to use for input, output, and intermediate arrays.
constexpr int kTensorArenaSize = 136 * 1024;
static uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
tflite::InitializeTarget();
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_person_detect_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
//
// tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddAveragePool2D();
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax();
// Build an interpreter to run the model with.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
// Get information about the memory area to use for the model's input.
input = interpreter->input(0);
if ((input->dims->size != 4) || (input->dims->data[0] != 1) ||
(input->dims->data[1] != kNumRows) ||
(input->dims->data[2] != kNumCols) ||
(input->dims->data[3] != kNumChannels) || (input->type != kTfLiteInt8)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad input tensor parameters in model");
return;
}
}
// The name of this function is important for Arduino compatibility.
void loop() {
// Get image from provider.
if (kTfLiteOk != GetImage(error_reporter, input)) {
TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
}
// Run the model on this input and make sure it succeeds.
if (kTfLiteOk != interpreter->Invoke()) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");
}
TfLiteTensor* output = interpreter->output(0);
// Process the inference results.
// The output tensor is int8-quantized, so read it through data.int8.
int8_t person_score = output->data.int8[kPersonIndex];
int8_t no_person_score = output->data.int8[kNotAPersonIndex];
float person_score_f =
(person_score - output->params.zero_point) * output->params.scale;
float no_person_score_f =
(no_person_score - output->params.zero_point) * output->params.scale;
RespondToDetection(error_reporter, person_score_f, no_person_score_f);
}

View File

@ -0,0 +1,15 @@
{
"name": "tf_lite_esp32",
"version": "0.0.1",
"keywords": "tensor flow",
"description": "Tensor flow lite for Arduino-ESP32",
"frameworks": "arduino",
"platforms": "espressif32",
"build": {
"libArchive": true,
"flags": [
"-I Source/include",
"-I Source/esp-nn/include"
]
}
}

View File

@ -0,0 +1,13 @@
name=Arduino_TensorFlowLite for ESP32
version=0.0.1-ALPHA
author=TensorFlow Authors
maintainer=Pete Warden <petewarden@google.com>
sentence=Allows you to run machine learning models locally on your device.
paragraph=This library runs TensorFlow machine learning models on microcontrollers, allowing you to build AI/ML applications powered by deep learning and neural networks. With the included examples, you can recognize speech, detect people using a camera, and recognize "magic wand" gestures using an accelerometer.
category=Data Processing
url=https://www.tensorflow.org/lite/microcontrollers/overview
ldflags=
includes=TensorFlowLite.h
precompiled=full
dot_a_linkage=false
depends=Arduino

View File

@ -0,0 +1,26 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_TOOLS_MAKE_TEMPLATES_TENSORFLOWLITE_H_
#define TENSORFLOW_LITE_MICRO_TOOLS_MAKE_TEMPLATES_TENSORFLOWLITE_H_
// This header is deliberately kept minimal: it is mainly present because
// including it in a .ino sketch forces the Arduino toolchain to build the rest
// of the library, and the ESP32 port additionally pulls in the headers below.
#include "third_party/flatbuffers/include/flatbuffers/flatbuffers.h"
#include "esp-nn/include/esp_nn.h"
#endif // TENSORFLOW_LITE_MICRO_TOOLS_MAKE_TEMPLATES_TENSORFLOWLITE_H_

View File

@ -0,0 +1,57 @@
.config
*.o
*.i
*.s
*.orig
*.pyc
# gtags
GTAGS
GRTAGS
GPATH
# emacs
.dir-locals.el
# emacs temp file suffixes
*~
.#*
\#*#
# eclipse setting
.settings
# MacOS directory files
.DS_Store
# Example project files
examples/**/sdkconfig
examples/**/sdkconfig.old
examples/**/build
# Test app files
test_app/build
test_app/sdkconfig
test_app/sdkconfig.old
# Doc build artifacts
docs/_build/
docs/doxygen-warning-log.txt
docs/sphinx-warning-log.txt
docs/sphinx-warning-log-sanitized.txt
docs/xml/
docs/xml_in/
docs/man/
docs/doxygen_sqlite3.db
TEST_LOGS
# gcov coverage reports
*.gcda
*.gcno
coverage.info
coverage_report/
# VS Code Settings
.vscode/

View File

@ -0,0 +1,55 @@
stages:
- build
variables:
BATCH_BUILD: "1"
V: "0"
MAKEFLAGS: "-j8 --no-keep-going"
IDF_PATH: "$CI_PROJECT_DIR/esp-idf"
LOG_PATH: "$CI_PROJECT_DIR"
.set_git_config: &set_git_config
# Set git config
- git config user.email "test@espressif.com"
- git config user.name "Espressif"
.add_ssh_key: &add_ssh_key
# Add gitlab ssh key
- mkdir -p ~/.ssh
- chmod 700 ~/.ssh
- echo -n $GITLAB_KEY > ~/.ssh/id_rsa_base64
- base64 --decode --ignore-garbage ~/.ssh/id_rsa_base64 > ~/.ssh/id_rsa
- chmod 600 ~/.ssh/id_rsa
- echo -e "Host gitlab.espressif.cn\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config
before_script:
# Add gitlab ssh key
- *add_ssh_key
# Set git config
- *set_git_config
.build_esp32s3: &build_esp32s3
- idf.py set-target esp32s3 build
.build_esp32: &build_esp32
- idf.py set-target esp32 build
build_demo:
stage: build
image: $CI_DOCKER_REGISTRY/esp32-ci-env:esp-nn
tags:
- build
script:
# Clone IDF
- git clone --recursive --single-branch -b release/v4.4 --reference-if-able /local_references/gitlab/ https://gitlab-ci-token:${BOT_TOKEN}@gitlab.espressif.cn:6688/espressif/esp-idf.git
- cd esp-idf
- ./install.sh
- . ./export.sh
- cd ..
# Build examples now
- cd test_app
# Build esp32s3
- *build_esp32s3
# Build esp32
- *build_esp32
- cd -

View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,55 @@
# ESP-NN
The library contains optimised NN (Neural Network) functions for various Espressif chipsets.
* Supported platforms:
* TensorFlow Lite Micro (TFLite Micro). Repo can be found [here](https://github.com/espressif/tflite-micro-esp-examples)
* Supported ESP chipsets include:
* ESP32-S3 (Assembly versions optimised to benefit from vector instructions of ESP32-S3)
* ESP32 (Generic optimisations)
* ESP32-C3 (Generic optimisations)
## Performance
### Kernelwise performance for s8 versions:
* Kernelwise performance on ESP32-S3 chip
* Numbers are ticks taken for a kernel to execute
* Chip config: 240MHz, SPI: QPI 80MHz, Data cache: 64KB
| Function | ANSI C | ESP32-S3 Opt | Opt Ratio | Data info | Memory |
| ----------------| --------|---------|---------|-------------|-----------|
| elementwise_add | 320397 | 87119 | 3.68 | size = 1615 | External |
| elementwise_mul | 125958 | 44239 | 2.85 | size = 1615 | External |
| convolution | 4663012 | 428675 | 10.88 | input(10,10), filter(64x1x1x64) | External |
| convolution | 301014 | 32433 | 9.28 | input(8,8), filter(16x1x1x16) | External |
| convolution | 2115418 | 1020923 | 2.07 | input(10,10), filter(64x3x3x3) | External |
| depthwise conv | 1190062 | 203278 | 5.85 | input (18, 18), pad(0,0), stride(1,1) filter: 1x3x3x16 | External |
| depthwise conv | 837072 | 182335 | 4.59 | input (12, 12), pad(1,1), stride(1,1) filter: 8x5x5x4 | External |
| max pool | 485714 | 76747 | 6.33 | input(16,16), filter (1x3x3x16) | Internal |
| avg pool | 541462 | 160580 | 3.37 | input(16,16), filter (1x3x3x16) | Internal |
| fully connected | 15853 | 9547 | 1.66 | len: 265, ch = 3 | Internal |
| prelu (relu6) | 19472 | 2734 | 7.12 | size = 1615 | Internal |
## Configuration
* To configure, please use `idf.py menuconfig` and under `ESP-NN` select `NN_OPTIMIZATIONS`
* There are two options presented:
* Optimized versions
* ANSI C
* Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for other chipsets (viz., ESP32, ESP32-C3), generic optimisations are selected.
* For debugging purposes, you may want to select `ANSI C` reference versions.
## Contributing
If you encounter an issue with ESP-NN, or wish to submit a feature request, please use the Issues section on GitHub.
For general questions related to this library, please use the esp32.com forum.
## Copyrights and License
All original source code in this repository is Copyright (C) 2020-2021 Espressif Systems. This source code is licensed under the Apache License 2.0 as described in the file LICENSE.

View File

@ -0,0 +1,46 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(CONFIG_NN_OPTIMIZED)
// select apt optimisations
#ifdef CONFIG_IDF_TARGET_ESP32S3
#define ARCH_ESP32_S3 1
#endif
#ifdef CONFIG_IDF_TARGET_ESP32
#define ARCH_ESP32 1
#endif
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* reference kernels included by default */
#include "esp_nn_ansi_headers.h"
#if defined(CONFIG_NN_OPTIMIZED)
#if defined(ARCH_ESP32_S3)
#include "esp_nn_esp32s3.h"
#else // for other platforms use generic optimisations
#include "esp_nn_generic_opt.h"
#endif // #if defined(ARCH_ESP32_S3)
#else
#include "esp_nn_ansi_c.h"
#endif
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,47 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file Header definitions to include for ANSI C versions.
* These are just typedefs to pick up ANSI versions.
*/
#pragma once
#include "esp_nn_defs.h"
#include "esp_nn_ansi_headers.h"
#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
#define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_ansi
#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_ansi
#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_ansi
#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_ansi
#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_ansi
#define esp_nn_softmax_s8 esp_nn_softmax_s8_ansi

View File

@ -0,0 +1,309 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/**
* @file Header definitions to include for esp_nn reference functions
*/
#include "esp_nn_defs.h"
/************************** Basic math functions ****************************/
/**
* @brief elementwise addition
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*
* shift values are expected to be <= 0
*/
void esp_nn_add_elementwise_s8_ansi(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
const int32_t input1_mult,
const int32_t input2_mult,
const int32_t input1_shift,
const int32_t input2_shift,
const int32_t left_shift,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size);
/**
* @brief elementwise multiplication
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*
* output shift is expected to be <= 0
*/
void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size);
/************************** Convolution functions *****************************/
/**
* @brief depthwise convolution per channel
*
* @note inputs type: int8_t, output: int8_t
* Version used in tflite is per channel.
* This version follows the same footprint, meaning it has a per-out_channel
* shift and multiplier for requantization
*
* optimization notes: Though input_offset is int32 type,
* offset values are contained in 8 bits [-128, 127]
*/
void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data);
/**
* @brief 2d-convolution channelwise
*
* @note operation: result += (input + offset) * filter
*
* inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data);
int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params);
void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params);
void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
/************************** Activation functions *****************************/
/**
* @brief relu6
*
* @note inout: int8_t
*/
void esp_nn_relu6_s8_ansi(int8_t *data, uint16_t size);
/************************** Pooling functions *****************************/
/**
* @brief max_pool
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_max_pool_s8_ansi(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels);
/**
* @brief avg_pool
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_avg_pool_s8_ansi(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels);
/************************** Fully connected functions ***********************/
/**
* @brief fully connected
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_fully_connected_s8_ansi(const int8_t *input_data,
const int32_t input_offset,
const uint16_t row_len,
const int8_t *filter_data,
const int32_t filter_offset,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t out_shift,
const int32_t out_mult,
const int32_t activation_min,
const int32_t activation_max);
/**
* @brief Get scratch buffer size needed by softmax function
*
* @param width
* @param height
* @return size in bytes
*
* @note buffer must be 4 byte aligned
*/
int32_t esp_nn_get_softmax_scratch_size_ansi(const int32_t width, const int32_t height);
/* ANSI C function to be hooked up when optimised version needed */
int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t height);
/**
* @brief Set scratch buffer to be used by softmax function
*
* @param buffer this can be NULL if one needs to unset it
* must be aligned to 4 bytes
*/
void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);
/**
* @brief reference softmax function
*
* @note inputs type: int8_t, output: int8_t
*/
void esp_nn_softmax_s8_ansi(const int8_t *input_data,
const int32_t height,
const int32_t width,
const int32_t mult,
const int32_t shift,
const int32_t diff_min,
int8_t *output_data);
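/*
 * Scratch-buffer protocol (illustrative only): query the required size,
 * provide a 4-byte-aligned caller-owned buffer of at least that size, then
 * run softmax.
 *
 *   int32_t scratch_size = esp_nn_get_softmax_scratch_size_ansi(width, height);
 *   esp_nn_set_softmax_scratch_buf_ansi(scratch_mem);  // >= scratch_size bytes
 *   esp_nn_softmax_s8_ansi(input, height, width, mult, shift, diff_min, output);
 */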
//////////////////////////// Generic optimisations /////////////////////////////
/************************** Convolution functions *****************************/
/**
* @brief 2d-convolution channelwise optimized version
*
* @note operation: result += (input + offset) * filter
*
* inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data);
/**
* @brief depthwise convolution per channel optimized version
*
* @note inputs type: int8_t, output: int8_t
 *       The version used in tflite is per-channel; this version follows
 *       the same footprint, meaning it has a per-out_channel shift and
 *       multiplier for requantization.
*
* optimization notes: Though input_offset is int32 type,
* offset values are contained in 8 bits [-128, 127]
*/
void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data);
int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params);
void esp_nn_set_conv_scratch_buf_opt(const void *buf);
int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params);
void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
/* ANSI C prototype, hooked up when the optimised version is needed */
void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
/**
* @brief optimised version of softmax function
*
* @note the function uses extra buffer (4 * width bytes)
* hence, scratch buffers must be set before calling this.
*/
void esp_nn_softmax_s8_opt(const int8_t *input_data,
const int32_t height,
const int32_t width,
const int32_t mult,
const int32_t shift,
const int32_t diff_min,
int8_t *output_data);

View File

@ -0,0 +1,83 @@
// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
/**
* @brief structure to club data dims
* this structure can be used for input, output and filter
*/
typedef struct data_dims {
int32_t width;
int32_t height;
int32_t channels;
int32_t extra; // can be used as batch or any other param
} data_dims_t;
/**
* @brief 2d data structure (width, height)
*
*/
typedef struct data_2d {
int32_t width;
int32_t height;
} data_2d_t;
/**
* @brief min/max activation
*/
typedef struct act_params {
int32_t min;
int32_t max;
} act_params_t;
/**
* @brief per channel quant data
*
* @note number of shift and mult elements are equal to output channels
*/
typedef struct quant_data {
int32_t *shift;
int32_t *mult;
} quant_data_t;
/**
* @brief params specific to convolution 2d
*
*/
typedef struct conv_params {
int32_t in_offset;
int32_t out_offset;
data_2d_t stride;
data_2d_t padding;
data_2d_t dilation;
act_params_t activation;
} conv_params_t;
/**
* @brief params specific to depthwise convolution 2d
*
*/
typedef struct dw_conv_params {
int32_t in_offset;
int32_t out_offset;
int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch)
data_2d_t stride;
data_2d_t padding;
data_2d_t dilation;
act_params_t activation;
} dw_conv_params_t;
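/*
 * Illustrative sketch (editor's addition, not part of esp-nn): how these
 * structures are typically filled for a 3x3, stride-1, pad-1 convolution
 * over an 8x8x4 input producing 8 output channels. All values are
 * assumptions for demonstration only.
 */
static inline void esp_nn_defs_usage_sketch(void)
{
    data_dims_t input_dims  = { .width = 8, .height = 8, .channels = 4, .extra = 1 };
    data_dims_t filter_dims = { .width = 3, .height = 3, .channels = 4, .extra = 0 };
    data_dims_t output_dims = { .width = 8, .height = 8, .channels = 8, .extra = 1 };
    conv_params_t conv_params = {
        .in_offset  = 128,                          /* input zero-point correction */
        .out_offset = -128,                         /* output zero-point */
        .stride     = { .width = 1, .height = 1 },
        .padding    = { .width = 1, .height = 1 },
        .dilation   = { .width = 1, .height = 1 },
        .activation = { .min = -128, .max = 127 },
    };
    (void) input_dims; (void) filter_dims; (void) output_dims; (void) conv_params;
}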

View File

@ -0,0 +1,233 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file Header definitions to include for esp_nn optimized functions for
* the ESP32-S3 platform
*/
#pragma once
#ifdef CONFIG_IDF_TARGET_ESP32S3
#include "esp_nn_defs.h"
#include "esp_nn_ansi_headers.h"
/************************** Basic math functions *****************************/
/**
* @brief elementwise addition
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*
* shift values are expected to be <= 0
*/
void esp_nn_add_elementwise_s8_esp32s3(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
const int32_t input1_mult,
const int32_t input2_mult,
const int32_t input1_shift,
const int32_t input2_shift,
const int32_t left_shift,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size);
/**
* @brief elementwise multiplication
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*
* output shift is expected to be <= 0
*/
void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size);
/************************** Convolution functions *****************************/
/**
* @brief depthwise convolution per channel
*
* @note inputs type: int8_t, output: int8_t
 *       The version used in tflite is per-channel; this version follows
 *       the same footprint, meaning it has a per-out_channel shift and
 *       multiplier for requantization.
*
* optimization notes: Though input_offset is int32 type,
* offset values are contained in 8 bits [-128, 127]
*/
void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *output_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data);
/**
* @brief 2d - convolution channelwise
*
* @note operation: result += (input + offset) * filter
*
* inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *output_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data);
int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params);
void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params);
void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
/************************** Pooling functions *****************************/
/**
* @brief max_pool
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_max_pool_s8_esp32s3(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels);
/**
* @brief avg_pool
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*/
void esp_nn_avg_pool_s8_esp32s3(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels);
/************************** Fully connected functions *****************************/
/**
* @brief fully connected
*
* @note inputs type: int8_t, output: int8_t
* input offsets: although int32_t, they are contained in 8 bits [-128, 127]
*
* Current version works only on aligned input.
* row_len and channels should both be multiple of 8.
*/
void esp_nn_fully_connected_s8_esp32s3(const int8_t *input_data,
const int32_t input_offset,
const uint16_t row_len,
const int8_t *filter_data,
const int32_t filter_offset,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t out_shift,
const int32_t out_mult,
const int32_t activation_min,
const int32_t activation_max);
/**
* @brief relu6
*
* @note inout: int8_t
*/
void esp_nn_relu6_s8_esp32s3(int8_t *data, uint16_t size);
/********************** function defines ***************************/
#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_esp32s3
#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_esp32s3
#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_esp32s3
#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_esp32s3
#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_esp32s3
#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_esp32s3
#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_esp32s3
#define esp_nn_conv_s8 esp_nn_conv_s8_esp32s3
#define esp_nn_relu6_s8 esp_nn_relu6_s8_esp32s3
#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_esp32s3
#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_esp32s3
#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_esp32s3
#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_opt
#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_opt
#define esp_nn_softmax_s8 esp_nn_softmax_s8_opt
#endif // CONFIG_IDF_TARGET_ESP32S3

View File

@ -0,0 +1,47 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file Header definitions to include for esp_nn generic optimisations
 * For functions that don't have optimisations, the _ansi versions are picked.
*/
#pragma once
#include "esp_nn_defs.h"
#include "esp_nn_ansi_headers.h"
#define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
#define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt
#define esp_nn_conv_s8 esp_nn_conv_s8_opt
#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt
#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt
#define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
#define esp_nn_avg_pool_s8 esp_nn_avg_pool_s8_ansi
#define esp_nn_max_pool_s8 esp_nn_max_pool_s8_ansi
#define esp_nn_fully_connected_s8 esp_nn_fully_connected_s8_ansi
#define esp_nn_get_softmax_scratch_size esp_nn_get_softmax_scratch_size_opt
#define esp_nn_set_softmax_scratch_buf esp_nn_set_softmax_scratch_buf_opt
#define esp_nn_softmax_s8 esp_nn_softmax_s8_opt
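/*
 * Editor's note (not part of esp-nn): with these defines, portable code
 * calls the generic names and the build selects the backend, e.g.
 *
 *     esp_nn_set_conv_scratch_buf(buf);   // resolves to the _opt variant here
 *     esp_nn_conv_s8(&in_dims, in, &flt_dims, flt, bias,
 *                    &out_dims, out, &params, &quant);
 *
 * The same call sites resolve to the _esp32s3 variants when the ESP32-S3
 * header is used instead.
 */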

View File

@ -0,0 +1,30 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <stdlib.h>
#include "../common/common_functions.h"
void esp_nn_relu6_s8_ansi(int8_t *data, uint16_t size)
{
int32_t i;
for (i = 0; i < size; i++) {
int32_t ip = data[i];
ip = max(ip, 0);
data[i] = min(ip, 6);
}
}
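/* Illustrative check (editor's addition, not part of esp-nn):
 * relu6 clamps every element to the range [0, 6] in place. */
static inline void esp_nn_relu6_s8_example(void)
{
    int8_t v[3] = { -5, 3, 9 };
    esp_nn_relu6_s8_ansi(v, 3);   /* v becomes { 0, 3, 6 } */
}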

View File

@ -0,0 +1,97 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
void esp_nn_add_elementwise_u8_ansi(const uint8_t *input1_data,
const uint8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
const int32_t input1_mult,
const int32_t input2_mult,
const int32_t input1_shift,
const int32_t input2_shift,
const int32_t left_shift,
uint8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size)
{
for (int i = 0; i < size; i++) {
int32_t tmp1 = input1_data[i] + input1_offset;
int32_t tmp2 = input2_data[i] + input2_offset;
tmp1 <<= left_shift;
tmp2 <<= left_shift;
tmp1 = esp_nn_sat_round_doubling_high_mul(tmp1, input1_mult);
tmp2 = esp_nn_sat_round_doubling_high_mul(tmp2, input2_mult);
tmp1 = esp_nn_div_by_power_of_two(tmp1, -input1_shift);
tmp2 = esp_nn_div_by_power_of_two(tmp2, -input2_shift);
int32_t out = tmp1 + tmp2;
out = esp_nn_sat_round_doubling_high_mul(out, out_mult);
out = esp_nn_div_by_power_of_two(out, -out_shift);
out = out + out_offset;
out = max(activation_min, min(out, activation_max));
output[i] = (uint8_t) out;
}
}
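/*
 * Editor's note (not part of esp-nn): the pipeline above first brings both
 * inputs to a common, higher-precision scale (offset add, left_shift, then
 * a per-input Q31 multiply and right shift), sums them, and finally
 * rescales the sum to the output scale before clamping to the activation
 * range. The s8 variant below is identical except for the element type.
 */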
void esp_nn_add_elementwise_s8_ansi(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
const int32_t input1_mult,
const int32_t input2_mult,
const int32_t input1_shift,
const int32_t input2_shift,
const int32_t left_shift,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size)
{
for (int i = 0; i < size; i++) {
int32_t tmp1 = input1_data[i] + input1_offset;
int32_t tmp2 = input2_data[i] + input2_offset;
tmp1 <<= left_shift;
tmp2 <<= left_shift;
tmp1 = esp_nn_sat_round_doubling_high_mul(tmp1, input1_mult);
tmp2 = esp_nn_sat_round_doubling_high_mul(tmp2, input2_mult);
tmp1 = esp_nn_div_by_power_of_two(tmp1, -input1_shift);
tmp2 = esp_nn_div_by_power_of_two(tmp2, -input2_shift);
int32_t out = tmp1 + tmp2;
out = esp_nn_sat_round_doubling_high_mul(out, out_mult);
out = esp_nn_div_by_power_of_two(out, -out_shift);
out = out + out_offset;
out = max(activation_min, min(out, activation_max));
output[i] = (int8_t) out;
}
}

View File

@ -0,0 +1,42 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
const int8_t *input2_data,
const int32_t input1_offset,
const int32_t input2_offset,
int8_t *output,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
const int32_t activation_min,
const int32_t activation_max,
const int32_t size)
{
for (int i = 0; i < size; i++) {
int32_t tmp1 = input1_data[i] + input1_offset;
int32_t tmp2 = input2_data[i] + input2_offset;
int32_t out = tmp1 * tmp2;
out = esp_nn_multiply_by_quantized_mult(out, out_mult, out_shift);
out = out + out_offset;
out = max(activation_min, min(out, activation_max));
output[i] = (int8_t) out;
}
}

View File

@ -0,0 +1,255 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
/**
 * The C99 `inline` keyword alone doesn't guarantee inlining;
 * we need to use the attribute as well to force it.
*/
#define __NN_FORCE_INLINE__ __attribute((always_inline)) static inline
/* min/max macros */
#ifndef max
#define max(a, b) ({ \
__typeof__ (a) _a = (a); \
__typeof__ (b) _b = (b); \
_a > _b ? _a : _b; \
})
#define min(a, b) ({ \
__typeof__ (a) _a = (a); \
__typeof__ (b) _b = (b); \
_a < _b ? _a : _b; \
})
#endif
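/*
 * Editor's note (not part of esp-nn): the statement-expression form above
 * evaluates each argument exactly once, so e.g. max(i++, j) increments i
 * once, unlike the naive ((a) > (b) ? (a) : (b)) macro, which would
 * evaluate the winning argument twice.
 */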
__NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
{
#if CONFIG_IDF_TARGET_ARCH_XTENSA
__asm__ volatile("nsau %0, %0" : "+r" (in));
return in;
#elif defined(__GNUC__)
return __builtin_clz(in);
#else
int32_t count = 32;
uint32_t x = in, y = in >> 16;
if (y != 0) {
count -= 16;
x = y;
}
y = x >> 8;
if (y != 0) {
count -= 8;
x = y;
}
y = x >> 4;
if (y != 0) {
count -= 4;
x = y;
}
y = x >> 2;
if (y != 0) {
count -= 2;
x = y;
}
y = x >> 1;
if (y != 0) {
return count - 2;
}
return count - x;
#endif
}
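/* Illustrative values (editor's addition, not part of esp-nn):
 * esp_nn_clz32(0x80000000u) == 0 and esp_nn_clz32(1u) == 31. For an input
 * of 0 the Xtensa and portable fallback paths return 32, while
 * __builtin_clz(0) is undefined. */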
/**
* Signed saturate a 32 bit value to 8 bits keeping output in 32 bit variable.
*/
__NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
{
#if CONFIG_IDF_TARGET_ARCH_XTENSA
__asm__ volatile("clamps %0, %0, 7" : "+a"(in));
return in;
#else
return max(INT8_MIN, min(in, INT8_MAX));
#endif
}
__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
{
int32_t sign = (int32_t) (val64 >> 63);
int32_t to_add = sign & ((1ul << 31) - 1);
return (int32_t) ((int64_t) (val64 + to_add) >> 31);
}
__NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
{
int32_t result;
int64_t in0_64 = (int64_t) in0;
bool overflow = (in0 == in1) && (in0 == (int32_t) INT32_MIN);
/* Nudge value */
int64_t nudge_val = 1 << 30;
if ((in0 < 0) ^ (in1 < 0)) {
nudge_val = 1 - nudge_val;
}
/* Multiply and add nudge */
int64_t mult = in0_64 * in1 + nudge_val;
/* Round and pickup 32 bits */
result = esp_nn_pick_sat_high32_of64(mult);
return overflow ? INT32_MAX : result;
}
/**
 * Fast version:
 * this gives a wrong result for values within `1 << (exponent - 1)` of
 * INT32_MAX or INT32_MIN. We can afford this because we are at the very
 * last stage of the filter, and it is a pretty rare condition as our
 * output is going to be 8 bit.
*/
__NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two_fast(int32_t val, int32_t exponent)
{
int32_t to_add = (1 << (exponent - 1)) - (val < 0);
return (int32_t) ((val + to_add) >> exponent);
}
__NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two(int32_t val, int32_t exponent)
{
int32_t result;
const int32_t mask = (1 << exponent) - 1;
const int32_t remainder = val & mask;
result = val >> exponent;
int32_t threshold = (mask >> 1) + (result < 0);
if (remainder > threshold) {
result += 1;
}
return result;
}
__NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult(int32_t x, int32_t mult, int32_t shift)
{
int32_t left_shift = shift > 0 ? shift : 0;
int32_t right_shift = shift > 0 ? 0 : -shift;
int32_t result = esp_nn_sat_round_doubling_high_mul(x * (1 << left_shift), mult);
return esp_nn_div_by_power_of_two(result, right_shift);
}
__NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult_fast(int32_t x, int32_t mult, int32_t shift)
{
int32_t left_shift = max(shift, 0);
int32_t right_shift = left_shift - shift;
int64_t nudge_val = 1 << 30;
int64_t in0_64 = (int64_t) (x << left_shift);
/* Multiply and add nudge */
int64_t mult_64 = in0_64 * mult + nudge_val;
int32_t result = (int32_t) (mult_64 >> 31);
if (right_shift) {
result = esp_nn_div_by_power_of_two_fast(result, right_shift);
}
return result;
}
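/*
 * Worked example (editor's addition, not part of esp-nn): the pair
 * (mult, shift) encodes a real-valued scale M = mult * 2^(shift - 31)
 * with mult in [2^30, 2^31). For M = 0.25, mult = 1 << 30 and shift = -1:
 * esp_nn_multiply_by_quantized_mult(100, 1 << 30, -1) first computes
 * esp_nn_sat_round_doubling_high_mul(100, 1 << 30) = 50, then
 * esp_nn_div_by_power_of_two(50, 1) = 25, i.e. round(100 * 0.25).
 */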
static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const int32_t pad_val,
const uint16_t pad_wd,
const uint16_t pad_ht)
{
/* memset with pad_val */
memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
    dst += (input_wd + 2 * pad_wd) * pad_ht * channels; /* skip the top padding rows */
for (int i = 0; i < input_ht; i++) {
dst += pad_wd * channels;
for (int j = 0; j < input_wd * channels; j++) {
*dst++ = *src++;
}
dst += pad_wd * channels;
}
}
static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const int32_t pad_val,
const uint16_t pad_wd,
const uint16_t pad_ht)
{
for (int i = 0; i < input_ht; i++) {
for (int j = 0; j < input_wd * channels; j++) {
*dst++ = *src++;
}
if (pad_wd) {
memset(dst, pad_val, pad_wd * channels);
dst += pad_wd * channels;
}
}
    /* pad `pad_ht` lines at the end */
if (pad_ht) {
memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
}
}
/**
* @brief convert 8 bit input data to 16 bit
*
* @param src int8_t source data
* @param dst int16_t dst data
* @param size length of data
* @param offset offset to be added to src data. Range: [-128, 127]
*/
__NN_FORCE_INLINE__ void esp_nn_s8_to_s16_with_offset(const int8_t *src, int16_t *dst,
const int size, const int32_t offset)
{
int i = 0;
    for (; i < size - 1; i += 2) { /* pairwise; stops before overrunning odd sizes */
        dst[i + 0] = src[i + 0] + offset;
        dst[i + 1] = src[i + 1] + offset;
    }
    if (i < size) { /* tail element when size is odd */
dst[i] = src[i] + offset;
}
}
/**
* @brief convert 8 bit input data to 16 bit
*
* @param src int8_t source data
* @param dst int16_t dst data
* @param size length of data
*/
__NN_FORCE_INLINE__ void esp_nn_s8_to_s16(const int8_t *src, int16_t *dst, const int size)
{
int i = 0;
    for (; i < size - 1; i += 2) { /* pairwise; stops before overrunning odd sizes */
        dst[i + 0] = src[i + 0];
        dst[i + 1] = src[i + 1];
    }
    if (i < size) { /* tail element when size is odd */
dst[i] = src[i];
}
}

View File

@ -0,0 +1,179 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params)
{
return 0;
}
void esp_nn_set_conv_scratch_buf_ansi(const void *buf)
{
}
/**
* Assumption 1: i/p channels == o/p channels
* Assumption 2: Pointers are valid
 * Assumption 3: dilation width = 1
*/
void esp_nn_conv_u8_ansi(const uint8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int32_t input_offset,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint8_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t filter_offset,
const int32_t *bias,
uint8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t out_shift,
const int32_t out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
    for (int out_y = 0; out_y < out_ht; out_y++) { // height loop
        const int16_t base_y = (out_y * stride_ht) - pad_ht;
        for (int out_x = 0; out_x < out_wd; out_x++) { // width loop
            const int16_t base_x = (out_x * stride_wd) - pad_wd;
            for (int out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) { // channel loop
int32_t result = 0;
                /* Clip the filter window so it doesn't read outside the input block */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
for (int in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
int32_t input_index = (idx_y * input_wd + idx_x) * in_channels + in_ch_idx;
int32_t filter_index = ((out_ch_idx * filter_ht + filter_y_idx)
* filter_wd + filter_x_idx) * in_channels
+ in_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index] + filter_offset;
result += input_val * filter_val;
}
}
}
if (bias) {
result += bias[out_ch_idx];
}
result = esp_nn_multiply_by_quantized_mult(result, out_mult, out_shift);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
int out_index = (out_y * out_wd + out_x) * out_channels + out_ch_idx;
out_data[out_index] = (uint8_t) result;
}
}
}
}
/**
* Assumption 1: i/p channels == o/p channels
* Assumption 2: Pointers are valid
 * Assumption 3: dilation width = 1
*/
void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t in_channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t out_channels = output_dims->channels;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
const int32_t base_y = stride_ht * out_y - pad_ht;
const int32_t base_x = stride_wd * out_x - pad_wd;
const int32_t filter_y_start = max(0, -base_y);
const int32_t filter_x_start = max(0, -base_x);
const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
(filter_y_idx * filter_wd + filter_x_idx) * in_channels;
for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
conv_out +=
(input_data[input_base_offset + in_ch_idx] + input_offset) *
filter_data[filter_base_offset + in_ch_idx];
}
}
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
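/*
 * Illustrative driver (editor's addition, not part of esp-nn): a 1x1
 * convolution over a 2x2x2 input with one output channel. The quant values
 * are placeholders encoding a scale of 0.5 for the single output channel.
 */
static inline void esp_nn_conv_s8_ansi_usage_sketch(void)
{
    static const int8_t input[2 * 2 * 2]  = { 1, 2, 3, 4, 5, 6, 7, 8 };
    static const int8_t filter[1 * 1 * 2] = { 1, 1 };   /* one 1x1x2 kernel */
    static const int32_t bias[1] = { 0 };
    static int8_t output[2 * 2 * 1];
    static int32_t shift[1] = { 0 };
    static int32_t mult[1]  = { 1 << 30 };              /* placeholder: scale 0.5 */

    const data_dims_t input_dims  = { .width = 2, .height = 2, .channels = 2 };
    const data_dims_t filter_dims = { .width = 1, .height = 1, .channels = 2 };
    const data_dims_t output_dims = { .width = 2, .height = 2, .channels = 1 };
    const conv_params_t conv_params = {
        .in_offset = 0, .out_offset = 0,
        .stride = { 1, 1 }, .padding = { 0, 0 }, .dilation = { 1, 1 },
        .activation = { -128, 127 },
    };
    const quant_data_t quant_data = { .shift = shift, .mult = mult };

    esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter, bias,
                        &output_dims, output, &conv_params, &quant_data);
    /* output now holds { 2, 4, 6, 8 }: each pixel's channel sum * 0.5, rounded */
}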

View File

@ -0,0 +1,463 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
static int16_t *scratch_buffer = NULL;
extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int32_t input_offset,
const int8_t *filter_aligned,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max,
void *buffer /* scratch buffer */);
extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max,
void *buffer /* scratch buffer */);
extern void esp_nn_conv_s16_mult8_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int16_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
const int size, const int32_t offset);
extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t in_ch = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t out_ch = output_dims->channels;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
int32_t conv_out = 0;
const int32_t base_y = stride_ht * out_y - pad_ht;
const int32_t base_x = stride_wd * out_x - pad_wd;
const int32_t filter_y_start = max(0, -base_y);
const int32_t filter_x_start = max(0, -base_x);
const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
(filter_y_idx * filter_wd + filter_x_idx) * in_ch;
for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
conv_out +=
(input_data[input_base_offset + in_ch_idx] + input_offset) *
filter_data[filter_base_offset + in_ch_idx];
}
}
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
static void esp_nn_conv_s8_pad_valid(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int32_t input_offset,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int8_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
const int32_t base_y = stride_ht * out_y;
const int32_t base_x = stride_wd * out_x;
for (filter_y_idx = 0; filter_y_idx < filter_ht; filter_y_idx++) {
for (filter_x_idx = 0; filter_x_idx < filter_wd; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
(filter_y_idx * filter_wd + filter_x_idx) * in_channels;
const int8_t *input_data_ptr = input_data + input_base_offset;
const int8_t *filter_data_ptr = filter_data + filter_base_offset;
for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
}
}
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
static void esp_nn_conv_s8_pad_valid_3x3(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int32_t input_offset,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int8_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
const int32_t base_y = stride_ht * out_y;
const int32_t base_x = stride_wd * out_x;
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
for (filter_x_idx = 0; filter_x_idx < 3; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
int32_t filter_base_offset = out_ch_idx * in_channels * 3 * 3 +
(filter_y_idx * 3 + filter_x_idx) * in_channels;
const int8_t *input_data_ptr = input_data + input_base_offset;
const int8_t *filter_data_ptr = filter_data + filter_base_offset;
for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
}
}
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const int32_t input_offset,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int8_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
int32_t out_ch_idx, out_y, out_x, filter_y_idx;
/* use scratch_buffer to pre-compute offset factor */
int16_t *filter_sum = (int16_t *) scratch_buffer;
const int8_t *filter_ptr = filter_data;
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int16_t sum_val = 0;
for (int i = 0; i < 9; i++) {
sum_val += *filter_ptr++;
sum_val += *filter_ptr++;
sum_val += *filter_ptr++;
}
*filter_sum++ = sum_val;
}
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
const int8_t *filter_data_ptr = filter_data;
const int32_t base_y = stride_ht * out_y;
const int32_t base_x = stride_wd * out_x;
const int8_t *input_base_ptr = input_data + (base_y * input_wd + base_x) * 3;
int16_t *filter_sum = (int16_t *) scratch_buffer;
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
const int8_t *input_data_ptr = input_base_ptr + (filter_y_idx * input_wd) * 3;
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
}
conv_out += *filter_sum++ * input_offset;
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t in_ch = input_dims->channels;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_ch = output_dims->channels;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
int filter_size = filter_wd * filter_ht * in_ch * out_ch;
int input_size = input_wd * input_ht * in_ch;
int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
if (input_wd * input_ht < 8) {
        transpose_buf_size = 0; // the transpose buffer isn't used for such small inputs
}
int align_buf_size = 32; /* extra buffer for alignment */
if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
return filter_size + transpose_buf_size + align_buf_size;
}
return 2 * (filter_size + input_size) + transpose_buf_size + align_buf_size;
}
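/*
 * Sizing example (editor's addition, not part of esp-nn): for an 8x8x8
 * input and a 1x1x8x16 filter with stride 1 and no padding, the fast 1x1
 * path applies and the requirement is filter (128) + transpose buffer
 * (2 * 8 * 8 = 128) + alignment (32) = 288 bytes.
 */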
void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf) /* matches the header declaration */
{
scratch_buffer = (int16_t *) buf;
}
void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
const int8_t *input,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t out_channels = output_dims->channels;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int filter_size = filter_wd * filter_ht * channels * out_channels;
int input_size = input_wd * input_ht * channels;
int align_len = 16 - (filter_size & 15);
int16_t *filter_data16 = scratch_buffer;
int16_t *input_data16 = scratch_buffer + filter_size + align_len;
if (scratch_buffer == NULL) {
printf("esp_nn_conv error! scratch_buffer not set!\n");
return;
}
if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
int8_t *filter_aligned = (int8_t *) scratch_buffer;
int scratch_offset = (int) (filter_aligned + filter_size);
void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
esp_nn_conv_s8_mult8_1x1_esp32s3(
input, input_wd, input_ht, channels, input_offset, filter_aligned,
bias, out_data, out_wd, out_ht, out_channels, out_offset,
out_shift, out_mult, activation_min, activation_max, scratch_buf);
} else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
(input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
int scratch_offset = (int) (input_data16 + input_size);
void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
esp_nn_conv_s16_mult4_1x1_esp32s3(
input_data16, input_wd, input_ht, channels, filter_data16,
bias, out_data, out_wd, out_ht, out_channels, out_offset,
out_shift, out_mult, activation_min, activation_max, scratch_buf);
} else if (channels % 8 == 0) {
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
esp_nn_conv_s16_mult8_esp32s3(
input_data16, input_wd, input_ht, channels, pad_wd, pad_ht,
stride_wd, stride_ht, filter_data16, filter_wd, filter_ht, bias,
out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
out_mult, activation_min, activation_max);
} else if (pad_wd == 0 && pad_ht == 0) {
if (filter_wd == 3 && filter_ht == 3 && channels == 3) {
esp_nn_conv_s8_pad_valid_ch3_3x3(input, input_wd, input_ht, input_offset,
stride_wd, stride_ht, filter_data, bias,
out_data, out_wd, out_ht, out_channels, out_offset,
out_shift, out_mult, activation_min, activation_max);
} else {
esp_nn_conv_s8_pad_valid(input, input_wd, input_ht, channels, input_offset,
stride_wd, stride_ht, filter_data, filter_wd, filter_ht, bias,
out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
out_mult, activation_min, activation_max);
}
} else {
/* Basic unrolled version */
esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
bias, output_dims, out_data, conv_params, quant_data);
}
}

View File

@ -0,0 +1,179 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const conv_params_t *conv_params)
{
return 0;
}
void esp_nn_set_conv_scratch_buf_opt(const void *buf)
{
}
__attribute__ ((noinline))
static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
const int8_t *input_data,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t in_channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t out_channels = output_dims->channels;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
const int32_t *out_mult = quant_data->mult;
const int32_t *out_shift = quant_data->shift;
const int8_t *filter_ptr = filter_data;
const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
int32_t out_ch_idx = 0;
for (; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
const int8_t *input_ptr = input_base_ptr;
int32_t in_ch_idx = 0;
for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
}
                for (; in_ch_idx < in_channels; in_ch_idx++) {
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}
/**
* Assumption 1: i/p channels == o/p channels
* Assumption 2: Pointers are valid
 * Assumption 3: dilation width = 1
*/
void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
if (filter_wd == 1 && filter_ht == 1) {
esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
output_dims, out_data, conv_params, quant_data);
return;
}
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t in_channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t out_channels = output_dims->channels;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
for (out_y = 0; out_y < out_ht; out_y++) {
for (out_x = 0; out_x < out_wd; out_x++) {
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
int32_t conv_out = 0;
const int32_t base_y = stride_ht * out_y - pad_ht;
const int32_t base_x = stride_wd * out_x - pad_wd;
const int32_t filter_y_start = max(0, -base_y);
const int32_t filter_x_start = max(0, -base_x);
const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t in_row = base_y + filter_y_idx;
const int32_t in_col = base_x + filter_x_idx;
const int8_t *input_ptr = input_data +
(in_row * input_wd + in_col) * in_channels;
const int8_t *filter_ptr = filter_data +
out_ch_idx * in_channels * filter_ht * filter_wd +
(filter_y_idx * filter_wd + filter_x_idx) * in_channels;
int32_t in_ch_idx = 0;
for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
}
                        for (; in_ch_idx < in_channels; in_ch_idx++) {
conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
}
}
}
if (bias) {
conv_out += bias[out_ch_idx];
}
conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
conv_out += out_offset;
conv_out = max(conv_out, activation_min);
conv_out = min(conv_out, activation_max);
*out_data++ = (int8_t) conv_out;
}
}
}
}

View File

@ -0,0 +1,100 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params)
{
return 0;
}
void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
{
}
void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
const uint16_t ch_mult = conv_params->ch_mult;
int out_idx = 0;
    for (int out_y = 0; out_y < out_ht; out_y++) { // height loop
        const int16_t base_y = (out_y * stride_ht) - pad_ht;
        for (int out_x = 0; out_x < out_wd; out_x++) { // width loop
            const int16_t base_x = (out_x * stride_wd) - pad_wd;
            for (int ch_idx = 0; ch_idx < channels; ch_idx++) { // channel loop
for (int ch_mult_idx = 0; ch_mult_idx < ch_mult; ch_mult_idx++) {
int32_t result = 0;
const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
                    /* Clip the filter window so it doesn't read outside the input block */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index];
result += input_val * filter_val;
}
}
if (bias) {
result += bias[out_ch_idx];
}
result = esp_nn_multiply_by_quantized_mult(result, out_mult[out_ch_idx], out_shift[out_ch_idx]);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
                    out_data[out_idx++] = (int8_t) result;
}
}
}
}
}
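/*
 * Illustrative driver (editor's addition, not part of esp-nn): depthwise
 * convolution with ch_mult = 2, so a 2-channel input yields 2 * 2 = 4
 * output channels. The quant values are identity-like placeholders.
 */
static inline void esp_nn_depthwise_conv_s8_ansi_usage_sketch(void)
{
    static const int8_t input[2 * 2 * 2]  = { 0 };
    static const int8_t filter[1 * 1 * 4] = { 1, 1, 1, 1 };  /* ht*wd*(ch*ch_mult) */
    static int8_t output[2 * 2 * 4];
    static int32_t shift[4] = { 0, 0, 0, 0 };
    static int32_t mult[4]  = { 1 << 30, 1 << 30, 1 << 30, 1 << 30 };

    const data_dims_t input_dims  = { .width = 2, .height = 2, .channels = 2 };
    const data_dims_t filter_dims = { .width = 1, .height = 1, .channels = 4 };
    const data_dims_t output_dims = { .width = 2, .height = 2, .channels = 4 };
    const dw_conv_params_t params = {
        .in_offset = 0, .out_offset = 0, .ch_mult = 2,
        .stride = { 1, 1 }, .padding = { 0, 0 }, .dilation = { 1, 1 },
        .activation = { -128, 127 },
    };
    const quant_data_t quant = { .shift = shift, .mult = mult };

    esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter,
                                  NULL /* no bias */, &output_dims, output,
                                  &params, &quant);
}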

View File

@ -0,0 +1,291 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params)
{
return 0;
}
void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
{
}
/* common channel multiplier == 1 case */
__attribute__ ((noinline))
static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int out_idx = 0;
for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
const int16_t base_y = (out_y * stride_ht) - pad_ht;
for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
const int16_t base_x = (out_x * stride_wd) - pad_wd;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
/* Clamp the filter window so the taps don't read outside the input */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
int ch_idx = 0;
for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop
int32_t result0 = 0;
int32_t result1 = 0;
int32_t result2 = 0;
int32_t result3 = 0;
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
int32_t input_val0 = input_data[input_index + 0] + input_offset;
int32_t input_val1 = input_data[input_index + 1] + input_offset;
int32_t input_val2 = input_data[input_index + 2] + input_offset;
int32_t input_val3 = input_data[input_index + 3] + input_offset;
int32_t filter_val0 = filter_data[filter_index + 0];
int32_t filter_val1 = filter_data[filter_index + 1];
int32_t filter_val2 = filter_data[filter_index + 2];
int32_t filter_val3 = filter_data[filter_index + 3];
result0 += input_val0 * filter_val0;
result1 += input_val1 * filter_val1;
result2 += input_val2 * filter_val2;
result3 += input_val3 * filter_val3;
}
}
if (bias) {
result0 += bias[ch_idx + 0];
result1 += bias[ch_idx + 1];
result2 += bias[ch_idx + 2];
result3 += bias[ch_idx + 3];
}
result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
result0 += out_offset;
result1 += out_offset;
result2 += out_offset;
result3 += out_offset;
result0 = max(result0, activation_min);
result1 = max(result1, activation_min);
result2 = max(result2, activation_min);
result3 = max(result3, activation_min);
result0 = min(result0, activation_max);
result1 = min(result1, activation_max);
result2 = min(result2, activation_max);
result3 = min(result3, activation_max);
out_data[out_idx++] = result0;
out_data[out_idx++] = result1;
out_data[out_idx++] = result2;
out_data[out_idx++] = result3;
}
for (; ch_idx < channels; ch_idx++) {//channel_loop
int32_t result = 0;
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index];
result += input_val * filter_val;
}
}
if (bias) {
result += bias[ch_idx];
}
result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
out_data[out_idx++] = result;
}
}
}
}
void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t ch_mult = conv_params->ch_mult;
if (ch_mult == 1) {
esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
bias, output_dims, out_data, conv_params, quant_data);
return;
}
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
int out_idx = 0;
for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
const int16_t base_y = (out_y * stride_ht) - pad_ht;
for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
const int16_t base_x = (out_x * stride_wd) - pad_wd;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
/* Clamp the filter window so the taps don't read outside the input */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
int ch_mult_idx = 0;
for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
int32_t result0 = 0;
int32_t result1 = 0;
int32_t result2 = 0;
int32_t result3 = 0;
const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val0 = filter_data[filter_index + 0];
int32_t filter_val1 = filter_data[filter_index + 1];
int32_t filter_val2 = filter_data[filter_index + 2];
int32_t filter_val3 = filter_data[filter_index + 3];
result0 += input_val * filter_val0;
result1 += input_val * filter_val1;
result2 += input_val * filter_val2;
result3 += input_val * filter_val3;
}
}
if (bias) {
result0 += bias[out_ch_idx + 0];
result1 += bias[out_ch_idx + 1];
result2 += bias[out_ch_idx + 2];
result3 += bias[out_ch_idx + 3];
}
result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
result0 += out_offset;
result1 += out_offset;
result2 += out_offset;
result3 += out_offset;
result0 = max(result0, activation_min);
result1 = max(result1, activation_min);
result2 = max(result2, activation_min);
result3 = max(result3, activation_min);
result0 = min(result0, activation_max);
result1 = min(result1, activation_max);
result2 = min(result2, activation_max);
result3 = min(result3, activation_max);
out_data[out_idx++] = result0;
out_data[out_idx++] = result1;
out_data[out_idx++] = result2;
out_data[out_idx++] = result3;
}
for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
int32_t result = 0;
const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index];
result += input_val * filter_val;
}
}
if (bias) {
result += bias[out_ch_idx];
}
result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
out_data[out_idx++] = result;
}
}
}
}
}
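/*
 * The optimized kernels above all use the same unroll-by-4 idiom: four
 * independent accumulators per iteration hide the multiply latency, then a
 * scalar loop handles the left-over elements. The pattern in isolation
 * (a stand-alone sketch, not library code):
 */
#include <stdint.h>

static void scale_by_two(const int8_t *in, int16_t *out, int n)
{
    int i = 0;
    for (; i < n - 3; i += 4) { /* main body: four elements per iteration */
        out[i + 0] = in[i + 0] * 2;
        out[i + 1] = in[i + 1] * 2;
        out[i + 2] = in[i + 2] * 2;
        out[i + 3] = in[i + 3] * 2;
    }
    for (; i < n; i++) { /* tail: the remaining n % 4 elements */
        out[i] = in[i] * 2;
    }
}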

View File

@ -0,0 +1,543 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include "../../include/esp_nn_defs.h"
#include "../common/common_functions.h"
static int16_t *scratch_buffer = NULL;
extern void esp_nn_depthwise_conv_s16_mult8_3x3_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t ch_mult,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const int32_t input_offset,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int8_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s16_mult1_3x3_no_pad_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s16_mult8_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t ch_mult,
const int16_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s16_mult4_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t ch_mult,
const int16_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_depthwise_conv_s16_mult1_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int16_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
const int size, const int32_t offset);
static void esp_nn_depthwise_conv_s8_unrolled(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const int32_t input_offset,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t ch_mult,
const int8_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
int out_idx = 0;
for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
const int16_t base_y = (out_y * stride_ht) - pad_ht;
for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
const int16_t base_x = (out_x * stride_wd) - pad_wd;
for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
int ch_mult_idx = 0;
for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
int32_t result0 = 0, result1 = 0, result2 = 0, result3 = 0;
const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
/* Clamp the filter window so the taps don't read outside the input */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val0 = filter_data[filter_index + 0];
int32_t filter_val1 = filter_data[filter_index + 1];
int32_t filter_val2 = filter_data[filter_index + 2];
int32_t filter_val3 = filter_data[filter_index + 3];
result0 += input_val * filter_val0;
result1 += input_val * filter_val1;
result2 += input_val * filter_val2;
result3 += input_val * filter_val3;
}
}
if (bias) {
result0 += bias[out_ch_idx + 0];
result1 += bias[out_ch_idx + 1];
result2 += bias[out_ch_idx + 2];
result3 += bias[out_ch_idx + 3];
}
result0 = esp_nn_multiply_by_quantized_mult(result0,
out_mult[out_ch_idx + 0], out_shift[out_ch_idx + 0]);
result1 = esp_nn_multiply_by_quantized_mult(result1,
out_mult[out_ch_idx + 1], out_shift[out_ch_idx + 1]);
result2 = esp_nn_multiply_by_quantized_mult(result2,
out_mult[out_ch_idx + 2], out_shift[out_ch_idx + 2]);
result3 = esp_nn_multiply_by_quantized_mult(result3,
out_mult[out_ch_idx + 3], out_shift[out_ch_idx + 3]);
result0 += out_offset;
result1 += out_offset;
result2 += out_offset;
result3 += out_offset;
result0 = max(result0, activation_min);
result1 = max(result1, activation_min);
result2 = max(result2, activation_min);
result3 = max(result3, activation_min);
result0 = min(result0, activation_max);
result1 = min(result1, activation_max);
result2 = min(result2, activation_max);
result3 = min(result3, activation_max);
out_data[out_idx++] = result0;
out_data[out_idx++] = result1;
out_data[out_idx++] = result2;
out_data[out_idx++] = result3;
}
/* left-over */
for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
int32_t result = 0;
const int out_ch_idx = ch_mult_idx + ch_idx * ch_mult;
/* Clamp the filter window so the taps don't read outside the input */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index];
result += input_val * filter_val;
}
}
if (bias) {
result += bias[out_ch_idx];
}
result = esp_nn_multiply_by_quantized_mult(result, out_mult[out_ch_idx], out_shift[out_ch_idx]);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
out_data[out_idx++] = result;
}
}
}
}
}
void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t channels,
const int32_t input_offset,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int8_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
int out_idx = 0;
for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
const int16_t base_y = (out_y * stride_ht) - pad_ht;
for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
const int16_t base_x = (out_x * stride_wd) - pad_wd;
for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
int32_t result = 0;
/* Clamp the filter window so the taps don't read outside the input */
int filter_y_start = max(0, -base_y);
int filter_x_start = max(0, -base_x);
int filter_y_end = min(filter_ht, input_ht - base_y);
int filter_x_end = min(filter_wd, input_wd - base_x);
for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
const int32_t idx_y = base_y + filter_y_idx;
for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
const int32_t idx_x = base_x + filter_x_idx;
int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * channels + ch_idx;
int32_t input_val = input_data[input_index] + input_offset;
int32_t filter_val = filter_data[filter_index];
result += input_val * filter_val;
}
}
if (bias) {
result += bias[ch_idx];
}
result = esp_nn_multiply_by_quantized_mult(result, out_mult[ch_idx], out_shift[ch_idx]);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
out_data[out_idx++] = result;
}
}
}
}
int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
const data_dims_t *filter_dims,
const data_dims_t *output_dims,
const dw_conv_params_t *conv_params)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t ch_mult = conv_params->ch_mult;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
int filter_size = filter_wd * filter_ht * channels * ch_mult;
int pad_width = 0, pad_height = 0;
if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) {
if (channels % 16 == 0) {
if (pad_wd || pad_ht) {
pad_width = pad_wd * 2;
pad_height = pad_ht * 2;
} else {
// check if we need to pad additionally
pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;
pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht;
// printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)",
// input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht,
// stride_wd, stride_ht, pad_wd, pad_ht);
}
if (pad_width || pad_height) {
int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels;
// printf("ask1 %d\n", filter_size + input_size + 16);
return filter_size + input_size + 16; // 16 for alignment
} else {
// printf("ask2 %d\n", filter_size + 16);
return filter_size + 16; // 16 for alignment
}
} else {
int input_size = input_wd * input_ht * channels;
// printf("ask3 %d\n", 2 * (filter_size + input_size) + 16);
return 2 * (filter_size + input_size) + 16; // 16 for alignment
}
} else if (ch_mult % 4 == 0) {
int input_size = input_wd * input_ht * channels;
// printf("ask4 %d\n", 2 * (filter_size + input_size) + 16);
return 2 * (filter_size + input_size) + 16; // 16 for alignment
}
return 32; // just a few bytes
}
void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
{
scratch_buffer = (int16_t *) buf;
}
/**
* Assumption 1: i/p channels == o/p channels
* Assumption 2: Pointers are valid
* Assumption 3: dilation width = 1
*/
void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
const int8_t *input_data,
const data_dims_t *filter_dims,
const int8_t *filter_data,
const int32_t *bias,
const data_dims_t *output_dims,
int8_t *out_data,
const dw_conv_params_t *conv_params,
const quant_data_t *quant_data)
{
const uint16_t input_wd = input_dims->width;
const uint16_t input_ht = input_dims->height;
const uint16_t channels = input_dims->channels;
const int32_t input_offset = conv_params->in_offset;
const int32_t out_offset = conv_params->out_offset;
const uint16_t pad_wd = conv_params->padding.width;
const uint16_t pad_ht = conv_params->padding.height;
const uint16_t stride_wd = conv_params->stride.width;
const uint16_t stride_ht = conv_params->stride.height;
const uint16_t filter_wd = filter_dims->width;
const uint16_t filter_ht = filter_dims->height;
const uint16_t out_wd = output_dims->width;
const uint16_t out_ht = output_dims->height;
const int32_t *out_shift = quant_data->shift;
const int32_t *out_mult = quant_data->mult;
const int32_t activation_min = conv_params->activation.min;
const int32_t activation_max = conv_params->activation.max;
const uint16_t ch_mult = conv_params->ch_mult;
int filter_size = filter_wd * filter_ht * channels * ch_mult;
int align_len = 16 - (filter_size & 15);
int input_size = input_wd * input_ht * channels;
/* bail out before doing pointer arithmetic on a NULL scratch buffer */
if (scratch_buffer == NULL) {
printf("esp_nn_depthwise_conv error! scratch_buffer not set!\n");
return;
}
int16_t *filter_data16 = scratch_buffer;
int16_t *input_data16 = scratch_buffer + filter_size + align_len;
if ((ch_mult == 1) && (channels % 8 == 0)) {
if ((filter_wd == 3) && (filter_ht == 3)) {
if ((channels % 16 == 0) && (pad_wd == 1) && (pad_ht == 1)) {
/* process in 8 bits */
int8_t *filter_aligned = (int8_t *) scratch_buffer;
int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
memcpy(filter_aligned, filter_data, filter_size);
esp_nn_aligned_s8_pad_with_value(input_data, input_padded, input_wd, input_ht, channels,
-input_offset, pad_wd, pad_ht);
esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + 2 * pad_wd,
input_ht + 2 * pad_ht, channels, input_offset,
stride_wd, stride_ht, filter_aligned, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
} else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) {
/* process in 8 bits */
int8_t *filter_aligned = (int8_t *) scratch_buffer;
int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
// check if we need to pad additionally
int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd;
int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht;
if (pad_right || pad_bottom) { // pad right and bottom
esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht,
channels, -input_offset, pad_right, pad_bottom);
} else {
input_padded = (int8_t *) input_data;
}
memcpy(filter_aligned, filter_data, filter_size);
esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right,
input_ht + pad_bottom, channels, input_offset,
stride_wd, stride_ht, filter_aligned, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
} else { /* (channels % 8) == 0 */
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
pad_wd, pad_ht, stride_wd, stride_ht, filter_data16,
bias, out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
}
} else { // any other filter size with ch_mult == 1 and channels % 8 == 0
esp_nn_depthwise_conv_s8_ch_mult1(input_data, input_wd, input_ht, channels, input_offset,
pad_wd, pad_ht, stride_wd, stride_ht,
filter_data, filter_wd, filter_ht,
bias, out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
}
} else if (ch_mult % 8 == 0) {
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
if (filter_wd == 3 && filter_ht == 3) {
esp_nn_depthwise_conv_s16_mult8_3x3_esp32s3(input_data16, input_wd, input_ht, channels,
pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
filter_data16, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
} else {
esp_nn_depthwise_conv_s16_mult8_esp32s3(input_data16, input_wd, input_ht, channels,
pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
filter_data16, filter_wd, filter_ht, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
}
} else if (ch_mult % 4 == 0) {
esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
esp_nn_depthwise_conv_s16_mult4_esp32s3(input_data16, input_wd, input_ht, channels,
pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
filter_data16, filter_wd, filter_ht, bias,
out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
} else {
esp_nn_depthwise_conv_s8_unrolled(input_data, input_wd, input_ht, channels, input_offset,
pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
filter_data, filter_wd, filter_ht,
bias, out_data, out_wd, out_ht, out_offset, out_shift,
out_mult, activation_min, activation_max);
}
}
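/*
 * A sketch of the call sequence implied by the scratch functions above:
 * query the size, provide a 16-byte aligned buffer, then run the kernel.
 * heap_caps_aligned_alloc/heap_caps_free are standard ESP-IDF allocators;
 * everything else follows the signatures in this file.
 */
#include "esp_heap_caps.h"

static void run_dw_conv_s3(const data_dims_t *in_dims, const int8_t *in,
                           const data_dims_t *flt_dims, const int8_t *flt,
                           const int32_t *bias, const data_dims_t *out_dims,
                           int8_t *out, const dw_conv_params_t *params,
                           const quant_data_t *quant)
{
    int size = esp_nn_get_depthwise_conv_scratch_size_esp32s3(in_dims, flt_dims,
                                                              out_dims, params);
    void *scratch = heap_caps_aligned_alloc(16, size, MALLOC_CAP_DEFAULT);
    if (scratch == NULL) {
        return; /* out of memory */
    }
    esp_nn_set_depthwise_conv_scratch_buf_esp32s3(scratch);
    esp_nn_depthwise_conv_s8_esp32s3(in_dims, in, flt_dims, flt, bias,
                                     out_dims, out, params, quant);
    heap_caps_free(scratch);
}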

View File

@ -0,0 +1,50 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
void esp_nn_fully_connected_s8_ansi(const int8_t *input_data,
const int32_t input_offset,
const uint16_t row_len,
const int8_t *filter_data,
const int32_t filter_offset,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t out_shift,
const int32_t out_mult,
const int32_t activation_min,
const int32_t activation_max)
{
for (int32_t out_c = 0; out_c < out_channels; ++out_c) {
int32_t result = 0;
for (int32_t data_idx = 0; data_idx < row_len; data_idx++) {
int32_t filter_index = row_len * out_c + data_idx;
int32_t input_val = input_data[data_idx];
int32_t filter_val = filter_data[filter_index];
result += (filter_val + filter_offset) * (input_val + input_offset);
}
if (bias) {
result += bias[out_c];
}
result = esp_nn_multiply_by_quantized_mult(result, out_mult, out_shift);
result += out_offset;
result = max(result, activation_min);
result = min(result, activation_max);
out_data[out_c] = (int8_t) result;
}
}
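/*
 * A minimal sketch exercising the kernel above: one output channel, a
 * three-element row, ~1.0 multiplier (Q0.31, shift 0) and zero offsets, so
 * the output is simply dot(input, filter) + bias.
 */
#include <stdio.h>
#include <stdint.h>

int main(void)
{
    const int8_t input[3]  = {1, 2, 3};
    const int8_t filter[3] = {4, 5, 6}; /* dot product = 4 + 10 + 18 = 32 */
    const int32_t bias[1]  = {10};
    int8_t out[1];

    esp_nn_fully_connected_s8_ansi(input, /*input_offset*/ 0, /*row_len*/ 3,
                                   filter, /*filter_offset*/ 0, bias, out,
                                   /*out_channels*/ 1, /*out_offset*/ 0,
                                   /*out_shift*/ 0, /*out_mult*/ INT32_MAX,
                                   /*activation_min*/ -128,
                                   /*activation_max*/ 127);
    printf("fc out = %d\n", out[0]); /* expect 32 + 10 = 42 */
    return 0;
}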

View File

@ -0,0 +1,72 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
void esp_nn_avg_pool_s8_ansi(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels)
{
int32_t base_y = -pad_ht;
for (int32_t out_y = 0; out_y < output_ht; out_y++, base_y += stride_ht) {
int32_t base_x = -pad_wd;
for (int32_t out_x = 0; out_x < output_wd; out_x++, base_x += stride_wd) {
for (int32_t ch_idx = 0; ch_idx < channels; ch_idx++) {
int32_t result = 0;
int32_t filter_cnt = 0;
/* Make sure the filter window does not cross the input bounds */
int32_t filter_y_start = max(0, -base_y);
int32_t filter_x_start = max(0, -base_x);
int32_t filter_y_end = min(filter_ht, input_ht - base_y);
int32_t filter_x_end = min(filter_wd, input_wd - base_x);
for (int32_t filter_y = filter_y_start; filter_y < filter_y_end; filter_y++) {
for (int32_t filter_x = filter_x_start; filter_x < filter_x_end; filter_x++) {
int32_t in_x_idx = base_x + filter_x;
int32_t in_y_idx = base_y + filter_y;
int32_t input_index = (in_y_idx * input_wd + in_x_idx) * channels + ch_idx;
result += input[input_index];
filter_cnt++;
}
}
/* Rounded average */
result = result > 0 ? (result + filter_cnt / 2) / filter_cnt
: (result - filter_cnt / 2) / filter_cnt;
/* Activation function */
result = max(result, activation_min);
result = min(result, activation_max);
int32_t output_index = (out_y * output_wd + out_x) * channels + ch_idx;
output[output_index] = (int8_t) result;
}
}
}
}
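/*
 * Note the rounding above: the average rounds half away from zero instead of
 * truncating, so a window summing to -5 over two elements yields -3, not -2.
 * The idiom in isolation:
 */
#include <stdio.h>

static int rounded_avg(int sum, int count)
{
    return sum > 0 ? (sum + count / 2) / count
                   : (sum - count / 2) / count;
}

int main(void)
{
    printf("%d\n", rounded_avg(5, 2));  /* 3: 2.5 rounds away from zero */
    printf("%d\n", rounded_avg(-5, 2)); /* -3: -2.5 rounds away from zero */
    return 0;
}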

View File

@ -0,0 +1,66 @@
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
void esp_nn_max_pool_s8_ansi(const int8_t *input,
const uint16_t input_wd,
const uint16_t input_ht,
int8_t *output,
const uint16_t output_wd,
const uint16_t output_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const uint16_t filter_wd,
const uint16_t filter_ht,
const uint16_t pad_wd,
const uint16_t pad_ht,
const int32_t activation_min,
const int32_t activation_max,
const uint16_t channels)
{
int32_t base_y = -pad_ht;
for (int32_t out_y = 0; out_y < output_ht; out_y++, base_y += stride_ht) {
int32_t base_x = -pad_wd;
for (int32_t out_x = 0; out_x < output_wd; out_x++, base_x += stride_wd) {
/* Make sure the filter window does not cross the input bounds */
int32_t filter_y_start = max(0, -base_y);
int32_t filter_x_start = max(0, -base_x);
int32_t filter_y_end = min(filter_ht, input_ht - base_y);
int32_t filter_x_end = min(filter_wd, input_wd - base_x);
for (int32_t ch_idx = 0; ch_idx < channels; ch_idx++) {
int8_t result = INT8_MIN;
for (int32_t filter_y = filter_y_start; filter_y < filter_y_end; filter_y++) {
for (int32_t filter_x = filter_x_start; filter_x < filter_x_end; filter_x++) {
int32_t in_x_idx = base_x + filter_x;
int32_t in_y_idx = base_y + filter_y;
int32_t input_index = (in_y_idx * input_wd + in_x_idx) * channels + ch_idx;
result = max(input[input_index], result);
}
}
/* Activation function */
result = max(result, activation_min);
result = min(result, activation_max);
int32_t output_index = (out_y * output_wd + out_x) * channels + ch_idx;
output[output_index] = result;
}
}
}
}

View File

@ -0,0 +1,88 @@
// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "softmax_common.h"
int32_t esp_nn_get_softmax_scratch_size_ansi(const int32_t width, const int32_t height)
{
(void) width;
(void) height;
return 0;
}
void esp_nn_set_softmax_scratch_buf_ansi(void *buffer)
{
(void) buffer;
return;
}
void esp_nn_softmax_s8_ansi(const int8_t *input_data,
const int32_t height,
const int32_t width,
const int32_t mult,
const int32_t shift,
const int32_t diff_min,
int8_t *output_data)
{
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input mult, and therefore as large as
// -16 afterwards. Note that exp(-8) is definitely not insignificant to
// accumulation, but exp(-16) definitely is.
#define ACCUM_BITS 12
#define DIFF_BITS 5
const int32_t mask = (1 << shift);
int32_t col = 0;
const int8_t *in_ptr = input_data;
int8_t *out_ptr = output_data;
for (int row_idx = 0; row_idx < height; row_idx++) {
int8_t max_in_row = in_ptr[0];
for (col = 1; col < width; col++) {
max_in_row = max(max_in_row, in_ptr[col]);
}
int32_t input_diff = 0;
int32_t sum_of_exps = 0;
for (col = 0; col < width; col++) {
input_diff = in_ptr[col] - max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
sum_of_exps += DIV_POW2(exp_raw, ACCUM_BITS);
}
}
const int32_t headroom_plus1 = esp_nn_clz32((uint32_t) sum_of_exps);
const int32_t shifted_scale = ONE_OVER_ONE_X((sum_of_exps << headroom_plus1) - (1 << 31));
const int32_t bits_over_unit = ACCUM_BITS - headroom_plus1 + 31 - sizeof(int8_t) * 8;
for (col = 0; col < width; col++) {
input_diff = in_ptr[col] - max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
const int32_t shifted_output = SAT_HIGH_MUL(shifted_scale, exp_raw);
const int32_t result = DIV_POW2(shifted_output, bits_over_unit) - 128;
out_ptr[col] = (int8_t) esp_nn_saturate8(result);
} else {
out_ptr[col] = -128;
}
}
in_ptr += width;
out_ptr += width;
}
}

View File

@ -0,0 +1,108 @@
// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "softmax_common.h"
#include <stdio.h>
static int32_t *scratch_buf = NULL;
/**
* @brief Get scratch buffer size needed by softmax function
*
* @param width
* @param height
* @return size in bytes
*
* @note buffer must be 4 byte aligned
*/
int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t height)
{
(void) height;
return width * 4;
}
/**
* @brief Set scratch buffer to be used by softmax function
*
* @param buffer this can be NULL if one needs to unset it
* must be aligned to 4 bytes
*/
void esp_nn_set_softmax_scratch_buf_opt(void *buffer)
{
scratch_buf = (int32_t *) buffer;
}
void esp_nn_softmax_s8_opt(const int8_t *input_data,
const int32_t height,
const int32_t width,
const int32_t mult,
const int32_t shift,
const int32_t diff_min,
int8_t *output_data)
{
if (scratch_buf == NULL) {
printf("%s error! scratch buffer not set\n", __FUNCTION__);
return;
}
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input mult, and therefore as large as
// -16 afterwards. Note that exp(-8) is definitely not insignificant to
// accumulation, but exp(-16) definitely is.
#define ACCUM_BITS 12
#define DIFF_BITS 5
const int32_t mask = (1 << shift);
int32_t col = 0;
const int8_t *in_ptr = input_data;
int8_t *out_ptr = output_data;
for (int row_idx = 0; row_idx < height; row_idx++) {
int8_t max_in_row = in_ptr[0];
for (col = 1; col < width; col++) {
max_in_row = max(max_in_row, in_ptr[col]);
}
int32_t input_diff = 0;
int32_t sum_of_exps = 0;
for (col = 0; col < width; col++) {
input_diff = in_ptr[col] - max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_rescaled = SAT_HIGH_MUL(input_diff * mask, mult);
const int32_t exp_raw = esp_nn_exp_on_negative_values(input_diff_rescaled);
scratch_buf[col] = exp_raw; // store to avoid duplicate calculation later
sum_of_exps += DIV_POW2(exp_raw, ACCUM_BITS);
}
}
const int32_t headroom_plus1 = esp_nn_clz32((uint32_t) sum_of_exps);
const int32_t shifted_scale = ONE_OVER_ONE_X((sum_of_exps << headroom_plus1) - (1 << 31));
const int32_t bits_over_unit = ACCUM_BITS - headroom_plus1 + 31 - sizeof(int8_t) * 8;
for (col = 0; col < width; col++) {
input_diff = in_ptr[col] - max_in_row;
if (input_diff >= diff_min) {
int32_t exp_raw = scratch_buf[col];
const int32_t shifted_output = SAT_HIGH_MUL(shifted_scale, exp_raw);
const int32_t result = DIV_POW2(shifted_output, bits_over_unit) - 128;
out_ptr[col] = (int8_t) esp_nn_saturate8(result);
} else {
out_ptr[col] = -128;
}
}
in_ptr += width;
out_ptr += width;
}
}
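/*
 * A sketch of driving the optimized softmax above: per the doc comments in
 * this file the scratch buffer holds one int32 per column and must be 4-byte
 * aligned, which plain malloc already guarantees.
 */
#include <stdlib.h>

static void softmax_rows(const int8_t *in, int8_t *out, int32_t height,
                         int32_t width, int32_t mult, int32_t shift,
                         int32_t diff_min)
{
    void *scratch = malloc(esp_nn_get_softmax_scratch_size_opt(width, height));
    if (scratch == NULL) {
        return; /* out of memory */
    }
    esp_nn_set_softmax_scratch_buf_opt(scratch);
    esp_nn_softmax_s8_opt(in, height, width, mult, shift, diff_min, out);
    esp_nn_set_softmax_scratch_buf_opt(NULL); /* unset before freeing */
    free(scratch);
}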

View File

@ -0,0 +1,104 @@
// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include "../common/common_functions.h"
#define MASK_IF_ZERO(x) ((x) == 0 ? ~0 : 0)
#define MASK_IF_NON_ZERO(x) ((x) != 0 ? ~0 : 0)
#define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
#define SAT_HIGH_MUL(x, y) esp_nn_sat_round_doubling_high_mul((x), (y))
#define DIV_POW2(x,y) esp_nn_div_by_power_of_two((x), (y))
__NN_FORCE_INLINE__ int32_t mul_power_of_2(int val, int exp)
{
const int32_t thresh = ((1 << (31 - exp)) - 1);
int32_t result = val << exp;
result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), INT32_MAX, result);
result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), INT32_MIN, result);
return result;
}
/**
* @brief Calculate `1 / (1 + x)` for x in [0, 1]
*
* @param val input value to calculate `1/(1+x)` for
* @return `int32_t` result
* @note Newton-Raphson division
*
* https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
* Refer to that page for the logic behind the 48/17 and 32/17 constants.
* Pseudocode: https://en.wikipedia.org/wiki/Division_algorithm#Pseudocode
*/
__NN_FORCE_INLINE__ int32_t esp_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
const int64_t sum = (int64_t) val + INT32_MAX;
const int32_t half_denominator = (int32_t) ((sum + (sum >= 0 ? 1 : -1)) / 2L);
int32_t constant_48_over_17 = 1515870810;
int32_t constant_neg_32_over_17 = -1010580540;
int32_t x = constant_48_over_17 + SAT_HIGH_MUL(half_denominator, constant_neg_32_over_17);
const int32_t fixed_2_one = (1 << 29);
x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
return mul_power_of_2(x, 1);
}
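/*
 * A float sketch of the Newton-Raphson recurrence implemented above in
 * fixed point: x <- x * (2 - d * x) converges to 1/d, and the 48/17 and
 * 32/17 constants give the initial estimate for d in [0.5, 1).
 */
#include <stdio.h>

int main(void)
{
    double d = 0.75; /* plays the role of half_denominator */
    double x = 48.0 / 17.0 - 32.0 / 17.0 * d; /* initial estimate of 1/d */
    for (int i = 0; i < 3; i++) { /* three iterations, matching the code */
        x = x + x * (1.0 - d * x); /* i.e. x * (2 - d * x) */
    }
    printf("1/%g ~= %g\n", d, x); /* ~1.333333 */
    return 0;
}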
#define ONE_OVER_ONE_X(x) esp_nn_one_over_one_plus_x_for_x_in_0_1((x))
/**
* @brief Return exp(x) for x < 0.
*
*/
__NN_FORCE_INLINE__ int32_t esp_nn_exp_on_negative_values(int32_t val)
{
int32_t shift = 24;
const int32_t one_quarter = (1 << shift);
int32_t mask = one_quarter - 1;
const int32_t val_mod_minus_quarter = (val & mask) - one_quarter;
const int32_t remainder = val_mod_minus_quarter - val;
// calculate exponent for x in [-1/4, 0) in `result`
const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
const int32_t x2 = SAT_HIGH_MUL(x, x);
const int32_t x3 = SAT_HIGH_MUL(x2, x);
const int32_t x4 = SAT_HIGH_MUL(x2, x2);
const int32_t one_over_3 = 715827883;
const int32_t one_over_8 = 1895147668;
const int32_t x4_over_4 = DIV_POW2(x4, 2);
const int32_t x4_over_4_plus_x3_over_6_plus_x2_over_2 = DIV_POW2(SAT_HIGH_MUL(x4_over_4 + x3, one_over_3) + x2, 1);
int32_t result = one_over_8 + SAT_HIGH_MUL(one_over_8, x + x4_over_4_plus_x3_over_6_plus_x2_over_2);
#define SELECT_IF_NON_ZERO(x) { \
mask = MASK_IF_NON_ZERO(remainder & (1 << shift++)); \
result = SELECT_USING_MASK(mask, SAT_HIGH_MUL(result, x), result); \
}
SELECT_IF_NON_ZERO(1672461947)
SELECT_IF_NON_ZERO(1302514674)
SELECT_IF_NON_ZERO(790015084)
SELECT_IF_NON_ZERO(290630308)
SELECT_IF_NON_ZERO(39332535)
SELECT_IF_NON_ZERO(720401)
SELECT_IF_NON_ZERO(242)
#undef SELECT_IF_NON_ZERO
mask = MASK_IF_ZERO(val);
return SELECT_USING_MASK(mask, INT32_MAX, result);
}
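/*
 * A quick sanity check of the fixed-point exp above, as a sketch: the Q5.26
 * input / Q0.31 output framing is taken from the softmax comments in this
 * library, and the header name is assumed.
 */
#include <stdio.h>
#include "softmax_common.h" /* assumed name of this header */

int main(void)
{
    int32_t x_q5_26 = -(1 << 26); /* -1.0 in Q5.26 */
    int32_t y_q0_31 = esp_nn_exp_on_negative_values(x_q5_26);
    printf("exp(-1) ~= %f\n", (double) y_q0_31 / 2147483648.0); /* ~0.367879 */
    return 0;
}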

View File

@ -0,0 +1,22 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Compatibility shim for new location of interface definitions.
#ifndef TENSORFLOW_LITE_BUILTIN_OP_DATA_H_
#define TENSORFLOW_LITE_BUILTIN_OP_DATA_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#endif // TENSORFLOW_LITE_BUILTIN_OP_DATA_H_

View File

@ -0,0 +1,194 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_BUILTIN_OPS_H_
#define TENSORFLOW_LITE_BUILTIN_OPS_H_
// DO NOT EDIT MANUALLY: This file is automatically generated by
// `schema/builtin_ops_header/generator.cc`.
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// The enum for builtin operators.
// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special
// ops which are not real built-in ops.
typedef enum {
kTfLiteBuiltinAdd = 0,
kTfLiteBuiltinAveragePool2d = 1,
kTfLiteBuiltinConcatenation = 2,
kTfLiteBuiltinConv2d = 3,
kTfLiteBuiltinDepthwiseConv2d = 4,
kTfLiteBuiltinDepthToSpace = 5,
kTfLiteBuiltinDequantize = 6,
kTfLiteBuiltinEmbeddingLookup = 7,
kTfLiteBuiltinFloor = 8,
kTfLiteBuiltinFullyConnected = 9,
kTfLiteBuiltinHashtableLookup = 10,
kTfLiteBuiltinL2Normalization = 11,
kTfLiteBuiltinL2Pool2d = 12,
kTfLiteBuiltinLocalResponseNormalization = 13,
kTfLiteBuiltinLogistic = 14,
kTfLiteBuiltinLshProjection = 15,
kTfLiteBuiltinLstm = 16,
kTfLiteBuiltinMaxPool2d = 17,
kTfLiteBuiltinMul = 18,
kTfLiteBuiltinRelu = 19,
kTfLiteBuiltinReluN1To1 = 20,
kTfLiteBuiltinRelu6 = 21,
kTfLiteBuiltinReshape = 22,
kTfLiteBuiltinResizeBilinear = 23,
kTfLiteBuiltinRnn = 24,
kTfLiteBuiltinSoftmax = 25,
kTfLiteBuiltinSpaceToDepth = 26,
kTfLiteBuiltinSvdf = 27,
kTfLiteBuiltinTanh = 28,
kTfLiteBuiltinConcatEmbeddings = 29,
kTfLiteBuiltinSkipGram = 30,
kTfLiteBuiltinCall = 31,
kTfLiteBuiltinCustom = 32,
kTfLiteBuiltinEmbeddingLookupSparse = 33,
kTfLiteBuiltinPad = 34,
kTfLiteBuiltinUnidirectionalSequenceRnn = 35,
kTfLiteBuiltinGather = 36,
kTfLiteBuiltinBatchToSpaceNd = 37,
kTfLiteBuiltinSpaceToBatchNd = 38,
kTfLiteBuiltinTranspose = 39,
kTfLiteBuiltinMean = 40,
kTfLiteBuiltinSub = 41,
kTfLiteBuiltinDiv = 42,
kTfLiteBuiltinSqueeze = 43,
kTfLiteBuiltinUnidirectionalSequenceLstm = 44,
kTfLiteBuiltinStridedSlice = 45,
kTfLiteBuiltinBidirectionalSequenceRnn = 46,
kTfLiteBuiltinExp = 47,
kTfLiteBuiltinTopkV2 = 48,
kTfLiteBuiltinSplit = 49,
kTfLiteBuiltinLogSoftmax = 50,
kTfLiteBuiltinDelegate = 51,
kTfLiteBuiltinBidirectionalSequenceLstm = 52,
kTfLiteBuiltinCast = 53,
kTfLiteBuiltinPrelu = 54,
kTfLiteBuiltinMaximum = 55,
kTfLiteBuiltinArgMax = 56,
kTfLiteBuiltinMinimum = 57,
kTfLiteBuiltinLess = 58,
kTfLiteBuiltinNeg = 59,
kTfLiteBuiltinPadv2 = 60,
kTfLiteBuiltinGreater = 61,
kTfLiteBuiltinGreaterEqual = 62,
kTfLiteBuiltinLessEqual = 63,
kTfLiteBuiltinSelect = 64,
kTfLiteBuiltinSlice = 65,
kTfLiteBuiltinSin = 66,
kTfLiteBuiltinTransposeConv = 67,
kTfLiteBuiltinSparseToDense = 68,
kTfLiteBuiltinTile = 69,
kTfLiteBuiltinExpandDims = 70,
kTfLiteBuiltinEqual = 71,
kTfLiteBuiltinNotEqual = 72,
kTfLiteBuiltinLog = 73,
kTfLiteBuiltinSum = 74,
kTfLiteBuiltinSqrt = 75,
kTfLiteBuiltinRsqrt = 76,
kTfLiteBuiltinShape = 77,
kTfLiteBuiltinPow = 78,
kTfLiteBuiltinArgMin = 79,
kTfLiteBuiltinFakeQuant = 80,
kTfLiteBuiltinReduceProd = 81,
kTfLiteBuiltinReduceMax = 82,
kTfLiteBuiltinPack = 83,
kTfLiteBuiltinLogicalOr = 84,
kTfLiteBuiltinOneHot = 85,
kTfLiteBuiltinLogicalAnd = 86,
kTfLiteBuiltinLogicalNot = 87,
kTfLiteBuiltinUnpack = 88,
kTfLiteBuiltinReduceMin = 89,
kTfLiteBuiltinFloorDiv = 90,
kTfLiteBuiltinReduceAny = 91,
kTfLiteBuiltinSquare = 92,
kTfLiteBuiltinZerosLike = 93,
kTfLiteBuiltinFill = 94,
kTfLiteBuiltinFloorMod = 95,
kTfLiteBuiltinRange = 96,
kTfLiteBuiltinResizeNearestNeighbor = 97,
kTfLiteBuiltinLeakyRelu = 98,
kTfLiteBuiltinSquaredDifference = 99,
kTfLiteBuiltinMirrorPad = 100,
kTfLiteBuiltinAbs = 101,
kTfLiteBuiltinSplitV = 102,
kTfLiteBuiltinUnique = 103,
kTfLiteBuiltinCeil = 104,
kTfLiteBuiltinReverseV2 = 105,
kTfLiteBuiltinAddN = 106,
kTfLiteBuiltinGatherNd = 107,
kTfLiteBuiltinCos = 108,
kTfLiteBuiltinWhere = 109,
kTfLiteBuiltinRank = 110,
kTfLiteBuiltinElu = 111,
kTfLiteBuiltinReverseSequence = 112,
kTfLiteBuiltinMatrixDiag = 113,
kTfLiteBuiltinQuantize = 114,
kTfLiteBuiltinMatrixSetDiag = 115,
kTfLiteBuiltinRound = 116,
kTfLiteBuiltinHardSwish = 117,
kTfLiteBuiltinIf = 118,
kTfLiteBuiltinWhile = 119,
kTfLiteBuiltinNonMaxSuppressionV4 = 120,
kTfLiteBuiltinNonMaxSuppressionV5 = 121,
kTfLiteBuiltinScatterNd = 122,
kTfLiteBuiltinSelectV2 = 123,
kTfLiteBuiltinDensify = 124,
kTfLiteBuiltinSegmentSum = 125,
kTfLiteBuiltinBatchMatmul = 126,
kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127,
kTfLiteBuiltinCumsum = 128,
kTfLiteBuiltinCallOnce = 129,
kTfLiteBuiltinBroadcastTo = 130,
kTfLiteBuiltinRfft2d = 131,
kTfLiteBuiltinConv3d = 132,
kTfLiteBuiltinImag = 133,
kTfLiteBuiltinReal = 134,
kTfLiteBuiltinComplexAbs = 135,
kTfLiteBuiltinHashtable = 136,
kTfLiteBuiltinHashtableFind = 137,
kTfLiteBuiltinHashtableImport = 138,
kTfLiteBuiltinHashtableSize = 139,
kTfLiteBuiltinReduceAll = 140,
kTfLiteBuiltinConv3dTranspose = 141,
kTfLiteBuiltinVarHandle = 142,
kTfLiteBuiltinReadVariable = 143,
kTfLiteBuiltinAssignVariable = 144,
kTfLiteBuiltinBroadcastArgs = 145,
kTfLiteBuiltinRandomStandardNormal = 146,
kTfLiteBuiltinBucketize = 147,
kTfLiteBuiltinRandomUniform = 148,
kTfLiteBuiltinMultinomial = 149,
kTfLiteBuiltinGelu = 150,
kTfLiteBuiltinDynamicUpdateSlice = 151,
kTfLiteBuiltinRelu0To1 = 152,
kTfLiteBuiltinUnsortedSegmentProd = 153,
kTfLiteBuiltinUnsortedSegmentMax = 154,
kTfLiteBuiltinUnsortedSegmentSum = 155,
kTfLiteBuiltinAtan2 = 156,
kTfLiteBuiltinUnsortedSegmentMin = 157,
kTfLiteBuiltinSign = 158,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_LITE_BUILTIN_OPS_H_

View File

@ -0,0 +1,525 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_
#define TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_
#include <stdint.h>
#include "tensorflow/lite/c/common.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible
// number of dimensions.
#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8
// TODO(aselle): Consider using "if this then that" for testing.
// Useful placeholder to put in otherwise empty structs to avoid size warnings.
typedef struct {
char dummy;
} EmptyStructPlaceholder;
// IMPORTANT: All new members of structs must be added at the end to ensure
// backwards compatibility.
// Possible padding types (for convolutions)
typedef enum {
kTfLitePaddingUnknown = 0,
kTfLitePaddingSame,
kTfLitePaddingValid,
} TfLitePadding;
typedef enum {
kTfLiteMirrorPaddingUnknown = 0,
kTfLiteMirrorPaddingReflect,
kTfLiteMirrorPaddingSymmetric,
} TfLiteMirrorPaddingMode;
// TODO(b/130259536): We should move this out of builtin_op_data.
typedef struct {
int width;
int height;
int width_offset;
int height_offset;
} TfLitePaddingValues;
typedef struct {
TfLiteMirrorPaddingMode mode;
} TfLiteMirrorPaddingParams;
// Possible fused activation functions.
typedef enum {
kTfLiteActNone = 0,
kTfLiteActRelu,
kTfLiteActReluN1To1, // min(max(-1, x), 1)
kTfLiteActRelu6, // min(max(0, x), 6)
kTfLiteActTanh,
kTfLiteActSignBit,
kTfLiteActSigmoid,
} TfLiteFusedActivation;
typedef struct {
// Parameters for CONV_2D version 1.
TfLitePadding padding;
int stride_width;
int stride_height;
TfLiteFusedActivation activation;
// Parameters for CONV_2D version 2.
// Note: Version 2 supports dilation values not equal to 1.
int dilation_width_factor;
int dilation_height_factor;
} TfLiteConvParams;
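/*
 * Illustrative only (not part of the API): how a 3x3, stride-1 convolution
 * with SAME padding and a fused ReLU would populate this struct. In practice
 * the flatbuffer parser fills it in, not user code.
 */
static const TfLiteConvParams example_conv_params = {
    .padding = kTfLitePaddingSame,
    .stride_width = 1,
    .stride_height = 1,
    .activation = kTfLiteActRelu,
    .dilation_width_factor = 1,  /* version 2 fields: 1 means no dilation */
    .dilation_height_factor = 1,
};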
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
int stride_depth;
int dilation_width_factor;
int dilation_height_factor;
int dilation_depth_factor;
TfLiteFusedActivation activation;
} TfLiteConv3DParams;
typedef TfLiteConv3DParams TfLiteConv3DTransposeParams;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
int filter_width;
int filter_height;
TfLiteFusedActivation activation;
struct {
TfLitePaddingValues padding;
} computed;
} TfLitePoolParams;
typedef struct {
// Parameters for DepthwiseConv version 1 or above.
TfLitePadding padding;
int stride_width;
int stride_height;
// `depth_multiplier` is redundant. It's used by CPU kernels in
// TensorFlow 2.0 or below, but ignored in versions above.
//
// The information can be deduced from the shape of input and the shape of
// weights. Since the TFLiteConverter toolchain doesn't support partially
// specified shapes, relying on `depth_multiplier` stops us from supporting
// graphs with dynamic shape tensors.
//
// Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this
// field.
int depth_multiplier;
TfLiteFusedActivation activation;
// Parameters for DepthwiseConv version 2 or above.
int dilation_width_factor;
int dilation_height_factor;
} TfLiteDepthwiseConvParams;
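/*
 * The redundancy noted above, as a one-liner: the multiplier follows from
 * the tensor shapes (hypothetical helper, not part of TFLite).
 */
static inline int infer_depth_multiplier(int input_channels, int output_channels)
{
    return output_channels / input_channels; /* e.g. 16 / 8 == 2 */
}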
typedef struct {
int rank;
TfLiteFusedActivation activation;
// Parameter for SVDF version 4.
bool asymmetric_quantize_inputs;
} TfLiteSVDFParams;
typedef struct {
TfLiteFusedActivation activation;
// Parameter for RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteRNNParams;
typedef struct {
bool time_major;
TfLiteFusedActivation activation;
// Parameter for Sequence RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteSequenceRNNParams;
typedef struct {
bool time_major;
TfLiteFusedActivation activation;
bool merge_outputs;
// Parameter for Bidirectional RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceRNNParams;
typedef enum {
kTfLiteFullyConnectedWeightsFormatDefault = 0,
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
} TfLiteFullyConnectedWeightsFormat;
typedef struct {
// Parameters for FullyConnected version 1 or above.
TfLiteFusedActivation activation;
// Parameters for FullyConnected version 2 or above.
TfLiteFullyConnectedWeightsFormat weights_format;
// Parameters for FullyConnected version 5 or above.
// If set to true, then the number of dimensions in the input and the output
// tensors are the same. Furthermore, all but the last dimension of the input
// and output shapes will be equal.
bool keep_num_dims;
// Parameters for FullyConnected version 7 or above.
// If set to true and the weights are quantized, then non constant inputs
// are quantized at evaluation time with asymmetric quantization.
bool asymmetric_quantize_inputs;
} TfLiteFullyConnectedParams;
typedef enum {
kTfLiteLshProjectionUnknown = 0,
kTfLiteLshProjectionSparse = 1,
kTfLiteLshProjectionDense = 2,
} TfLiteLSHProjectionType;
typedef struct {
TfLiteLSHProjectionType type;
} TfLiteLSHProjectionParams;
typedef struct {
float beta;
} TfLiteSoftmaxParams;
typedef struct {
int axis;
TfLiteFusedActivation activation;
} TfLiteConcatenationParams;
typedef struct {
TfLiteFusedActivation activation;
// Parameter added for the version 4.
bool pot_scale_int16;
} TfLiteAddParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteSpaceToBatchNDParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteBatchToSpaceNDParams;
typedef struct {
bool adj_x;
bool adj_y;
// Parameters for BatchMatMul version 4 or above.
// If set to true and the weights are quantized, then non constant inputs
// are quantized at evaluation time with asymmetric quantization.
bool asymmetric_quantize_inputs;
} TfLiteBatchMatMulParams;
typedef struct {
TfLiteFusedActivation activation;
} TfLiteMulParams;
typedef struct {
TfLiteFusedActivation activation;
// Parameter added for the version 5.
bool pot_scale_int16;
} TfLiteSubParams;
typedef struct {
TfLiteFusedActivation activation;
} TfLiteDivParams;
typedef struct {
TfLiteFusedActivation activation;
} TfLiteL2NormParams;
typedef struct {
int radius;
float bias;
float alpha;
float beta;
} TfLiteLocalResponseNormParams;
typedef enum {
kTfLiteLSTMFullKernel = 0,
kTfLiteLSTMBasicKernel
} TfLiteLSTMKernelType;
typedef struct {
// Parameters for LSTM version 1.
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
// Parameters for LSTM version 2.
// kTfLiteLSTMBasicKernel is only supported in version 2 or above.
TfLiteLSTMKernelType kernel_type;
// Parameters for LSTM version 4.
bool asymmetric_quantize_inputs;
} TfLiteLSTMParams;
typedef struct {
// Parameters needed for the underlying LSTM.
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
// If set to true then the first dimension is time, otherwise batch.
bool time_major;
// Parameter for unidirectional sequence RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteUnidirectionalSequenceLSTMParams;
typedef struct {
// Parameters supported by version 1:
// Parameters inherited for the LSTM kernel.
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
// If true, store the outputs of both directions in the first output.
bool merge_outputs;
// Parameters supported by version 2:
// If set to true then the first dimension is time, otherwise batch.
bool time_major;
// Parameters supported by version 4:
// If set to true, then hybrid ops use asymmetric quantization for inputs.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceLSTMParams;
typedef struct {
bool align_corners;
// half_pixel_centers assumes pixels are of half the actual dimensions, and
// yields more accurate resizes. Corresponds to the same argument for the
// original TensorFlow op in TF2.0.
bool half_pixel_centers;
} TfLiteResizeBilinearParams;
typedef struct {
bool align_corners;
bool half_pixel_centers;
} TfLiteResizeNearestNeighborParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLitePadParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLitePadV2Params;
typedef struct {
// These fields are only used in old models for backward compatibility.
// In the current implementation, we use the 2nd input of the op as the shape,
// and these fields are unused.
int shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT];
int num_dimensions;
} TfLiteReshapeParams;
typedef struct {
int ngram_size;
int max_skip_size;
bool include_all_ngrams;
} TfLiteSkipGramParams;
typedef struct {
int block_size;
} TfLiteSpaceToDepthParams;
typedef struct {
int block_size;
} TfLiteDepthToSpaceParams;
typedef struct {
TfLiteType in_data_type;
TfLiteType out_data_type;
} TfLiteCastParams;
typedef enum {
kTfLiteCombinerTypeSum = 0,
kTfLiteCombinerTypeMean = 1,
kTfLiteCombinerTypeSqrtn = 2,
} TfLiteCombinerType;
typedef struct {
TfLiteCombinerType combiner;
} TfLiteEmbeddingLookupSparseParams;
typedef struct {
int axis;
int batch_dims;
} TfLiteGatherParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteTransposeParams;
typedef struct {
bool keep_dims;
} TfLiteReducerParams;
typedef struct {
int num_splits;
} TfLiteSplitParams;
typedef struct {
int num_splits;
} TfLiteSplitVParams;
typedef struct {
// TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
// For now we will fix the maximum possible number of dimensions.
int squeeze_dims[8];
int num_squeeze_dims;
} TfLiteSqueezeParams;
typedef struct {
int begin_mask;
int end_mask;
int ellipsis_mask;
int new_axis_mask;
int shrink_axis_mask;
} TfLiteStridedSliceParams;
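// Sketch (assumption): the five masks are per-axis bitfields with NumPy-style
// slicing semantics; bit i of begin_mask means "ignore begin[i] and start from
// the beginning of axis i".
static inline bool ExampleBeginIsIgnored(const TfLiteStridedSliceParams* p,
                                         int axis) {
  return (p->begin_mask & (1 << axis)) != 0;
}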
typedef struct {
TfLiteType output_type;
} TfLiteArgMaxParams;
typedef struct {
TfLiteType output_type;
} TfLiteArgMinParams;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
} TfLiteTransposeConvParams;
typedef struct {
bool validate_indices;
} TfLiteSparseToDenseParams;
typedef struct {
TfLiteType out_type;
} TfLiteShapeParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteRankParams;
typedef struct {
// Parameters supported by version 1:
float min;
float max;
int num_bits;
// Parameters supported by version 2:
bool narrow_range;
} TfLiteFakeQuantParams;
typedef struct {
int values_count;
int axis;
} TfLitePackParams;
typedef struct {
int axis;
} TfLiteOneHotParams;
typedef struct {
int num;
int axis;
} TfLiteUnpackParams;
typedef struct {
float alpha;
} TfLiteLeakyReluParams;
typedef struct {
TfLiteType index_out_type;
} TfLiteUniqueParams;
typedef struct {
int seq_dim;
int batch_dim;
} TfLiteReverseSequenceParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixDiagParams;
typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixSetDiagParams;
typedef struct {
int then_subgraph_index;
int else_subgraph_index;
} TfLiteIfParams;
typedef struct {
int cond_subgraph_index;
int body_subgraph_index;
} TfLiteWhileParams;
typedef struct {
bool exclusive;
bool reverse;
} TfLiteCumsumParams;
typedef struct {
int init_subgraph_index;
} TfLiteCallOnceParams;
typedef struct {
int table_id;
TfLiteType key_dtype;
TfLiteType value_dtype;
} TfLiteHashtableParams;
typedef struct {
const char* container;
const char* shared_name;
} TfLiteVarHandleParams;
typedef struct {
int seed;
int seed2;
} TfLiteRandomParams;
typedef struct {
int num_boundaries;
// This points to the memory stored in the model (flatbuffer),
// and is not owned.
const float* boundaries;
} TfLiteBucketizeParams;
typedef struct {
bool approximate;
} TfLiteGeluParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_

View File

@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file declares types used by the pure C inference API defined in c_api.h,
// some of which are also used in the C++ and C kernel and interpreter APIs.
#ifndef TENSORFLOW_LITE_C_C_API_TYPES_H_
#define TENSORFLOW_LITE_C_C_API_TYPES_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
// library.
#ifdef SWIG
#define TFL_CAPI_EXPORT
#elif defined(TFL_STATIC_LIBRARY_BUILD)
#define TFL_CAPI_EXPORT
#else // not defined TFL_STATIC_LIBRARY_BUILD
#if defined(_WIN32)
#ifdef TFL_COMPILE_LIBRARY
#define TFL_CAPI_EXPORT __declspec(dllexport)
#else
#define TFL_CAPI_EXPORT __declspec(dllimport)
#endif // TFL_COMPILE_LIBRARY
#else
#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
// Note that new error status values may be added in future in order to
// indicate more fine-grained internal states; therefore, applications should
// not rely on status values being members of the enum.
typedef enum TfLiteStatus {
kTfLiteOk = 0,
// Generally referring to an error in the runtime (i.e. interpreter)
kTfLiteError = 1,
// Generally referring to an error from a TfLiteDelegate itself.
kTfLiteDelegateError = 2,
// Generally referring to an error in applying a delegate due to
// incompatibility between runtime and delegate, e.g., this error is returned
// when trying to apply a TF Lite delegate onto a model graph that's already
// immutable.
kTfLiteApplicationError = 3,
// Generally referring to serialized delegate data not being found.
// See tflite::delegates::Serialization.
kTfLiteDelegateDataNotFound = 4,
// Generally referring to data-writing issues in delegate serialization.
// See tflite::delegates::Serialization.
kTfLiteDelegateDataWriteError = 5,
// Generally referring to data-reading issues in delegate serialization.
// See tflite::delegates::Serialization.
kTfLiteDelegateDataReadError = 6,
// Generally referring to issues when the TF Lite model has ops that cannot be
// resolved at runtime. This could happen when the specific op is not
// registered or built with the TF Lite framework.
kTfLiteUnresolvedOps = 7,
// Generally referring to invocation cancelled by the user.
// See `interpreter::Cancel`.
// TODO(b/194915839): Implement `interpreter::Cancel`.
// TODO(b/250636993): Cancellation triggered by `SetCancellationFunction`
// should also return this status code.
kTfLiteCancelled = 8,
} TfLiteStatus;
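// Sketch of the defensive call-site handling suggested by the note above:
// treat anything other than kTfLiteOk as failure, including values added in
// future releases.
static inline int ExampleInvokeSucceeded(TfLiteStatus status) {
  return status == kTfLiteOk;
}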
// Types supported by tensor
typedef enum {
kTfLiteNoType = 0,
kTfLiteFloat32 = 1,
kTfLiteInt32 = 2,
kTfLiteUInt8 = 3,
kTfLiteInt64 = 4,
kTfLiteString = 5,
kTfLiteBool = 6,
kTfLiteInt16 = 7,
kTfLiteComplex64 = 8,
kTfLiteInt8 = 9,
kTfLiteFloat16 = 10,
kTfLiteFloat64 = 11,
kTfLiteComplex128 = 12,
kTfLiteUInt64 = 13,
kTfLiteResource = 14,
kTfLiteVariant = 15,
kTfLiteUInt32 = 16,
kTfLiteUInt16 = 17,
kTfLiteInt4 = 18,
} TfLiteType;
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
// If per-layer quantization is specified this field will still be populated in
// addition to TfLiteAffineQuantization.
// Parameters for asymmetric quantization. Quantized values can be converted
// back to float using:
// real_value = scale * (quantized_value - zero_point)
typedef struct TfLiteQuantizationParams {
float scale;
int32_t zero_point;
} TfLiteQuantizationParams;
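// Sketch (illustration only, not part of the C API): the affine mapping
// documented above, in both directions, for an int8-quantized tensor.
static inline float ExampleDequantize(int8_t q, TfLiteQuantizationParams p) {
  return p.scale * ((int32_t)q - p.zero_point);
}
static inline int8_t ExampleQuantize(float real, TfLiteQuantizationParams p) {
  float scaled = real / p.scale + (real >= 0.0f ? 0.5f : -0.5f);  // round half away from zero
  int32_t q = (int32_t)scaled + p.zero_point;
  if (q < -128) q = -128;  // clamp to the int8 range
  if (q > 127) q = 127;
  return (int8_t)q;
}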
// --------------------------------------------------------------------------
// Opaque types used by c_api.h, c_api_opaque.h and common.h.
// TfLiteOpaqueContext is an opaque version of TfLiteContext;
typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
// TfLiteOpaqueNode is an opaque version of TfLiteNode;
typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
// TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;
// TfLiteOpaqueDelegateStruct: opaque version of TfLiteDelegate; allows
// delegation of nodes to alternative backends.
//
// This is an abstract type that is intended to have the same
// role as TfLiteDelegate from common.h, but without exposing the implementation
// details of how delegates are implemented.
// WARNING: This is an experimental type and subject to change.
typedef struct TfLiteOpaqueDelegateStruct TfLiteOpaqueDelegateStruct;
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TENSORFLOW_LITE_C_C_API_TYPES_H_

View File

@ -0,0 +1,321 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/c_api_types.h"
#ifdef TF_LITE_TENSORFLOW_PROFILER
#include "tensorflow/lite/tensorflow_profiler_logger.h"
#endif
#ifndef ARDUINO
#include <stdlib.h>
#include <string.h>
#endif // ARDUINO
extern "C" {
size_t TfLiteIntArrayGetSizeInBytes(int size) {
static TfLiteIntArray dummy;
size_t computed_size = sizeof(dummy) + sizeof(dummy.data[0]) * size;
#if defined(_MSC_VER)
// Context for why this is needed is in http://b/189926408#comment21
computed_size -= sizeof(dummy.data[0]);
#endif
return computed_size;
}
int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b) {
if (a == b) return 1;
if (a == nullptr || b == nullptr) return 0;
return TfLiteIntArrayEqualsArray(a, b->size, b->data);
}
int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
const int b_data[]) {
if (a == nullptr) return (b_size == 0);
if (a->size != b_size) return 0;
int i = 0;
for (; i < a->size; i++)
if (a->data[i] != b_data[i]) return 0;
return 1;
}
#ifndef ARDUINO
TfLiteIntArray* TfLiteIntArrayCreate(int size) {
size_t alloc_size = TfLiteIntArrayGetSizeInBytes(size);
if (alloc_size <= 0) return nullptr;
TfLiteIntArray* ret = (TfLiteIntArray*)malloc(alloc_size);
if (!ret) return ret;
ret->size = size;
return ret;
}
TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src) {
if (!src) return nullptr;
TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
if (ret) {
memcpy(ret->data, src->data, src->size * sizeof(int));
}
return ret;
}
void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); }
#endif // ARDUINO
int TfLiteFloatArrayGetSizeInBytes(int size) {
static TfLiteFloatArray dummy;
int computed_size = sizeof(dummy) + sizeof(dummy.data[0]) * size;
#if defined(_MSC_VER)
// Context for why this is needed is in http://b/189926408#comment21
computed_size -= sizeof(dummy.data[0]);
#endif
return computed_size;
}
#ifndef ARDUINO
TfLiteFloatArray* TfLiteFloatArrayCreate(int size) {
TfLiteFloatArray* ret =
(TfLiteFloatArray*)malloc(TfLiteFloatArrayGetSizeInBytes(size));
ret->size = size;
return ret;
}
void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
void TfLiteTensorDataFree(TfLiteTensor* t) {
if (t->allocation_type == kTfLiteDynamic ||
t->allocation_type == kTfLitePersistentRo) {
if (t->data.raw) {
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorDealloc(t);
#endif
free(t->data.raw);
}
}
t->data.raw = nullptr;
}
void TfLiteQuantizationFree(TfLiteQuantization* quantization) {
if (quantization->type == kTfLiteAffineQuantization) {
TfLiteAffineQuantization* q_params =
(TfLiteAffineQuantization*)(quantization->params);
if (q_params->scale) {
TfLiteFloatArrayFree(q_params->scale);
q_params->scale = nullptr;
}
if (q_params->zero_point) {
TfLiteIntArrayFree(q_params->zero_point);
q_params->zero_point = nullptr;
}
free(q_params);
}
quantization->params = nullptr;
quantization->type = kTfLiteNoQuantization;
}
void TfLiteSparsityFree(TfLiteSparsity* sparsity) {
if (sparsity == nullptr) {
return;
}
if (sparsity->traversal_order) {
TfLiteIntArrayFree(sparsity->traversal_order);
sparsity->traversal_order = nullptr;
}
if (sparsity->block_map) {
TfLiteIntArrayFree(sparsity->block_map);
sparsity->block_map = nullptr;
}
if (sparsity->dim_metadata) {
int i = 0;
for (; i < sparsity->dim_metadata_size; i++) {
TfLiteDimensionMetadata metadata = sparsity->dim_metadata[i];
if (metadata.format == kTfLiteDimSparseCSR) {
TfLiteIntArrayFree(metadata.array_segments);
metadata.array_segments = nullptr;
TfLiteIntArrayFree(metadata.array_indices);
metadata.array_indices = nullptr;
}
}
free(sparsity->dim_metadata);
sparsity->dim_metadata = nullptr;
}
free(sparsity);
}
void TfLiteTensorFree(TfLiteTensor* t) {
TfLiteTensorDataFree(t);
if (t->dims) TfLiteIntArrayFree(t->dims);
t->dims = nullptr;
if (t->dims_signature) {
TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
}
t->dims_signature = nullptr;
TfLiteQuantizationFree(&t->quantization);
TfLiteSparsityFree(t->sparsity);
t->sparsity = nullptr;
}
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
tensor->dims = dims;
tensor->params = quantization;
tensor->data.raw = buffer;
tensor->bytes = size;
tensor->allocation_type = allocation_type;
tensor->allocation = allocation;
tensor->is_variable = is_variable;
tensor->quantization.type = kTfLiteNoQuantization;
tensor->quantization.params = nullptr;
}
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
if (!src || !dst) return kTfLiteOk;
if (src->bytes != dst->bytes) return kTfLiteError;
if (src == dst) return kTfLiteOk;
dst->type = src->type;
if (dst->dims) TfLiteIntArrayFree(dst->dims);
dst->dims = TfLiteIntArrayCopy(src->dims);
memcpy(dst->data.raw, src->data.raw, src->bytes);
dst->buffer_handle = src->buffer_handle;
dst->data_is_stale = src->data_is_stale;
dst->delegate = src->delegate;
return kTfLiteOk;
}
void TfLiteTensorResizeMaybeCopy(size_t num_bytes, TfLiteTensor* tensor,
bool preserve_data) {
if (tensor->allocation_type != kTfLiteDynamic &&
tensor->allocation_type != kTfLitePersistentRo) {
return;
}
// TODO(b/145340303): Tensor data should be aligned.
if (!tensor->data.data) {
tensor->data.data = (char*)malloc(num_bytes);
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
#endif
} else if (num_bytes > tensor->bytes) {
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorDealloc(tensor);
#endif
if (preserve_data) {
tensor->data.data = (char*)realloc(tensor->data.data, num_bytes);
} else {
// Calling free and malloc can be more efficient as it avoids needlessly
// copying the data when it is not required.
free(tensor->data.data);
tensor->data.data = (char*)malloc(num_bytes);
}
#ifdef TF_LITE_TENSORFLOW_PROFILER
tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
#endif
}
tensor->bytes = num_bytes;
}
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
return TfLiteTensorResizeMaybeCopy(num_bytes, tensor, true);
}
#endif // ARDUINO
const char* TfLiteTypeGetName(TfLiteType type) {
switch (type) {
case kTfLiteNoType:
return "NOTYPE";
case kTfLiteFloat32:
return "FLOAT32";
case kTfLiteUInt16:
return "UINT16";
case kTfLiteInt16:
return "INT16";
case kTfLiteInt32:
return "INT32";
case kTfLiteUInt32:
return "UINT32";
case kTfLiteUInt8:
return "UINT8";
case kTfLiteInt8:
return "INT8";
case kTfLiteInt64:
return "INT64";
case kTfLiteUInt64:
return "UINT64";
case kTfLiteBool:
return "BOOL";
case kTfLiteComplex64:
return "COMPLEX64";
case kTfLiteComplex128:
return "COMPLEX128";
case kTfLiteString:
return "STRING";
case kTfLiteFloat16:
return "FLOAT16";
case kTfLiteFloat64:
return "FLOAT64";
case kTfLiteResource:
return "RESOURCE";
case kTfLiteVariant:
return "VARIANT";
case kTfLiteInt4:
return "INT4";
}
return "Unknown type";
}
TfLiteDelegate TfLiteDelegateCreate() { return TfLiteDelegate{}; }
struct TfLiteOpaqueDelegateStruct* TfLiteOpaqueDelegateCreate(
const TfLiteOpaqueDelegateBuilder* opaque_delegate_builder) {
if (!opaque_delegate_builder) return nullptr;
TfLiteDelegate* result = new TfLiteDelegate{};
result->opaque_delegate_builder = new TfLiteOpaqueDelegateBuilder{};
*(result->opaque_delegate_builder) = *opaque_delegate_builder;
return reinterpret_cast<struct TfLiteOpaqueDelegateStruct*>(result);
}
void TfLiteOpaqueDelegateDelete(
const struct TfLiteOpaqueDelegateStruct* opaque_delegate) {
if (!opaque_delegate) return;
const TfLiteDelegate* tflite_delegate =
reinterpret_cast<const TfLiteDelegate*>(opaque_delegate);
delete tflite_delegate->opaque_delegate_builder;
delete tflite_delegate;
}
} // extern "C"

File diff suppressed because it is too large

View File

@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// This provides a few C++ helpers that are useful for manipulating C
/// structures in C++.
#ifndef TENSORFLOW_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_LITE_CONTEXT_UTIL_H_
#include <stddef.h>
#include "tensorflow/lite/c/common.h"
namespace tflite {
/// Provides a range iterable wrapper for TfLiteIntArray* (C lists) that TfLite
/// C api uses.
// Can't use google's array_view, since for embedded-device reasons we can't
// depend even on absl.
class TfLiteIntArrayView {
public:
/// Construct a view of a TfLiteIntArray*. Note, `int_array` should be
/// non-null and this view does not take ownership of it.
explicit TfLiteIntArrayView(const TfLiteIntArray* int_array)
: int_array_(int_array) {}
TfLiteIntArrayView(const TfLiteIntArrayView&) = default;
TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default;
typedef const int* const_iterator;
const_iterator begin() const { return int_array_->data; }
const_iterator end() const { return &int_array_->data[int_array_->size]; }
size_t size() const { return end() - begin(); }
int operator[](size_t pos) const { return int_array_->data[pos]; }
private:
const TfLiteIntArray* int_array_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_CONTEXT_UTIL_H_
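A minimal usage sketch of the view above (assumes "tensorflow/lite/c/common.h" for TfLiteIntArray; the function name is illustrative):

// Multiply out a tensor shape using the range-for support of the view.
int ExampleNumElements(const TfLiteIntArray* dims) {
  int count = 1;
  for (int d : tflite::TfLiteIntArrayView(dims)) count *= d;
  return count;
}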

View File

@ -0,0 +1,38 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/core/api/error_reporter.h"
#include <cstdarg>
namespace tflite {
int ErrorReporter::Report(const char* format, ...) {
va_list args;
va_start(args, format);
int code = Report(format, args);
va_end(args);
return code;
}
// TODO(aselle): Make the name of ReportError on context the same, so
// we can use the ensure functions w/o a context and w/ a reporter.
int ErrorReporter::ReportError(void*, const char* format, ...) {
va_list args;
va_start(args, format);
int code = Report(format, args);
va_end(args);
return code;
}
} // namespace tflite

View File

@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
#define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
#include <cstdarg>
namespace tflite {
/// A functor that reports error to supporting system. Invoked similar to
/// printf.
///
/// Usage:
/// ErrorReporter foo;
/// foo.Report("test %d", 5);
/// or
/// va_list args;
/// foo.Report("test %d", args); // where args is va_list
///
/// Subclass ErrorReporter to provide another reporting destination.
/// For example, if you have a GUI program, you might redirect to a buffer
/// that drives a GUI error log box.
class ErrorReporter {
public:
virtual ~ErrorReporter() {}
virtual int Report(const char* format, va_list args) = 0;
int Report(const char* format, ...);
int ReportError(void*, const char* format, ...);
};
} // namespace tflite
// You should not make bare calls to the error reporter; instead, use the
// TF_LITE_REPORT_ERROR macro, since this allows message strings to be
// stripped when the binary size has to be optimized. If you are looking to
// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and
// every call will be stubbed out, taking no memory.
#ifndef TF_LITE_STRIP_ERROR_STRINGS
#define TF_LITE_REPORT_ERROR(reporter, ...) \
do { \
static_cast<tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \
} while (false)
#else // TF_LITE_STRIP_ERROR_STRINGS
#define TF_LITE_REPORT_ERROR(reporter, ...)
#endif // TF_LITE_STRIP_ERROR_STRINGS
#endif // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_

View File

@ -0,0 +1,412 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
#define TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
// These functions transform codes and data structures that are defined in the
// flatbuffer serialization format into in-memory values that are used by the
// runtime API and interpreter.
#include <cstddef>
#include <new>
#include <type_traits>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// Interface class for builtin data allocations.
class BuiltinDataAllocator {
public:
virtual void* Allocate(size_t size, size_t alignment_hint) = 0;
virtual void Deallocate(void* data) = 0;
// Allocate a structure, but make sure it is a POD structure that doesn't
// require constructors to run. The reason we do this is that Interpreter's C
// extension part will take ownership, so destructors will not be run during
// deallocation.
template <typename T>
T* AllocatePOD() {
// TODO(b/154346074): Change this to is_trivially_destructible when all
// platform targets support that properly.
static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
void* allocated_memory = this->Allocate(sizeof(T), alignof(T));
return new (allocated_memory) T();
}
virtual ~BuiltinDataAllocator() {}
};
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly, as it is defined by flatbuffer schemas.
// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
// calling function has to pass in an allocator object, and this allocator
// will be called to reserve space for the output data. If the calling
// function's allocator reserves memory on the heap, then it's the calling
// function's responsibility to free it.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
// Converts the tensor data type used in the flat buffer to the representation
// used by the runtime.
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
ErrorReporter* error_reporter);
TfLiteStatus ParseAbs(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseAdd(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseAddN(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseArgMax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseArgMin(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseAssignVariable(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseBatchMatMul(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseBatchToSpaceNd(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseBroadcastArgs(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseBroadcastTo(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseCallOnce(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseCeil(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCast(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseConcatenation(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseConv2D(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCos(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCumsum(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseDepthToSpace(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseDepthwiseConv2D(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseDequantize(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseDiv(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseElu(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseEqual(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseExp(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseExpandDims(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseFill(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFloor(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFloorDiv(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseFloorMod(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseFullyConnected(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseGather(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseGatherNd(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseGreater(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseGreaterEqual(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseHardSwish(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseIf(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseL2Normalization(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLeakyRelu(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLess(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseLessEqual(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLog(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseLogicalAnd(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogicalNot(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogicalOr(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogistic(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogSoftmax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLSTM(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseMaximum(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseMinimum(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseMirrorPad(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseMul(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseNeg(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseNotEqual(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParsePack(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePad(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePadV2(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePool(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePow(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePrelu(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseQuantize(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseReadVariable(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseReducer(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseRelu(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseRelu6(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseResizeBilinear(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseRound(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseRsqrt(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSelectV2(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSlice(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSpaceToBatchNd(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSpaceToDepth(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSplitV(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSqueeze(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSqrt(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSquare(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSquaredDifference(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseStridedSlice(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSub(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSvdf(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseTanh(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseTranspose(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseTransposeConv(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseUnpack(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseUnidirectionalSequenceLSTM(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseVarHandle(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseWhile(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
} // namespace tflite
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
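A minimal heap-backed implementation of the BuiltinDataAllocator interface declared above, as a sketch; micro builds carve builtin data out of the tensor arena instead, and the plain malloc/free pairing here is an assumption for illustration:

#include <cstdlib>

class MallocDataAllocator : public tflite::BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
    (void)alignment_hint;  // Sketch: assume malloc's default alignment suffices.
    return malloc(size);
  }
  void Deallocate(void* data) override { free(data); }
};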

View File

@ -0,0 +1,68 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/core/api/op_resolver.h"
#include "third_party/flatbuffers/include/flatbuffers/flatbuffers.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_utils.h"
namespace tflite {
TfLiteStatus GetRegistrationFromOpCode(
const OperatorCode* opcode, const OpResolver& op_resolver,
ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
TfLiteStatus status = kTfLiteOk;
*registration = nullptr;
auto builtin_code = GetBuiltinCode(opcode);
int version = opcode->version();
if (builtin_code > BuiltinOperator_MAX) {
TF_LITE_REPORT_ERROR(
error_reporter,
"Op builtin_code out of range: %d. Are you using old TFLite binary "
"with newer model?",
builtin_code);
status = kTfLiteError;
} else if (builtin_code != BuiltinOperator_CUSTOM) {
*registration = op_resolver.FindOp(builtin_code, version);
if (*registration == nullptr) {
TF_LITE_REPORT_ERROR(
error_reporter,
"Didn't find op for builtin opcode '%s' version '%d'. "
"An older version of this builtin might be supported. "
"Are you using an old TFLite binary with a newer model?\n",
EnumNameBuiltinOperator(builtin_code), version);
status = kTfLiteError;
}
} else if (!opcode->custom_code()) {
TF_LITE_REPORT_ERROR(
error_reporter,
"Operator with CUSTOM builtin_code has no custom_code.\n");
status = kTfLiteError;
} else {
const char* name = opcode->custom_code()->c_str();
*registration = op_resolver.FindOp(name, version);
if (*registration == nullptr) {
// Do not report error for unresolved custom op, we do the final check
// while preparing ops.
status = kTfLiteError;
}
}
return status;
}
} // namespace tflite

View File

@ -0,0 +1,140 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
// This is used for cases (e.g. when using "TF Lite with Google Play Services")
// where the TF Lite runtime might be built using a newer (or older)
// version of the TF Lite sources than the app, and hence might have a
// different definition of the TfLiteDelegate type. TF Lite APIs use
// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
// refer to a delegate defined with that potentially different version
// of the TfLiteDelegate type.
struct TfLiteOpaqueDelegateStruct;
namespace tflite {
/// Abstract interface that returns TfLiteRegistrations given op codes or custom
/// op names. This is the mechanism by which ops referenced in the flatbuffer
/// model are mapped to executable function pointers (TfLiteRegistrations).
class OpResolver {
public:
/// Finds the op registration for a builtin operator by enum code.
virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const = 0;
/// Finds the op registration of a custom operator by op name.
virtual const TfLiteRegistration* FindOp(const char* op,
int version) const = 0;
// Represents a sequence of delegates.
using TfLiteDelegatePtrVector =
std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
// Returns optional delegates for resolving and handling ops in the flatbuffer
// model. This may be used in addition to the standard TfLiteRegistration
// lookup for graph resolution.
// WARNING: This API is deprecated, GetDelegateCreators is preferred.
virtual TfLiteDelegatePtrVector GetDelegates(int num_threads) const {
return {};
}
// Represents a function that creates a TfLite delegate instance.
using TfLiteDelegateCreator =
std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
int /*num_threads*/)>;
// Represents a sequence of delegate creator functions.
using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
// Returns a vector of delegate creators to create optional delegates for
// resolving and handling ops in the flatbuffer model. This may be used in
// addition to the standard TfLiteRegistration lookup for graph resolution.
//
// Note that this method is not used (will not be called) if you are using
// TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
// (see below) is used for that case.
virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
// TODO(b/202712825): it would be nice if we could avoid the need for separate
// "opaque" types & methods for use only with TF Lite in Google Play Services.
// Represents an opaque delegate instance.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegatePtr =
std::unique_ptr<TfLiteOpaqueDelegateStruct,
void (*)(TfLiteOpaqueDelegateStruct*)>;
// Represents a function that creates an opaque delegate instance.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegateCreator =
std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
// Represents a sequence of opaque delegate creator functions.
// WARNING: Experimental interface, subject to change.
using TfLiteOpaqueDelegateCreators = std::vector<TfLiteOpaqueDelegateCreator>;
// Returns a vector of opaque delegate creators to create optional opaque
// delegates for resolving and handling ops in the flatbuffer model. This may
// be used in addition to the standard TfLiteRegistration lookup for graph
// resolution.
//
// Note that this method will be called only if you are using TF Lite in
// Google Play Services; if you are using regular TF Lite, GetDelegateCreators
// (see above) is used instead.
//
// WARNING: Experimental interface, subject to change.
virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
return {};
}
virtual ~OpResolver() {}
private:
/// Returns true if this OpResolver may contain any "user defined" ops.
/// By "user defined" ops, we mean any op definitions other than those
/// contained in tflite::ops::builtin::BuiltinOpResolver.
///
/// If this method returns true, it doesn't necessarily mean that the
/// OpResolver contains a user-defined op, just that the absence of
/// user-defined ops can't be guaranteed.
///
/// Note that "user-defined" ops are not the same as "custom" ops;
/// BuiltinOpResolver may support certain "custom" ops, in addition to
/// "builtin" ops, and may not support all of the "builtin" op enum values.
virtual bool MayContainUserDefinedOps() const { return true; }
friend class OpResolverInternal;
};
// Handles the logic for converting between an OperatorCode structure extracted
// from a flatbuffer and information about a registered operator
// implementation.
TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
const OpResolver& op_resolver,
ErrorReporter* error_reporter,
const TfLiteRegistration** registration);
} // namespace tflite
#endif // TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_
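A toy OpResolver sketch that resolves exactly one builtin, illustrating the two FindOp lookups above (real micro deployments use MicroMutableOpResolver; the single-entry design here is purely illustrative):

class SingleOpResolver : public tflite::OpResolver {
 public:
  SingleOpResolver(tflite::BuiltinOperator op, const TfLiteRegistration* reg)
      : op_(op), reg_(reg) {}
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                   int version) const override {
    return (op == op_) ? reg_ : nullptr;  // Version checking omitted in this sketch.
  }
  const TfLiteRegistration* FindOp(const char* name, int version) const override {
    return nullptr;  // This sketch registers no custom ops.
  }
 private:
  tflite::BuiltinOperator op_;
  const TfLiteRegistration* reg_;
};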

View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/core/api/tensor_utils.h"
#include <string.h>
#include "tensorflow/lite/c/common.h"
namespace tflite {
TfLiteStatus ResetVariableTensor(TfLiteTensor* tensor) {
if (!tensor->is_variable) {
return kTfLiteOk;
}
// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
// to the value of the buffer.
int value = 0;
if (tensor->type == kTfLiteInt8) {
value = tensor->params.zero_point;
}
// TODO(b/139446230): Provide a platform header to better handle these
// specific scenarios.
#if __ANDROID__ || defined(__x86_64__) || defined(__i386__) || \
defined(__i386) || defined(__x86__) || defined(__X86__) || \
defined(_X86_) || defined(_M_IX86) || defined(_M_X64)
memset(tensor->data.raw, value, tensor->bytes);
#else
char* raw_ptr = tensor->data.raw;
for (size_t i = 0; i < tensor->bytes; ++i) {
*raw_ptr = value;
raw_ptr++;
}
#endif
return kTfLiteOk;
}
} // namespace tflite

Some files were not shown because too many files have changed in this diff