Keras stateful LSTM/GRU support
ST Edge AI Core Technology 2.2.0
r2.0
Introduction
ST Edge AI Core provides initial support for Keras stateful LSTM/GRU layers ('legacy' C-API). Only a batch size of 1 is supported. To manage the internal state of a layer completely, the client application can implement dedicated weak functions. Note that the return-state attribute is not supported.
Example of Keras model with LSTM layer
import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(1, 64), batch_size=1)
lstm = tf.keras.layers.LSTM(20, activation='relu', stateful=True,
                            return_sequences=True)(inputs)
dense = tf.keras.layers.Dense(5)(lstm)
model = tf.keras.Model([inputs], [dense])
Limitations
- STM32 target: Weak functions to manage the state of the layer are not supported for the relocatable binary model.
State buffer allocation
The state of a recurrent layer is currently stored in a memory buffer that must be allocated outside the [activations][API_data_placement] buffer, because its contents must be preserved between two inferences. By default, dedicated weak functions (_allocate_[lstm,gru]_states()) are implemented in the network runtime library to allocate the buffer in the system heap (call to the malloc() function). The allocation function is called during the execution of the ai_<network>_init() function. The allocated buffer is then used by the forward LSTM/GRU function during inference (ai_<network>_run() function) to restore and save the global state.
- called for each instance of the Keras stateful LSTM/GRU layer
- size_in_bytes = 2 * n_cell * sizeof(ai_float)
AI_DECLARE_WEAK
void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  ai_handle src = AI_HANDLE_PTR(*states);
  *states = (ai_float*)realloc(src, size_in_bytes);
  // Clear lstm initial state
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
}
AI_DECLARE_WEAK
void _allocate_gru_states(ai_float **states, ai_u32 size_in_bytes)
{
  AI_ASSERT(states && size_in_bytes>0)
  ai_handle src = AI_HANDLE_PTR(*states);
  *states = (ai_float*)realloc(src, size_in_bytes);
  // Clear gru initial state
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
}
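As a worked example of the size formula given above, the following minimal sketch computes the size of the state buffer requested for one layer. The helper name state_size_in_bytes() is hypothetical and not part of the runtime library, and ai_float is replaced by a stand-in typedef that assumes a 32-bit float.

```c
#include <stddef.h>

/* Stand-in for the runtime's ai_float type (assumed here to be a 32-bit float). */
typedef float ai_float;

/* Hypothetical helper: state size in bytes for a layer with n_cell units,
   following size_in_bytes = 2 * n_cell * sizeof(ai_float). */
size_t state_size_in_bytes(size_t n_cell)
{
  return 2 * n_cell * sizeof(ai_float);
}
```

For the LSTM(20) layer of the Keras example, this gives 2 * 20 * 4 = 160 bytes when ai_float occupies 4 bytes.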
To set the state to a specific value or to reset it regularly, the client application should implement the allocation function itself so that it keeps a reference to the allocated buffers, as illustrated in the following code snippet.
/* NUM_INSTANCES is application-defined: number of stateful LSTM layers */
static ai_float* _lstm_states[NUM_INSTANCES] = {0};
static ai_u32 _lstm_states_size[NUM_INSTANCES] = {0};
static int _lstm_instance_idx = 0;

void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  if ((_lstm_instance_idx >= NUM_INSTANCES) || (!states) || (*states) || (!size_in_bytes)) {
    /* error - invalid call */
    return;
  }
  _lstm_states[_lstm_instance_idx] = (ai_float *)malloc(size_in_bytes);
  _lstm_states_size[_lstm_instance_idx] = size_in_bytes;
  *states = _lstm_states[_lstm_instance_idx];
  /*
    Clear lstm initial state or
    set the state with a user-defined value.
  */
  if (*states) {
    memset(*states, 0, size_in_bytes);
  }
  _lstm_instance_idx++;
}
...
void clear_lstm_states(void) {
  for (int i=0; i<NUM_INSTANCES; i++) {
    /* skip instances that have not been allocated */
    if (_lstm_states[i]) {
      memset(_lstm_states[i], 0, _lstm_states_size[i]);
    }
  }
}
...
aiInit();

/* main processing loop */
int sample_n = 0;
while (1) {
  /* 1 - Acquire, pre-process and fill the input buffers */
  acquire_and_process_data(in_data);

  /* 2 - Call inference engine */
  if (sample_n++ > 20) {
    clear_lstm_states();
    sample_n = 0;
  }
  aiRun(in_data, out_data);

  /* 3 - Post-process the predictions */
  post_process(out_data);
}
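The snippet above only clears the state. One possible way to set the state to a user-defined value is to save and restore the whole buffer as an opaque block, which avoids any assumption about its internal layout. The following is a minimal sketch under stated assumptions: save_lstm_state() and restore_lstm_state() are hypothetical helpers, the typedefs are stand-ins for the runtime types, and the two arrays are expected to be filled by a client-side _allocate_lstm_states() as shown above.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Stand-ins for the runtime types so the sketch compiles on its own. */
typedef float ai_float;
typedef uint32_t ai_u32;

#define NUM_INSTANCES 1  /* hypothetical: number of stateful LSTM layers */

/* Bookkeeping arrays, assumed to be filled by _allocate_lstm_states(). */
static ai_float *_lstm_states[NUM_INSTANCES];
static ai_u32 _lstm_states_size[NUM_INSTANCES];

/* Copy the state of instance idx into dst. Returns 0 on success, -1 on error. */
int save_lstm_state(int idx, void *dst, ai_u32 dst_size)
{
  if (idx < 0 || idx >= NUM_INSTANCES || !dst || !_lstm_states[idx] ||
      dst_size < _lstm_states_size[idx]) {
    return -1;
  }
  memcpy(dst, _lstm_states[idx], _lstm_states_size[idx]);
  return 0;
}

/* Overwrite the state of instance idx with a previously saved copy.
   The size must match exactly. Returns 0 on success, -1 on error. */
int restore_lstm_state(int idx, const void *src, ai_u32 src_size)
{
  if (idx < 0 || idx >= NUM_INSTANCES || !src || !_lstm_states[idx] ||
      src_size != _lstm_states_size[idx]) {
    return -1;
  }
  memcpy(_lstm_states[idx], src, src_size);
  return 0;
}
```

Such helpers would be called between two inferences, never while ai_<network>_run() is executing.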
To free the allocated resources, a second weak function is also defined. It is called when the instance of the model is destroyed (ai_<network>_destroy() function).
AI_DECLARE_WEAK
void _deallocate_lstm_states(ai_float **states)
{
  if (*states) {
    free(*states);
    *states = NULL;
  }
}

AI_DECLARE_WEAK
void _deallocate_gru_states(ai_float **states)
{
  if (*states) {
    free(*states);
    *states = NULL;
  }
}
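Because the default deallocation functions call free(), an application that overrides the allocation function with non-heap storage would presumably need to override the deallocation function as well. The following minimal sketch serves the state buffers from a static pool instead of the heap. It assumes that the runtime accesses the buffer only through the pointer returned in *states. NUM_INSTANCES, MAX_STATE_BYTES, state_pool and pool_idx are hypothetical names, and the typedefs are stand-ins for the runtime types.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Stand-ins for the runtime types so the sketch compiles on its own. */
typedef float ai_float;
typedef uint32_t ai_u32;

#define NUM_INSTANCES   2    /* hypothetical: number of stateful LSTM layers */
#define MAX_STATE_BYTES 256  /* hypothetical: upper bound of size_in_bytes */

/* One statically allocated state buffer per layer instance. */
static ai_float state_pool[NUM_INSTANCES][MAX_STATE_BYTES / sizeof(ai_float)];
static int pool_idx = 0;

void _allocate_lstm_states(ai_float **states, ai_u32 size_in_bytes)
{
  if (!states || *states || !size_in_bytes ||
      pool_idx >= NUM_INSTANCES || size_in_bytes > MAX_STATE_BYTES) {
    return; /* invalid call or pool exhausted: *states is left untouched */
  }
  *states = state_pool[pool_idx++];
  /* Clear lstm initial state */
  memset(*states, 0, size_in_bytes);
}

void _deallocate_lstm_states(ai_float **states)
{
  /* Static storage: nothing to free, only drop the reference. */
  if (states) {
    *states = NULL;
  }
}
```

In this sketch pool_idx is never rewound, so a model that is destroyed and initialized again would exhaust the pool; a real application would need its own policy for that case. As stated in the Limitations section, overriding the weak functions is not supported for the relocatable binary model.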
Validate command limitation
With the current implementation of the aiValidation firmware, the state buffers are created and reset only during the initialization of the firmware. If the validate command is run several times, the board must be restarted between two executions.
# board should be reset.
$ stm32ai validate model.h5 --mode stm32 -b 20
There is no such limitation for validation on desktop: the runtime is re-instantiated between two executions.
$ stm32ai validate model.h5 -b 20