Trains CNN classifiers from raw speech using Keras and tests them.
Raw Speech Classification
Trains CNN-based (or any other neural-network-based) classifiers from raw speech using Keras, and tests them. The inputs are lists of wav files, where each file is labelled. The wav files are split into fixed-length signals, which are then processed by the network. During testing, it computes scores at the utterance or speaker level by averaging the corresponding frame-level scores of the fixed-length signals.
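The utterance-level or speaker-level scoring described above can be sketched in Python as follows. This is a minimal illustration, not the package's own code: the function name `average_scores`, the input layout (one group ID and one posterior vector per fixed-length frame) and the choice of the highest averaged posterior as the decision are assumptions.

```python
def average_scores(group_ids, frame_posteriors):
    """Average frame-level posterior vectors per group (utterance or speaker).

    Hypothetical helper: group_ids holds one identifier per fixed-length frame,
    frame_posteriors holds one list of class posteriors per frame.
    Returns {group_id: (mean posterior vector, predicted class index)}.
    """
    sums = {}
    counts = {}
    for gid, post in zip(group_ids, frame_posteriors):
        acc = sums.setdefault(gid, [0.0] * len(post))
        for i, p in enumerate(post):
            acc[i] += p
        counts[gid] = counts.get(gid, 0) + 1

    results = {}
    for gid, total in sums.items():
        mean = [s / counts[gid] for s in total]
        # Assumed decision rule: the class with the highest averaged posterior.
        predicted = max(range(len(mean)), key=mean.__getitem__)
        results[gid] = (mean, predicted)
    return results


if __name__ == "__main__":
    ids = ["utt1", "utt1", "utt2"]
    posts = [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]
    for gid, (mean, label) in average_scores(ids, posts).items():
        print(gid, mean, label)
```

Passing speaker IDs instead of utterance IDs as `group_ids` gives speaker-level scores with the same function.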
Installation
From source in a conda environment
To install Keras 3 with the PyTorch backend, run:
conda env create -f conda/rsclf-pytorch.yaml
To install Keras 3 with the TensorFlow backend, run:
conda env create -f conda/rsclf-tensorflow.yaml
To install Keras 3 with the JAX backend, run:
conda env create -f conda/rsclf-jax.yaml
Then install the package in that environment (the default name is rsclf) with:
conda run -n rsclf pip install .
Installing from PyPI
If you want to install the latest release of this package in your current environment, run one of the following commands, depending on your desired framework:
pip install raw-speech-classification[torch]
or
pip install raw-speech-classification[tensorflow]
or
pip install raw-speech-classification[jax]
If you already have an environment with PyTorch, TensorFlow, or JAX installed, you can simply run:
pip install raw-speech-classification
You will also need to set the KERAS_BACKEND environment variable to the correct backend before running rsclf-train or rsclf-test (see below), or set it globally for the current bash session with:
export KERAS_BACKEND=torch
Replace torch with tensorflow or jax accordingly.
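If you drive the package from your own Python script instead of the rsclf-train and rsclf-test commands, the same variable can be set from Python. Keras 3 reads KERAS_BACKEND when it is first imported, so it has to be set before `import keras`. The helper `set_keras_backend` below is hypothetical and not part of this package; it only accepts the three backends named above.

```python
import os

# Backends covered by the installation instructions above.
SUPPORTED_BACKENDS = ("torch", "tensorflow", "jax")


def set_keras_backend(name: str) -> None:
    """Hypothetical helper: set KERAS_BACKEND. Call it before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unsupported backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name


if __name__ == "__main__":
    set_keras_backend("torch")
    # import keras  # Imported after this point, Keras would use the PyTorch backend.
    print(os.environ["KERAS_BACKEND"])
```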
Using the code
- Create lists for training, cross-validation and testing. Each line in a list must contain the path to a wav file (relative to the -R or --root option), followed by its integer label indexed from 0, separated by a space. For example, if your data files are in /home/bob/data/my_dataset/part*/file*.wav, the root option could be /home/bob/data/my_dataset and the list files would then contain lines like:
part1/file1.wav 1
part1/file2.wav 0
Full list files for IEMOCAP are available in the repository as an example in datasets/IEMOCAP/F1_lists.
- If you installed from source with Conda: a run script, run.sh, is available that chains all the steps together. Provide the model architecture as an argument; see model_architecture.py for valid options. Optionally, provide an integer giving the number of times the experiment is repeated, which is useful when the same experiment must be repeated with different initializations. This argument defaults to 1.
If you installed with pip: run the following commands (pass the --help option to each command for more details):
rsclf-wav2feat --wav-list-file list_files/cv.list --feature-dir output/cv_feat --mode train --root path/to/dataset/basedir
rsclf-wav2feat --wav-list-file list_files/train.list --feature-dir output/train_feat --mode train --root path/to/dataset/basedir
rsclf-wav2feat --wav-list-file list_files/test.list --feature-dir output/test_feat --mode test --root path/to/dataset/basedir
KERAS_BACKEND=torch rsclf-train --train-feature-dir output/train_feat --validation-feature-dir output/cv_feat --output-dir output/cnn_subseg --arch subseg --splice-size 25 --verbose 2
KERAS_BACKEND=torch rsclf-test --feature-dir output/test_feat --model-filename output/cnn_subseg/cnn.keras --output-dir output/cnn_subseg --splice-size 25 --verbose 0
rsclf-plot --output-dir output/ output/cnn_subseg
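The list-file format from the first step can be checked with a small parser such as the one below. It is a sketch, not the package's own reader: the function names `parse_list_lines` and `read_list_file`, and resolving paths against the root directory with `pathlib`, are assumptions. Because fields are separated by whitespace, paths containing spaces are not supported.

```python
from pathlib import Path


def parse_list_lines(lines, root, source="<list>"):
    """Parse '<relative wav path> <integer label>' lines into (path, label) tuples.

    Hypothetical helper. Blank lines are skipped; labels must be integers >= 0.
    """
    entries = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        parts = line.split()
        if len(parts) != 2:
            raise ValueError(f"{source}:{line_no}: expected '<path> <label>', got {line!r}")
        rel_path, label_text = parts
        label = int(label_text)  # raises ValueError for non-integer labels
        if label < 0:
            raise ValueError(f"{source}:{line_no}: labels are indexed from 0")
        entries.append((Path(root) / rel_path, label))
    return entries


def read_list_file(list_path, root):
    """Read and parse a list file from disk (hypothetical wrapper)."""
    with open(list_path, encoding="utf-8") as f:
        return parse_list_lines(f, root, source=str(list_path))
```

For example, `read_list_file("list_files/train.list", "/home/bob/data/my_dataset")` would return one `(absolute path, label)` pair per non-empty line.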
This is an example of how to run on the IEMOCAP dataset using conda, assuming conda is installed in ~/miniconda3 and your environment is rsclf:
bash run.sh -C ~/miniconda3 -n rsclf -D ./datasets/IEMOCAP/F1_lists -a seg -o results/seg-f1 -R <IEMOCAP_ROOT>
For instance, <IEMOCAP_ROOT> can be /ssd/data/IEMOCAP, which should contain IEMOCAP_full_release/Session*.
The script prints a training log to the terminal, and you should obtain a learning curve in results/seg-f1/plot.png.
Code components
- wav2feat.py creates directories where the wav files are stored as fixed-length frames for faster access during training and testing.
- train.py is the Keras training script.
- The model architecture can be configured in model_architecture.py.
- rawdataset.py provides an object that reads the saved directories in batches and retrieves mini-batches for training.
- test.py performs the testing and generates scores as posterior probabilities. If you need the results per speaker, configure it accordingly (see the script for details). The default output format is:
<speakerID> <label> [<posterior_probability_vector>]
- plot.py generates and saves the learning curves.
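The fixed-length framing performed by wav2feat.py can be pictured with the sketch below. The frame length and shift the script actually uses are not stated here, so `frame_len` and `frame_shift` are hypothetical parameters, and discarding an incomplete trailing segment is an assumption, not documented behaviour. In this picture each frame would inherit the label of the wav file it was cut from.

```python
def split_into_frames(samples, frame_len, frame_shift):
    """Cut a 1-D sequence of audio samples into fixed-length frames.

    Frames start every `frame_shift` samples. A trailing segment shorter than
    `frame_len` is discarded (an assumption made for this sketch).
    """
    if frame_len <= 0 or frame_shift <= 0:
        raise ValueError("frame_len and frame_shift must be positive")
    return [
        samples[start:start + frame_len]
        for start in range(0, len(samples) - frame_len + 1, frame_shift)
    ]


if __name__ == "__main__":
    # A toy 10-sample signal cut into 4-sample frames with a 2-sample shift.
    signal = list(range(10))
    for frame in split_into_frames(signal, frame_len=4, frame_shift=2):
        print(frame)
```

A shift smaller than the frame length, as in the toy example, produces overlapping frames; a shift equal to it produces contiguous, non-overlapping frames.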
Training schedule
The script uses stochastic gradient descent with a momentum of 0.5. It starts with a learning rate of 0.1 for a minimum of 5 epochs. Whenever the validation loss decreases by less than 0.002 between successive epochs, the learning rate is halved. Halving continues until the learning rate reaches 1e-7.
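The schedule above can be simulated in plain Python to see how the learning rate evolves over epochs. This is a sketch under stated assumptions: the function name `learning_rate_schedule` is hypothetical, the 5-epoch minimum is read as "never halve before the end of epoch 5", and the simulation stops once halving would take the rate below 1e-7. The optimizer itself would correspond to `keras.optimizers.SGD(learning_rate=0.1, momentum=0.5)` in Keras 3.

```python
def learning_rate_schedule(val_losses, initial_lr=0.1, min_epochs=5,
                           min_improvement=0.002, min_lr=1e-7):
    """Return the learning rate used in each epoch, given per-epoch validation losses.

    Hypothetical simulation of the schedule: the rate stays at `initial_lr` for
    at least `min_epochs` epochs; afterwards it is halved whenever the validation
    loss improves by less than `min_improvement` over the previous epoch.
    Simulation stops once halving would take the rate below `min_lr`.
    """
    lrs = []
    lr = initial_lr
    prev_loss = None
    for epoch, loss in enumerate(val_losses, start=1):
        lrs.append(lr)  # rate used during this epoch
        if (prev_loss is not None and epoch >= min_epochs
                and prev_loss - loss < min_improvement):
            lr /= 2  # rate for the next epoch
            if lr < min_lr:
                break
        prev_loss = loss
    return lrs


if __name__ == "__main__":
    losses = [1.0, 0.9, 0.8, 0.7, 0.6, 0.599, 0.598, 0.5]
    print(learning_rate_schedule(losses))
```

Starting from 0.1, the twentieth halving is the first to go below 1e-7 (0.1 / 2**20 is about 9.5e-8), so at most 20 halvings take place.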
Contributors
Idiap Research Institute
Authors: S. Pavankumar Dubagunta and Dr. Mathew Magimai-Doss
License
GNU GPL v3
Download files
File details
Details for the file raw_speech_classification-1.0.2.tar.gz.
File metadata
- Download URL: raw_speech_classification-1.0.2.tar.gz
- Upload date:
- Size: 24.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | c8f71c2b99439e644cd3f4296205895459cbada2506feb1190e95e5f60517e40
MD5 | 715102c1d3fe32a85f19815447b28ad2
BLAKE2b-256 | 83397d98bb3e54e07de46dd66f0086b4a5ae9631673475618e74e580ceb40843
File details
Details for the file raw_speech_classification-1.0.2-py3-none-any.whl.
File metadata
- Download URL: raw_speech_classification-1.0.2-py3-none-any.whl
- Upload date:
- Size: 14.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 96026150fac38a4429aaf9f89f64ce07f70ad0aac2d12f41d0a5712e69179950
MD5 | efa11120b935bf0c4aac5cfd8d1865b3
BLAKE2b-256 | ee2f8c0716b50d7f23e5c63de8a0963a8a8a8f5461b5021f675c81c047e6f47b