tensorflow
Warning
Additional Dependencies: To use the TensorFlow models, you need to install additional dependencies. They are optional because they can be heavy and are not needed in every use case. To install them, run:
audioclass.models.tensorflow
Module for defining TensorFlow-based audio classification models.
This module provides classes and functions for creating and using TensorFlow models for audio classification tasks. It includes a TensorflowModel class, which wraps a TensorFlow callable, and a Signature dataclass, which defines the model's input and output specifications.
Classes
Signature(input_name, classification_name, feature_name, input_length, input_dtype=np.float32)
dataclass
Defines the input and output signature of a TensorFlow model.
Attributes
classification_name: str
instance-attribute
The name of the output tensor containing classification probabilities.
feature_name: str
instance-attribute
The name of the output tensor containing extracted features.
input_dtype: DTypeLike = np.float32
class-attribute
instance-attribute
The data type of the input tensor. Defaults to np.float32.
input_length: int
instance-attribute
The number of samples expected in the input tensor.
input_name: str
instance-attribute
The name of the input tensor.
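To illustrate how the fields fit together, the sketch below defines a local stand-in with the same fields and default as the documented Signature and fills it in for a hypothetical model. The tensor names ("inputs", "scores", "embeddings") and the 3-second window at 48 kHz are assumptions chosen for the example; a real model's names come from its exported signature, and real code would import Signature from audioclass.models.tensorflow instead of redefining it.

```python
from dataclasses import dataclass

import numpy as np
from numpy.typing import DTypeLike


# Hypothetical local mirror of the documented Signature fields.
@dataclass
class Signature:
    input_name: str
    classification_name: str
    feature_name: str
    input_length: int
    input_dtype: DTypeLike = np.float32


# Hypothetical tensor names and a 3-second window at 48 kHz.
signature = Signature(
    input_name="inputs",
    classification_name="scores",
    feature_name="embeddings",
    input_length=3 * 48_000,
)

print(signature.input_length)  # 144000
```

Only input_dtype has a default, so the other four fields must always be supplied.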
TensorflowModel(callable, signature, tags, confidence_threshold, samplerate, name, logits=True)
Bases: ClipClassificationModel
A wrapper class for TensorFlow audio classification models.
This class provides a standardized interface to TensorFlow models so that they can be used seamlessly with the rest of the audioclass library.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
callable | Callable | The TensorFlow callable representing the model. | required |
signature | Signature | The input and output signature of the model. | required |
tags | List[Tag] | The list of tags that the model can predict. | required |
confidence_threshold | float | The minimum confidence threshold for assigning a tag to a clip. | required |
samplerate | int | The sample rate of the audio data expected by the model (in Hz). | required |
name | str | The name of the model. | required |
logits | bool | Whether the model outputs logits (True) or probabilities (False). Defaults to True. | True |
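The logits and confidence_threshold parameters work together: raw outputs are turned into probabilities, which are then compared with the threshold. A minimal sketch, assuming an independent sigmoid per class (common for multi-label audio taggers) and an inclusive >= comparison; the library's actual activation and comparison may differ. The helper names are hypothetical, and plain strings stand in for the Tag objects the real class uses.

```python
from typing import List

import numpy as np


def to_probabilities(outputs: np.ndarray, logits: bool = True) -> np.ndarray:
    """Map raw outputs to per-class probabilities.

    Assumes an independent sigmoid per class; when logits is False the
    outputs are taken to be probabilities already and returned unchanged.
    """
    if not logits:
        return outputs
    return 1.0 / (1.0 + np.exp(-outputs))


def predicted_tags(probs: np.ndarray, tags: List[str], confidence_threshold: float) -> List[str]:
    """Keep the tags whose probability reaches the threshold (inclusive, assumed)."""
    return [tag for tag, p in zip(tags, probs) if p >= confidence_threshold]


tags = ["species_a", "species_b", "species_c"]  # hypothetical tag labels
raw = np.array([2.0, -1.0, 0.5])  # hypothetical logits for one clip

probs = to_probabilities(raw, logits=True)
print(predicted_tags(probs, tags, confidence_threshold=0.5))  # ['species_a', 'species_c']
```

Setting logits=False for a model that already ends in a sigmoid or softmax avoids applying the activation twice.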
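Attributes
callable: Callable = callable
instance-attribute
The TensorFlow callable representing the model.
confidence_threshold = confidence_threshold
instance-attribute
The minimum confidence threshold for assigning a tag to a clip.
input_samples = signature.input_length
instance-attribute
The number of samples expected per model input, taken from the signature.
logits = logits
instance-attribute
Whether the model outputs logits (True) or probabilities (False).
name = name
instance-attribute
The name of the model.
num_classes = len(tags)
instance-attribute
The number of classes, equal to the number of tags.
samplerate = samplerate
instance-attribute
The sample rate of the audio data expected by the model (in Hz).
signature: Signature = signature
instance-attribute
The input and output signature of the model.
tags = tags
instance-attribute
The list of tags that the model can predict.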
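Because input_samples comes from the signature and samplerate is given separately, together they determine how much audio one model input covers. The arithmetic below uses hypothetical values (a 144,000-sample window at 48 kHz and a 10-second recording) and assumes non-overlapping windows with a padded final window; the helper names are illustrative only.

```python
import math


def window_duration(input_samples: int, samplerate: int) -> float:
    """Seconds of audio covered by one model input window."""
    return input_samples / samplerate


def num_windows(recording_samples: int, input_samples: int) -> int:
    """Windows needed to cover a recording, counting a padded final window."""
    return math.ceil(recording_samples / input_samples)


# Hypothetical values: a 3 s window at 48 kHz and a 10 s recording.
print(window_duration(144_000, 48_000))  # 3.0
print(num_windows(10 * 48_000, 144_000))  # 4
```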
Functions
process_array(array)
Process a single audio array and return the model output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array | ndarray | The audio array to be processed, with shape (num_frames, input_samples) or (input_samples,). | required |
Returns:
Type | Description |
---|---|
ModelOutput | A ModelOutput object with the model's results. |
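process_array expects audio that is already cut into windows of input_samples. The hypothetical helper below shows one way to produce the (num_frames, input_samples) shape from a 1-D waveform, assuming non-overlapping windows and zero-padding of the last one; the library may provide or apply its own framing logic.

```python
import numpy as np


def frame_audio(wav: np.ndarray, input_samples: int) -> np.ndarray:
    """Split a 1-D waveform into non-overlapping windows of input_samples.

    The last window is zero-padded so every row has exactly input_samples
    values; the result has shape (num_frames, input_samples).
    """
    num_frames = -(-len(wav) // input_samples)  # ceiling division
    padded = np.zeros(num_frames * input_samples, dtype=wav.dtype)
    padded[: len(wav)] = wav
    return padded.reshape(num_frames, input_samples)


wav = np.ones(10, dtype=np.float32)  # toy 10-sample "recording"
frames = frame_audio(wav, input_samples=4)
print(frames.shape)  # (3, 4)
```

A real recording would use input_samples from the model and audio resampled to the model's samplerate.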
Functions
process_array(call, signature, array, validate_signature=False, logits=True)
Process an array with a TensorFlow model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
call | Callable | The TensorFlow callable representing the model. | required |
signature | Signature | The input and output signature of the model. | required |
array | ndarray | The audio array to be processed, with shape (num_frames, input_samples) or (input_samples,). | required |
validate_signature | bool | Whether to validate the model signature. Defaults to False. | False |
logits | bool | Whether the model outputs logits (True) or probabilities (False). Defaults to True. | True |
Returns:
Type | Description |
---|---|
ModelOutput | A ModelOutput object with the model's results. |
Raises:
Type | Description |
---|---|
ValueError | If the input array has the wrong shape or if the model signature is invalid. |
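The sketch below approximates the documented behaviour of the module-level process_array: it accepts a 1-D or 2-D array, raises ValueError on a wrong shape, calls the model, and converts logits. It assumes the callable takes the input tensor as a keyword argument named by signature.input_name and returns a dict keyed by output names (as TensorFlow SavedModel signatures do), assumes a sigmoid for logits, returns a plain (scores, features) tuple instead of ModelOutput, and does not model validate_signature. The Signature stand-in, process_array_sketch, and fake_model are hypothetical, so the example runs without TensorFlow or audioclass.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

import numpy as np
from numpy.typing import DTypeLike


@dataclass
class Signature:  # hypothetical stand-in with the documented fields
    input_name: str
    classification_name: str
    feature_name: str
    input_length: int
    input_dtype: DTypeLike = np.float32


def process_array_sketch(
    call: Callable[..., Dict[str, np.ndarray]],
    signature: Signature,
    array: np.ndarray,
    logits: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Approximate the documented behaviour; not the library's implementation."""
    # Accept a single window of shape (input_samples,) by adding a frame axis.
    if array.ndim == 1:
        array = array[np.newaxis, :]
    if array.ndim != 2 or array.shape[1] != signature.input_length:
        raise ValueError(
            f"Expected shape (num_frames, {signature.input_length}), got {array.shape}"
        )
    outputs = call(**{signature.input_name: array.astype(signature.input_dtype)})
    scores = np.asarray(outputs[signature.classification_name])
    features = np.asarray(outputs[signature.feature_name])
    if logits:
        scores = 1.0 / (1.0 + np.exp(-scores))  # assumed sigmoid activation
    return scores, features


# A fake model standing in for a TensorFlow callable: 2 classes, 3-d features.
def fake_model(inputs: np.ndarray) -> Dict[str, np.ndarray]:
    n = inputs.shape[0]
    return {"scores": np.zeros((n, 2)), "embeddings": np.ones((n, 3))}


sig = Signature("inputs", "scores", "embeddings", input_length=4)
probs, feats = process_array_sketch(fake_model, sig, np.zeros(4))
print(probs.shape, feats.shape)  # (1, 2) (1, 3)
```

Passing an array whose last dimension differs from input_length, for example np.zeros(5) here, raises ValueError before the model is called.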