Source code for whisper_web.whisper_model

import torch
import asyncio

from pydantic import BaseModel, Field
from typing import Optional, Tuple
from torch import Tensor
from datetime import datetime

from whisper_web.utils import set_device
from whisper_web.events import EventBus, TranscriptionCompleted
from whisper_web.types import Transcription
from transformers.models.whisper import WhisperProcessor, WhisperForConditionalGeneration
from transformers import logging
from dotenv import load_dotenv

load_dotenv()

logging.set_verbosity_error()


class ModelConfig(BaseModel):
    """Settings used to create and load the Whisper ASR model.

    Groups the checkpoint selection (an explicit model id or a model size),
    the inference device, and the audio-related options (continuous
    generation, voice activity detection, sample rate) into one validated
    object.
    """

    # Explicit checkpoint id; unset (``None``) by default.
    model_id: Optional[str] = Field(
        default=None,
        description="The model id to be used for loading the model.",
    )

    # Size of the model to run inference with.
    model_size: str = Field(
        default="large-v3",
        description="The size of the model to be used for inference. choices=['small', 'medium', 'large-v3']",
    )

    # Device on which inference is executed.
    device: str = Field(
        default="cuda",
        description="The device to be used for inference. choices=['cpu', 'cuda', 'mps']",
    )

    # Continuous audio generation toggle.
    continuous: bool = Field(
        default=True,
        description="Whether to generate audio data continuously or not.",
    )

    # Voice activity detection toggle.
    use_vad: bool = Field(
        default=False,
        description="Whether to use VAD (Voice Activity Detection) or not.",
    )

    # Sample rate of the generated audio.
    samplerate: int = Field(
        default=16000,
        description="The sample rate of the generated audio.",
    )
class WhisperModel:
    """Whisper speech-recognition wrapper that reports results via an event bus.

    Constructing an instance loads a ``WhisperForConditionalGeneration`` model
    and its ``WhisperProcessor``. Awaiting the instance with an
    ``(audio, is_final)`` tuple transcribes the audio and publishes a
    ``TranscriptionCompleted`` event on the injected ``EventBus``.
    """

    # Sizes for which a default checkpoint id can be derived.
    _SUPPORTED_MODEL_SIZES = ("small", "medium", "large-v3")
    # Sizes whose default checkpoint id carries the ``.en`` suffix.
    _ENGLISH_ONLY_MODEL_SIZES = ("small", "medium")
    # Size used when the requested one is not supported.
    _DEFAULT_MODEL_SIZE = "large-v3"
    # Accepted shorthand -> canonical size name.
    _MODEL_SIZE_ALIASES = {"large": "large-v3"}

    def __init__(self, model_args: "ModelConfig", event_bus: "EventBus"):
        """Store the configuration, pick the dtype and load the model.

        Args:
            model_args: Configuration providing ``device``, ``samplerate``,
                ``model_size`` and ``model_id``.
            event_bus: Bus on which transcription results are published.
        """
        self.event_bus = event_bus
        self.device = set_device(model_args.device)  # type: ignore
        self.samplerate: int = model_args.samplerate

        # NOTE(review): this is an equality test against ``torch.device("cuda")``;
        # whether an indexed device (e.g. ``cuda:0``) matches depends on what
        # ``set_device`` returns -- confirm.
        on_cuda = self.device == torch.device("cuda")
        self.torch_dtype = torch.float16 if on_cuda else torch.float32
        if on_cuda:
            torch.backends.cuda.matmul.allow_tf32 = True  # type: ignore

        self.available_model_sizes = []
        self.model_size = ""
        self.model_id = ""
        self.speech_model: Optional[WhisperForConditionalGeneration] = None
        self.processor: Optional[WhisperProcessor] = None

        self.load_model(model_args.model_size, model_args.model_id)

    @staticmethod
    def _canonical_model_size(model_size: str) -> Optional[str]:
        """Return the supported size name for *model_size*, or ``None`` if unsupported."""
        canonical = WhisperModel._MODEL_SIZE_ALIASES.get(model_size, model_size)
        if canonical in WhisperModel._SUPPORTED_MODEL_SIZES:
            return canonical
        return None

    @staticmethod
    def _default_model_id(model_size: str) -> str:
        """Build the distil-whisper checkpoint id for a canonical *model_size*."""
        if model_size in WhisperModel._ENGLISH_ONLY_MODEL_SIZES:
            return f"distil-whisper/distil-{model_size}.en"
        return f"distil-whisper/distil-{model_size}"

    def load_model(self, model_size: str, model_id: Optional[str]) -> None:
        """Load the Whisper model and processor for a size or an explicit id.

        When ``model_id`` is ``None`` the checkpoint id is derived from
        ``model_size``. ``"large"`` is accepted as shorthand for
        ``"large-v3"``. An unsupported size prints a notice and falls back to
        ``"large-v3"`` instead of producing an id for a size outside the
        supported list. When ``model_id`` is given it is used as-is and
        ``model_size`` is ignored.

        Args:
            model_size: Requested model size.
            model_id: Explicit checkpoint id, or ``None`` to derive one.
        """
        if model_id is None:
            self.available_model_sizes = list(self._SUPPORTED_MODEL_SIZES)

            canonical_size = self._canonical_model_size(model_size)
            if canonical_size is None:
                # Actually apply the fallback that the message announces.
                canonical_size = self._DEFAULT_MODEL_SIZE
                print(f"Model size not supported. Defaulting to {canonical_size}.")

            self.model_size = canonical_size
            self.model_id = self._default_model_id(canonical_size)
        else:
            self.model_id = model_id

        # Use HF_HOME environment variable if set, otherwise default to relative path
        import os

        cache_dir = os.environ.get("HF_HOME", "./.models")

        self.speech_model = WhisperForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            cache_dir=cache_dir,
        ).to(self.device)  # type: ignore

        self.processor = WhisperProcessor.from_pretrained(self.model_id, cache_dir=cache_dir)  # type: ignore

    async def _transcribe(self, audio: Tensor) -> str:
        """Run the processor and model on *audio* and return the decoded text.

        Generation and decoding are dispatched with ``asyncio.to_thread``;
        the processor call itself runs inline.

        Args:
            audio: Audio samples, passed to the processor together with
                ``self.samplerate``.

        Returns:
            The last decoded sequence with surrounding whitespace stripped.
        """
        assert isinstance(self.processor, WhisperProcessor), "Processor must be an instance of WhisperProcessor"
        assert isinstance(self.speech_model, WhisperForConditionalGeneration), "Speechmodel must be instance of WhisperForConditionalGeneration"

        inputs = self.processor(
            audio,
            sampling_rate=self.samplerate,
            return_tensors="pt",
        )
        # Move the features to the model's device and dtype.
        inputs = inputs.to(self.device, dtype=self.torch_dtype)

        generated_ids = await asyncio.to_thread(
            self.speech_model.generate,
            **inputs,
            max_new_tokens=128,
            num_beams=1,
            return_timestamps=False,
            pad_token_id=self.processor.tokenizer.pad_token_id,  # type: ignore
            eos_token_id=self.processor.tokenizer.eos_token_id,  # type: ignore
        )

        transcript = await asyncio.to_thread(
            self.processor.batch_decode,
            generated_ids,
            skip_special_tokens=True,
            decode_with_timestamps=False,
            return_timestamps=False,
        )

        return transcript[-1].strip()

    async def __call__(self, audio_data: Tuple[Tensor, bool]) -> None:
        """Transcribe audio and publish the result as an event.

        Args:
            audio_data: ``(audio, is_final)`` pair; ``is_final`` is forwarded
                unchanged on the published event.
        """
        audio, is_final = audio_data
        assert isinstance(audio, Tensor), "Audio data must be a torch.Tensor"

        transcript_text = await self._transcribe(audio)

        # The transcription is stamped with the current local date.
        transcription = Transcription(transcript_text, datetime.now().date())
        await self.event_bus.publish(TranscriptionCompleted(transcription=transcription, is_final=is_final))