From 6e0c9b700c46a9a6801f4c4e99644c9000fbd6ec Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 11:04:09 +0200 Subject: [PATCH 01/13] fix(STT): whisper call was blocking entire huri's loop --- src/modules/speech_to_text/speech_to_text.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 1300dd3..fdc68f0 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -41,7 +41,7 @@ def __init__( ): super().__init__() - self.model_faster = WhisperModel(model) + self.model_faster = WhisperModel(model, cpu_threads=2) self.language = language self.sample_rate = sample_rate @@ -83,13 +83,15 @@ async def process(self, voice: Voice) -> Optional[Transcript]: self.pending_silence = False processing_audio = np.concatenate(processing_chunks, axis=0) - segments, _ = self.model_faster.transcribe( - processing_audio, - language=self.language, - beam_size=1, # faster for realtime - ) + def transcribe_text(): + segments, _ = self.model_faster.transcribe( + processing_audio, + language=self.language, + beam_size=1, + ) + return " ".join(seg.text for seg in segments).strip() - current_text = " ".join([seg.text for seg in segments]).strip() + current_text = await asyncio.to_thread(transcribe_text) processed_size = self.window_size - self.step_size async with self.lock: From d72b1c369b3ff8729d701435edfd789e77f36255 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 17:19:19 +0200 Subject: [PATCH 02/13] feat(client): added ClientHook abstract class --- src/core/client.py | 119 ++++++++++++++++++++++++++++++++----- src/core/client_senders.py | 102 ------------------------------- 2 files changed, 105 insertions(+), 116 deletions(-) delete mode 100644 src/core/client_senders.py diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..4170e70 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,14 +1,79 @@ import asyncio +import importlib import json import os +import struct +from collections import defaultdict from dataclasses import asdict -from typing import Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type import websockets from src.core.dataclasses.config import ClientConfig +from src.core.events import EventData -from .client_senders import ClientSender, get_senders +from .interface import Interface + + +class ClientSender: + """This class abstract sending data to HuRI. + + output_type: is the event data structure that the ClientSender will send. + It can be EventData or bytes, and must match event topic it send. + + Class derived from ClientSender must implement input_loop, + and use ClientSender.send to send data to HuRI. + """ + + output_type: Type[EventData] | bytes + + def __init__(self, topic: str, **_): + self.topic = topic + if issubclass(self.output_type, EventData): + self.send_function = self._send_event_data + elif issubclass(self.output_type, bytes): + self.send_function = self._send_bytes + else: + raise RuntimeError(f"{self.output_type} should be inherited from \ +EventData or bytes") + + async def input_loop(self, ws: websockets.ClientConnection): + raise NotImplementedError + + async def _send_bytes(self, ws: websockets.ClientConnection, data: bytes): + topic_bytes = self.topic.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + + await ws.send(packet) + + async def _send_event_data(self, ws: websockets.ClientConnection, data: EventData): + packet = json.dumps({"topic": self.topic, "data": asdict(data)}) + + await ws.send(packet) + + async def send(self, ws: websockets.ClientConnection, data: EventData | bytes): + await self.send_function(ws, data) + + +class ClientHook: + """This class abstract processing data from HuRI. + + input_type: is the event data structure that the ClientHook will process. + It can be EventData or bytes, and must match event topic it react to. + + Class derived from ClientHook must implement hook. + + `singletton` allow hooks to modifies shared ressources, + and comes from the used interface. + """ + + input_type: Type[EventData] | bytes + + def __init__(self, **_): + pass + + async def hook(self, singletton: Any, data: EventData | bytes): + raise NotImplementedError class Client: @@ -18,11 +83,29 @@ def __init__( self, config: ClientConfig, user_id_file: str = os.path.expanduser("~/.huri_user_id"), - senders_dict: Dict[str, Type[ClientSender]] = get_senders(), ): self.config = config + + module_path, object_name = self.config.interface_path.split(":", 1) + + module = importlib.import_module(module_path) + interface: Interface = getattr(module, object_name) + + self.singletton = interface.singletton + + available_senders = interface.get_senders() + self.senders: List[ClientSender] = [ + available_senders[sender.name](topic=sender.topic, **sender.args) + for sender in self.config.senders.values() + ] + + available_hooks = interface.get_hooks() + self.hooks: Dict[str, List[ClientHook]] = defaultdict(list) + for hook in self.config.hooks.values(): + for topic in hook.topics: + self.hooks[topic].append(available_hooks[hook.name](**hook.args)) + self.user_id_file = user_id_file - self.senders_dict = senders_dict def _load_user_id(self) -> Optional[str]: if os.path.exists(self.user_id_file): @@ -37,9 +120,22 @@ def _save_user_id(self, _user_id: str): async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: - text = await ws.recv() - print("<<", text) - await asyncio.sleep(0.1) + msg = await ws.recv() + + if isinstance(msg, bytes): + topic_len = struct.unpack("!H", msg[:2])[0] + + topic = msg[2 : 2 + topic_len].decode() + data = msg[2 + topic_len :] + else: + event = json.loads(msg) + topic = event["topic"] + data = event["data"] + + for hook in self.hooks[topic]: + if not issubclass(hook.input_type, bytes): + data = hook.input_type(**data) + asyncio.create_task(hook.hook(self.singletton, data)) except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass @@ -50,11 +146,6 @@ async def run(self): self.config.user_id = self._load_user_id() - senders: List[ClientSender] = [ - self.senders_dict[config.name](ws=ws, **config.args) - for config in self.config.senders.values() - ] - await ws.send(json.dumps(asdict(self.config))) init_msg = json.loads(await ws.recv()) @@ -63,9 +154,9 @@ async def run(self): self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") - receive_task = asyncio.create_task(self._receive_loop(ws)) + receive_task = asyncio.create_task(self._receive_loop(ws=ws)) await asyncio.gather( - *(sender.input_loop() for sender in senders), + *(sender.input_loop(ws=ws) for sender in self.senders), ) receive_task.cancel() diff --git a/src/core/client_senders.py b/src/core/client_senders.py deleted file mode 100644 index 03301a6..0000000 --- a/src/core/client_senders.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -import json -import struct -from dataclasses import asdict -from typing import Dict, Type - -import numpy as np -import sounddevice as sd -import websockets -from prompt_toolkit import PromptSession -from prompt_toolkit.patch_stdout import patch_stdout - -from src.core.events import EventData -from src.modules.speech_to_text.events import Sentence - - -class ClientSender: - """This class abstract sending data to HuRI. - - output_type: is the topic that the ClientSender will send. - Data structure must match event topic. - - Class derived from ClientSender must implement input_loop, - and use ClientSender.send to send data to HuRI. It can be EventData or bytes - """ - - output_type: str - - def __init__(self, ws: websockets.ClientConnection): - self.ws = ws - - async def input_loop(self): - raise NotImplementedError - - async def send(self, topic: str, data: EventData | bytes): - packet: str | bytes - if isinstance(data, EventData): - packet = json.dumps({"topic": topic, "data": asdict(data)}) - else: - topic_bytes = topic.encode() - - packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data - - await self.ws.send(packet) - - -class AudioSender(ClientSender): - output_type = "audio" - - def __init__( - self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs - ): - super().__init__(**kwargs) - - self.sample_rate = sample_rate - self.frame_size = int(sample_rate * frame_duration) - - async def input_loop(self): - loop = asyncio.get_running_loop() - - queue: asyncio.Queue[np.ndarray] = asyncio.Queue() - - def callback(indata: np.ndarray, frames, time, status): - loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) - - with sd.InputStream( - samplerate=self.sample_rate, - channels=1, - dtype="int16", - callback=callback, - blocksize=self.frame_size, - ): - while True: - chunk = await queue.get() - await self.send(self.output_type, chunk.tobytes()) - - -class TextSender(ClientSender): - output_type = "question" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def input_loop(self): - print("'\\exit' or CTRL+D/C to exit.") - session: PromptSession = PromptSession() - try: - while True: - with patch_stdout(): - text = await session.prompt_async(">> ") - if text == "\\exit": - return - await self.send(self.output_type, Sentence(text)) - - except (EOFError, KeyboardInterrupt): - pass - finally: - print("TextSender Exited...") - - -def get_senders() -> Dict[str, Type[ClientSender]]: - return {"audio": AudioSender, "text": TextSender} From eafd64d51503f6100bbfb8f5368c1aa354d12c15 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:07:57 +0200 Subject: [PATCH 03/13] evol(client): better typing for event --- src/core/client.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/core/client.py b/src/core/client.py index 4170e70..cae9dfa 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -5,17 +5,17 @@ import struct from collections import defaultdict from dataclasses import asdict -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import websockets from src.core.dataclasses.config import ClientConfig from src.core.events import EventData -from .interface import Interface +T = TypeVar("T", bound=EventData | bytes) -class ClientSender: +class ClientSender(Generic[T]): """This class abstract sending data to HuRI. output_type: is the event data structure that the ClientSender will send. @@ -25,17 +25,10 @@ class ClientSender: and use ClientSender.send to send data to HuRI. """ - output_type: Type[EventData] | bytes + output_type: Type[T] def __init__(self, topic: str, **_): self.topic = topic - if issubclass(self.output_type, EventData): - self.send_function = self._send_event_data - elif issubclass(self.output_type, bytes): - self.send_function = self._send_bytes - else: - raise RuntimeError(f"{self.output_type} should be inherited from \ -EventData or bytes") async def input_loop(self, ws: websockets.ClientConnection): raise NotImplementedError @@ -51,11 +44,14 @@ async def _send_event_data(self, ws: websockets.ClientConnection, data: EventDat await ws.send(packet) - async def send(self, ws: websockets.ClientConnection, data: EventData | bytes): - await self.send_function(ws, data) + async def send(self, ws: websockets.ClientConnection, data: T): + if isinstance(data, bytes): + await self._send_bytes(ws, data) + else: + await self._send_event_data(ws, data) -class ClientHook: +class ClientHook(Generic[T]): """This class abstract processing data from HuRI. input_type: is the event data structure that the ClientHook will process. @@ -67,12 +63,12 @@ class ClientHook: and comes from the used interface. """ - input_type: Type[EventData] | bytes + input_type: Type[T] def __init__(self, **_): pass - async def hook(self, singletton: Any, data: EventData | bytes): + async def hook(self, singletton: Any, data: T): raise NotImplementedError @@ -89,7 +85,7 @@ def __init__( module_path, object_name = self.config.interface_path.split(":", 1) module = importlib.import_module(module_path) - interface: Interface = getattr(module, object_name) + interface = getattr(module, object_name) self.singletton = interface.singletton @@ -133,7 +129,7 @@ async def _receive_loop(self, ws: websockets.ClientConnection): data = event["data"] for hook in self.hooks[topic]: - if not issubclass(hook.input_type, bytes): + if not isinstance(data, bytes): data = hook.input_type(**data) asyncio.create_task(hook.hook(self.singletton, data)) From 32f990faeb4b328e41b4d29462d6164de605aaff Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:09:18 +0200 Subject: [PATCH 04/13] evol(Sender): send topic and data --- src/modules/utils/sender.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index f09b0ba..a9fc2fa 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -1,3 +1,4 @@ +import struct from dataclasses import asdict from fastapi import WebSocket @@ -23,8 +24,8 @@ def __init__(self, ws: WebSocket, type: str): async def process(self, data: EventData | bytes): if isinstance(data, bytes): - await self.ws.send_bytes(data) - elif isinstance(data, EventData): - await self.ws.send_json(asdict(data)) + topic_bytes = self.input_type.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + await self.ws.send_bytes(packet) else: - await self.ws.send_text(data) + await self.ws.send_json({"topic": self.input_type, "data": asdict(data)}) From d078a628459d747acd0a84dd9adbc8b9364de500 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:10:09 +0200 Subject: [PATCH 05/13] feat(Interface): abstract class to define specific Client sender and hooks --- src/core/interface.py | 22 ++++++++++++++++++++++ src/interfaces/__init__.py | 0 2 files changed, 22 insertions(+) create mode 100644 src/core/interface.py create mode 100644 src/interfaces/__init__.py diff --git a/src/core/interface.py b/src/core/interface.py new file mode 100644 index 0000000..fb2b7ac --- /dev/null +++ b/src/core/interface.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, Type + +from .client import ClientHook, ClientSender + + +class Interface: + """This class abstract defining specific Client senders and hooks. + + `self.singletton`: allow hooks to modifies shared ressources, + and comes from the used interface. + + Class derived from Interface must implement get_senders and get_hooks. + """ + + def __init__(self, singletton: Any): + self.singletton = singletton + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + raise NotImplementedError + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + raise NotImplementedError diff --git a/src/interfaces/__init__.py b/src/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 From 349372d32ff18cd59e8b544b0ddf1cbf7b5495a7 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:11:17 +0200 Subject: [PATCH 06/13] feat(config): ClientHookConfig + Interface path + modified topic_list --- src/core/dataclasses/config.py | 31 +++++++++++++++++++++++++++---- src/core/huri.py | 11 +++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index aea111f..f515026 100644 --- a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -15,15 +15,32 @@ def from_dict(self, raw: dict) -> "ModuleConfig": ) +@dataclass +class ClientHookConfig: + name: str + topics: List[str] + args: Mapping[str, Any] + + @classmethod + def from_dict(self, raw: dict) -> "ClientHookConfig": + return self( + name=raw["name"], + topics=raw["topics"], + args=raw.get("args", {}), + ) + + @dataclass class ClientSenderConfig: name: str + topic: str args: Mapping[str, Any] @classmethod def from_dict(self, raw: dict) -> "ClientSenderConfig": return self( name=raw["name"], + topic=raw["topic"], args=raw.get("args", {}), ) @@ -32,15 +49,20 @@ def from_dict(self, raw: dict) -> "ClientSenderConfig": class ClientConfig: user_id: Optional[str] huri_url: str - topic_list: List[str] + interface_path: str + hooks: Dict[str, ClientHookConfig] senders: Dict[str, ClientSenderConfig] modules: Dict[str, ModuleConfig] @classmethod def from_dict(cls, raw: Dict) -> "ClientConfig": + hooks = { + hook_id: ClientHookConfig.from_dict(hok_raw) + for hook_id, hok_raw in raw.get("hooks", {}).items() + } senders = { - sender_id: ClientSenderConfig.from_dict(mod_raw) - for sender_id, mod_raw in raw.get("senders", {}).items() + sender_id: ClientSenderConfig.from_dict(snd_raw) + for sender_id, snd_raw in raw.get("senders", {}).items() } modules = { module_id: ModuleConfig.from_dict(mod_raw) @@ -49,7 +71,8 @@ def from_dict(cls, raw: Dict) -> "ClientConfig": return cls( user_id=None, huri_url=raw["huri_url"], - topic_list=raw["topic_list"], + interface_path=raw["interface_path"], + hooks=hooks, senders=senders, modules=modules, ) diff --git a/src/core/huri.py b/src/core/huri.py index 5fa8038..f5e4eeb 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -32,7 +32,7 @@ def __init__( self, modules: Dict[str, Type[Module]], handles: Dict[str, handle.DeploymentHandle], - events: Dict[str, Type[EventData]], + events: Dict[str, Type[EventData | bytes]], ) -> None: self.module_factory = ModuleFactory(handles) self.event_factory = EventDataFactory() @@ -80,9 +80,12 @@ async def run_session(self, ws: WebSocket): user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - senders: List[Module] = [ - Sender(ws, topic) for topic in client_config.topic_list + topic_list = [ + topic + for hook_config in client_config.hooks.values() + for topic in hook_config.topics ] + senders: List[Module] = [Sender(ws, topic) for topic in topic_list] modules: List[Module] = ( self.module_factory.create_from_config(user_id, client_config.modules) + senders @@ -112,7 +115,7 @@ async def receive_loop(session: Session, ws: WebSocket): msg_text = msg["text"] event = json.loads(msg_text) topic = event["topic"] - data = event["data"] + data = event["data"] # TODO client/server one function data = self.event_factory.create(topic, data) From f1d112d9ff4b50a39ce213627ad3b8a64ed11d00 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:11:46 +0200 Subject: [PATCH 07/13] feat(interface): added cli_interface for cli use --- src/interfaces/cli_interface.py | 123 +++++++++++++++++++ src/modules/speech_to_text/speech_to_text.py | 2 +- 2 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 src/interfaces/cli_interface.py diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py new file mode 100644 index 0000000..c07469f --- /dev/null +++ b/src/interfaces/cli_interface.py @@ -0,0 +1,123 @@ +import asyncio +from typing import Dict, Type + +import numpy as np +import sounddevice as sd +from prompt_toolkit import PromptSession +from prompt_toolkit.patch_stdout import patch_stdout +from scipy.signal import resample + +from src.core.client import ClientHook, ClientSender +from src.core.interface import Interface +from src.modules.speech_to_text.events import Sentence + + +class AudioSender(ClientSender[bytes]): + def __init__( + self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs + ): + super().__init__(**kwargs) + + self.sample_rate = sample_rate + self.frame_size = int(sample_rate * frame_duration) + + async def input_loop(self, ws): + loop = asyncio.get_running_loop() + + queue: asyncio.Queue[np.ndarray] = asyncio.Queue() + + def callback(indata: np.ndarray, frames, time, status): + loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) + + with sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype="int16", + callback=callback, + blocksize=self.frame_size, + ): + while True: + chunk = await queue.get() + await self.send(ws, chunk.tobytes()) + + +class TextSender(ClientSender[Sentence]): + output_type = Sentence + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def input_loop(self, ws): + print("'\\exit' or CTRL+D/C to exit.") + session: PromptSession = PromptSession() + try: + while True: + with patch_stdout(): + text = await session.prompt_async(">> ") + if text == "\\exit": + return + await self.send(ws, Sentence(text)) + + except (EOFError, KeyboardInterrupt): + pass + finally: + print("TextSender Exited...") + + +class AudioHook(ClientHook[bytes]): + input_type = bytes + + def __init__(self, sample_rate=48000, incoming_sample_rate=16000, **kwargs): + super().__init__(**kwargs) + + print("Speaker:", sd.query_devices(kind="output")) + + self.incoming_sample_rate = incoming_sample_rate + self.sample_rate = sample_rate + self.stream = sd.OutputStream( + samplerate=sample_rate, + channels=1, + dtype="int16", + ) + self.stream.start() + + self.resample_function = ( + self._resample if sample_rate != incoming_sample_rate else lambda x: x + ) + + def _resample(self, audio: np.ndarray): + return resample( + audio, + int(len(audio) * self.sample_rate / self.incoming_sample_rate), + ).astype(np.int16) + + async def hook(self, singletton: None, data: bytes): + audio = np.frombuffer(data, dtype=np.int16) + + audio = self.resample_function(audio) + + self.stream.write(audio.reshape(-1, 1)) + + +class TextHook(ClientHook[Sentence]): + input_type = Sentence + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def hook(self, singletton: None, data: Sentence): + print("<<", data.text) + + +class CLIInterface(Interface): + def __init__(self): + super().__init__(singletton=None) + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + return {"audio": AudioSender, "text": TextSender} + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + return {"audio": AudioHook, "text": TextHook} + + +cli_interface = CLIInterface() diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index fdc68f0..63bf060 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -41,7 +41,7 @@ def __init__( ): super().__init__() - self.model_faster = WhisperModel(model, cpu_threads=2) + self.model_faster = WhisperModel(model) self.language = language self.sample_rate = sample_rate From 88119f62ce11512c3ee01f6830aacc765ac42087 Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:32:07 +0200 Subject: [PATCH 08/13] evol(interface): cli TextHook is for RAGResult event type --- src/interfaces/cli_interface.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py index c07469f..03cacbf 100644 --- a/src/interfaces/cli_interface.py +++ b/src/interfaces/cli_interface.py @@ -9,6 +9,7 @@ from src.core.client import ClientHook, ClientSender from src.core.interface import Interface +from src.modules.rag.events import RAGResult from src.modules.speech_to_text.events import Sentence @@ -99,14 +100,14 @@ async def hook(self, singletton: None, data: bytes): self.stream.write(audio.reshape(-1, 1)) -class TextHook(ClientHook[Sentence]): - input_type = Sentence +class TextHook(ClientHook[RAGResult]): + input_type = RAGResult def __init__(self, **kwargs): super().__init__(**kwargs) - async def hook(self, singletton: None, data: Sentence): - print("<<", data.text) + async def hook(self, singletton: None, data: RAGResult): + print("<<", data.answer) class CLIInterface(Interface): From ac292b680b5d897d776b7c80e5975a5f30986f0a Mon Sep 17 00:00:00 2001 From: Popochounet Date: Mon, 1 Jun 2026 18:32:31 +0200 Subject: [PATCH 09/13] evol(config): yaml files with new config --- config/{client_aux2.yaml => client_aux .yaml} | 22 +++++++++++++-- config/client_aux.yaml | 28 ------------------- config/client_auxio.yaml | 25 ----------------- config/client_template.yaml | 21 ++++++++++++-- config/client_text.yaml | 8 +++++- 5 files changed, 44 insertions(+), 60 deletions(-) rename config/{client_aux2.yaml => client_aux .yaml} (55%) delete mode 100644 config/client_aux.yaml delete mode 100644 config/client_auxio.yaml diff --git a/config/client_aux2.yaml b/config/client_aux .yaml similarity index 55% rename from config/client_aux2.yaml rename to config/client_aux .yaml index 7d7b601..d1f5195 100644 --- a/config/client_aux2.yaml +++ b/config/client_aux .yaml @@ -1,13 +1,31 @@ huri_url: ws://localhost:8000/session -topic_list: [transcript, question, rag_response] +interface_path: src.interfaces.cli_interface:cli_interface senders: audio: name: audio + topic: audio args: sample_rate: 16000 frame_duration: 0.030 + text: + name: text + topic: question + args: + sample_rate: 16000 + frame_duration: 0.030 + +hooks: + text: + name: text + topics: [question, answer] + audio: + name: audio + topics: [audio] + args: + incoming_sample_rate: ${senders.audio.args.sample_rate} + sample_rate: 44100 modules: mic: @@ -21,10 +39,8 @@ modules: args: language: en block_duration: ${senders.audio.args.frame_duration} - logging: INFO tag: name: tag - logging: INFO rag: name: rag args: diff --git a/config/client_aux.yaml b/config/client_aux.yaml deleted file mode 100644 index fe3e332..0000000 --- a/config/client_aux.yaml +++ /dev/null @@ -1,28 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [question] - -senders: - audio: - name: audio - args: - sample_rate: 16000 - frame_duration: 0.030 - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO - stt: - name: stt - args: - language: "en" - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO diff --git a/config/client_auxio.yaml b/config/client_auxio.yaml deleted file mode 100644 index 8fa2a91..0000000 --- a/config/client_auxio.yaml +++ /dev/null @@ -1,25 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [question] - -senders: - text: - name: text - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - stt: - name: stt - args: - language: en - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO diff --git a/config/client_template.yaml b/config/client_template.yaml index cf1627d..441f3c5 100644 --- a/config/client_template.yaml +++ b/config/client_template.yaml @@ -1,19 +1,34 @@ # HuRI websocket server url huri_url: ws://localhost:8000/session -# List of event topic the client will receive -topic_list: [topic1, topic2] +# Define interface to be used's import path +interface_path: src.interfaces.cli_interface:cli_interface # Define senders to be used and their custom args senders: # sender tag can be anything example: - # sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders) + # sender name must be in the list of available ClientSender in chosen Interface (Interface.get_senders) name: my_sender + # topic the sender will send to HuRI, it must match output_type event data structure + topic: my_event # if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here args: refresh_rate: infinite +# Define hooks to be used and their custom args +hooks: + # hook tag can be anything + example: + # hook name must be in the list of available ClientHook in chosen Interface (Interface.get_senders) + name: my_hook + # topics the hook will process from HuRI, it must match input_type event data structure + topics: [my_event, llm_response] + # if my_hook init with "model", "sample_rate" and "refresh_rate" params, they can be customized here + args: + sample_rate: 0 + no: beat + # Define module to be used and their custom args modules: # module tag can be anything diff --git a/config/client_text.yaml b/config/client_text.yaml index 8ddcaab..d2fb26f 100644 --- a/config/client_text.yaml +++ b/config/client_text.yaml @@ -1,10 +1,16 @@ huri_url: ws://localhost:8000/session -topic_list: [question, rag_response] +interface_path: src.interfaces.cli_interface:cli_interface senders: text: name: text + topic: question + +hooks: + text: + name: text + topics: [rag_response] modules: rag: From 87210a74979cc0e3d875e6e32386be4b43ce936e Mon Sep 17 00:00:00 2001 From: Popochounet Date: Tue, 2 Jun 2026 06:27:50 +0200 Subject: [PATCH 10/13] evol(client): move singletton to __init__ for senders and hooks --- src/core/client.py | 31 ++++++++++++++++++++----------- src/interfaces/cli_interface.py | 4 ++-- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/core/client.py b/src/core/client.py index cae9dfa..6927565 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -23,12 +23,18 @@ class ClientSender(Generic[T]): Class derived from ClientSender must implement input_loop, and use ClientSender.send to send data to HuRI. + + `singletton` is available to access shared ressources. """ output_type: Type[T] - def __init__(self, topic: str, **_): + def __init__(self, topic: str, singletton: Any, **_): + """ + :topic: topic sent to HuRI + :singletton: allow to get shared ressources""" self.topic = topic + self.singletton = singletton async def input_loop(self, ws: websockets.ClientConnection): raise NotImplementedError @@ -59,16 +65,15 @@ class ClientHook(Generic[T]): Class derived from ClientHook must implement hook. - `singletton` allow hooks to modifies shared ressources, - and comes from the used interface. + `singletton` is available to access and modifies shared ressources. """ input_type: Type[T] - def __init__(self, **_): - pass + def __init__(self, singletton: Any, **_): + self.singletton = singletton - async def hook(self, singletton: Any, data: T): + async def hook(self, data: T): raise NotImplementedError @@ -87,11 +92,11 @@ def __init__( module = importlib.import_module(module_path) interface = getattr(module, object_name) - self.singletton = interface.singletton - available_senders = interface.get_senders() self.senders: List[ClientSender] = [ - available_senders[sender.name](topic=sender.topic, **sender.args) + available_senders[sender.name]( + topic=sender.topic, singletton=interface.singletton, **sender.args + ) for sender in self.config.senders.values() ] @@ -99,7 +104,11 @@ def __init__( self.hooks: Dict[str, List[ClientHook]] = defaultdict(list) for hook in self.config.hooks.values(): for topic in hook.topics: - self.hooks[topic].append(available_hooks[hook.name](**hook.args)) + self.hooks[topic].append( + available_hooks[hook.name]( + singletton=interface.singletton, **hook.args + ) + ) self.user_id_file = user_id_file @@ -131,7 +140,7 @@ async def _receive_loop(self, ws: websockets.ClientConnection): for hook in self.hooks[topic]: if not isinstance(data, bytes): data = hook.input_type(**data) - asyncio.create_task(hook.hook(self.singletton, data)) + asyncio.create_task(hook.hook(data)) except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py index 03cacbf..2634d07 100644 --- a/src/interfaces/cli_interface.py +++ b/src/interfaces/cli_interface.py @@ -92,7 +92,7 @@ def _resample(self, audio: np.ndarray): int(len(audio) * self.sample_rate / self.incoming_sample_rate), ).astype(np.int16) - async def hook(self, singletton: None, data: bytes): + async def hook(self, data: bytes): audio = np.frombuffer(data, dtype=np.int16) audio = self.resample_function(audio) @@ -106,7 +106,7 @@ class TextHook(ClientHook[RAGResult]): def __init__(self, **kwargs): super().__init__(**kwargs) - async def hook(self, singletton: None, data: RAGResult): + async def hook(self, data: RAGResult): print("<<", data.answer) From ba39bb50e5a2168e6cf3bc1325414aa061109d6f Mon Sep 17 00:00:00 2001 From: Popochounet Date: Wed, 24 Jun 2026 14:48:39 +0200 Subject: [PATCH 11/13] evol(client): moved audio into the AudioHook --- src/core/client.py | 76 +++++++------------------------- src/interfaces/cli_interface.py | 78 +++++++++++++++++++++++++++++---- 2 files changed, 85 insertions(+), 69 deletions(-) diff --git a/src/core/client.py b/src/core/client.py index f35e7d0..fed316b 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -3,10 +3,8 @@ import json import os import struct -import wave from collections import defaultdict from dataclasses import asdict -from datetime import datetime from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import numpy as np @@ -44,7 +42,7 @@ async def input_loop(self, ws: websockets.ClientConnection): async def _send_bytes(self, ws: websockets.ClientConnection, data: bytes): topic_bytes = self.topic.encode() - packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + packet = struct.pack(">H", len(topic_bytes)) + topic_bytes + data await ws.send(packet) @@ -87,7 +85,6 @@ def __init__( self, config: ClientConfig, user_id_file: str = os.path.expanduser("~/.huri_user_id"), - save_audio_dir: Optional[str] = None, ): self.config = config @@ -116,16 +113,6 @@ def __init__( self.user_id_file = user_id_file - # When set, incoming audio chunks are buffered per utterance and written - # to a .wav under this directory each time an end-of-utterance marker - # arrives — handy for ear-checking what the TTS actually streamed. - self.save_audio_dir = save_audio_dir - self._audio_buf: List[np.ndarray] = [] - self._audio_sr: Optional[int] = None - self._audio_idx = 0 - if save_audio_dir: - os.makedirs(save_audio_dir, exist_ok=True) - def _load_user_id(self) -> Optional[str]: if os.path.exists(self.user_id_file): with open(self.user_id_file) as f: @@ -136,65 +123,37 @@ def _save_user_id(self, _user_id: str): with open(self.user_id_file, "w") as f: f.write(_user_id) - def _collect_audio(self, samples: np.ndarray, sample_rate: int, end: bool) -> None: - if samples.size: - self._audio_buf.append(samples) - self._audio_sr = sample_rate - if end: - self._flush_audio() - - def _flush_audio(self) -> None: - if not self._audio_buf or self._audio_sr is None: - self._audio_buf = [] - return - audio = np.concatenate(self._audio_buf) - stamp = datetime.now().strftime("%Y%m%d-%H%M%S") - path = os.path.join( - self.save_audio_dir, f"utt-{self._audio_idx:03d}-{stamp}.wav" - ) - self._write_wav(path, audio, self._audio_sr) - print( - f"** saved audio: {path} ({audio.size} samples, " - f"~{audio.size / self._audio_sr:.2f}s @ {self._audio_sr}Hz)" - ) - self._audio_idx += 1 - self._audio_buf = [] - - @staticmethod - def _write_wav(path: str, audio: np.ndarray, sample_rate: int) -> None: - # float32 [-1, 1] -> 16-bit PCM, clipped to avoid wraparound on overshoot. - pcm = np.clip(audio, -1.0, 1.0) - pcm = (pcm * 32767.0).astype("H", msg[:2])[0] topic = msg[2 : 2 + topic_len].decode() data = msg[2 + topic_len :] if topic == "audio" and len(data) >= 13: - sample_rate, end_flag, pts = struct.unpack(">IBd", data[:13]) + sample_rate, end, pts = struct.unpack(">IBd", data[:13]) # Samples are native-endian float32 (Sender uses ndarray.tobytes()). samples = np.frombuffer(data[13:], dtype=np.float32) - print( - f"<< audio: pts={pts:.3f}s samples={samples.size} @ {sample_rate}Hz " - f"end={bool(end_flag)}" - ) - if self.save_audio_dir: - self._collect_audio(samples, sample_rate, bool(end_flag)) + data = { + "sample_rate": sample_rate, + "end": end, + "pts": pts, + "data": samples, + } elif topic == "motion" and len(data) >= 16: pts, fps, n_frames = struct.unpack(">dII", data[:16]) print(f"<< motion: pts={pts:.3f}s frames={n_frames} @ {fps}fps") + data = { + "poses": np.ndarray(), + "expressions": np.ndarray(), + "trans": np.ndarray(), + "fps": fps, + "pts": pts, + } else: print(f"<< {topic}: bytes ({len(data)}B)") else: @@ -209,9 +168,6 @@ async def _receive_loop(self, ws: websockets.ClientConnection): except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass - finally: - if self.save_audio_dir: - self._flush_audio() # save anything left if the stream ended mid-utterance async def run(self): async with websockets.connect(self.config.huri_url) as ws: diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py index 2634d07..4c76d61 100644 --- a/src/interfaces/cli_interface.py +++ b/src/interfaces/cli_interface.py @@ -1,16 +1,19 @@ import asyncio -from typing import Dict, Type - +from typing import Dict, Type, Optional, List +import os import numpy as np import sounddevice as sd from prompt_toolkit import PromptSession from prompt_toolkit.patch_stdout import patch_stdout from scipy.signal import resample +import wave +from datetime import datetime from src.core.client import ClientHook, ClientSender from src.core.interface import Interface from src.modules.rag.events import RAGResult from src.modules.speech_to_text.events import Sentence +from src.modules.text_to_speech.events import Audio class AudioSender(ClientSender[bytes]): @@ -65,10 +68,16 @@ async def input_loop(self, ws): print("TextSender Exited...") -class AudioHook(ClientHook[bytes]): - input_type = bytes +class AudioHook(ClientHook[Audio]): + input_type = Audio - def __init__(self, sample_rate=48000, incoming_sample_rate=16000, **kwargs): + def __init__( + self, + sample_rate=48000, + incoming_sample_rate=16000, + save_audio_dir: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) print("Speaker:", sd.query_devices(kind="output")) @@ -86,18 +95,69 @@ def __init__(self, sample_rate=48000, incoming_sample_rate=16000, **kwargs): self._resample if sample_rate != incoming_sample_rate else lambda x: x ) + # When set, incoming audio chunks are buffered per utterance and written + # to a .wav under this directory each time an end-of-utterance marker + # arrives — handy for ear-checking what the TTS actually streamed. + self.save_audio_dir = save_audio_dir + self._audio_buf: List[np.ndarray] = [] + self._audio_sr: Optional[int] = None + self._audio_idx = 0 + if save_audio_dir: + os.makedirs(save_audio_dir, exist_ok=True) + def _resample(self, audio: np.ndarray): return resample( audio, int(len(audio) * self.sample_rate / self.incoming_sample_rate), ).astype(np.int16) - async def hook(self, data: bytes): - audio = np.frombuffer(data, dtype=np.int16) + def _collect_audio(self, samples: np.ndarray, sample_rate: int, end: bool) -> None: + if samples.size: + self._audio_buf.append(samples) + self._audio_sr = sample_rate + if end: + self._flush_audio() + + def _flush_audio(self) -> None: + if not self._audio_buf or self._audio_sr is None: + self._audio_buf = [] + return + audio = np.concatenate(self._audio_buf) + stamp = datetime.now().strftime("%Y%m%d-%H%M%S") + path = os.path.join( + self.save_audio_dir, f"utt-{self._audio_idx:03d}-{stamp}.wav" + ) + self._write_wav(path, audio, self._audio_sr) + print( + f"** saved audio: {path} ({audio.size} samples, " + f"~{audio.size / self._audio_sr:.2f}s @ {self._audio_sr}Hz)" + ) + self._audio_idx += 1 + self._audio_buf = [] + + @staticmethod + def _write_wav(path: str, audio: np.ndarray, sample_rate: int) -> None: + # float32 [-1, 1] -> 16-bit PCM, clipped to avoid wraparound on overshoot. + pcm = np.clip(audio, -1.0, 1.0) + pcm = (pcm * 32767.0).astype(" Date: Wed, 24 Jun 2026 14:58:22 +0200 Subject: [PATCH 12/13] feat(user_config): clean load of user_id, user_id is set client-side --- src/client.py | 5 +++ src/core/client.py | 16 --------- src/core/dataclasses/config.py | 4 +-- src/core/huri.py | 16 +++++---- src/core/user_config.py | 60 ++++++++++++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 25 deletions(-) create mode 100644 src/core/user_config.py diff --git a/src/client.py b/src/client.py index d6bba63..5ba67c4 100644 --- a/src/client.py +++ b/src/client.py @@ -6,6 +6,7 @@ from src.core.client import Client from src.core.dataclasses.config import ClientConfig +from src.core.user_config import get_or_create_and_save_user_id def load_client_config(path: str) -> ClientConfig: @@ -16,6 +17,10 @@ def load_client_config(path: str) -> ClientConfig: if not isinstance(raw_resolved, Dict): raise RuntimeError("error yaml does not output a dict") + user_id_file_path = raw_resolved.get("user_id_file_path") + user_id = get_or_create_and_save_user_id(user_id_file_path) + raw_resolved["user_id"] = user_id + return ClientConfig.from_dict(raw_resolved) diff --git a/src/core/client.py b/src/core/client.py index fed316b..ae1f1bd 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -84,7 +84,6 @@ class Client: def __init__( self, config: ClientConfig, - user_id_file: str = os.path.expanduser("~/.huri_user_id"), ): self.config = config @@ -111,18 +110,6 @@ def __init__( ) ) - self.user_id_file = user_id_file - - def _load_user_id(self) -> Optional[str]: - if os.path.exists(self.user_id_file): - with open(self.user_id_file) as f: - return f.read().strip() - return None - - def _save_user_id(self, _user_id: str): - with open(self.user_id_file, "w") as f: - f.write(_user_id) - async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: @@ -173,14 +160,11 @@ async def run(self): async with websockets.connect(self.config.huri_url) as ws: print("Connected to server") - self.config.user_id = self._load_user_id() - await ws.send(json.dumps(asdict(self.config))) init_msg = json.loads(await ws.recv()) if init_msg.get("type") == "session_init": user_id = init_msg["user_id"] - self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") receive_task = asyncio.create_task(self._receive_loop(ws=ws)) diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index 7aad49f..8e47682 100644 --- a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -47,7 +47,7 @@ def from_dict(self, raw: dict) -> "ClientSenderConfig": @dataclass class ClientConfig: - user_id: Optional[str] + user_id: str huri_url: str interface_path: str hooks: Dict[str, ClientHookConfig] @@ -69,7 +69,7 @@ def from_dict(cls, raw: Dict) -> "ClientConfig": for module_id, mod_raw in raw.get("modules", {}).items() } return cls( - user_id=raw.get("user_id"), + user_id=raw["user_id"], huri_url=raw["huri_url"], interface_path=raw["interface_path"], hooks=hooks, diff --git a/src/core/huri.py b/src/core/huri.py index f5e4eeb..f4ef50c 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -78,8 +78,6 @@ async def run_session(self, ws: WebSocket): client_config_raw: Dict = await ws.receive_json() client_config = ClientConfig.from_dict(client_config_raw) - user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - topic_list = [ topic for hook_config in client_config.hooks.values() @@ -87,15 +85,19 @@ async def run_session(self, ws: WebSocket): ] senders: List[Module] = [Sender(ws, topic) for topic in topic_list] modules: List[Module] = ( - self.module_factory.create_from_config(user_id, client_config.modules) + self.module_factory.create_from_config( + client_config.user_id, client_config.modules + ) + senders ) - await ws.send_json({"type": "session_init", "user_id": user_id}) + await ws.send_json({"type": "session_init", "user_id": client_config.user_id}) session_id = str(uuid.uuid4()) self.clients[session_id] = Session(modules) - print(f"Client registered with _user_id={user_id}, config: {client_config}") + print( + f"Client registered with _user_id={client_config.user_id}, config: {client_config}" + ) async def receive_loop(session: Session, ws: WebSocket): try: @@ -122,11 +124,11 @@ async def receive_loop(session: Session, ws: WebSocket): await session.publish(topic, data) except RuntimeError as e: - print(f"[ERROR] Client {user_id}:", e) + print(f"[ERROR] Client {client_config.user_id}:", e) except WebSocketDisconnect: pass finally: - print(f"Client {user_id} disconnected") + print(f"Client {client_config.user_id} disconnected") await receive_loop(self.clients[session_id], ws) del self.clients[session_id] diff --git a/src/core/user_config.py b/src/core/user_config.py new file mode 100644 index 0000000..22f4b9f --- /dev/null +++ b/src/core/user_config.py @@ -0,0 +1,60 @@ +import os +import platform +import uuid +from pathlib import Path + + +def get_config_dir() -> Path: + """Cross-platform config directory.""" + system = platform.system() + + if system == "Windows": + # TODO: To be tested -> also consider language-specific if needed + base = os.environ.get("APPDATA", os.path.expanduser("~/AppData/Roaming")) + elif system == "Darwin": + # TODO: To be tested -> also consider language-specific if needed + base = os.path.expanduser("~/Library/Application Support") + else: + base = os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) + + config_dir = Path(base) / "huri" + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir + + +def load_user_id(path: str | None = None) -> str | None: + """Load existing _user_id, or return None if new user.""" + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + if id_file.exists(): + uid = id_file.read_text().strip() + if uid: + return uid + return None + + +def save_user_id(_user_id: str, path: str | None = None): + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + + id_file.write_text(_user_id) + if platform.system() != "Windows": + id_file.chmod(0o600) + + +def get_or_create_and_save_user_id(path: str | None = None) -> str: + """Load existing or generate new _user_id.""" + uid = load_user_id(path) + if uid: + return uid + uid = str(uuid.uuid4()) + save_user_id(uid, path) + return uid From a923fd7153683965c41ee71096a64d55f850b48b Mon Sep 17 00:00:00 2001 From: Popochounet Date: Wed, 24 Jun 2026 16:19:07 +0200 Subject: [PATCH 13/13] fix(linter): make lint --- config/client_aux .yaml | 48 -- config/client_aux.yaml | 51 ++ config/huri_cpu.yaml | 119 +++++ src/client.py | 12 +- src/core/client.py | 12 +- src/core/dataclasses/config.py | 2 +- src/core/events.py | 23 +- src/core/huri.py | 5 +- src/core/module.py | 4 +- src/interfaces/cli_interface.py | 10 +- src/modules/events.py | 6 +- src/modules/gesture/__init__.py | 2 - src/modules/gesture/emage/modeling.py | 484 +++++++++++++++---- src/modules/gesture/emage/processing.py | 76 ++- src/modules/gesture/events.py | 7 +- src/modules/gesture/gesture.py | 131 +++-- src/modules/modules.py | 11 +- src/modules/rag/ingestion.py | 8 +- src/modules/rag/rag.py | 120 ++--- src/modules/speech_to_text/speech_to_text.py | 2 +- src/modules/text_to_speech/text_to_speech.py | 17 +- src/modules/utils/sender.py | 14 +- 22 files changed, 840 insertions(+), 324 deletions(-) delete mode 100644 config/client_aux .yaml create mode 100644 config/client_aux.yaml create mode 100644 config/huri_cpu.yaml diff --git a/config/client_aux .yaml b/config/client_aux .yaml deleted file mode 100644 index d1f5195..0000000 --- a/config/client_aux .yaml +++ /dev/null @@ -1,48 +0,0 @@ -huri_url: ws://localhost:8000/session - -interface_path: src.interfaces.cli_interface:cli_interface - -senders: - audio: - name: audio - topic: audio - args: - sample_rate: 16000 - frame_duration: 0.030 - text: - name: text - topic: question - args: - sample_rate: 16000 - frame_duration: 0.030 - -hooks: - text: - name: text - topics: [question, answer] - audio: - name: audio - topics: [audio] - args: - incoming_sample_rate: ${senders.audio.args.sample_rate} - sample_rate: 44100 - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${senders.audio.args.frame_duration} - stt: - name: stt - args: - language: en - block_duration: ${senders.audio.args.frame_duration} - tag: - name: tag - rag: - name: rag - args: - language: en - tone: formal diff --git a/config/client_aux.yaml b/config/client_aux.yaml new file mode 100644 index 0000000..5caef92 --- /dev/null +++ b/config/client_aux.yaml @@ -0,0 +1,51 @@ +huri_url: ws://localhost:8000/session + +interface_path: src.interfaces.cli_interface:cli_interface + +senders: + # audio: + # name: audio + # topic: audio_in + # args: + # sample_rate: 16000 + # frame_duration: 0.030 + text: + name: text + topic: question + args: + sample_rate: 16000 + frame_duration: 0.030 + +hooks: + text: + name: text + topics: [question, answer] + audio: + name: audio + topics: [audio] + args: + incoming_sample_rate: ${senders.text.args.sample_rate} + sample_rate: 44100 + save_audio_dir: "uuid" + +modules: + # mic: + # name: mic + # args: + # vad_agressiveness: 3 + # silence_duration: 1.5 + # block_duration: ${senders.audio.args.frame_duration} + # stt: + # name: stt + # args: + # language: en + # block_duration: ${senders.audio.args.frame_duration} + # tag: + # name: tag + rag: + name: rag + args: + language: en + tone: formal + tts: + name: tts diff --git a/config/huri_cpu.yaml b/config/huri_cpu.yaml new file mode 100644 index 0000000..2c6169f --- /dev/null +++ b/config/huri_cpu.yaml @@ -0,0 +1,119 @@ +# HuRI — local Ray Serve config (no Kubernetes) +# ============================================================================ +# This is the standalone equivalent of the inline `ray.serveConfig` in +# deploy/examples/local_nvidia_amd/values.yaml, adapted to run on a single +# local Ray node started with `ray start` instead of KubeRay. +# +# Run it with: +# +# # 1. Start a local Ray head: +# ray start --head --num-cpus=8 --num-gpus=1 +# +# # 2. Deploy this config: +# serve deploy config/huri.yaml +# # ...or run it in the foreground (starts its own Ray if none is running): +# serve run config/huri.yaml +# +# # 3. Tear down when done: +# serve shutdown -y && ray stop +# +# NOTE: The Helm chart uses GPU_TYPE_NVIDIA / GPU_TYPE_AMD custom resources to +# pin deployments to vendor-specific worker groups / container images. On a +# single local machine there is only one environment, so those are dropped and +# scheduling is done purely with num_gpus fractions. +# +# Likewise there are no PVC volume mounts here, so the model-path / +# voice-sample / GPU env vars the chart injects from .Values.models, +# .Values.voiceAssets and workerGroups[*].containerEnv are folded into +# runtime_env.env_vars below. Adjust the paths to wherever the weights +# actually live on this machine. +# ============================================================================ + +proxy_location: EveryNode +http_options: + host: 0.0.0.0 + port: 8000 + +applications: + - name: huri-app + route_prefix: / + import_path: src.app:app + runtime_env: + env_vars: + RAY_COLOR_PREFIX: "1" + + # --- Gesture sliding-window defaults (run in the HuRI CPU actor) --- + HURI_GESTURE_CONTEXT_SEC: "2.0" + HURI_GESTURE_MIN_CHUNK_SEC: "0.5" + # Caps the EMAGE process to a fraction of GPU memory so TTS keeps the + # rest. "0" disables the cap. Keep roughly in line with GestureGeneration + # num_gpus below. (Helm sets this on the nvidia worker's containerEnv.) + HURI_GESTURE_GPU_MEM_FRACTION: "0.2" + + # --- CosyVoice3 / TTS --- + # CosyVoice3 contract: "<|endofprompt|>". + # The reference transcript MUST come AFTER the marker, or the LM treats + # it as an instruction and intermittently speaks it (prompt leakage). + HURI_VOICE_TRANSCRIPT: "You are a helpful assistant.<|endofprompt|>Instinct creates its own oppressors and bids us rise up against them." + # From .Values.models.cosytts.env (mountPath/modelId) — edit for local layout. + HURI_MODEL_PATH: /models/cosytts/FunAudioLLM/Fun-CosyVoice3-0.5B-2512 + # Path to the CosyVoice repo root containing third_party/Matcha-TTS. + HURI_COSY_DIR: /app/cosyvoice + # From .Values.voiceAssets.env — the reference voice sample. + HURI_VOICE_SAMPLE_PATH: /assets/voice.wav + + # --- STT (faster-whisper) --- + # From .Values.models.whisper.env (mountPath/repoId) — edit for local layout. + HURI_STT_MODEL_PATH: /models/whisper/Systran/faster-whisper-base + + # --- Gesture (EMAGE) --- + # From .Values.models.emage.env (mountPath/repoId) — edit for local layout. + HURI_EMAGE_REPO: /models/emage/H-Liu1997/emage_audio + + # --- GPU-vendor runtime env (Helm puts these on the worker containers) --- + NVIDIA_VISIBLE_DEVICES: "all" + NVIDIA_DRIVER_CAPABILITIES: "compute,utility" + HF_HUB_DOWNLOAD_TIMEOUT: "10" + + deployments: + # HuRI: FastAPI/WebSocket ingress + per-session router. CPU only — + # all GPU work is offloaded to the handle-backed deployments below. + - name: HuRI + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + + # STT: shared faster-whisper actor. + - name: STT + num_replicas: 1 + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + + # RAG: embeddings (API) + LLM client. No GPU needed. + - name: RAGHandle + num_replicas: 1 + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + user_config: + embedding_model: "bge-large-en-v1.5-gguf-Q4_K_M" + llm_model: "Qwen3.5-4B-GGUF" + + # GPU split (manual override knob): num_gpus are Ray *scheduling* + # fractions that let replicas pack onto the same device and bias the + # split. TTS gets the lion's share so streamed speech stays low-latency; + # gesture gets the remainder. To also cap gesture's actual VRAM, set + # HURI_GESTURE_GPU_MEM_FRACTION above. + # + # These fractions must sum to <= the --num-gpus you pass to `ray start`. + # As written STT(0.5) + TTS(0.8) + Gesture(0.2) = 1.5, so use + # --num-gpus=2 (e.g. two physical GPUs), or lower the fractions to fit 1. + - name: TTS + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + - name: GestureGeneration + ray_actor_options: + num_cpus: 1 + num_gpus: 0 diff --git a/src/client.py b/src/client.py index 5ba67c4..238be1d 100644 --- a/src/client.py +++ b/src/client.py @@ -31,21 +31,11 @@ async def launch_client(): required=True, help="Path to Client config file (YAML)", ) - parser.add_argument( - "--save-audio", - nargs="?", - const="audio_dumps", - default=None, - metavar="DIR", - help="Save streamed TTS audio to .wav files (one per utterance) in DIR " - "for quality-checking. Defaults to ./audio_dumps when the flag is given " - "without a value.", - ) args = parser.parse_args() config = load_client_config(args.config) - await Client(config=config, save_audio_dir=args.save_audio).run() + await Client(config=config).run() if __name__ == "__main__": diff --git a/src/core/client.py b/src/core/client.py index ae1f1bd..3786b4e 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,11 +1,10 @@ import asyncio import importlib import json -import os import struct from collections import defaultdict from dataclasses import asdict -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Type, TypeVar import numpy as np import websockets @@ -123,7 +122,8 @@ async def _receive_loop(self, ws: websockets.ClientConnection): if topic == "audio" and len(data) >= 13: sample_rate, end, pts = struct.unpack(">IBd", data[:13]) - # Samples are native-endian float32 (Sender uses ndarray.tobytes()). + # Samples are native-endian float32 + # (Sender uses ndarray.tobytes()). samples = np.frombuffer(data[13:], dtype=np.float32) data = { "sample_rate": sample_rate, @@ -135,9 +135,9 @@ async def _receive_loop(self, ws: websockets.ClientConnection): pts, fps, n_frames = struct.unpack(">dII", data[:16]) print(f"<< motion: pts={pts:.3f}s frames={n_frames} @ {fps}fps") data = { - "poses": np.ndarray(), - "expressions": np.ndarray(), - "trans": np.ndarray(), + "poses": np.ndarray(0), + "expressions": np.ndarray(0), + "trans": np.ndarray(0), "fps": fps, "pts": pts, } diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index 8e47682..2df9e0a 100644 --- a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping @dataclass diff --git a/src/core/events.py b/src/core/events.py index dfaf66b..43176d2 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -52,7 +52,8 @@ async def publish(self, event_topic, data): if event_topic not in ("audio_in",): # skip mic-frame spam logger.info( "[GRAPH] publish topic=%r subscribers=%s", - event_topic, [type(m).__name__ for m in subs], + event_topic, + [type(m).__name__ for m in subs], ) for module in subs: asyncio.create_task(self._run(module, data)) @@ -68,11 +69,15 @@ async def _run(self, module: Module, data): continue logger.info( "[GRAPH] %s -> %r: %s", - type(module).__name__, module.output_type, _summarize(item), + type(module).__name__, + module.output_type, + _summarize(item), ) await self.publish(module.output_type, item) except Exception: - logger.exception("[GRAPH] async generator failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] async generator failed in %s", type(module).__name__ + ) else: try: @@ -80,14 +85,20 @@ async def _run(self, module: Module, data): if value is not None: logger.info( "[GRAPH] %s -> %r: %s", - type(module).__name__, module.output_type, _summarize(value), + type(module).__name__, + module.output_type, + _summarize(value), ) await self.publish(module.output_type, value) except Exception: - logger.exception("[GRAPH] coroutine failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] coroutine failed in %s", type(module).__name__ + ) except Exception: - logger.exception("[GRAPH] process() call failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] process() call failed in %s", type(module).__name__ + ) def _summarize(item) -> str: diff --git a/src/core/huri.py b/src/core/huri.py index f4ef50c..0e9e665 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -95,9 +95,8 @@ async def run_session(self, ws: WebSocket): session_id = str(uuid.uuid4()) self.clients[session_id] = Session(modules) - print( - f"Client registered with _user_id={client_config.user_id}, config: {client_config}" - ) + print(f"Client registered with _user_id={client_config.user_id}, \ +config: {client_config}") async def receive_loop(session: Session, ws: WebSocket): try: diff --git a/src/core/module.py b/src/core/module.py index 308273b..281a68b 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type +from typing import Any, AsyncGenerator, Coroutine, Optional, Type from ray.serve import handle @@ -24,7 +24,7 @@ class Module: input_type: str output_type: Optional[str] - async def process(self, _) -> Optional[Any]: + def process(self, _) -> Coroutine[Any, Any, Any] | AsyncGenerator[Any, None]: raise NotImplementedError diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py index 4c76d61..39cebb0 100644 --- a/src/interfaces/cli_interface.py +++ b/src/interfaces/cli_interface.py @@ -1,13 +1,14 @@ import asyncio -from typing import Dict, Type, Optional, List import os +import wave +from datetime import datetime +from typing import Dict, List, Optional, Type + import numpy as np import sounddevice as sd from prompt_toolkit import PromptSession from prompt_toolkit.patch_stdout import patch_stdout from scipy.signal import resample -import wave -from datetime import datetime from src.core.client import ClientHook, ClientSender from src.core.interface import Interface @@ -148,7 +149,8 @@ def _write_wav(path: str, audio: np.ndarray, sample_rate: int) -> None: async def hook(self, data: Audio): print( - f"<< audio: pts={data.pts:.3f}s samples={data.data.size} @ {data.sample_rate}Hz " + f"<< audio: pts={data.pts:.3f}s " + f"samples={data.data.size} @ {data.sample_rate}Hz " f"end={bool(data.end)}" ) # audio = np.frombuffer(data, dtype=np.int16) diff --git a/src/modules/events.py b/src/modules/events.py index 93779a2..5026b8a 100644 --- a/src/modules/events.py +++ b/src/modules/events.py @@ -1,9 +1,9 @@ from typing import Dict, Type from src.core.events import EventData -from src.modules.speech_to_text.events import Sentence, Transcript, Voice -from src.modules.text_to_speech.events import Audio, Token from src.modules.gesture.events import Motion +from src.modules.speech_to_text.events import Sentence, Transcript, Voice +from src.modules.text_to_speech.events import Token def get_events() -> Dict[str, Type[EventData | bytes]]: @@ -14,7 +14,7 @@ def get_events() -> Dict[str, Type[EventData | bytes]]: "transcript": Transcript, "question": Sentence, "token": Token, - "motion": Motion + "motion": Motion, } return events diff --git a/src/modules/gesture/__init__.py b/src/modules/gesture/__init__.py index b98a65c..e69de29 100644 --- a/src/modules/gesture/__init__.py +++ b/src/modules/gesture/__init__.py @@ -1,2 +0,0 @@ -from .events import Motion -from .gesture import Gesture diff --git a/src/modules/gesture/emage/modeling.py b/src/modules/gesture/emage/modeling.py index b4d71ab..f58b967 100644 --- a/src/modules/gesture/emage/modeling.py +++ b/src/modules/gesture/emage/modeling.py @@ -15,13 +15,9 @@ VQEncoderV6, WavEncoder, axis_angle_to_rotation_6d, - matrix_to_axis_angle, - matrix_to_rotation_6d, recover_from_mask_ts, rotation_6d_to_axis_angle, - rotation_6d_to_matrix, velocity2position, - axis_angle_to_matrix, ) @@ -47,14 +43,21 @@ class EmageVQVAEConv(PreTrainedModel): def __init__(self, config): super().__init__(config) self.encoder = VQEncoderV5(config) - self.quantizer = Quantizer(config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda) + self.quantizer = Quantizer( + config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda + ) self.decoder = VQDecoderV5(config) def forward(self, inputs): pre_latent = self.encoder(inputs) embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) rec_pose = self.decoder(vq_latent) - return {"poses_feat": vq_latent, "embedding_loss": embedding_loss, "perplexity": perplexity, "rec_pose": rec_pose} + return { + "poses_feat": vq_latent, + "embedding_loss": embedding_loss, + "perplexity": perplexity, + "rec_pose": rec_pose, + } def map2index(self, inputs): pre_latent = self.encoder(inputs) @@ -72,8 +75,8 @@ def decode(self, index): def decode_from_latent(self, latent): z_flattened = latent.contiguous().view(-1, self.quantizer.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.quantizer.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.quantizer.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.quantizer.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -86,20 +89,118 @@ class EmageVQModel(nn.Module): def __init__(self, face_model, upper_model, hands_model, lower_model, global_model): super().__init__() self.joint_mask_upper = [ - False, False, False, True, False, False, True, False, False, True, - False, False, True, True, True, True, True, True, True, True, - True, True, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, + False, + False, + False, + True, + False, + False, + True, + False, + False, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ] self.joint_mask_lower = [ - True, True, True, False, True, True, False, True, True, False, - True, True, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, + True, + True, + True, + False, + True, + True, + False, + True, + True, + False, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ] self.vq_model_face = face_model self.vq_model_upper = upper_model @@ -107,21 +208,37 @@ def __init__(self, face_model, upper_model, hands_model, lower_model, global_mod self.vq_model_lower = lower_model self.global_motion = global_model - def spilt_inputs(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): + def spilt_inputs( + self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None + ): bs, t, j6 = smplx_body_rot6d.shape smplx_body_rot6d = smplx_body_rot6d.reshape(bs, t, j6 // 6, 6) jaw_rot6d = smplx_body_rot6d[:, :, 22:23, :].reshape(bs, t, 6) face = torch.cat([jaw_rot6d, expression], dim=2) - upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape(bs, t, 78) + upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape( + bs, t, 78 + ) hands_rot6d = smplx_body_rot6d[:, :, 25:55, :].reshape(bs, t, 180) - lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape(bs, t, 54) - tar_contact = torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) if tar_contact is None else tar_contact - tar_trans = torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) if tar_trans is None else tar_trans + lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape( + bs, t, 54 + ) + tar_contact = ( + torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) + if tar_contact is None + else tar_contact + ) + tar_trans = ( + torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) + if tar_trans is None + else tar_trans + ) lower = torch.cat([lower_rot6d, tar_trans, tar_contact], dim=2) return dict(face=face, upper=upper_rot6d, hands=hands_rot6d, lower=lower) def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): - inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans) + inputs = self.spilt_inputs( + smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans + ) return dict( face=self.vq_model_face.map2index(inputs["face"]), upper=self.vq_model_upper.map2index(inputs["upper"]), @@ -129,8 +246,12 @@ def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=No lower=self.vq_model_lower.map2index(inputs["lower"]), ) - def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): - inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans) + def map2latent( + self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None + ): + inputs = self.spilt_inputs( + smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans + ) return dict( face=self.vq_model_face.map2latent(inputs["face"]), upper=self.vq_model_upper.map2latent(inputs["upper"]), @@ -140,11 +261,27 @@ def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=N def decode( self, - face_index=None, upper_index=None, hands_index=None, lower_index=None, - face_latent=None, upper_latent=None, hands_latent=None, lower_latent=None, - get_global_motion=False, ref_trans=None, + face_index=None, + upper_index=None, + hands_index=None, + lower_index=None, + face_latent=None, + upper_latent=None, + hands_latent=None, + lower_latent=None, + get_global_motion=False, + ref_trans=None, ): - for t in [face_index, upper_index, hands_index, lower_index, face_latent, upper_latent, hands_latent, lower_latent]: + for t in [ + face_index, + upper_index, + hands_index, + lower_index, + face_latent, + upper_latent, + hands_latent, + lower_latent, + ]: if t is not None: bs, seq = t.shape[:2] break @@ -163,34 +300,48 @@ def decode( if upper_index is not None: upper_6d = self.vq_model_upper.decode(upper_index) - upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif upper_latent is not None: upper_6d = self.vq_model_upper.decode_from_latent(upper_latent) - upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: upper = torch.zeros(bs, seq, 39, device=self.vq_model_upper.device) if hands_index is not None: hands_6d = self.vq_model_hands.decode(hands_index) - hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif hands_latent is not None: hands_6d = self.vq_model_hands.decode_from_latent(hands_latent) - hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: hands = torch.zeros(bs, seq, 90, device=self.vq_model_hands.device) if lower_index is not None: lower_mix = self.vq_model_lower.decode(lower_index) lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] - lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif lower_latent is not None: lower_mix = self.vq_model_lower.decode_from_latent(lower_latent) lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] - lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: lower = torch.zeros(bs, seq, 27, device=self.vq_model_lower.device) transfoot = torch.zeros(bs, seq, 7, device=self.vq_model_lower.device) - lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, seq, -1, 3)).reshape(bs, seq, -1) + lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, seq, -1, 3)).reshape( + bs, seq, -1 + ) lower_mix = torch.cat([lower_6d, transfoot], dim=-1) upper2all = recover_from_mask_ts(upper, self.joint_mask_upper) @@ -198,8 +349,10 @@ def decode( lower2all = recover_from_mask_ts(lower, self.joint_mask_lower) all_motion_axis_angle = upper2all + hands2all + lower2all - all_motion_axis_angle[:, :, 22 * 3:22 * 3 + 3] = face_jaw - all_motion_rot6d = axis_angle_to_rotation_6d(all_motion_axis_angle.reshape(bs, seq, 55, 3)).reshape(bs, seq, 55 * 6) + all_motion_axis_angle[:, :, 22 * 3 : 22 * 3 + 3] = face_jaw + all_motion_rot6d = axis_angle_to_rotation_6d( + all_motion_axis_angle.reshape(bs, seq, 55, 3) + ).reshape(bs, seq, 55 * 6) all_motion4inference = torch.cat([all_motion_rot6d, transfoot], dim=2) global_motion = None @@ -218,8 +371,12 @@ def _get_global_motion(self, lower_body, ref_trans): rec_trans_v_s = global_motion["rec_pose"][:, :, 54:57] if len(ref_trans.shape) == 2: ref_trans = ref_trans.unsqueeze(0).repeat(rec_trans_v_s.shape[0], 1, 1) - rec_x_trans = velocity2position(rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1]) - rec_z_trans = velocity2position(rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3]) + rec_x_trans = velocity2position( + rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1] + ) + rec_z_trans = velocity2position( + rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3] + ) rec_y_trans = rec_trans_v_s[:, :, 1:2] return torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) @@ -233,41 +390,97 @@ def __init__(self, config: EmageAudioConfig): self.cfg = config self.audio_encoder_face = WavEncoder(self.cfg.audio_f) self.audio_encoder_body = WavEncoder(self.cfg.audio_f) - self.speaker_embedding_body = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) - self.speaker_embedding_face = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) - self.mask_embedding = nn.Parameter(torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4)) - nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size ** -0.5) + self.speaker_embedding_body = nn.Embedding( + self.cfg.speaker_dims, self.cfg.hidden_size + ) + self.speaker_embedding_face = nn.Embedding( + self.cfg.speaker_dims, self.cfg.hidden_size + ) + self.mask_embedding = nn.Parameter( + torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4) + ) + nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size**-0.5) args_top = copy.deepcopy(self.cfg) args_top.vae_layer = 3 args_top.vae_length = self.cfg.motion_f args_top.vae_test_dim = self.cfg.pose_dims + 3 + 4 self.motion_encoder = VQEncoderV6(args_top) - self.bodyhints_face = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) - self.bodyhints_body = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) + self.bodyhints_face = MLP( + self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f + ) + self.bodyhints_body = MLP( + self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f + ) self.audio_body_motion_proj = nn.Linear(self.cfg.audio_f, self.cfg.hidden_size) self.moton_proj = nn.Linear(self.cfg.motion_f, self.cfg.hidden_size) - self.position_embeddings = PeriodicPositionalEncoding(self.cfg.hidden_size, period=self.cfg.pose_length, max_seq_len=self.cfg.pose_length) - self.transformer_en_layer = nn.TransformerEncoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) - self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) - self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) - self.audio_motion_cross_attn = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=8) - self.motion2latent_upper = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.motion2latent_hands = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.motion2latent_lower = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.body_motion_decoder_upper = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.body_motion_decoder_hands = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.body_motion_decoder_lower = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.motion_out_proj_upper = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_out_proj_hands = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_out_proj_lower = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_upper = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_hands = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_lower = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.audio_face_motion_proj = nn.Linear(self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size) - self.face_motion_decoder = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=4) + self.position_embeddings = PeriodicPositionalEncoding( + self.cfg.hidden_size, + period=self.cfg.pose_length, + max_seq_len=self.cfg.pose_length, + ) + self.transformer_en_layer = nn.TransformerEncoderLayer( + d_model=self.cfg.hidden_size, + nhead=4, + dim_feedforward=self.cfg.hidden_size * 2, + ) + self.motion_self_encoder = nn.TransformerEncoder( + self.transformer_en_layer, num_layers=1 + ) + self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer( + d_model=self.cfg.hidden_size, + nhead=4, + dim_feedforward=self.cfg.hidden_size * 2, + ) + self.audio_motion_cross_attn = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=8 + ) + self.motion2latent_upper = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.motion2latent_hands = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.motion2latent_lower = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.body_motion_decoder_upper = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.body_motion_decoder_hands = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.body_motion_decoder_lower = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.motion_out_proj_upper = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_out_proj_hands = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_out_proj_lower = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_upper = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_hands = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_lower = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.audio_face_motion_proj = nn.Linear( + self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size + ) + self.face_motion_decoder = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=4 + ) self.face_out_proj = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.face_cls = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.face_cls = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): masked_embeddings = self.mask_embedding.expand_as(masked_motion) @@ -281,30 +494,40 @@ def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): audio2body_fea = self.audio_encoder_body(audio) if audio2face_fea.shape[1] > body_hint_face.shape[1]: - audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + audio2face_fea = audio2face_fea[:, : body_hint_face.shape[1]] if audio2body_fea.shape[1] > body_hint_face.shape[1]: - audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + audio2face_fea = audio2face_fea[:, : body_hint_face.shape[1]] bs, t, _ = audio2face_fea.shape - speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat(1, t, 1) + speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat( + 1, t, 1 + ) speaker_face_fea_proj = self.speaker_embedding_face(speaker_id).repeat(1, t, 1) - audio2face_fea_proj = self.audio_face_motion_proj(torch.cat([audio2face_fea, body_hint_face], dim=2)) + audio2face_fea_proj = self.audio_face_motion_proj( + torch.cat([audio2face_fea, body_hint_face], dim=2) + ) face_proj = self.position_embeddings(speaker_face_fea_proj) - decode_face = self.face_motion_decoder(tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + decode_face = self.face_motion_decoder( + tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2) + ).permute(1, 0, 2) face_latent = self.face_out_proj(decode_face) classify_face = self.face_cls(face_latent) masked_motion_proj = self.moton_proj(body_hint_body) masked_motion_proj = self.position_embeddings(masked_motion_proj) masked_motion_proj = speaker_motion_fea_proj + masked_motion_proj - motion_fea = self.motion_self_encoder(masked_motion_proj.permute(1, 0, 2)).permute(1, 0, 2) + motion_fea = self.motion_self_encoder( + masked_motion_proj.permute(1, 0, 2) + ).permute(1, 0, 2) audio2body_fea_proj = self.audio_body_motion_proj(audio2body_fea) motion_fea = motion_fea + speaker_motion_fea_proj motion_fea = self.position_embeddings(motion_fea) - audio2body_fea_cross = self.audio_motion_cross_attn(tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + audio2body_fea_cross = self.audio_motion_cross_attn( + tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2) + ).permute(1, 0, 2) if not use_audio: audio2body_fea_cross = audio2body_fea_cross * 0.0 motion_fea = motion_fea + audio2body_fea_cross @@ -313,9 +536,21 @@ def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): hands_latent = self.motion2latent_hands(motion_fea) lower_latent = self.motion2latent_lower(motion_fea) - motion_upper_refine = self.body_motion_decoder_upper(tgt=upper_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(hands_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) - motion_hands_refine = self.body_motion_decoder_hands(tgt=hands_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) - motion_lower_refine = self.body_motion_decoder_lower(tgt=lower_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + hands_latent).permute(1, 0, 2)).permute(1, 0, 2) + motion_upper_refine = self.body_motion_decoder_upper( + tgt=upper_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(hands_latent + lower_latent).permute(1, 0, 2), + ).permute(1, 0, 2) + motion_hands_refine = self.body_motion_decoder_hands( + tgt=hands_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(upper_latent + lower_latent).permute(1, 0, 2), + ).permute(1, 0, 2) + motion_lower_refine = self.body_motion_decoder_lower( + tgt=lower_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(upper_latent + hands_latent).permute(1, 0, 2), + ).permute(1, 0, 2) upper_latent = self.motion_out_proj_upper(upper_latent + motion_upper_refine) hands_latent = self.motion_out_proj_hands(hands_latent + motion_hands_refine) lower_latent = self.motion_out_proj_lower(lower_latent + motion_lower_refine) @@ -344,12 +579,12 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): fake_foot_and_trans = torch.zeros(bs, length, 7).to(audio.device) fake_motion = torch.cat([fake_motion, fake_foot_and_trans], dim=-1) if masked_motion is not None: - fake_motion[:, :masked_motion.shape[1]] = masked_motion + fake_motion[:, : masked_motion.shape[1]] = masked_motion masked_motion = fake_motion fake_mask = torch.ones_like(masked_motion) if mask is not None: - fake_mask[:, :mask.shape[1]] = mask + fake_mask[:, : mask.shape[1]] = mask mask = fake_mask bs, total_len, c = masked_motion.shape @@ -371,34 +606,71 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): window_motion = masked_motion[:, start_idx:end_idx, :].clone() window_motion[:, :pre_frames, :] = torch.where( (window_mask[:, :pre_frames, :] == 0), - masked_motion[:, start_idx:start_idx + pre_frames, :], + masked_motion[:, start_idx : start_idx + pre_frames, :], last_motion, ) window_mask[:, :pre_frames, :] = 0 audio_slice_len = (end_idx - start_idx) * (16000 // 30) - audio_slice = audio[:, start_idx * (16000 // 30):start_idx * (16000 // 30) + audio_slice_len] - net_out_val = self.forward(audio_slice, speaker_id, masked_motion=window_motion, mask=window_mask, use_audio=True) - - _, cls_face = torch.max(F.log_softmax(net_out_val["cls_face"], dim=2), dim=2) - _, cls_upper = torch.max(F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2) - _, cls_hands = torch.max(F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2) - _, cls_lower = torch.max(F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2) - - face_latent = net_out_val["rec_face"] if self.cfg.lf > 0 and self.cfg.cf == 0 else None - upper_latent = net_out_val["rec_upper"] if self.cfg.lu > 0 and self.cfg.cu == 0 else None - hands_latent = net_out_val["rec_hands"] if self.cfg.lh > 0 and self.cfg.ch == 0 else None - lower_latent = net_out_val["rec_lower"] if self.cfg.ll > 0 and self.cfg.cl == 0 else None + audio_slice = audio[ + :, + start_idx * (16000 // 30) : start_idx * (16000 // 30) + audio_slice_len, + ] + net_out_val = self.forward( + audio_slice, + speaker_id, + masked_motion=window_motion, + mask=window_mask, + use_audio=True, + ) + + _, cls_face = torch.max( + F.log_softmax(net_out_val["cls_face"], dim=2), dim=2 + ) + _, cls_upper = torch.max( + F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2 + ) + _, cls_hands = torch.max( + F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2 + ) + _, cls_lower = torch.max( + F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2 + ) + + face_latent = ( + net_out_val["rec_face"] + if self.cfg.lf > 0 and self.cfg.cf == 0 + else None + ) + upper_latent = ( + net_out_val["rec_upper"] + if self.cfg.lu > 0 and self.cfg.cu == 0 + else None + ) + hands_latent = ( + net_out_val["rec_hands"] + if self.cfg.lh > 0 and self.cfg.ch == 0 + else None + ) + lower_latent = ( + net_out_val["rec_lower"] + if self.cfg.ll > 0 and self.cfg.cl == 0 + else None + ) face_index = cls_face if self.cfg.cf > 0 else None upper_index = cls_upper if self.cfg.cu > 0 else None hands_index = cls_hands if self.cfg.ch > 0 else None lower_index = cls_lower if self.cfg.cl > 0 else None decode_dict = vq_model.decode( - face_latent=face_latent, upper_latent=upper_latent, - lower_latent=lower_latent, hands_latent=hands_latent, - face_index=face_index, upper_index=upper_index, - lower_index=lower_index, hands_index=hands_index, + face_latent=face_latent, + upper_latent=upper_latent, + lower_latent=lower_latent, + hands_latent=hands_latent, + face_index=face_index, + upper_index=upper_index, + lower_index=lower_index, + hands_index=hands_index, ) last_motion = decode_dict["all_motion4inference"][:, -pre_frames:, :] @@ -419,14 +691,24 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): final_motion = masked_motion[:, final_start:final_end, :].clone() final_motion[:, :pre_frames, :] = torch.where( (final_mask[:, :pre_frames, :] == 0), - masked_motion[:, final_start:final_start + pre_frames, :], + masked_motion[:, final_start : final_start + pre_frames, :], last_motion, ) final_mask[:, :pre_frames, :] = 0 audio_slice_len = (final_end - final_start) * (16000 // 30) - audio_slice = audio[:, final_start * (16000 // 30):final_start * (16000 // 30) + audio_slice_len] - net_out_val = self.forward(audio_slice, speaker_id, masked_motion=final_motion, mask=final_mask, use_audio=True) + audio_slice = audio[ + :, + final_start * (16000 // 30) : final_start * (16000 // 30) + + audio_slice_len, + ] + net_out_val = self.forward( + audio_slice, + speaker_id, + masked_motion=final_motion, + mask=final_mask, + use_audio=True, + ) rec_all_face.append(net_out_val["rec_face"]) rec_all_upper.append(net_out_val["rec_upper"]) diff --git a/src/modules/gesture/emage/processing.py b/src/modules/gesture/emage/processing.py index 2d128b1..91266b0 100644 --- a/src/modules/gesture/emage/processing.py +++ b/src/modules/gesture/emage/processing.py @@ -1,4 +1,5 @@ import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -121,7 +122,7 @@ def velocity2position(data_seq, dt, init_pos): if i == 0: res_trans.append(init_pos.unsqueeze(1)) else: - res = data_seq[:, i - 1:i] * dt + res_trans[-1] + res = data_seq[:, i - 1 : i] * dt + res_trans[-1] res_trans.append(res) return torch.cat(res_trans, dim=1) @@ -156,13 +157,15 @@ def forward(self, z): assert z.shape[-1] == self.e_dim z_flattened = z.contiguous().view(-1, self.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) - loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2) + loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean( + (z_q.detach() - z) ** 2 + ) z_q = z + (z_q - z).detach() min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) e_mean = torch.mean(min_encodings, dim=0) @@ -173,8 +176,8 @@ def map2index(self, z): assert z.shape[-1] == self.e_dim z_flattened = z.contiguous().view(-1, self.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -288,24 +291,51 @@ def forward(self, inputs): class BasicBlock(nn.Module): - def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d): + def __init__( + self, + inplanes, + planes, + ker_size, + stride=1, + downsample=None, + dilation=1, + first_dilation=None, + act_layer=nn.LeakyReLU, + norm_layer=nn.BatchNorm1d, + ): super().__init__() self.conv1 = nn.Conv1d( - inplanes, planes, kernel_size=ker_size, stride=stride, - padding=first_dilation, dilation=dilation, bias=True, + inplanes, + planes, + kernel_size=ker_size, + stride=stride, + padding=first_dilation, + dilation=dilation, + bias=True, ) self.bn1 = norm_layer(planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv1d( - planes, planes, kernel_size=ker_size, padding=ker_size // 2, - dilation=dilation, bias=True, + planes, + planes, + kernel_size=ker_size, + padding=ker_size // 2, + dilation=dilation, + bias=True, ) self.bn2 = norm_layer(planes) self.act2 = act_layer(inplace=True) if downsample is not None: self.downsample = nn.Sequential( - nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, - padding=first_dilation, dilation=dilation, bias=True), + nn.Conv1d( + inplanes, + planes, + stride=stride, + kernel_size=ker_size, + padding=first_dilation, + dilation=dilation, + bias=True, + ), norm_layer(planes), ) else: @@ -330,10 +360,16 @@ def __init__(self, out_dim, audio_in=1): super().__init__() self.out_dim = out_dim self.feat_extractor = nn.Sequential( - BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True), - BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True), + BasicBlock( + audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True + ), + BasicBlock( + out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True + ), BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7), - BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True), + BasicBlock( + out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True + ), BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7), BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True), ) @@ -367,14 +403,16 @@ def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60): self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(period, d_model) position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) repeat_num = (max_seq_len // period) + 1 pe = pe.repeat(1, repeat_num, 1) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:, :x.size(1), :] + x = x + self.pe[:, : x.size(1), :] return self.dropout(x) diff --git a/src/modules/gesture/events.py b/src/modules/gesture/events.py index 228b6d8..7793de1 100644 --- a/src/modules/gesture/events.py +++ b/src/modules/gesture/events.py @@ -6,10 +6,11 @@ _EMAGE_FPS = 30 + @dataclass class Motion(EventData): - poses: np.ndarray # (t, 165) SMPL-X axis-angle, 55 joints × 3 + poses: np.ndarray # (t, 165) SMPL-X axis-angle, 55 joints × 3 expressions: np.ndarray # (t, 100) facial expression coefficients - trans: np.ndarray # (t, 3) global root translation + trans: np.ndarray # (t, 3) global root translation fps: int = _EMAGE_FPS - pts: float = 0.0 # presentation timestamp in seconds, paired with Audio.pts + pts: float = 0.0 # presentation timestamp in seconds, paired with Audio.pts diff --git a/src/modules/gesture/gesture.py b/src/modules/gesture/gesture.py index 9028a17..554b8f4 100644 --- a/src/modules/gesture/gesture.py +++ b/src/modules/gesture/gesture.py @@ -1,6 +1,5 @@ import asyncio import os -from dataclasses import dataclass from typing import AsyncGenerator, Optional import numpy as np @@ -11,7 +10,6 @@ from src.modules.gesture.events import Motion from src.modules.text_to_speech.events import Audio - _HF_REPO = os.environ.get("HURI_EMAGE_REPO", "H-Liu1997/emage_audio") _EMAGE_SR = 16000 # EMAGE expects 16 kHz mono audio @@ -46,7 +44,7 @@ def __init__( device: Optional[str] = None, gpu_mem_fraction: float = _GPU_MEM_FRACTION, ): - print(f"[Gesture] importing torch...") + print("[Gesture] importing torch...") import torch # Pin algorithm selection so the kernels warmed below are the same ones @@ -61,7 +59,7 @@ def __init__( torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - print(f"[Gesture] importing emage...") + print("[Gesture] importing emage...") from .emage import EmageAudioModel, EmageVAEConv, EmageVQModel, EmageVQVAEConv self.device = torch.device( @@ -84,15 +82,25 @@ def __init__( print(f"[Gesture] WARNING could not cap GPU memory: {e!r}") print("[Gesture] loading face_vq...") - face_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/face").to(self.device) + face_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/face").to( + self.device + ) print("[Gesture] loading upper_vq...") - upper_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/upper").to(self.device) + upper_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/upper" + ).to(self.device) print("[Gesture] loading lower_vq...") - lower_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/lower").to(self.device) + lower_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/lower" + ).to(self.device) print("[Gesture] loading hands_vq...") - hands_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/hands").to(self.device) + hands_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/hands" + ).to(self.device) print("[Gesture] loading global_ae...") - global_ae = EmageVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/global").to(self.device) + global_ae = EmageVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/global" + ).to(self.device) self.motion_vq = EmageVQModel( face_model=face_vq, @@ -108,7 +116,7 @@ def __init__( self.model.eval() self._warmup() - print(f"[Gesture] ready") + print("[Gesture] ready") def _warmup(self) -> None: # The first inference pays one-time costs that are *shape- and @@ -128,21 +136,24 @@ def _warmup(self) -> None: # synchronize so the GPU work is finished before we report ready. # Best-effort: a failure here must never prevent the deployment coming up. import time + import torch # Representative window lengths (seconds). The dominant per-window # transformer forward is a fixed shape warmed by any window, but the # trailing remainder forward varies with total length, so warm a spread. - secs = sorted({ - round(s, 3) - for s in ( - _MIN_CHUNK_SEC, # first tiny window of an utterance - _CONTEXT_SEC, # context-only sized window - _CONTEXT_SEC + _MIN_CHUNK_SEC, # steady-state window - _CONTEXT_SEC + 2 * _MIN_CHUNK_SEC, # a larger fresh chunk - ) - if s and s > 0 - }) or [3.0] + secs = sorted( + { + round(s, 3) + for s in ( + _MIN_CHUNK_SEC, # first tiny window of an utterance + _CONTEXT_SEC, # context-only sized window + _CONTEXT_SEC + _MIN_CHUNK_SEC, # steady-state window + _CONTEXT_SEC + 2 * _MIN_CHUNK_SEC, # a larger fresh chunk + ) + if s and s > 0 + } + ) or [3.0] try: t0 = time.time() @@ -156,10 +167,12 @@ def _warmup(self) -> None: torch.cuda.synchronize(self.device) print( f"[Gesture] warmup pass {pass_idx} {s:.2f}s " - f"({n} samples @ {_WARMUP_SRC_SR} Hz) in {time.time() - ts:.2f}s", + f"({n} samples @ {_WARMUP_SRC_SR} Hz) \ +in {time.time() - ts:.2f}s", ) print( - f"[Gesture] warmup done ({len(secs)} shapes x2) in {time.time() - t0:.2f}s", + f"[Gesture] warmup done ({len(secs)} shapes x2) \ +in {time.time() - t0:.2f}s", ) except Exception as e: # noqa: BLE001 — warmup is an optimisation, never fatal print(f"[Gesture] WARNING warmup failed: {e!r}") @@ -170,7 +183,10 @@ def infer(self, audio_np: np.ndarray, source_sr: int = _EMAGE_SR) -> Motion: if source_sr != _EMAGE_SR: import librosa - audio_np = librosa.resample(audio_np, orig_sr=source_sr, target_sr=_EMAGE_SR) + + audio_np = librosa.resample( + audio_np, orig_sr=source_sr, target_sr=_EMAGE_SR + ) audio_ts = torch.from_numpy(audio_np).to(self.device).unsqueeze(0) speaker_id = torch.zeros(1, 1, dtype=torch.long, device=self.device) @@ -180,21 +196,50 @@ def infer(self, audio_np: np.ndarray, source_sr: int = _EMAGE_SR) -> Motion: latent_dict = self.model.inference(audio_ts, speaker_id, self.motion_vq) cfg = self.model.cfg - face_latent = latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None - upper_latent = latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None - hands_latent = latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None - lower_latent = latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None - face_index = torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] if cfg.cf > 0 else None - upper_index = torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] if cfg.cu > 0 else None - hands_index = torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] if cfg.ch > 0 else None - lower_index = torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] if cfg.cl > 0 else None + face_latent = ( + latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None + ) + upper_latent = ( + latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None + ) + hands_latent = ( + latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None + ) + lower_latent = ( + latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None + ) + face_index = ( + torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] + if cfg.cf > 0 + else None + ) + upper_index = ( + torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] + if cfg.cu > 0 + else None + ) + hands_index = ( + torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] + if cfg.ch > 0 + else None + ) + lower_index = ( + torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] + if cfg.cl > 0 + else None + ) all_pred = self.motion_vq.decode( - face_latent=face_latent, upper_latent=upper_latent, - lower_latent=lower_latent, hands_latent=hands_latent, - face_index=face_index, upper_index=upper_index, - lower_index=lower_index, hands_index=hands_index, - get_global_motion=True, ref_trans=ref_trans[:, 0], + face_latent=face_latent, + upper_latent=upper_latent, + lower_latent=lower_latent, + hands_latent=hands_latent, + face_index=face_index, + upper_index=upper_index, + lower_index=lower_index, + hands_index=hands_index, + get_global_motion=True, + ref_trans=ref_trans[:, 0], ) t = all_pred["motion_axis_angle"].shape[1] @@ -271,9 +316,11 @@ def __init__( # source sample rate; resampling to 16 kHz happens once inside infer(). self._lock = asyncio.Lock() self._sr: Optional[int] = None - self._buffer = np.empty(0, dtype=np.float32) # trailing audio (ctx + unprocessed) - self._buf_start = 0 # source-sr sample index of buffer[0] in utterance timeline - self._emitted = 0 # source-sr samples whose motion has been emitted + self._buffer = np.empty( + 0, dtype=np.float32 + ) # trailing audio (ctx + unprocessed) + self._buf_start = 0 # source-sr sample index of buffer[0] in utterance timeline + self._emitted = 0 # source-sr samples whose motion has been emitted # Last emitted frame per channel, used to ease the next segment's seam. # These persist across the end-of-utterance reset (see _end_utterance) so @@ -292,7 +339,7 @@ def _end_utterance(self) -> None: self._buf_start = 0 self._emitted = 0 - async def process(self, audio: Audio) -> AsyncGenerator[Motion, None]: # type: ignore[override] + async def process(self, audio: Audio) -> AsyncGenerator[Motion, None]: # Each chunk arrives as its own process() task on the shared per-session # instance, so serialise under a lock to keep the buffer ordered. async with self._lock: @@ -319,7 +366,9 @@ async def process(self, audio: Audio) -> AsyncGenerator[Motion, None]: # type: new_samples = global_end - self._emitted # Wait for more audio unless this is the final flush of the utterance. - if new_samples <= 0 or (not end_of_utterance and new_samples < min_new_samples): + if new_samples <= 0 or ( + not end_of_utterance and new_samples < min_new_samples + ): if end_of_utterance: self._end_utterance() return diff --git a/src/modules/modules.py b/src/modules/modules.py index d8fefb9..b551704 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -1,15 +1,22 @@ from typing import Dict, Type +from src.modules.gesture.gesture import Gesture from src.modules.rag.rag import RAG from src.modules.speech_to_text.microphone_vad import MIC from src.modules.speech_to_text.speech_to_text import STT from src.modules.speech_to_text.text_aggregator import TAG from src.modules.text_to_speech.text_to_speech import TTS -from src.modules.gesture.gesture import Gesture from .factory import Module def get_modules() -> Dict[str, Type[Module]]: - modules: Dict[str, Type[Module]] = {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG, "tts": TTS, "gesture": Gesture} + modules: Dict[str, Type[Module]] = { + "mic": MIC, + "stt": STT, + "tag": TAG, + "rag": RAG, + "tts": TTS, + "gesture": Gesture, + } return modules diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index 529c7ff..2a825b3 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -185,7 +185,8 @@ def ingest_chunks( def chunk_strat(text: str, args, model: Any) -> list[str] | Any: """Pick the right chunking strategy based on args.""" if args.chunking == "semantic": - # Thomas: I need to import here, bceause it takes too much time earlier, or use a jupyter notebook to do it instead + # Thomas: I need to import here, bceause it takes too much time earlier, + # or use a jupyter notebook to do it instead from .semantic_chunker import SemanticChunker chunker = SemanticChunker( @@ -526,7 +527,10 @@ def main(): if needs_embeddings: if args.embedding_url: - print(f"Embedding remotely via {args.embedding_url} (model={args.embedding_model})") + print( + f"Embedding remotely via {args.embedding_url}" + f"(model={args.embedding_model})" + ) model = RemoteEmbedder(args.embedding_url, args.embedding_model) else: from sentence_transformers import SentenceTransformer diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 0e65bbe..166cf04 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -4,7 +4,9 @@ from dataclasses import dataclass, field from typing import Any, AsyncGenerator +import httpx from pydantic import BaseModel +from qdrant_client.models import FieldCondition, Filter, MatchValue from ray import serve from ray.serve import handle @@ -12,9 +14,6 @@ from src.modules.speech_to_text.events import Sentence from src.modules.text_to_speech.events import Token -import httpx - -from qdrant_client.models import FieldCondition, Filter, MatchValue from .qdrant_utils import make_qdrant_client # Default character persona. Overridable per session via the `persona` key in the @@ -116,7 +115,9 @@ def _get_profile(self, collection: str, _user_id: str) -> list[str]: collection_name=collection, scroll_filter=Filter( must=[ - FieldCondition(key="_user_id", match=MatchValue(value=_user_id)), + FieldCondition( + key="_user_id", match=MatchValue(value=_user_id) + ), FieldCondition(key="type", match=MatchValue(value="profile")), ] ), @@ -194,7 +195,8 @@ def _build_prompt( "it is relevant, but always answer in character. If you don't know " "something, improvise in character rather than admitting you lack " "information or breaking character. " - "IMPORTANT: Reply in 1-3 short sentences maximum. Be extremely concise. No lists, no emojis, no long explanations." + "IMPORTANT: Reply in 1-3 short sentences maximum. Be extremely concise." + "No lists, no emojis, no long explanations." ) system_prompt = " ".join(parts) @@ -226,28 +228,28 @@ async def _stream_ollama( self, messages: list, max_tokens: int, temperature: float = 0.7 ) -> AsyncGenerator[str, None]: async with self._llm_client.stream( - "POST", - f"{self._cfg.llm_url}/api/chat", - json={ - "model": self._cfg.llm_model, - "messages": messages, - "stream": True, - "options": {"num_predict": max_tokens, "temperature": temperature}, - }, - ) as resp: - resp.raise_for_status() - async for line in resp.aiter_lines(): - if not line: - continue - try: - chunk = json.loads(line) - except json.JSONDecodeError: - continue - delta = chunk.get("message", {}).get("content", "") - if delta: - yield delta - if chunk.get("done"): - return + "POST", + f"{self._cfg.llm_url}/api/chat", + json={ + "model": self._cfg.llm_model, + "messages": messages, + "stream": True, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + delta = chunk.get("message", {}).get("content", "") + if delta: + yield delta + if chunk.get("done"): + return async def _stream_openai_compatible( self, @@ -261,35 +263,33 @@ async def _stream_openai_compatible( if api_key: headers["Authorization"] = f"Bearer {api_key}" async with self._llm_client.stream( - "POST", - url, - headers=headers, - json={ - "model": self._cfg.llm_model, - "messages": messages, - "max_tokens": max_tokens, - "temperature": temperature, - "stream": True, - }, - ) as resp: - resp.raise_for_status() - async for line in resp.aiter_lines(): - if not line or not line.startswith("data:"): - continue - payload = line[len("data:"):].strip() - if payload == "[DONE]": - return - try: - chunk = json.loads(payload) - except json.JSONDecodeError: - continue - delta = ( - chunk.get("choices", [{}])[0] - .get("delta", {}) - .get("content", "") - ) - if delta: - yield delta + "POST", + url, + headers=headers, + json={ + "model": self._cfg.llm_model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "stream": True, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line or not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + return + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + continue + delta = ( + chunk.get("choices", [{}])[0].get("delta", {}).get("content", "") + ) + if delta: + yield delta async def _llm_stream( self, @@ -391,7 +391,11 @@ def __init__( ): super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) - print(f"[RAG] Initialized with user_id={_user_id}, language={language}, tone={tone}, response_format={response_format}, max_length={max_length}, temperature={temperature}, max_history_turns={max_history_turns}") + print( + f"[RAG] Initialized with user_id={_user_id}, language={language}, " + f"tone={tone}, response_format={response_format}, max_length={max_length}, " + f"temperature={temperature}, max_history_turns={max_history_turns}" + ) self.preferences = { "language": language, @@ -410,7 +414,7 @@ def __init__( self._max_history_turns = max_history_turns self.history: list[dict] = [] - async def process(self, data: Sentence) -> AsyncGenerator[Token, None]: # type: ignore[override] + async def process(self, data: Sentence) -> AsyncGenerator[Token, None]: query = RAGQuery( _user_id=self._user_id if self._user_id else "anonymous", question=data.text, diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 9e48a62..fac12e0 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -104,7 +104,7 @@ def __init__( self.running = False self.lock: asyncio.Lock = asyncio.Lock() - async def process(self, voice: Voice) -> Optional[Transcript]: # type: ignore[override] + async def process(self, voice: Voice) -> Optional[Transcript]: if voice.data is None: self.silence = True else: diff --git a/src/modules/text_to_speech/text_to_speech.py b/src/modules/text_to_speech/text_to_speech.py index ad49dab..95d2099 100644 --- a/src/modules/text_to_speech/text_to_speech.py +++ b/src/modules/text_to_speech/text_to_speech.py @@ -37,9 +37,10 @@ def _normalize_transcript(raw: str) -> str: else f"{_DEFAULT_INSTRUCTION}<|endofprompt|>{raw}" ) -_END_TEXT = object() # sentinel pushed into the text queue to close synth + +_END_TEXT = object() # sentinel pushed into the text queue to close synth _END_AUDIO = object() # sentinel pushed into the audio queue when synth completes -_DONE = object() # sentinel for exhausted sync generator +_DONE = object() # sentinel for exhausted sync generator @serve.deployment(name="TTS", max_ongoing_requests=200) @@ -66,7 +67,7 @@ def __init__( sys.path.insert(0, matcha_path) from cosyvoice.cli.cosyvoice import CosyVoice3 - + # Resolve the reference transcript here (deploy time on the GPU worker) # rather than at module import: importing this module must not require # HURI_VOICE_TRANSCRIPT, since modules.py imports it inside a broad @@ -171,7 +172,7 @@ def __init__(self, _handle: handle.DeploymentHandle): # and silently drop trailing words). self._push_lock = asyncio.Lock() - async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: ignore[override] + async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # Acquire BEFORE any await so lock-acquisition order matches token order. # Setup + push happen under the lock; only the first token of an # utterance goes on to drain/yield audio (outside the lock, so pushes of @@ -181,7 +182,9 @@ async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: i if is_first: self._session_id = str(uuid.uuid4()) self._audio_q = asyncio.Queue() - print(f"[TTS-client] [{self._session_id}] opening new utterance session") + print( + f"[TTS-client] [{self._session_id}] opening new utterance session" + ) await self._handle.start_session.remote(self._session_id) self._stream_task = asyncio.create_task( self._drain_audio(self._session_id, self._audio_q) @@ -210,7 +213,9 @@ async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: i print(f"[TTS-client] [{sid}] utterance complete ({count} chunks)") sample_rate = await self._handle.get_sample_rate.remote() - yield Audio(data=np.array([], dtype=np.float32), sample_rate=sample_rate, end=True) + yield Audio( + data=np.array([], dtype=np.float32), sample_rate=sample_rate, end=True + ) finally: async with self._push_lock: self._session_id = None diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index c1485f0..db5f633 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -1,5 +1,4 @@ import logging - import struct from dataclasses import asdict @@ -18,12 +17,17 @@ class Sender(Module): """Sender Module Send output data to the client. + This data must be JSON serialisable, like a dataclass. - Audio wire format: [4B sample_rate uint32][1B end][8B pts float64][float32 PCM]. - Motion wire format: [8B pts float64][4B fps uint32][4B n_frames uint32] - [poses float32 n*165][expressions float32 n*100][trans float32 n*3]. - input: auto, output: None""" + Audio wire format: + [4B sample_rate uint32][1B end][8B pts float64][float32 PCM]. + Motion wire format: + [8B pts float64][4B fps uint32][4B n_frames uint32] + [poses float32 n*165][expressions float32 n*100][trans float32 n*3]. + + input: auto, + output: None""" output_type = None