feat: git init
This commit is contained in:
59
internal/diarization/align.go
Normal file
59
internal/diarization/align.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package diarization
|
||||
|
||||
import (
|
||||
"transcribe/internal/whisper"
|
||||
)
|
||||
|
||||
// AlignSpeakers maps speaker segments to transcription segments by timestamp overlap
|
||||
func AlignSpeakers(transcription *whisper.TranscriptionResult, diarization *DiarizationResult) {
|
||||
if diarization == nil || len(diarization.Speakers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transcription.Segments {
|
||||
seg := &transcription.Segments[i]
|
||||
speaker := findSpeakerForSegment(seg.Start, seg.End, diarization.Speakers)
|
||||
seg.Speaker = speaker
|
||||
}
|
||||
}
|
||||
|
||||
// findSpeakerForSegment finds the speaker with the most overlap with the given time range
|
||||
func findSpeakerForSegment(start, end float64, speakers []SpeakerSegment) string {
|
||||
var bestSpeaker string
|
||||
var maxOverlap float64
|
||||
|
||||
for _, spk := range speakers {
|
||||
overlap := calculateOverlap(start, end, spk.Start, spk.End)
|
||||
if overlap > maxOverlap {
|
||||
maxOverlap = overlap
|
||||
bestSpeaker = spk.Speaker
|
||||
}
|
||||
}
|
||||
|
||||
return bestSpeaker
|
||||
}
|
||||
|
||||
// calculateOverlap returns the duration of overlap between two time ranges
func calculateOverlap(start1, end1, start2, end2 float64) float64 {
	// The overlapping window is bounded by the later start and earlier end.
	lo := start2
	if start1 > start2 {
		lo = start1
	}
	hi := end2
	if end1 < end2 {
		hi = end1
	}

	// Disjoint (or merely touching) ranges overlap for zero duration.
	if d := hi - lo; d > 0 {
		return d
	}
	return 0
}
|
||||
|
||||
// max returns the larger of a and b (b when they are equal, matching
// the original comparison direction).
func max(a, b float64) float64 {
	result := b
	if a > b {
		result = a
	}
	return result
}
|
||||
|
||||
// min returns the smaller of a and b (b when they are equal, matching
// the original comparison direction).
func min(a, b float64) float64 {
	result := b
	if a < b {
		result = a
	}
	return result
}
|
||||
222
internal/diarization/client.go
Normal file
222
internal/diarization/client.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package diarization
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// SpeakerSegment represents a segment with speaker identification.
// Start and End are offsets in seconds from the beginning of the audio
// (the embedded diarization script emits them from librosa timestamps).
type SpeakerSegment struct {
	Speaker string  `json:"speaker"` // "Speaker 1", "Speaker 2", etc.
	Start   float64 `json:"start"`   // segment start, seconds
	End     float64 `json:"end"`     // segment end, seconds
}
|
||||
|
||||
// DiarizationResult contains the speaker diarization output.
// It mirrors the JSON object printed by the Python script that Client runs.
type DiarizationResult struct {
	Speakers    []SpeakerSegment `json:"speakers"`     // merged speaker turns, in chronological order
	NumSpeakers int              `json:"num_speakers"` // distinct speakers detected (0 when no speech found)
}
|
||||
|
||||
// Client handles speaker diarization using resemblyzer.
// It is stateless: the zero value is usable, and each Diarize call shells
// out to python3 with an inline script.
type Client struct{}
|
||||
|
||||
// NewClient creates a new diarization client
|
||||
func NewClient() *Client {
|
||||
return &Client{}
|
||||
}
|
||||
|
||||
// DiarizationOptions contains options for diarization.
type DiarizationOptions struct {
	NumSpeakers int // Number of speakers (0 = auto-detect via silhouette score in the script)
}
|
||||
|
||||
// DefaultDiarizationOptions returns default diarization options
|
||||
func DefaultDiarizationOptions() *DiarizationOptions {
|
||||
return &DiarizationOptions{
|
||||
NumSpeakers: 0, // Auto-detect
|
||||
}
|
||||
}
|
||||
|
||||
// Diarize processes an audio file and returns speaker segments
|
||||
func (c *Client) Diarize(audioPath string, options *DiarizationOptions) (*DiarizationResult, error) {
|
||||
if options == nil {
|
||||
options = DefaultDiarizationOptions()
|
||||
}
|
||||
|
||||
// Build the Python command
|
||||
cmd := exec.Command("python3", "-c", c.buildPythonCommand(audioPath, options))
|
||||
|
||||
// Capture stdout and stderr
|
||||
var out bytes.Buffer
|
||||
var errBuf bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &errBuf
|
||||
|
||||
// Execute the command
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("diarization failed: %v, stderr: %s", err, errBuf.String())
|
||||
}
|
||||
|
||||
// Parse the JSON output
|
||||
var result DiarizationResult
|
||||
err = json.Unmarshal(out.Bytes(), &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse diarization output: %v, output: %s", err, out.String())
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// buildPythonCommand constructs the Python command for diarization
|
||||
func (c *Client) buildPythonCommand(audioPath string, options *DiarizationOptions) string {
|
||||
numSpeakersStr := "None"
|
||||
if options.NumSpeakers > 0 {
|
||||
numSpeakersStr = fmt.Sprintf("%d", options.NumSpeakers)
|
||||
}
|
||||
|
||||
pythonCode := fmt.Sprintf(`
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
# Suppress warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# Redirect both stdout and stderr during imports to suppress library noise
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
sys.stderr = open(os.devnull, 'w')
|
||||
|
||||
from resemblyzer import VoiceEncoder, preprocess_wav
|
||||
from sklearn.cluster import SpectralClustering, AgglomerativeClustering
|
||||
import librosa
|
||||
|
||||
# Initialize voice encoder while stdout is suppressed (it prints loading message)
|
||||
encoder = VoiceEncoder()
|
||||
|
||||
# Restore stdout/stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
# Configuration
|
||||
AUDIO_PATH = "%s"
|
||||
NUM_SPEAKERS = %s
|
||||
SEGMENT_DURATION = 1.5 # seconds per segment for embedding extraction
|
||||
HOP_DURATION = 0.75 # hop between segments
|
||||
|
||||
# Load audio
|
||||
audio, sr = librosa.load(AUDIO_PATH, sr=16000)
|
||||
duration = len(audio) / sr
|
||||
|
||||
# Extract embeddings for overlapping segments
|
||||
embeddings = []
|
||||
timestamps = []
|
||||
current_time = 0.0
|
||||
|
||||
while current_time + SEGMENT_DURATION <= duration:
|
||||
start_sample = int(current_time * sr)
|
||||
end_sample = int((current_time + SEGMENT_DURATION) * sr)
|
||||
segment = audio[start_sample:end_sample]
|
||||
|
||||
# Skip silent segments
|
||||
if np.abs(segment).mean() > 0.01:
|
||||
try:
|
||||
wav = preprocess_wav(segment, source_sr=sr)
|
||||
if len(wav) > 0:
|
||||
embedding = encoder.embed_utterance(wav)
|
||||
embeddings.append(embedding)
|
||||
timestamps.append((current_time, current_time + SEGMENT_DURATION))
|
||||
except:
|
||||
pass
|
||||
|
||||
current_time += HOP_DURATION
|
||||
|
||||
# Handle edge cases
|
||||
if len(embeddings) == 0:
|
||||
print(json.dumps({"speakers": [], "num_speakers": 0}))
|
||||
sys.exit(0)
|
||||
|
||||
embeddings = np.array(embeddings)
|
||||
|
||||
# Determine number of speakers
|
||||
if NUM_SPEAKERS is None or NUM_SPEAKERS <= 0:
|
||||
# Auto-detect using silhouette score
|
||||
from sklearn.metrics import silhouette_score
|
||||
best_n = 2
|
||||
best_score = -1
|
||||
for n in range(2, min(6, len(embeddings))):
|
||||
try:
|
||||
clustering = AgglomerativeClustering(n_clusters=n)
|
||||
labels = clustering.fit_predict(embeddings)
|
||||
score = silhouette_score(embeddings, labels)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_n = n
|
||||
except:
|
||||
pass
|
||||
num_speakers = best_n
|
||||
else:
|
||||
num_speakers = NUM_SPEAKERS
|
||||
|
||||
# Cluster embeddings
|
||||
try:
|
||||
if len(embeddings) >= num_speakers:
|
||||
clustering = AgglomerativeClustering(n_clusters=num_speakers)
|
||||
labels = clustering.fit_predict(embeddings)
|
||||
else:
|
||||
labels = list(range(len(embeddings)))
|
||||
num_speakers = len(embeddings)
|
||||
except Exception as e:
|
||||
labels = [0] * len(embeddings)
|
||||
num_speakers = 1
|
||||
|
||||
# Build speaker segments with merging of consecutive same-speaker segments
|
||||
speaker_segments = []
|
||||
prev_speaker = None
|
||||
prev_start = None
|
||||
prev_end = None
|
||||
|
||||
for i, (start, end) in enumerate(timestamps):
|
||||
speaker = f"Speaker {labels[i] + 1}"
|
||||
|
||||
if speaker == prev_speaker and prev_end is not None:
|
||||
# Extend previous segment if same speaker and close in time
|
||||
if start - prev_end < 0.5:
|
||||
prev_end = end
|
||||
continue
|
||||
|
||||
# Save previous segment
|
||||
if prev_speaker is not None:
|
||||
speaker_segments.append({
|
||||
"speaker": prev_speaker,
|
||||
"start": prev_start,
|
||||
"end": prev_end
|
||||
})
|
||||
|
||||
prev_speaker = speaker
|
||||
prev_start = start
|
||||
prev_end = end
|
||||
|
||||
# Don't forget the last segment
|
||||
if prev_speaker is not None:
|
||||
speaker_segments.append({
|
||||
"speaker": prev_speaker,
|
||||
"start": prev_start,
|
||||
"end": prev_end
|
||||
})
|
||||
|
||||
print(json.dumps({
|
||||
"speakers": speaker_segments,
|
||||
"num_speakers": num_speakers
|
||||
}))
|
||||
`, audioPath, numSpeakersStr)
|
||||
|
||||
return pythonCode
|
||||
}
|
||||
Reference in New Issue
Block a user