tflite#
Warning
Additional Dependencies: To use the TensorFlow Lite models, you will need to install additional dependencies. These dependencies are optional, as they can be heavy and may not be needed in all use cases. To install them, run:
audioclass.models.tflite
Module for defining TensorFlow Lite-based audio classification models.
This module provides classes and functions for creating and using TensorFlow Lite models for audio classification tasks. It includes a TFLiteModel class that wraps a TensorFlow Lite interpreter and a Signature dataclass to define the model's input and output specifications.
Classes:
Name | Description |
---|---|
Signature | Defines the input and output signature of a TensorFlow Lite model. |
TFLiteModel | A wrapper class for TensorFlow Lite audio classification models. |
Functions:
Name | Description |
---|---|
load_model | Load a TensorFlow Lite model from a file. |
Classes#
Signature(input_index, classification_index, feature_index, input_length, input_dtype=np.float32)
dataclass
Defines the input and output signature of a TensorFlow Lite model.
Attributes:
Name | Type | Description |
---|---|---|
classification_index | int | The index of the tensor containing classification probabilities. |
feature_index | int | The index of the tensor containing extracted features. |
input_dtype | DTypeLike | The data type of the input tensor. Defaults to np.float32. |
input_index | int | The index of the input tensor in the model. |
input_length | int | The number of audio samples expected in the input tensor. |
Attributes#
classification_index: int
instance-attribute
The index of the tensor containing classification probabilities.
feature_index: int
instance-attribute
The index of the tensor containing extracted features.
input_dtype: DTypeLike = np.float32
class-attribute
instance-attribute
The data type of the input tensor. Defaults to np.float32.
input_index: int
instance-attribute
The index of the input tensor in the model.
input_length: int
instance-attribute
The number of audio samples expected in the input tensor.
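As a rough sketch of where these values might come from: a TensorFlow Lite interpreter reports its tensors through get_input_details() and get_output_details(), each of which returns a list of dictionaries with keys such as "index", "shape", and "dtype". The helper below, signature_fields, is hypothetical and not part of audioclass. It assumes a single input tensor and assumes that the first output is the classification tensor and the second is the feature tensor, which has to be verified for each model.

```python
from typing import Any, Dict, List, Sequence


def signature_fields(
    input_details: List[Dict[str, Any]],
    output_details: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Collect Signature-style fields from interpreter tensor details.

    Hypothetical helper. Assumes one input tensor and that the outputs are
    ordered as (classification, features); check this for your model.
    """
    if len(input_details) != 1:
        raise ValueError("expected exactly one input tensor")
    if len(output_details) < 2:
        raise ValueError("expected classification and feature outputs")

    model_input = input_details[0]
    shape: Sequence[int] = model_input["shape"]
    return {
        "input_index": model_input["index"],
        "classification_index": output_details[0]["index"],
        "feature_index": output_details[1]["index"],
        # The last axis of the input shape holds the audio samples.
        "input_length": int(shape[-1]),
        "input_dtype": model_input["dtype"],
    }


# Arbitrary stand-in values shaped like the interpreter's detail entries.
fields = signature_fields(
    [{"index": 0, "shape": [1, 48000], "dtype": "float32"}],
    [{"index": 10, "shape": [1, 5]}, {"index": 11, "shape": [1, 32]}],
)
```

The keys of the returned dictionary match the Signature fields listed above, so under these assumptions it could be passed on as Signature(**fields).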
TFLiteModel(interpreter, signature, tags, confidence_threshold, samplerate, name, logits=True, batch_size=8)
Bases: ClipClassificationModel
A wrapper class for TensorFlow Lite audio classification models.
This class provides a standardized interface for interacting with TensorFlow Lite models, allowing them to be used seamlessly with the audioclass library.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
interpreter | Interpreter | The TensorFlow Lite interpreter object. | required |
signature | Signature | The input and output signature of the model. | required |
tags | List[Tag] | The list of tags that the model can predict. | required |
confidence_threshold | float | The minimum confidence threshold for assigning a tag to a clip. | required |
samplerate | int | The sample rate of the audio data expected by the model (in Hz). | required |
name | str | The name of the model. | required |
logits | bool | Whether the model outputs logits (True) or probabilities (False). Defaults to True. | True |
batch_size | int | The maximum number of frames to process in each batch. Defaults to 8. | 8 |
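To illustrate how logits and confidence_threshold could interact, here is a minimal sketch in plain Python. It assumes that logits are mapped to probabilities with a sigmoid (a common choice for multi-label classifiers; the text above does not state which function audioclass applies) and that a tag is kept when its probability reaches the threshold. The function names and tag strings are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(x: float) -> float:
    """Map a logit to a probability, avoiding overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def tags_above_threshold(
    outputs: Sequence[float],
    tags: Sequence[str],
    confidence_threshold: float,
    logits: bool = True,
) -> List[Tuple[str, float]]:
    """Return (tag, probability) pairs that reach the threshold."""
    if len(outputs) != len(tags):
        raise ValueError("one output value is needed per tag")
    # Assumption: raw logits are converted with a sigmoid; probabilities pass through.
    probs = [sigmoid(v) for v in outputs] if logits else list(outputs)
    return [(t, p) for t, p in zip(tags, probs) if p >= confidence_threshold]


selected = tags_above_threshold(
    [2.0, -1.0, 0.0],
    ["bat", "bird", "frog"],
    confidence_threshold=0.5,
)
```

With logits=False the same values would be treated as probabilities directly, so a model that already emits probabilities should not be run with the default setting.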
Methods:
Name | Description |
---|---|
process_array | Process a single audio array and return the model output. |
Attributes:
Name | Type | Description |
---|---|---|
batch_size | | |
confidence_threshold | | |
input_samples | | |
interpreter | Interpreter | The TensorFlow Lite interpreter object. |
logits | | |
name | | |
num_classes | | |
samplerate | | |
signature | Signature | The input and output signature of the model. |
tags | | |
Attributes#
batch_size = batch_size
instance-attribute
confidence_threshold = confidence_threshold
instance-attribute
input_samples = signature.input_length
instance-attribute
interpreter: Interpreter = interpreter
instance-attribute
The TensorFlow Lite interpreter object.
logits = logits
instance-attribute
name = name
instance-attribute
num_classes = len(tags)
instance-attribute
samplerate = samplerate
instance-attribute
signature: Signature = signature
instance-attribute
The input and output signature of the model.
tags = tags
instance-attribute
Functions#
process_array(array)
Process a single audio array and return the model output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array | ndarray | The audio array to be processed, with shape | required |
Returns:
Type | Description |
---|---|
ModelOutput | A |
Note
This is a low-level method that requires manual batching of the input audio array. If you prefer a higher-level interface that handles batching automatically, consider using process_file, process_recording, or process_clip instead.
Be aware that passing an array with a large batch size may exceed available device memory and cause the process to crash.
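The manual batching mentioned in this note can be sketched as follows. The helper iter_batches is hypothetical and not part of audioclass; it only relies on len() and slicing, so it works on a list of frames as well as on an array whose first axis indexes frames.

```python
from typing import Iterator, Sequence, TypeVar

Frame = TypeVar("Frame")


def iter_batches(
    frames: Sequence[Frame],
    batch_size: int = 8,
) -> Iterator[Sequence[Frame]]:
    """Yield consecutive slices of at most batch_size frames."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(frames), batch_size):
        yield frames[start:start + batch_size]


# Twenty placeholder frames split with the documented default of 8.
batches = list(iter_batches(list(range(20)), batch_size=8))
batch_lengths = [len(batch) for batch in batches]

# Hypothetical use with a model instance named `model`:
#     outputs = [model.process_array(b) for b in iter_batches(array, model.batch_size)]
```

Keeping each slice no larger than batch_size bounds how much data is passed to the interpreter per call, which is the memory concern raised above.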
Functions#
load_model(path, num_threads=None)
Load a TensorFlow Lite model from a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[Path, str] | The path to the TensorFlow Lite model file. | required |
num_threads | Optional[int] | The number of threads to use for inference. If None, the default number of threads will be used. | None |
Returns:
Type | Description |
---|---|
Interpreter | The TensorFlow Lite interpreter object. |
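A hedged usage sketch for load_model follows. The wrapper load_checked is hypothetical: it checks that the file exists before deferring to load_model, and it imports audioclass.models.tflite lazily so that the optional TensorFlow Lite dependencies are only required once a model is actually loaded. The model path in the trailing comment is a placeholder.

```python
from pathlib import Path
from typing import Optional, Union


def load_checked(path: Union[Path, str], num_threads: Optional[int] = None):
    """Validate the path, then call load_model as documented above."""
    model_path = Path(path)
    if not model_path.is_file():
        raise FileNotFoundError(f"no model file at {model_path}")

    # Lazy import: needs audioclass with its optional TFLite dependencies.
    from audioclass.models.tflite import load_model

    return load_model(model_path, num_threads=num_threads)


# Example call (placeholder path, not run here):
#     interpreter = load_checked("path/to/model.tflite", num_threads=2)
```

Failing early on a missing file gives a clearer error than whatever the interpreter would raise, and leaving num_threads as None keeps the interpreter's default threading.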
process_array(interpreter, signature, array, validate_signature=False, logits=True)
Process an array with a TF Lite model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
interpreter | Interpreter | The TF Lite model interpreter. | required |
signature | Signature | The input and output signature of the model. | required |
array | ndarray | The audio array to be processed, with shape (num_frames, input_samples) or (input_samples,). | required |
validate_signature | bool | Whether to validate the model signature. Defaults to False. | False |
logits | bool | Whether the model outputs logits (True) or probabilities (False). Defaults to True. | True |
Returns:
Type | Description |
---|---|
ModelOutput | A |
Raises:
Type | Description |
---|---|
ValueError | If the input array has the wrong shape or if the model signature is invalid. |
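The two accepted input shapes, (num_frames, input_samples) and (input_samples,), and the ValueError for anything else can be illustrated with a small check on the shape tuple. This is a sketch of the documented contract under the assumption that a 1-D input is treated as a single frame; it is not the library's actual validation code, and normalized_shape is a hypothetical name.

```python
from typing import Tuple


def normalized_shape(shape: Tuple[int, ...], input_samples: int) -> Tuple[int, int]:
    """Return the shape as (num_frames, input_samples) or raise ValueError."""
    if len(shape) == 1:
        # Assumption: a single 1-D frame is promoted to a batch of one.
        shape = (1, shape[0])
    if len(shape) != 2 or shape[1] != input_samples:
        raise ValueError(
            f"expected shape (num_frames, {input_samples}) or "
            f"({input_samples},), got {shape}"
        )
    return (shape[0], shape[1])


single = normalized_shape((48000,), input_samples=48000)
batch = normalized_shape((3, 48000), input_samples=48000)
```

In practice the tuple would come from array.shape and input_samples from signature.input_length.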