223 lines
5.8 KiB
Go
223 lines
5.8 KiB
Go
package diarization
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os/exec"
|
|
)
|
|
|
|
// SpeakerSegment is one contiguous stretch of audio attributed to a single
// speaker, as printed by the diarization script.
type SpeakerSegment struct {
	// Speaker is the cluster label, formatted as "Speaker 1", "Speaker 2", ...
	Speaker string `json:"speaker"`
	// Start is the offset of the segment from the beginning of the audio, in seconds.
	Start float64 `json:"start"`
	// End is the offset at which the segment stops, in seconds.
	End float64 `json:"end"`
}
|
|
|
|
// DiarizationResult contains the speaker diarization output
|
|
type DiarizationResult struct {
|
|
Speakers []SpeakerSegment `json:"speakers"`
|
|
NumSpeakers int `json:"num_speakers"`
|
|
}
|
|
|
|
// Client performs speaker diarization by running a Python script built on
// resemblyzer voice embeddings and scikit-learn clustering. It carries no
// configuration of its own; per-call settings are passed as
// DiarizationOptions.
type Client struct{}
|
|
|
|
// NewClient creates a new diarization client
|
|
func NewClient() *Client {
|
|
return &Client{}
|
|
}
|
|
|
|
// DiarizationOptions tunes a single Diarize call.
type DiarizationOptions struct {
	// NumSpeakers fixes how many speaker clusters to produce. Zero (or any
	// non-positive value) lets the script choose the count automatically.
	NumSpeakers int
}
|
|
|
|
// DefaultDiarizationOptions returns default diarization options
|
|
func DefaultDiarizationOptions() *DiarizationOptions {
|
|
return &DiarizationOptions{
|
|
NumSpeakers: 0, // Auto-detect
|
|
}
|
|
}
|
|
|
|
// Diarize processes an audio file and returns speaker segments
|
|
func (c *Client) Diarize(audioPath string, options *DiarizationOptions) (*DiarizationResult, error) {
|
|
if options == nil {
|
|
options = DefaultDiarizationOptions()
|
|
}
|
|
|
|
// Build the Python command
|
|
cmd := exec.Command("python3", "-c", c.buildPythonCommand(audioPath, options))
|
|
|
|
// Capture stdout and stderr
|
|
var out bytes.Buffer
|
|
var errBuf bytes.Buffer
|
|
cmd.Stdout = &out
|
|
cmd.Stderr = &errBuf
|
|
|
|
// Execute the command
|
|
err := cmd.Run()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("diarization failed: %v, stderr: %s", err, errBuf.String())
|
|
}
|
|
|
|
// Parse the JSON output
|
|
var result DiarizationResult
|
|
err = json.Unmarshal(out.Bytes(), &result)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse diarization output: %v, output: %s", err, out.String())
|
|
}
|
|
|
|
return &result, nil
|
|
}
|
|
|
|
// buildPythonCommand constructs the Python command for diarization
|
|
func (c *Client) buildPythonCommand(audioPath string, options *DiarizationOptions) string {
|
|
numSpeakersStr := "None"
|
|
if options.NumSpeakers > 0 {
|
|
numSpeakersStr = fmt.Sprintf("%d", options.NumSpeakers)
|
|
}
|
|
|
|
pythonCode := fmt.Sprintf(`
|
|
import json
|
|
import sys
|
|
import os
|
|
import warnings
|
|
import numpy as np
|
|
|
|
# Suppress warnings
|
|
warnings.filterwarnings("ignore")
|
|
|
|
# Redirect both stdout and stderr during imports to suppress library noise
|
|
old_stdout = sys.stdout
|
|
old_stderr = sys.stderr
|
|
sys.stdout = open(os.devnull, 'w')
|
|
sys.stderr = open(os.devnull, 'w')
|
|
|
|
from resemblyzer import VoiceEncoder, preprocess_wav
|
|
from sklearn.cluster import SpectralClustering, AgglomerativeClustering
|
|
import librosa
|
|
|
|
# Initialize voice encoder while stdout is suppressed (it prints loading message)
|
|
encoder = VoiceEncoder()
|
|
|
|
# Restore stdout/stderr
|
|
sys.stdout = old_stdout
|
|
sys.stderr = old_stderr
|
|
|
|
# Configuration
|
|
AUDIO_PATH = "%s"
|
|
NUM_SPEAKERS = %s
|
|
SEGMENT_DURATION = 1.5 # seconds per segment for embedding extraction
|
|
HOP_DURATION = 0.75 # hop between segments
|
|
|
|
# Load audio
|
|
audio, sr = librosa.load(AUDIO_PATH, sr=16000)
|
|
duration = len(audio) / sr
|
|
|
|
# Extract embeddings for overlapping segments
|
|
embeddings = []
|
|
timestamps = []
|
|
current_time = 0.0
|
|
|
|
while current_time + SEGMENT_DURATION <= duration:
|
|
start_sample = int(current_time * sr)
|
|
end_sample = int((current_time + SEGMENT_DURATION) * sr)
|
|
segment = audio[start_sample:end_sample]
|
|
|
|
# Skip silent segments
|
|
if np.abs(segment).mean() > 0.01:
|
|
try:
|
|
wav = preprocess_wav(segment, source_sr=sr)
|
|
if len(wav) > 0:
|
|
embedding = encoder.embed_utterance(wav)
|
|
embeddings.append(embedding)
|
|
timestamps.append((current_time, current_time + SEGMENT_DURATION))
|
|
except:
|
|
pass
|
|
|
|
current_time += HOP_DURATION
|
|
|
|
# Handle edge cases
|
|
if len(embeddings) == 0:
|
|
print(json.dumps({"speakers": [], "num_speakers": 0}))
|
|
sys.exit(0)
|
|
|
|
embeddings = np.array(embeddings)
|
|
|
|
# Determine number of speakers
|
|
if NUM_SPEAKERS is None or NUM_SPEAKERS <= 0:
|
|
# Auto-detect using silhouette score
|
|
from sklearn.metrics import silhouette_score
|
|
best_n = 2
|
|
best_score = -1
|
|
for n in range(2, min(6, len(embeddings))):
|
|
try:
|
|
clustering = AgglomerativeClustering(n_clusters=n)
|
|
labels = clustering.fit_predict(embeddings)
|
|
score = silhouette_score(embeddings, labels)
|
|
if score > best_score:
|
|
best_score = score
|
|
best_n = n
|
|
except:
|
|
pass
|
|
num_speakers = best_n
|
|
else:
|
|
num_speakers = NUM_SPEAKERS
|
|
|
|
# Cluster embeddings
|
|
try:
|
|
if len(embeddings) >= num_speakers:
|
|
clustering = AgglomerativeClustering(n_clusters=num_speakers)
|
|
labels = clustering.fit_predict(embeddings)
|
|
else:
|
|
labels = list(range(len(embeddings)))
|
|
num_speakers = len(embeddings)
|
|
except Exception as e:
|
|
labels = [0] * len(embeddings)
|
|
num_speakers = 1
|
|
|
|
# Build speaker segments with merging of consecutive same-speaker segments
|
|
speaker_segments = []
|
|
prev_speaker = None
|
|
prev_start = None
|
|
prev_end = None
|
|
|
|
for i, (start, end) in enumerate(timestamps):
|
|
speaker = f"Speaker {labels[i] + 1}"
|
|
|
|
if speaker == prev_speaker and prev_end is not None:
|
|
# Extend previous segment if same speaker and close in time
|
|
if start - prev_end < 0.5:
|
|
prev_end = end
|
|
continue
|
|
|
|
# Save previous segment
|
|
if prev_speaker is not None:
|
|
speaker_segments.append({
|
|
"speaker": prev_speaker,
|
|
"start": prev_start,
|
|
"end": prev_end
|
|
})
|
|
|
|
prev_speaker = speaker
|
|
prev_start = start
|
|
prev_end = end
|
|
|
|
# Don't forget the last segment
|
|
if prev_speaker is not None:
|
|
speaker_segments.append({
|
|
"speaker": prev_speaker,
|
|
"start": prev_start,
|
|
"end": prev_end
|
|
})
|
|
|
|
print(json.dumps({
|
|
"speakers": speaker_segments,
|
|
"num_speakers": num_speakers
|
|
}))
|
|
`, audioPath, numSpeakersStr)
|
|
|
|
return pythonCode
|
|
}
|