feat: git init
This commit is contained in:
59
internal/diarization/align.go
Normal file
59
internal/diarization/align.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package diarization
|
||||
|
||||
import (
|
||||
"transcribe/internal/whisper"
|
||||
)
|
||||
|
||||
// AlignSpeakers maps speaker segments to transcription segments by timestamp overlap
|
||||
func AlignSpeakers(transcription *whisper.TranscriptionResult, diarization *DiarizationResult) {
|
||||
if diarization == nil || len(diarization.Speakers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range transcription.Segments {
|
||||
seg := &transcription.Segments[i]
|
||||
speaker := findSpeakerForSegment(seg.Start, seg.End, diarization.Speakers)
|
||||
seg.Speaker = speaker
|
||||
}
|
||||
}
|
||||
|
||||
// findSpeakerForSegment finds the speaker with the most overlap with the given time range
|
||||
func findSpeakerForSegment(start, end float64, speakers []SpeakerSegment) string {
|
||||
var bestSpeaker string
|
||||
var maxOverlap float64
|
||||
|
||||
for _, spk := range speakers {
|
||||
overlap := calculateOverlap(start, end, spk.Start, spk.End)
|
||||
if overlap > maxOverlap {
|
||||
maxOverlap = overlap
|
||||
bestSpeaker = spk.Speaker
|
||||
}
|
||||
}
|
||||
|
||||
return bestSpeaker
|
||||
}
|
||||
|
||||
// calculateOverlap returns the duration of overlap between two time ranges,
// or 0 when the ranges are disjoint (or merely touch at an endpoint).
func calculateOverlap(start1, end1, start2, end2 float64) float64 {
	// The overlap is bounded by the later start and the earlier end.
	lo := start1
	if start2 > lo {
		lo = start2
	}
	hi := end1
	if end2 < hi {
		hi = end2
	}

	if hi <= lo {
		return 0
	}
	return hi - lo
}
|
||||
|
||||
// max returns the larger of a and b. The comparison is a > b, so when the
// arguments are equal (or a is NaN) b is returned, matching the original.
func max(a, b float64) float64 {
	result := b
	if a > b {
		result = a
	}
	return result
}
|
||||
|
||||
// min returns the smaller of a and b. The comparison is a < b, so when the
// arguments are equal (or a is NaN) b is returned, matching the original.
func min(a, b float64) float64 {
	result := b
	if a < b {
		result = a
	}
	return result
}
|
||||
222
internal/diarization/client.go
Normal file
222
internal/diarization/client.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package diarization
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// SpeakerSegment represents a segment with speaker identification.
// Times are in seconds, matching the timestamps emitted by the embedded
// Python diarization script.
type SpeakerSegment struct {
	Speaker string  `json:"speaker"` // "Speaker 1", "Speaker 2", etc.
	Start   float64 `json:"start"`   // segment start, seconds
	End     float64 `json:"end"`     // segment end, seconds
}

// DiarizationResult contains the speaker diarization output decoded from
// the JSON printed by the Python script.
type DiarizationResult struct {
	Speakers    []SpeakerSegment `json:"speakers"`     // merged, time-ordered speaker segments
	NumSpeakers int              `json:"num_speakers"` // distinct speakers found (or requested)
}

// Client handles speaker diarization using resemblyzer.
// It is stateless; each Diarize call shells out to python3.
type Client struct{}
|
||||
|
||||
// NewClient creates a new diarization client.
func NewClient() *Client {
	return &Client{}
}

// DiarizationOptions contains options for diarization.
type DiarizationOptions struct {
	NumSpeakers int // Number of speakers (0 = auto-detect)
}

// DefaultDiarizationOptions returns default diarization options:
// auto-detected speaker count.
func DefaultDiarizationOptions() *DiarizationOptions {
	return &DiarizationOptions{
		NumSpeakers: 0, // Auto-detect
	}
}
|
||||
|
||||
// Diarize processes an audio file and returns speaker segments
|
||||
func (c *Client) Diarize(audioPath string, options *DiarizationOptions) (*DiarizationResult, error) {
|
||||
if options == nil {
|
||||
options = DefaultDiarizationOptions()
|
||||
}
|
||||
|
||||
// Build the Python command
|
||||
cmd := exec.Command("python3", "-c", c.buildPythonCommand(audioPath, options))
|
||||
|
||||
// Capture stdout and stderr
|
||||
var out bytes.Buffer
|
||||
var errBuf bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &errBuf
|
||||
|
||||
// Execute the command
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("diarization failed: %v, stderr: %s", err, errBuf.String())
|
||||
}
|
||||
|
||||
// Parse the JSON output
|
||||
var result DiarizationResult
|
||||
err = json.Unmarshal(out.Bytes(), &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse diarization output: %v, output: %s", err, out.String())
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// buildPythonCommand constructs the Python command for diarization
|
||||
func (c *Client) buildPythonCommand(audioPath string, options *DiarizationOptions) string {
|
||||
numSpeakersStr := "None"
|
||||
if options.NumSpeakers > 0 {
|
||||
numSpeakersStr = fmt.Sprintf("%d", options.NumSpeakers)
|
||||
}
|
||||
|
||||
pythonCode := fmt.Sprintf(`
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
# Suppress warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# Redirect both stdout and stderr during imports to suppress library noise
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
sys.stderr = open(os.devnull, 'w')
|
||||
|
||||
from resemblyzer import VoiceEncoder, preprocess_wav
|
||||
from sklearn.cluster import SpectralClustering, AgglomerativeClustering
|
||||
import librosa
|
||||
|
||||
# Initialize voice encoder while stdout is suppressed (it prints loading message)
|
||||
encoder = VoiceEncoder()
|
||||
|
||||
# Restore stdout/stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
# Configuration
|
||||
AUDIO_PATH = "%s"
|
||||
NUM_SPEAKERS = %s
|
||||
SEGMENT_DURATION = 1.5 # seconds per segment for embedding extraction
|
||||
HOP_DURATION = 0.75 # hop between segments
|
||||
|
||||
# Load audio
|
||||
audio, sr = librosa.load(AUDIO_PATH, sr=16000)
|
||||
duration = len(audio) / sr
|
||||
|
||||
# Extract embeddings for overlapping segments
|
||||
embeddings = []
|
||||
timestamps = []
|
||||
current_time = 0.0
|
||||
|
||||
while current_time + SEGMENT_DURATION <= duration:
|
||||
start_sample = int(current_time * sr)
|
||||
end_sample = int((current_time + SEGMENT_DURATION) * sr)
|
||||
segment = audio[start_sample:end_sample]
|
||||
|
||||
# Skip silent segments
|
||||
if np.abs(segment).mean() > 0.01:
|
||||
try:
|
||||
wav = preprocess_wav(segment, source_sr=sr)
|
||||
if len(wav) > 0:
|
||||
embedding = encoder.embed_utterance(wav)
|
||||
embeddings.append(embedding)
|
||||
timestamps.append((current_time, current_time + SEGMENT_DURATION))
|
||||
except:
|
||||
pass
|
||||
|
||||
current_time += HOP_DURATION
|
||||
|
||||
# Handle edge cases
|
||||
if len(embeddings) == 0:
|
||||
print(json.dumps({"speakers": [], "num_speakers": 0}))
|
||||
sys.exit(0)
|
||||
|
||||
embeddings = np.array(embeddings)
|
||||
|
||||
# Determine number of speakers
|
||||
if NUM_SPEAKERS is None or NUM_SPEAKERS <= 0:
|
||||
# Auto-detect using silhouette score
|
||||
from sklearn.metrics import silhouette_score
|
||||
best_n = 2
|
||||
best_score = -1
|
||||
for n in range(2, min(6, len(embeddings))):
|
||||
try:
|
||||
clustering = AgglomerativeClustering(n_clusters=n)
|
||||
labels = clustering.fit_predict(embeddings)
|
||||
score = silhouette_score(embeddings, labels)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_n = n
|
||||
except:
|
||||
pass
|
||||
num_speakers = best_n
|
||||
else:
|
||||
num_speakers = NUM_SPEAKERS
|
||||
|
||||
# Cluster embeddings
|
||||
try:
|
||||
if len(embeddings) >= num_speakers:
|
||||
clustering = AgglomerativeClustering(n_clusters=num_speakers)
|
||||
labels = clustering.fit_predict(embeddings)
|
||||
else:
|
||||
labels = list(range(len(embeddings)))
|
||||
num_speakers = len(embeddings)
|
||||
except Exception as e:
|
||||
labels = [0] * len(embeddings)
|
||||
num_speakers = 1
|
||||
|
||||
# Build speaker segments with merging of consecutive same-speaker segments
|
||||
speaker_segments = []
|
||||
prev_speaker = None
|
||||
prev_start = None
|
||||
prev_end = None
|
||||
|
||||
for i, (start, end) in enumerate(timestamps):
|
||||
speaker = f"Speaker {labels[i] + 1}"
|
||||
|
||||
if speaker == prev_speaker and prev_end is not None:
|
||||
# Extend previous segment if same speaker and close in time
|
||||
if start - prev_end < 0.5:
|
||||
prev_end = end
|
||||
continue
|
||||
|
||||
# Save previous segment
|
||||
if prev_speaker is not None:
|
||||
speaker_segments.append({
|
||||
"speaker": prev_speaker,
|
||||
"start": prev_start,
|
||||
"end": prev_end
|
||||
})
|
||||
|
||||
prev_speaker = speaker
|
||||
prev_start = start
|
||||
prev_end = end
|
||||
|
||||
# Don't forget the last segment
|
||||
if prev_speaker is not None:
|
||||
speaker_segments.append({
|
||||
"speaker": prev_speaker,
|
||||
"start": prev_start,
|
||||
"end": prev_end
|
||||
})
|
||||
|
||||
print(json.dumps({
|
||||
"speakers": speaker_segments,
|
||||
"num_speakers": num_speakers
|
||||
}))
|
||||
`, audioPath, numSpeakersStr)
|
||||
|
||||
return pythonCode
|
||||
}
|
||||
162
internal/whisper/client.go
Normal file
162
internal/whisper/client.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package whisper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// ModelSize represents the different Whisper model sizes.
// The value is passed verbatim to Python's whisper.load_model.
type ModelSize string

// Supported Whisper model identifiers.
const (
	ModelTiny   ModelSize = "tiny"
	ModelBase   ModelSize = "base"
	ModelSmall  ModelSize = "small"
	ModelMedium ModelSize = "medium"
	ModelLarge  ModelSize = "large"
	ModelTurbo  ModelSize = "turbo"
)

// TranscriptionResult contains the transcription output decoded from the
// JSON printed by the Whisper Python script.
type TranscriptionResult struct {
	Text     string    `json:"text"`     // full transcript text
	Segments []Segment `json:"segments"` // timestamped segments
	Language string    `json:"language"` // language code reported by Whisper
	Duration float64   `json:"duration"` // NOTE(review): the script defaults this to 0.0 when Whisper omits it
}

// Segment represents a segment of transcription with timestamps.
type Segment struct {
	Start   float64 `json:"start"` // presumably seconds — matches diarization timestamps; verify against Whisper output
	End     float64 `json:"end"`
	Text    string  `json:"text"`
	Words   []Word  `json:"words,omitempty"`   // per-word timing, when Whisper provides it
	Speaker string  `json:"speaker,omitempty"` // filled in later by speaker-diarization alignment
}

// Word represents a word with timestamp.
type Word struct {
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Word  string  `json:"word"`
}
|
||||
|
||||
// Client is the Whisper client that handles transcription.
type Client struct {
	ModelPath string    // NOTE(review): not referenced by Transcribe in this file — confirm intended use
	ModelSize ModelSize // which Whisper model to load
}

// NewClient creates a new Whisper client for the given model size.
func NewClient(modelSize ModelSize) *Client {
	return &Client{
		ModelSize: modelSize,
	}
}
|
||||
|
||||
// Transcribe processes an audio file and returns transcription
|
||||
func (c *Client) Transcribe(audioPath string, options *TranscriptionOptions) (*TranscriptionResult, error) {
|
||||
if options == nil {
|
||||
options = &TranscriptionOptions{}
|
||||
}
|
||||
|
||||
// Build the Python command
|
||||
cmd := exec.Command("python3", "-c", c.buildPythonCommand(audioPath, options))
|
||||
|
||||
// Capture stdout and stderr
|
||||
var out bytes.Buffer
|
||||
var errBuf bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &errBuf
|
||||
|
||||
// Execute the command
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transcription failed: %v, stderr: %s", err, errBuf.String())
|
||||
}
|
||||
|
||||
// Parse the JSON output
|
||||
var result TranscriptionResult
|
||||
err = json.Unmarshal(out.Bytes(), &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse transcription output: %v", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// buildPythonCommand constructs the Python command for Whisper
|
||||
func (c *Client) buildPythonCommand(audioPath string, options *TranscriptionOptions) string {
|
||||
// Convert Go bool to Python bool string
|
||||
verboseStr := "False"
|
||||
if options.Verbose {
|
||||
verboseStr = "True"
|
||||
}
|
||||
|
||||
// Handle language option
|
||||
langStr := "None"
|
||||
if options.Language != "" && options.Language != "auto" {
|
||||
langStr = fmt.Sprintf(`"%s"`, options.Language)
|
||||
}
|
||||
|
||||
pythonCode := fmt.Sprintf(`
|
||||
import whisper
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import warnings
|
||||
|
||||
# Suppress warnings and stdout during transcription
|
||||
warnings.filterwarnings("ignore")
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
|
||||
# Load model
|
||||
model = whisper.load_model("%s")
|
||||
|
||||
# Transcribe
|
||||
result = model.transcribe("%s",
|
||||
language=%s,
|
||||
verbose=%s,
|
||||
temperature=%.1f,
|
||||
best_of=%d)
|
||||
|
||||
# Restore stdout for JSON output
|
||||
sys.stdout = old_stdout
|
||||
|
||||
# Output as JSON
|
||||
print(json.dumps({
|
||||
"text": result["text"],
|
||||
"language": result.get("language", ""),
|
||||
"duration": result.get("duration", 0.0),
|
||||
"segments": [{
|
||||
"start": seg["start"],
|
||||
"end": seg["end"],
|
||||
"text": seg["text"],
|
||||
"words": seg.get("words", [])
|
||||
} for seg in result.get("segments", [])]
|
||||
}))
|
||||
`, c.ModelSize, audioPath, langStr, verboseStr, options.Temperature, options.BestOf)
|
||||
|
||||
return pythonCode
|
||||
}
|
||||
|
||||
// TranscriptionOptions contains options for transcription.
// Each field is interpolated into the generated Whisper Python script.
type TranscriptionOptions struct {
	Language    string  // Language code or "auto"
	Verbose     bool    // Show progress bar
	Temperature float64 // Temperature for sampling (higher = more creative)
	BestOf      int     // Number of candidates when sampling with temperature > 0
}

// DefaultTranscriptionOptions returns default transcription options.
// NOTE(review): Transcribe falls back to the zero-value options, not these
// defaults, when passed nil — confirm which behavior is intended.
func DefaultTranscriptionOptions() *TranscriptionOptions {
	return &TranscriptionOptions{
		Language:    "auto",
		Verbose:     false,
		Temperature: 0.0,
		BestOf:      5,
	}
}
|
||||
Reference in New Issue
Block a user