transcribe/internal/diarization/client.go

package diarization

import (
	"bytes"
	"encoding/json"
	"fmt"
	"os/exec"
)

// SpeakerSegment represents a segment with speaker identification.
// Start and End are offsets into the audio, in seconds.
type SpeakerSegment struct {
	Speaker string  `json:"speaker"` // "Speaker 1", "Speaker 2", etc.
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
}

// DiarizationResult contains the speaker diarization output.
type DiarizationResult struct {
	Speakers    []SpeakerSegment `json:"speakers"`
	NumSpeakers int              `json:"num_speakers"`
}
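
// For reference, the embedded Python script prints JSON of the following
// shape (values here are illustrative), which Diarize decodes into a
// DiarizationResult:
//
//	{
//	  "speakers": [
//	    {"speaker": "Speaker 1", "start": 0.0, "end": 3.75},
//	    {"speaker": "Speaker 2", "start": 3.75, "end": 6.0}
//	  ],
//	  "num_speakers": 2
//	}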

// Client handles speaker diarization using resemblyzer. It shells out to
// python3 with an embedded script, so python3 and the resemblyzer, librosa,
// scikit-learn, and numpy packages must be available on the host. The client
// itself holds no state, so it is safe for concurrent use.
type Client struct{}

// NewClient creates a new diarization client.
func NewClient() *Client {
	return &Client{}
}

// DiarizationOptions contains options for diarization.
type DiarizationOptions struct {
	NumSpeakers int // number of speakers (0 = auto-detect)
}

// DefaultDiarizationOptions returns default diarization options.
func DefaultDiarizationOptions() *DiarizationOptions {
	return &DiarizationOptions{
		NumSpeakers: 0, // auto-detect
	}
}

// Diarize processes an audio file and returns speaker segments.
func (c *Client) Diarize(audioPath string, options *DiarizationOptions) (*DiarizationResult, error) {
	if options == nil {
		options = DefaultDiarizationOptions()
	}

	// Run the embedded Python script via python3 -c.
	cmd := exec.Command("python3", "-c", c.buildPythonCommand(audioPath, options))

	// Capture stdout (the JSON result) and stderr (diagnostics) separately.
	var out bytes.Buffer
	var errBuf bytes.Buffer
	cmd.Stdout = &out
	cmd.Stderr = &errBuf

	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("diarization failed: %w, stderr: %s", err, errBuf.String())
	}

	// Decode the JSON the script prints on stdout.
	var result DiarizationResult
	if err := json.Unmarshal(out.Bytes(), &result); err != nil {
		return nil, fmt.Errorf("failed to parse diarization output: %w, output: %s", err, out.String())
	}

	return &result, nil
}
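
// Example usage (an illustrative sketch; "meeting.wav" is a placeholder
// path, and the Python dependencies listed on Client must be installed):
//
//	client := diarization.NewClient()
//	result, err := client.Diarize("meeting.wav", &diarization.DiarizationOptions{NumSpeakers: 2})
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, seg := range result.Speakers {
//		fmt.Printf("%s: %.2fs-%.2fs\n", seg.Speaker, seg.Start, seg.End)
//	}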

// buildPythonCommand constructs the Python source that is passed to
// `python3 -c`.
func (c *Client) buildPythonCommand(audioPath string, options *DiarizationOptions) string {
	numSpeakersStr := "None"
	if options.NumSpeakers > 0 {
		numSpeakersStr = fmt.Sprintf("%d", options.NumSpeakers)
	}
	pythonCode := fmt.Sprintf(`
import json
import sys
import os
import warnings
import numpy as np

# Suppress warnings
warnings.filterwarnings("ignore")

# Redirect both stdout and stderr during imports to suppress library noise
old_stdout = sys.stdout
old_stderr = sys.stderr
devnull = open(os.devnull, "w")
sys.stdout = devnull
sys.stderr = devnull

from resemblyzer import VoiceEncoder, preprocess_wav
from sklearn.cluster import AgglomerativeClustering
import librosa

# Initialize the voice encoder while output is suppressed (it prints a
# loading message)
encoder = VoiceEncoder()

# Restore stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
devnull.close()

# Configuration. AUDIO_PATH is interpolated with Go's %%q verb so quotes
# and backslashes in the path are escaped as a valid string literal.
AUDIO_PATH = %q
NUM_SPEAKERS = %s
SEGMENT_DURATION = 1.5  # seconds per window for embedding extraction
HOP_DURATION = 0.75     # hop between windows

# Load audio resampled to 16 kHz mono
audio, sr = librosa.load(AUDIO_PATH, sr=16000)
duration = len(audio) / sr
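
# Each 1.5 s window below advances by 0.75 s, so adjacent windows overlap by
# half; every sufficiently loud window is embedded with resemblyzer into a
# fixed-length speaker vector (a 256-dimensional d-vector), and clustering
# those vectors is what separates the speakers further down.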
# Extract embeddings for overlapping windows
embeddings = []
timestamps = []
current_time = 0.0
while current_time + SEGMENT_DURATION <= duration:
    start_sample = int(current_time * sr)
    end_sample = int((current_time + SEGMENT_DURATION) * sr)
    segment = audio[start_sample:end_sample]
    # Skip near-silent windows (mean absolute amplitude at or below 0.01)
    if np.abs(segment).mean() > 0.01:
        try:
            wav = preprocess_wav(segment, source_sr=sr)
            if len(wav) > 0:
                embedding = encoder.embed_utterance(wav)
                embeddings.append(embedding)
                timestamps.append((current_time, current_time + SEGMENT_DURATION))
        except Exception:
            # Windows the encoder cannot process are simply skipped
            pass
    current_time += HOP_DURATION

# If no usable windows were found (silence, or audio shorter than one
# window), report an empty result
if len(embeddings) == 0:
    print(json.dumps({"speakers": [], "num_speakers": 0}))
    sys.exit(0)

embeddings = np.array(embeddings)

# Determine the number of speakers
if NUM_SPEAKERS is None or NUM_SPEAKERS <= 0:
    # Auto-detect: try 2-5 clusters and keep the count with the highest
    # silhouette score (higher means tighter, better-separated clusters)
    from sklearn.metrics import silhouette_score
    best_n = 2
    best_score = -1
    for n in range(2, min(6, len(embeddings))):
        try:
            clustering = AgglomerativeClustering(n_clusters=n)
            labels = clustering.fit_predict(embeddings)
            score = silhouette_score(embeddings, labels)
            if score > best_score:
                best_score = score
                best_n = n
        except Exception:
            pass
    num_speakers = best_n
else:
    num_speakers = NUM_SPEAKERS

# Cluster the embeddings into speaker groups
try:
    if len(embeddings) >= num_speakers:
        clustering = AgglomerativeClustering(n_clusters=num_speakers)
        labels = clustering.fit_predict(embeddings)
    else:
        # Fewer windows than requested speakers: one cluster per window
        labels = list(range(len(embeddings)))
        num_speakers = len(embeddings)
except Exception:
    # Fall back to a single speaker if clustering fails outright
    labels = [0] * len(embeddings)
    num_speakers = 1
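
# Merge example: overlapping windows (0.0, 1.5) and (0.75, 2.25) that are
# both labeled "Speaker 1" collapse into one segment (0.0, 2.25), while a
# gap of 0.5 s or more between same-speaker windows starts a new segment.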

# Build speaker segments, merging consecutive same-speaker windows
speaker_segments = []
prev_speaker = None
prev_start = None
prev_end = None
for i, (start, end) in enumerate(timestamps):
    speaker = f"Speaker {labels[i] + 1}"
    if speaker == prev_speaker and prev_end is not None:
        # Extend the previous segment if the same speaker continues soon enough
        if start - prev_end < 0.5:
            prev_end = end
            continue
    # Save the previous segment before starting a new one
    if prev_speaker is not None:
        speaker_segments.append({
            "speaker": prev_speaker,
            "start": prev_start,
            "end": prev_end
        })
    prev_speaker = speaker
    prev_start = start
    prev_end = end

# Don't forget the last open segment
if prev_speaker is not None:
    speaker_segments.append({
        "speaker": prev_speaker,
        "start": prev_start,
        "end": prev_end
    })

# Emit the final result as JSON on stdout for the Go caller to parse
print(json.dumps({
    "speakers": speaker_segments,
    "num_speakers": num_speakers
}))
`, audioPath, numSpeakersStr)

	return pythonCode
}