diff --git a/config/client_aux.yaml b/config/client_aux.yaml index fe3e332..c219ee4 100644 --- a/config/client_aux.yaml +++ b/config/client_aux.yaml @@ -1,6 +1,6 @@ huri_url: ws://localhost:8000/session -topic_list: [question] +topic_list: [transcript, question, token, motion] senders: audio: @@ -8,6 +8,12 @@ senders: args: sample_rate: 16000 frame_duration: 0.030 + text: + name: text + topic: question + args: + sample_rate: 16000 + frame_duration: 0.030 modules: mic: @@ -26,3 +32,27 @@ modules: tag: name: tag logging: INFO + emo: + name: emo + args: + block_duration: ${senders.audio.args.frame_duration} + eag: + name: eag + qag: + name: qag + rag: + name: rag + args: + language: en + tone: formal + response_format: paragraph + max_length: 1024 + logging: INFO + tts: + name: tts + args: + min_clause_chars: 20 + logging: INFO + gesture: + name: gesture + logging: INFO diff --git a/config/client_full.yaml b/config/client_full.yaml index 1481430..4dc4c89 100644 --- a/config/client_full.yaml +++ b/config/client_full.yaml @@ -8,6 +8,20 @@ senders: args: sample_rate: 16000 frame_duration: 0.030 + text: + name: text + topic: question + args: + sample_rate: 16000 + frame_duration: 0.030 + +hooks: + audio: + name: audio + topics: [audio] + args: + incoming_sample_rate: ${senders.audio.args.sample_rate} + sample_rate: 44100 modules: mic: @@ -26,6 +40,14 @@ modules: tag: name: tag logging: INFO + emo: + name: emo + args: + block_duration: ${senders.audio.args.frame_duration} + eag: + name: eag + qag: + name: qag rag: name: rag args: diff --git a/src/core/client.py b/src/core/client.py index 2af97f2..c651c07 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -93,8 +93,8 @@ async def _receive_loop(self, ws: websockets.ClientConnection): print(f"<< bytes ({len(msg)}B, no topic)") continue (topic_len,) = struct.unpack(">H", msg[:2]) - topic = msg[2:2 + topic_len].decode() - payload = msg[2 + topic_len:] + topic = msg[2 : 2 + topic_len].decode() + payload = msg[2 + topic_len :] if topic == "audio" and len(payload) >= 13: sample_rate, end_flag, pts = struct.unpack(">IBd", payload[:13]) diff --git a/src/core/events.py b/src/core/events.py index dfaf66b..43176d2 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -52,7 +52,8 @@ async def publish(self, event_topic, data): if event_topic not in ("audio_in",): # skip mic-frame spam logger.info( "[GRAPH] publish topic=%r subscribers=%s", - event_topic, [type(m).__name__ for m in subs], + event_topic, + [type(m).__name__ for m in subs], ) for module in subs: asyncio.create_task(self._run(module, data)) @@ -68,11 +69,15 @@ async def _run(self, module: Module, data): continue logger.info( "[GRAPH] %s -> %r: %s", - type(module).__name__, module.output_type, _summarize(item), + type(module).__name__, + module.output_type, + _summarize(item), ) await self.publish(module.output_type, item) except Exception: - logger.exception("[GRAPH] async generator failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] async generator failed in %s", type(module).__name__ + ) else: try: @@ -80,14 +85,20 @@ async def _run(self, module: Module, data): if value is not None: logger.info( "[GRAPH] %s -> %r: %s", - type(module).__name__, module.output_type, _summarize(value), + type(module).__name__, + module.output_type, + _summarize(value), ) await self.publish(module.output_type, value) except Exception: - logger.exception("[GRAPH] coroutine failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] coroutine failed in %s", type(module).__name__ + ) except Exception: - logger.exception("[GRAPH] process() call failed in %s", type(module).__name__) + logger.exception( + "[GRAPH] process() call failed in %s", type(module).__name__ + ) def _summarize(item) -> str: diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/emotion/__init__.py b/src/modules/emotion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/emotion/emotion_aggregator.py b/src/modules/emotion/emotion_aggregator.py new file mode 100644 index 0000000..f3144fa --- /dev/null +++ b/src/modules/emotion/emotion_aggregator.py @@ -0,0 +1,73 @@ +from collections import defaultdict +from typing import Dict, Optional + +from src.core.module import Module +from src.modules.rag.events import PartialQuestion + +from .events import Emotion + + +class EAG(Module): + """EAG Module + + Aggregate all emotions and send when voice end. + + input: emotion, + output: partial_question + + :ema_alpha: if not None, the aggragation will use ema computation instead + of average. Recents emotion will have stronger impact on the final score. + Lower alpha will make impact lower, and higher alpha will make it higher. \ + Default alpha would be ~0.3. + """ + + input_type = "emotion" + output_type = "partial_question" + + def __init__(self, ema_alpha: Optional[float] = None): + super().__init__() + + self.scores: Dict[str, float] = defaultdict(float) + self.count: int = 0 + + self.ema_alpha = ema_alpha + + def _finalize(self) -> Emotion: + avg_scores = ( + {label: score / self.count for label, score in self.scores.items()} + if self.ema_alpha is None + else self.scores + ) + + best_label = max(avg_scores, key=lambda label: avg_scores[label]) + + result = Emotion( + label=best_label, + confidence=avg_scores[best_label], + scores=avg_scores, + end=True, + ) + + self.scores.clear() + self.count = 0 + + return result + + async def process(self, emotion: Emotion) -> Optional[PartialQuestion]: + if self.ema_alpha is not None: + for label, score in emotion.scores.items(): + self.scores[label] = ( + self.ema_alpha * score + (1 - self.ema_alpha) * self.scores[label] + ) + else: + for label, score in emotion.scores.items(): + self.scores[label] += score + + self.count += 1 + + if emotion.end: + emotion = self._finalize() + + return PartialQuestion(transcript=None, emotion=emotion) + + return None diff --git a/src/modules/emotion/events.py b/src/modules/emotion/events.py new file mode 100644 index 0000000..4c29c37 --- /dev/null +++ b/src/modules/emotion/events.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Dict + +from src.core.events import EventData + + +@dataclass +class Emotion(EventData): + label: str + confidence: float + scores: Dict[str, float] + end: bool diff --git a/src/modules/emotion/prosody_analysis.py b/src/modules/emotion/prosody_analysis.py new file mode 100644 index 0000000..cd50b3d --- /dev/null +++ b/src/modules/emotion/prosody_analysis.py @@ -0,0 +1,109 @@ +import asyncio +from typing import List, Optional + +import numpy as np +import torch +from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor + +from src.core.module import Module +from src.modules.speech_to_text.events import Voice + +from .events import Emotion + + +class EMO(Module): + """EMO Module + + Prosody Analysis of user voice speech. + + input: voice, + output: emotion + + :model_name: name of the Emotion Analysis model. + :sample_rate: size of received voice audio. Usually 8000, 16000 or 48000. + :block_duration: size of received voice audio (in s). + :analysis_window: duration of audio per analysis (in s). + """ + + input_type = "voice" + output_type = "emotion" + + def __init__( + self, + model_name: str = "superb/hubert-large-superb-er", + sample_rate: int = 16000, + block_duration: float = 0.020, # s + analysis_window: float = 4.0, # s + ): + super().__init__() + + self.model = AutoModelForAudioClassification.from_pretrained(model_name) + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name) + + self.sample_rate = sample_rate + self.window_size = int(analysis_window / block_duration) + + self.buffer: List[np.ndarray] = [] + + self.silence: bool = True + + self.running = False + self.lock: asyncio.Lock = asyncio.Lock() + + def _predict_emotion(self, audio_np: np.ndarray): + inputs = self.feature_extractor( + audio_np, sampling_rate=self.sample_rate, return_tensors="pt", padding=True + ) + + with torch.no_grad(): + logits = self.model(**inputs).logits + probs = torch.softmax(logits, dim=-1)[0] + + predicted_id = int(torch.argmax(probs).item()) + + labels = self.model.config.id2label + + return { + "label": labels[predicted_id], + "confidence": float(probs[predicted_id]), + "scores": {labels[i]: float(probs[i]) for i in range(len(labels))}, + } + + async def process(self, voice: Voice) -> Optional[Emotion]: + if voice.data is None: + self.silence = True + else: + self.silence = False + async with self.lock: + self.buffer.append(voice.data) + + async with self.lock: + if self.running: + return None + self.running = True + + async with self.lock: + buffer_size = len(self.buffer) + if buffer_size == 0 or ( + self.silence is False and buffer_size < self.window_size + ): + self.running = False + return None + processing_chunks = self.buffer[: self.window_size] + + processing_audio = np.concatenate(processing_chunks, axis=0) + + emotion_result = await asyncio.to_thread( + self._predict_emotion, audio_np=processing_audio + ) + + async with self.lock: + self.buffer = self.buffer[self.window_size :] + self.running = False + + return Emotion( + emotion_result["label"], + emotion_result["confidence"], + emotion_result["scores"], + self.silence, + ) diff --git a/src/modules/events.py b/src/modules/events.py index 93779a2..8cdc066 100644 --- a/src/modules/events.py +++ b/src/modules/events.py @@ -1,9 +1,11 @@ from typing import Dict, Type from src.core.events import EventData -from src.modules.speech_to_text.events import Sentence, Transcript, Voice -from src.modules.text_to_speech.events import Audio, Token +from src.modules.emotion.events import Emotion from src.modules.gesture.events import Motion +from src.modules.rag.events import PartialQuestion, RAGQuestion +from src.modules.speech_to_text.events import Transcript, Voice +from src.modules.text_to_speech.events import Token def get_events() -> Dict[str, Type[EventData | bytes]]: @@ -12,9 +14,11 @@ def get_events() -> Dict[str, Type[EventData | bytes]]: "audio": bytes, "voice": Voice, "transcript": Transcript, - "question": Sentence, + "emotion": Emotion, + "partial_question": PartialQuestion, + "question": RAGQuestion, "token": Token, - "motion": Motion + "motion": Motion, } return events diff --git a/src/modules/gesture/emage/modeling.py b/src/modules/gesture/emage/modeling.py index b4d71ab..2f50d2a 100644 --- a/src/modules/gesture/emage/modeling.py +++ b/src/modules/gesture/emage/modeling.py @@ -14,6 +14,7 @@ VQEncoderV5, VQEncoderV6, WavEncoder, + axis_angle_to_matrix, axis_angle_to_rotation_6d, matrix_to_axis_angle, matrix_to_rotation_6d, @@ -21,7 +22,6 @@ rotation_6d_to_axis_angle, rotation_6d_to_matrix, velocity2position, - axis_angle_to_matrix, ) @@ -47,14 +47,21 @@ class EmageVQVAEConv(PreTrainedModel): def __init__(self, config): super().__init__(config) self.encoder = VQEncoderV5(config) - self.quantizer = Quantizer(config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda) + self.quantizer = Quantizer( + config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda + ) self.decoder = VQDecoderV5(config) def forward(self, inputs): pre_latent = self.encoder(inputs) embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) rec_pose = self.decoder(vq_latent) - return {"poses_feat": vq_latent, "embedding_loss": embedding_loss, "perplexity": perplexity, "rec_pose": rec_pose} + return { + "poses_feat": vq_latent, + "embedding_loss": embedding_loss, + "perplexity": perplexity, + "rec_pose": rec_pose, + } def map2index(self, inputs): pre_latent = self.encoder(inputs) @@ -72,8 +79,8 @@ def decode(self, index): def decode_from_latent(self, latent): z_flattened = latent.contiguous().view(-1, self.quantizer.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.quantizer.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.quantizer.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.quantizer.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -86,20 +93,118 @@ class EmageVQModel(nn.Module): def __init__(self, face_model, upper_model, hands_model, lower_model, global_model): super().__init__() self.joint_mask_upper = [ - False, False, False, True, False, False, True, False, False, True, - False, False, True, True, True, True, True, True, True, True, - True, True, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, + False, + False, + False, + True, + False, + False, + True, + False, + False, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ] self.joint_mask_lower = [ - True, True, True, False, True, True, False, True, True, False, - True, True, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, False, - False, False, False, False, False, + True, + True, + True, + False, + True, + True, + False, + True, + True, + False, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ] self.vq_model_face = face_model self.vq_model_upper = upper_model @@ -107,21 +212,37 @@ def __init__(self, face_model, upper_model, hands_model, lower_model, global_mod self.vq_model_lower = lower_model self.global_motion = global_model - def spilt_inputs(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): + def spilt_inputs( + self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None + ): bs, t, j6 = smplx_body_rot6d.shape smplx_body_rot6d = smplx_body_rot6d.reshape(bs, t, j6 // 6, 6) jaw_rot6d = smplx_body_rot6d[:, :, 22:23, :].reshape(bs, t, 6) face = torch.cat([jaw_rot6d, expression], dim=2) - upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape(bs, t, 78) + upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape( + bs, t, 78 + ) hands_rot6d = smplx_body_rot6d[:, :, 25:55, :].reshape(bs, t, 180) - lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape(bs, t, 54) - tar_contact = torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) if tar_contact is None else tar_contact - tar_trans = torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) if tar_trans is None else tar_trans + lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape( + bs, t, 54 + ) + tar_contact = ( + torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) + if tar_contact is None + else tar_contact + ) + tar_trans = ( + torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) + if tar_trans is None + else tar_trans + ) lower = torch.cat([lower_rot6d, tar_trans, tar_contact], dim=2) return dict(face=face, upper=upper_rot6d, hands=hands_rot6d, lower=lower) def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): - inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans) + inputs = self.spilt_inputs( + smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans + ) return dict( face=self.vq_model_face.map2index(inputs["face"]), upper=self.vq_model_upper.map2index(inputs["upper"]), @@ -129,8 +250,12 @@ def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=No lower=self.vq_model_lower.map2index(inputs["lower"]), ) - def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): - inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans) + def map2latent( + self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None + ): + inputs = self.spilt_inputs( + smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans + ) return dict( face=self.vq_model_face.map2latent(inputs["face"]), upper=self.vq_model_upper.map2latent(inputs["upper"]), @@ -140,11 +265,27 @@ def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=N def decode( self, - face_index=None, upper_index=None, hands_index=None, lower_index=None, - face_latent=None, upper_latent=None, hands_latent=None, lower_latent=None, - get_global_motion=False, ref_trans=None, + face_index=None, + upper_index=None, + hands_index=None, + lower_index=None, + face_latent=None, + upper_latent=None, + hands_latent=None, + lower_latent=None, + get_global_motion=False, + ref_trans=None, ): - for t in [face_index, upper_index, hands_index, lower_index, face_latent, upper_latent, hands_latent, lower_latent]: + for t in [ + face_index, + upper_index, + hands_index, + lower_index, + face_latent, + upper_latent, + hands_latent, + lower_latent, + ]: if t is not None: bs, seq = t.shape[:2] break @@ -163,34 +304,48 @@ def decode( if upper_index is not None: upper_6d = self.vq_model_upper.decode(upper_index) - upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif upper_latent is not None: upper_6d = self.vq_model_upper.decode_from_latent(upper_latent) - upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: upper = torch.zeros(bs, seq, 39, device=self.vq_model_upper.device) if hands_index is not None: hands_6d = self.vq_model_hands.decode(hands_index) - hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif hands_latent is not None: hands_6d = self.vq_model_hands.decode_from_latent(hands_latent) - hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: hands = torch.zeros(bs, seq, 90, device=self.vq_model_hands.device) if lower_index is not None: lower_mix = self.vq_model_lower.decode(lower_index) lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] - lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) elif lower_latent is not None: lower_mix = self.vq_model_lower.decode_from_latent(lower_latent) lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] - lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape( + bs, seq, -1 + ) else: lower = torch.zeros(bs, seq, 27, device=self.vq_model_lower.device) transfoot = torch.zeros(bs, seq, 7, device=self.vq_model_lower.device) - lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, seq, -1, 3)).reshape(bs, seq, -1) + lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, seq, -1, 3)).reshape( + bs, seq, -1 + ) lower_mix = torch.cat([lower_6d, transfoot], dim=-1) upper2all = recover_from_mask_ts(upper, self.joint_mask_upper) @@ -198,8 +353,10 @@ def decode( lower2all = recover_from_mask_ts(lower, self.joint_mask_lower) all_motion_axis_angle = upper2all + hands2all + lower2all - all_motion_axis_angle[:, :, 22 * 3:22 * 3 + 3] = face_jaw - all_motion_rot6d = axis_angle_to_rotation_6d(all_motion_axis_angle.reshape(bs, seq, 55, 3)).reshape(bs, seq, 55 * 6) + all_motion_axis_angle[:, :, 22 * 3 : 22 * 3 + 3] = face_jaw + all_motion_rot6d = axis_angle_to_rotation_6d( + all_motion_axis_angle.reshape(bs, seq, 55, 3) + ).reshape(bs, seq, 55 * 6) all_motion4inference = torch.cat([all_motion_rot6d, transfoot], dim=2) global_motion = None @@ -218,8 +375,12 @@ def _get_global_motion(self, lower_body, ref_trans): rec_trans_v_s = global_motion["rec_pose"][:, :, 54:57] if len(ref_trans.shape) == 2: ref_trans = ref_trans.unsqueeze(0).repeat(rec_trans_v_s.shape[0], 1, 1) - rec_x_trans = velocity2position(rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1]) - rec_z_trans = velocity2position(rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3]) + rec_x_trans = velocity2position( + rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1] + ) + rec_z_trans = velocity2position( + rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3] + ) rec_y_trans = rec_trans_v_s[:, :, 1:2] return torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) @@ -233,41 +394,97 @@ def __init__(self, config: EmageAudioConfig): self.cfg = config self.audio_encoder_face = WavEncoder(self.cfg.audio_f) self.audio_encoder_body = WavEncoder(self.cfg.audio_f) - self.speaker_embedding_body = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) - self.speaker_embedding_face = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) - self.mask_embedding = nn.Parameter(torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4)) - nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size ** -0.5) + self.speaker_embedding_body = nn.Embedding( + self.cfg.speaker_dims, self.cfg.hidden_size + ) + self.speaker_embedding_face = nn.Embedding( + self.cfg.speaker_dims, self.cfg.hidden_size + ) + self.mask_embedding = nn.Parameter( + torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4) + ) + nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size**-0.5) args_top = copy.deepcopy(self.cfg) args_top.vae_layer = 3 args_top.vae_length = self.cfg.motion_f args_top.vae_test_dim = self.cfg.pose_dims + 3 + 4 self.motion_encoder = VQEncoderV6(args_top) - self.bodyhints_face = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) - self.bodyhints_body = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) + self.bodyhints_face = MLP( + self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f + ) + self.bodyhints_body = MLP( + self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f + ) self.audio_body_motion_proj = nn.Linear(self.cfg.audio_f, self.cfg.hidden_size) self.moton_proj = nn.Linear(self.cfg.motion_f, self.cfg.hidden_size) - self.position_embeddings = PeriodicPositionalEncoding(self.cfg.hidden_size, period=self.cfg.pose_length, max_seq_len=self.cfg.pose_length) - self.transformer_en_layer = nn.TransformerEncoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) - self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) - self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) - self.audio_motion_cross_attn = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=8) - self.motion2latent_upper = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.motion2latent_hands = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.motion2latent_lower = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) - self.body_motion_decoder_upper = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.body_motion_decoder_hands = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.body_motion_decoder_lower = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) - self.motion_out_proj_upper = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_out_proj_hands = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_out_proj_lower = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_upper = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_hands = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.motion_cls_lower = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.audio_face_motion_proj = nn.Linear(self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size) - self.face_motion_decoder = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=4) + self.position_embeddings = PeriodicPositionalEncoding( + self.cfg.hidden_size, + period=self.cfg.pose_length, + max_seq_len=self.cfg.pose_length, + ) + self.transformer_en_layer = nn.TransformerEncoderLayer( + d_model=self.cfg.hidden_size, + nhead=4, + dim_feedforward=self.cfg.hidden_size * 2, + ) + self.motion_self_encoder = nn.TransformerEncoder( + self.transformer_en_layer, num_layers=1 + ) + self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer( + d_model=self.cfg.hidden_size, + nhead=4, + dim_feedforward=self.cfg.hidden_size * 2, + ) + self.audio_motion_cross_attn = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=8 + ) + self.motion2latent_upper = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.motion2latent_hands = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.motion2latent_lower = MLP( + self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size + ) + self.body_motion_decoder_upper = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.body_motion_decoder_hands = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.body_motion_decoder_lower = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=1 + ) + self.motion_out_proj_upper = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_out_proj_hands = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_out_proj_lower = nn.Linear( + self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_upper = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_hands = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.motion_cls_lower = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) + self.audio_face_motion_proj = nn.Linear( + self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size + ) + self.face_motion_decoder = nn.TransformerDecoder( + self.audio_motion_cross_attn_layer, num_layers=4 + ) self.face_out_proj = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) - self.face_cls = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.face_cls = MLP( + self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size + ) def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): masked_embeddings = self.mask_embedding.expand_as(masked_motion) @@ -281,30 +498,40 @@ def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): audio2body_fea = self.audio_encoder_body(audio) if audio2face_fea.shape[1] > body_hint_face.shape[1]: - audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + audio2face_fea = audio2face_fea[:, : body_hint_face.shape[1]] if audio2body_fea.shape[1] > body_hint_face.shape[1]: - audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + audio2face_fea = audio2face_fea[:, : body_hint_face.shape[1]] bs, t, _ = audio2face_fea.shape - speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat(1, t, 1) + speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat( + 1, t, 1 + ) speaker_face_fea_proj = self.speaker_embedding_face(speaker_id).repeat(1, t, 1) - audio2face_fea_proj = self.audio_face_motion_proj(torch.cat([audio2face_fea, body_hint_face], dim=2)) + audio2face_fea_proj = self.audio_face_motion_proj( + torch.cat([audio2face_fea, body_hint_face], dim=2) + ) face_proj = self.position_embeddings(speaker_face_fea_proj) - decode_face = self.face_motion_decoder(tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + decode_face = self.face_motion_decoder( + tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2) + ).permute(1, 0, 2) face_latent = self.face_out_proj(decode_face) classify_face = self.face_cls(face_latent) masked_motion_proj = self.moton_proj(body_hint_body) masked_motion_proj = self.position_embeddings(masked_motion_proj) masked_motion_proj = speaker_motion_fea_proj + masked_motion_proj - motion_fea = self.motion_self_encoder(masked_motion_proj.permute(1, 0, 2)).permute(1, 0, 2) + motion_fea = self.motion_self_encoder( + masked_motion_proj.permute(1, 0, 2) + ).permute(1, 0, 2) audio2body_fea_proj = self.audio_body_motion_proj(audio2body_fea) motion_fea = motion_fea + speaker_motion_fea_proj motion_fea = self.position_embeddings(motion_fea) - audio2body_fea_cross = self.audio_motion_cross_attn(tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + audio2body_fea_cross = self.audio_motion_cross_attn( + tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2) + ).permute(1, 0, 2) if not use_audio: audio2body_fea_cross = audio2body_fea_cross * 0.0 motion_fea = motion_fea + audio2body_fea_cross @@ -313,9 +540,21 @@ def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): hands_latent = self.motion2latent_hands(motion_fea) lower_latent = self.motion2latent_lower(motion_fea) - motion_upper_refine = self.body_motion_decoder_upper(tgt=upper_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(hands_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) - motion_hands_refine = self.body_motion_decoder_hands(tgt=hands_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) - motion_lower_refine = self.body_motion_decoder_lower(tgt=lower_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + hands_latent).permute(1, 0, 2)).permute(1, 0, 2) + motion_upper_refine = self.body_motion_decoder_upper( + tgt=upper_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(hands_latent + lower_latent).permute(1, 0, 2), + ).permute(1, 0, 2) + motion_hands_refine = self.body_motion_decoder_hands( + tgt=hands_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(upper_latent + lower_latent).permute(1, 0, 2), + ).permute(1, 0, 2) + motion_lower_refine = self.body_motion_decoder_lower( + tgt=lower_latent.permute(1, 0, 2) + + speaker_motion_fea_proj.permute(1, 0, 2), + memory=(upper_latent + hands_latent).permute(1, 0, 2), + ).permute(1, 0, 2) upper_latent = self.motion_out_proj_upper(upper_latent + motion_upper_refine) hands_latent = self.motion_out_proj_hands(hands_latent + motion_hands_refine) lower_latent = self.motion_out_proj_lower(lower_latent + motion_lower_refine) @@ -344,12 +583,12 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): fake_foot_and_trans = torch.zeros(bs, length, 7).to(audio.device) fake_motion = torch.cat([fake_motion, fake_foot_and_trans], dim=-1) if masked_motion is not None: - fake_motion[:, :masked_motion.shape[1]] = masked_motion + fake_motion[:, : masked_motion.shape[1]] = masked_motion masked_motion = fake_motion fake_mask = torch.ones_like(masked_motion) if mask is not None: - fake_mask[:, :mask.shape[1]] = mask + fake_mask[:, : mask.shape[1]] = mask mask = fake_mask bs, total_len, c = masked_motion.shape @@ -371,34 +610,71 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): window_motion = masked_motion[:, start_idx:end_idx, :].clone() window_motion[:, :pre_frames, :] = torch.where( (window_mask[:, :pre_frames, :] == 0), - masked_motion[:, start_idx:start_idx + pre_frames, :], + masked_motion[:, start_idx : start_idx + pre_frames, :], last_motion, ) window_mask[:, :pre_frames, :] = 0 audio_slice_len = (end_idx - start_idx) * (16000 // 30) - audio_slice = audio[:, start_idx * (16000 // 30):start_idx * (16000 // 30) + audio_slice_len] - net_out_val = self.forward(audio_slice, speaker_id, masked_motion=window_motion, mask=window_mask, use_audio=True) - - _, cls_face = torch.max(F.log_softmax(net_out_val["cls_face"], dim=2), dim=2) - _, cls_upper = torch.max(F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2) - _, cls_hands = torch.max(F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2) - _, cls_lower = torch.max(F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2) - - face_latent = net_out_val["rec_face"] if self.cfg.lf > 0 and self.cfg.cf == 0 else None - upper_latent = net_out_val["rec_upper"] if self.cfg.lu > 0 and self.cfg.cu == 0 else None - hands_latent = net_out_val["rec_hands"] if self.cfg.lh > 0 and self.cfg.ch == 0 else None - lower_latent = net_out_val["rec_lower"] if self.cfg.ll > 0 and self.cfg.cl == 0 else None + audio_slice = audio[ + :, + start_idx * (16000 // 30) : start_idx * (16000 // 30) + audio_slice_len, + ] + net_out_val = self.forward( + audio_slice, + speaker_id, + masked_motion=window_motion, + mask=window_mask, + use_audio=True, + ) + + _, cls_face = torch.max( + F.log_softmax(net_out_val["cls_face"], dim=2), dim=2 + ) + _, cls_upper = torch.max( + F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2 + ) + _, cls_hands = torch.max( + F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2 + ) + _, cls_lower = torch.max( + F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2 + ) + + face_latent = ( + net_out_val["rec_face"] + if self.cfg.lf > 0 and self.cfg.cf == 0 + else None + ) + upper_latent = ( + net_out_val["rec_upper"] + if self.cfg.lu > 0 and self.cfg.cu == 0 + else None + ) + hands_latent = ( + net_out_val["rec_hands"] + if self.cfg.lh > 0 and self.cfg.ch == 0 + else None + ) + lower_latent = ( + net_out_val["rec_lower"] + if self.cfg.ll > 0 and self.cfg.cl == 0 + else None + ) face_index = cls_face if self.cfg.cf > 0 else None upper_index = cls_upper if self.cfg.cu > 0 else None hands_index = cls_hands if self.cfg.ch > 0 else None lower_index = cls_lower if self.cfg.cl > 0 else None decode_dict = vq_model.decode( - face_latent=face_latent, upper_latent=upper_latent, - lower_latent=lower_latent, hands_latent=hands_latent, - face_index=face_index, upper_index=upper_index, - lower_index=lower_index, hands_index=hands_index, + face_latent=face_latent, + upper_latent=upper_latent, + lower_latent=lower_latent, + hands_latent=hands_latent, + face_index=face_index, + upper_index=upper_index, + lower_index=lower_index, + hands_index=hands_index, ) last_motion = decode_dict["all_motion4inference"][:, -pre_frames:, :] @@ -419,14 +695,24 @@ def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): final_motion = masked_motion[:, final_start:final_end, :].clone() final_motion[:, :pre_frames, :] = torch.where( (final_mask[:, :pre_frames, :] == 0), - masked_motion[:, final_start:final_start + pre_frames, :], + masked_motion[:, final_start : final_start + pre_frames, :], last_motion, ) final_mask[:, :pre_frames, :] = 0 audio_slice_len = (final_end - final_start) * (16000 // 30) - audio_slice = audio[:, final_start * (16000 // 30):final_start * (16000 // 30) + audio_slice_len] - net_out_val = self.forward(audio_slice, speaker_id, masked_motion=final_motion, mask=final_mask, use_audio=True) + audio_slice = audio[ + :, + final_start * (16000 // 30) : final_start * (16000 // 30) + + audio_slice_len, + ] + net_out_val = self.forward( + audio_slice, + speaker_id, + masked_motion=final_motion, + mask=final_mask, + use_audio=True, + ) rec_all_face.append(net_out_val["rec_face"]) rec_all_upper.append(net_out_val["rec_upper"]) diff --git a/src/modules/gesture/emage/processing.py b/src/modules/gesture/emage/processing.py index 2d128b1..91266b0 100644 --- a/src/modules/gesture/emage/processing.py +++ b/src/modules/gesture/emage/processing.py @@ -1,4 +1,5 @@ import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -121,7 +122,7 @@ def velocity2position(data_seq, dt, init_pos): if i == 0: res_trans.append(init_pos.unsqueeze(1)) else: - res = data_seq[:, i - 1:i] * dt + res_trans[-1] + res = data_seq[:, i - 1 : i] * dt + res_trans[-1] res_trans.append(res) return torch.cat(res_trans, dim=1) @@ -156,13 +157,15 @@ def forward(self, z): assert z.shape[-1] == self.e_dim z_flattened = z.contiguous().view(-1, self.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) - loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2) + loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean( + (z_q.detach() - z) ** 2 + ) z_q = z + (z_q - z).detach() min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) e_mean = torch.mean(min_encodings, dim=0) @@ -173,8 +176,8 @@ def map2index(self, z): assert z.shape[-1] == self.e_dim z_flattened = z.contiguous().view(-1, self.e_dim) d = ( - torch.sum(z_flattened ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -288,24 +291,51 @@ def forward(self, inputs): class BasicBlock(nn.Module): - def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d): + def __init__( + self, + inplanes, + planes, + ker_size, + stride=1, + downsample=None, + dilation=1, + first_dilation=None, + act_layer=nn.LeakyReLU, + norm_layer=nn.BatchNorm1d, + ): super().__init__() self.conv1 = nn.Conv1d( - inplanes, planes, kernel_size=ker_size, stride=stride, - padding=first_dilation, dilation=dilation, bias=True, + inplanes, + planes, + kernel_size=ker_size, + stride=stride, + padding=first_dilation, + dilation=dilation, + bias=True, ) self.bn1 = norm_layer(planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv1d( - planes, planes, kernel_size=ker_size, padding=ker_size // 2, - dilation=dilation, bias=True, + planes, + planes, + kernel_size=ker_size, + padding=ker_size // 2, + dilation=dilation, + bias=True, ) self.bn2 = norm_layer(planes) self.act2 = act_layer(inplace=True) if downsample is not None: self.downsample = nn.Sequential( - nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, - padding=first_dilation, dilation=dilation, bias=True), + nn.Conv1d( + inplanes, + planes, + stride=stride, + kernel_size=ker_size, + padding=first_dilation, + dilation=dilation, + bias=True, + ), norm_layer(planes), ) else: @@ -330,10 +360,16 @@ def __init__(self, out_dim, audio_in=1): super().__init__() self.out_dim = out_dim self.feat_extractor = nn.Sequential( - BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True), - BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True), + BasicBlock( + audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True + ), + BasicBlock( + out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True + ), BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7), - BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True), + BasicBlock( + out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True + ), BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7), BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True), ) @@ -367,14 +403,16 @@ def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60): self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(period, d_model) position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) repeat_num = (max_seq_len // period) + 1 pe = pe.repeat(1, repeat_num, 1) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:, :x.size(1), :] + x = x + self.pe[:, : x.size(1), :] return self.dropout(x) diff --git a/src/modules/gesture/events.py b/src/modules/gesture/events.py index 228b6d8..7793de1 100644 --- a/src/modules/gesture/events.py +++ b/src/modules/gesture/events.py @@ -6,10 +6,11 @@ _EMAGE_FPS = 30 + @dataclass class Motion(EventData): - poses: np.ndarray # (t, 165) SMPL-X axis-angle, 55 joints × 3 + poses: np.ndarray # (t, 165) SMPL-X axis-angle, 55 joints × 3 expressions: np.ndarray # (t, 100) facial expression coefficients - trans: np.ndarray # (t, 3) global root translation + trans: np.ndarray # (t, 3) global root translation fps: int = _EMAGE_FPS - pts: float = 0.0 # presentation timestamp in seconds, paired with Audio.pts + pts: float = 0.0 # presentation timestamp in seconds, paired with Audio.pts diff --git a/src/modules/gesture/gesture.py b/src/modules/gesture/gesture.py index 9028a17..3dcc99b 100644 --- a/src/modules/gesture/gesture.py +++ b/src/modules/gesture/gesture.py @@ -11,7 +11,6 @@ from src.modules.gesture.events import Motion from src.modules.text_to_speech.events import Audio - _HF_REPO = os.environ.get("HURI_EMAGE_REPO", "H-Liu1997/emage_audio") _EMAGE_SR = 16000 # EMAGE expects 16 kHz mono audio @@ -84,15 +83,25 @@ def __init__( print(f"[Gesture] WARNING could not cap GPU memory: {e!r}") print("[Gesture] loading face_vq...") - face_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/face").to(self.device) + face_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/face").to( + self.device + ) print("[Gesture] loading upper_vq...") - upper_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/upper").to(self.device) + upper_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/upper" + ).to(self.device) print("[Gesture] loading lower_vq...") - lower_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/lower").to(self.device) + lower_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/lower" + ).to(self.device) print("[Gesture] loading hands_vq...") - hands_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/hands").to(self.device) + hands_vq = EmageVQVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/hands" + ).to(self.device) print("[Gesture] loading global_ae...") - global_ae = EmageVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/global").to(self.device) + global_ae = EmageVAEConv.from_pretrained( + hf_repo, subfolder="emage_vq/global" + ).to(self.device) self.motion_vq = EmageVQModel( face_model=face_vq, @@ -128,21 +137,24 @@ def _warmup(self) -> None: # synchronize so the GPU work is finished before we report ready. # Best-effort: a failure here must never prevent the deployment coming up. import time + import torch # Representative window lengths (seconds). The dominant per-window # transformer forward is a fixed shape warmed by any window, but the # trailing remainder forward varies with total length, so warm a spread. - secs = sorted({ - round(s, 3) - for s in ( - _MIN_CHUNK_SEC, # first tiny window of an utterance - _CONTEXT_SEC, # context-only sized window - _CONTEXT_SEC + _MIN_CHUNK_SEC, # steady-state window - _CONTEXT_SEC + 2 * _MIN_CHUNK_SEC, # a larger fresh chunk - ) - if s and s > 0 - }) or [3.0] + secs = sorted( + { + round(s, 3) + for s in ( + _MIN_CHUNK_SEC, # first tiny window of an utterance + _CONTEXT_SEC, # context-only sized window + _CONTEXT_SEC + _MIN_CHUNK_SEC, # steady-state window + _CONTEXT_SEC + 2 * _MIN_CHUNK_SEC, # a larger fresh chunk + ) + if s and s > 0 + } + ) or [3.0] try: t0 = time.time() @@ -170,7 +182,10 @@ def infer(self, audio_np: np.ndarray, source_sr: int = _EMAGE_SR) -> Motion: if source_sr != _EMAGE_SR: import librosa - audio_np = librosa.resample(audio_np, orig_sr=source_sr, target_sr=_EMAGE_SR) + + audio_np = librosa.resample( + audio_np, orig_sr=source_sr, target_sr=_EMAGE_SR + ) audio_ts = torch.from_numpy(audio_np).to(self.device).unsqueeze(0) speaker_id = torch.zeros(1, 1, dtype=torch.long, device=self.device) @@ -180,21 +195,50 @@ def infer(self, audio_np: np.ndarray, source_sr: int = _EMAGE_SR) -> Motion: latent_dict = self.model.inference(audio_ts, speaker_id, self.motion_vq) cfg = self.model.cfg - face_latent = latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None - upper_latent = latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None - hands_latent = latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None - lower_latent = latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None - face_index = torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] if cfg.cf > 0 else None - upper_index = torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] if cfg.cu > 0 else None - hands_index = torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] if cfg.ch > 0 else None - lower_index = torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] if cfg.cl > 0 else None + face_latent = ( + latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None + ) + upper_latent = ( + latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None + ) + hands_latent = ( + latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None + ) + lower_latent = ( + latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None + ) + face_index = ( + torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] + if cfg.cf > 0 + else None + ) + upper_index = ( + torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] + if cfg.cu > 0 + else None + ) + hands_index = ( + torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] + if cfg.ch > 0 + else None + ) + lower_index = ( + torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] + if cfg.cl > 0 + else None + ) all_pred = self.motion_vq.decode( - face_latent=face_latent, upper_latent=upper_latent, - lower_latent=lower_latent, hands_latent=hands_latent, - face_index=face_index, upper_index=upper_index, - lower_index=lower_index, hands_index=hands_index, - get_global_motion=True, ref_trans=ref_trans[:, 0], + face_latent=face_latent, + upper_latent=upper_latent, + lower_latent=lower_latent, + hands_latent=hands_latent, + face_index=face_index, + upper_index=upper_index, + lower_index=lower_index, + hands_index=hands_index, + get_global_motion=True, + ref_trans=ref_trans[:, 0], ) t = all_pred["motion_axis_angle"].shape[1] @@ -271,9 +315,11 @@ def __init__( # source sample rate; resampling to 16 kHz happens once inside infer(). self._lock = asyncio.Lock() self._sr: Optional[int] = None - self._buffer = np.empty(0, dtype=np.float32) # trailing audio (ctx + unprocessed) - self._buf_start = 0 # source-sr sample index of buffer[0] in utterance timeline - self._emitted = 0 # source-sr samples whose motion has been emitted + self._buffer = np.empty( + 0, dtype=np.float32 + ) # trailing audio (ctx + unprocessed) + self._buf_start = 0 # source-sr sample index of buffer[0] in utterance timeline + self._emitted = 0 # source-sr samples whose motion has been emitted # Last emitted frame per channel, used to ease the next segment's seam. # These persist across the end-of-utterance reset (see _end_utterance) so @@ -319,7 +365,9 @@ async def process(self, audio: Audio) -> AsyncGenerator[Motion, None]: # type: new_samples = global_end - self._emitted # Wait for more audio unless this is the final flush of the utterance. - if new_samples <= 0 or (not end_of_utterance and new_samples < min_new_samples): + if new_samples <= 0 or ( + not end_of_utterance and new_samples < min_new_samples + ): if end_of_utterance: self._end_utterance() return diff --git a/src/modules/modules.py b/src/modules/modules.py index d8fefb9..66b627d 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -1,15 +1,29 @@ from typing import Dict, Type +from src.modules.emotion.emotion_aggregator import EAG +from src.modules.emotion.prosody_analysis import EMO +from src.modules.gesture.gesture import Gesture +from src.modules.rag.question_aggregator import QAG from src.modules.rag.rag import RAG from src.modules.speech_to_text.microphone_vad import MIC from src.modules.speech_to_text.speech_to_text import STT from src.modules.speech_to_text.text_aggregator import TAG from src.modules.text_to_speech.text_to_speech import TTS -from src.modules.gesture.gesture import Gesture from .factory import Module def get_modules() -> Dict[str, Type[Module]]: - modules: Dict[str, Type[Module]] = {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG, "tts": TTS, "gesture": Gesture} + modules: Dict[str, Type[Module]] = { + "mic": MIC, + "stt": STT, + "tag": TAG, + "emo": EMO, + "rag": RAG, + "eag": EAG, + "qag": QAG, + "rag": RAG, + "tts": TTS, + "gesture": Gesture, + } return modules diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/rag/events.py b/src/modules/rag/events.py index 5d237d2..24f883f 100644 --- a/src/modules/rag/events.py +++ b/src/modules/rag/events.py @@ -1,6 +1,9 @@ from dataclasses import dataclass, field +from typing import Optional from src.core.events import EventData +from src.modules.emotion.events import Emotion +from src.modules.speech_to_text.events import Transcript @dataclass @@ -9,3 +12,19 @@ class RAGResult(EventData): answer: str sources: list[dict] = field(default_factory=list) + + +@dataclass +class PartialQuestion(EventData): + """Partial question used to aggregate a sentence to an emotion.""" + + transcript: Optional[Transcript] + emotion: Optional[Emotion] + + +@dataclass +class RAGQuestion(EventData): + """Fully aggregated question to send to the RAG.""" + + transcript: Transcript + emotion: Optional[Emotion] diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index 529c7ff..f0d3e75 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -526,7 +526,9 @@ def main(): if needs_embeddings: if args.embedding_url: - print(f"Embedding remotely via {args.embedding_url} (model={args.embedding_model})") + print( + f"Embedding remotely via {args.embedding_url} (model={args.embedding_model})" + ) model = RemoteEmbedder(args.embedding_url, args.embedding_model) else: from sentence_transformers import SentenceTransformer diff --git a/src/modules/rag/question_aggregator.py b/src/modules/rag/question_aggregator.py new file mode 100644 index 0000000..8fd864b --- /dev/null +++ b/src/modules/rag/question_aggregator.py @@ -0,0 +1,47 @@ +from typing import Optional + +from src.core.module import Module +from src.modules.emotion.events import Emotion +from src.modules.speech_to_text.events import Transcript + +from .events import PartialQuestion, RAGQuestion + + +class QAG(Module): + """QAG Module + + Aggregate sentence and emotion into a RAGQuestion. + + input: partial_question, + output: question + + :use_emotion: default True. + Set to False if you do not analyze emotion or do not need it. + """ + + input_type = "partial_question" + output_type = "question" + + def __init__(self, use_emotion: bool = True): + super().__init__() + + self.current_transcript: Optional[Transcript] = None + self.current_emotion: Optional[Emotion] = None + + self.use_emotion = use_emotion + + async def process(self, partial_question: PartialQuestion) -> Optional[RAGQuestion]: + if partial_question.emotion is not None: + self.current_emotion = partial_question.emotion + + if partial_question.transcript is not None: + self.current_transcript = partial_question.transcript + + if self.current_transcript is not None: + if self.use_emotion: + if self.current_emotion is not None: + return RAGQuestion(self.current_transcript, self.current_emotion) + else: + return RAGQuestion(self.current_transcript, None) + + return None diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 0e65bbe..18ff2e1 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -4,17 +4,16 @@ from dataclasses import dataclass, field from typing import Any, AsyncGenerator +import httpx from pydantic import BaseModel +from qdrant_client.models import FieldCondition, Filter, MatchValue from ray import serve from ray.serve import handle from src.core.module import ModuleWithHandle, ModuleWithId -from src.modules.speech_to_text.events import Sentence from src.modules.text_to_speech.events import Token -import httpx - -from qdrant_client.models import FieldCondition, Filter, MatchValue +from .events import RAGQuestion from .qdrant_utils import make_qdrant_client # Default character persona. Overridable per session via the `persona` key in the @@ -116,7 +115,9 @@ def _get_profile(self, collection: str, _user_id: str) -> list[str]: collection_name=collection, scroll_filter=Filter( must=[ - FieldCondition(key="_user_id", match=MatchValue(value=_user_id)), + FieldCondition( + key="_user_id", match=MatchValue(value=_user_id) + ), FieldCondition(key="type", match=MatchValue(value="profile")), ] ), @@ -226,28 +227,28 @@ async def _stream_ollama( self, messages: list, max_tokens: int, temperature: float = 0.7 ) -> AsyncGenerator[str, None]: async with self._llm_client.stream( - "POST", - f"{self._cfg.llm_url}/api/chat", - json={ - "model": self._cfg.llm_model, - "messages": messages, - "stream": True, - "options": {"num_predict": max_tokens, "temperature": temperature}, - }, - ) as resp: - resp.raise_for_status() - async for line in resp.aiter_lines(): - if not line: - continue - try: - chunk = json.loads(line) - except json.JSONDecodeError: - continue - delta = chunk.get("message", {}).get("content", "") - if delta: - yield delta - if chunk.get("done"): - return + "POST", + f"{self._cfg.llm_url}/api/chat", + json={ + "model": self._cfg.llm_model, + "messages": messages, + "stream": True, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + delta = chunk.get("message", {}).get("content", "") + if delta: + yield delta + if chunk.get("done"): + return async def _stream_openai_compatible( self, @@ -261,35 +262,33 @@ async def _stream_openai_compatible( if api_key: headers["Authorization"] = f"Bearer {api_key}" async with self._llm_client.stream( - "POST", - url, - headers=headers, - json={ - "model": self._cfg.llm_model, - "messages": messages, - "max_tokens": max_tokens, - "temperature": temperature, - "stream": True, - }, - ) as resp: - resp.raise_for_status() - async for line in resp.aiter_lines(): - if not line or not line.startswith("data:"): - continue - payload = line[len("data:"):].strip() - if payload == "[DONE]": - return - try: - chunk = json.loads(payload) - except json.JSONDecodeError: - continue - delta = ( - chunk.get("choices", [{}])[0] - .get("delta", {}) - .get("content", "") - ) - if delta: - yield delta + "POST", + url, + headers=headers, + json={ + "model": self._cfg.llm_model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "stream": True, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line or not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload == "[DONE]": + return + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + continue + delta = ( + chunk.get("choices", [{}])[0].get("delta", {}).get("content", "") + ) + if delta: + yield delta async def _llm_stream( self, @@ -391,7 +390,9 @@ def __init__( ): super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) - print(f"[RAG] Initialized with user_id={_user_id}, language={language}, tone={tone}, response_format={response_format}, max_length={max_length}, temperature={temperature}, max_history_turns={max_history_turns}") + print( + f"[RAG] Initialized with user_id={_user_id}, language={language}, tone={tone}, response_format={response_format}, max_length={max_length}, temperature={temperature}, max_history_turns={max_history_turns}" + ) self.preferences = { "language": language, @@ -410,10 +411,16 @@ def __init__( self._max_history_turns = max_history_turns self.history: list[dict] = [] - async def process(self, data: Sentence) -> AsyncGenerator[Token, None]: # type: ignore[override] + async def process(self, data: RAGQuestion) -> AsyncGenerator[Token, None]: # type: ignore[override] + """ + Called when a "question" event arrives through the event bus. + Packages _user_id + question, sends to the stateless RAGHandle. + """ + question_text = data.transcript.text + query = RAGQuery( _user_id=self._user_id if self._user_id else "anonymous", - question=data.text, + question=question_text, preferences=self.preferences, history=list(self.history), # snapshot of prior turns ) @@ -425,7 +432,7 @@ async def process(self, data: Sentence) -> AsyncGenerator[Token, None]: # type: yield Token(text=delta, end=False) yield Token(text="", end=True) - self._record_turn(data.text, "".join(parts)) + self._record_turn(question_text, "".join(parts)) def _record_turn(self, question: str, answer: str) -> None: """Append this turn to the session history (raw Q/A, no RAG context) diff --git a/src/modules/speech_to_text/events.py b/src/modules/speech_to_text/events.py index e80f522..fc04674 100644 --- a/src/modules/speech_to_text/events.py +++ b/src/modules/speech_to_text/events.py @@ -15,8 +15,3 @@ class Transcript(EventData): @dataclass class Voice(EventData): data: Optional[np.ndarray] - - -@dataclass -class Sentence(EventData): - text: str diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 9e48a62..53a1207 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -70,6 +70,8 @@ class STT(ModuleWithHandle): as "en" or "fr". :sample_rate: size of received voice audio. Usually 8000, 16000 or 48000. :block_duration: size of received voice audio (in s). + :transcribe_window: duration of audio per transcription (in s). + :transcribe_step: overlap between consecutive transcription windows (in s). """ _handle_cls = STTDeployment @@ -98,9 +100,6 @@ def __init__( self.silence: bool = True - self.prev_text: str = "" - self.stable_text: str = "" - self.running = False self.lock: asyncio.Lock = asyncio.Lock() @@ -126,7 +125,6 @@ async def process(self, voice: Voice) -> Optional[Transcript]: # type: ignore[o return None processing_chunks = self.buffer[: self.window_size] - self.pending_silence = False processing_audio = np.concatenate(processing_chunks, axis=0) current_text = await self._handle.transcribe.remote( diff --git a/src/modules/speech_to_text/text_aggregator.py b/src/modules/speech_to_text/text_aggregator.py index 72760af..b9ef1ce 100644 --- a/src/modules/speech_to_text/text_aggregator.py +++ b/src/modules/speech_to_text/text_aggregator.py @@ -2,8 +2,9 @@ from typing import Optional from src.core.module import Module +from src.modules.rag.events import PartialQuestion -from .events import Sentence, Transcript +from .events import Transcript class TAG(Module): @@ -12,11 +13,11 @@ class TAG(Module): Aggregate all transcriptions and send when transcript end. input: transcript, - output: question + output: partial_question """ input_type = "transcript" - output_type = "question" + output_type = "partial_question" def __init__( self, @@ -36,7 +37,7 @@ def _merge(self, current: str, new: str) -> str: self.prev_index = len(current) return current + new - async def process(self, transcript: Transcript) -> Optional[Sentence]: + async def process(self, transcript: Transcript) -> Optional[PartialQuestion]: text = transcript.text if text != "": @@ -46,9 +47,9 @@ async def process(self, transcript: Transcript) -> Optional[Sentence]: self.sentence = self._merge(self.sentence, text) if transcript.end and self.sentence != "": - sentence = Sentence(self.sentence) + transcript = Transcript(self.sentence, True) self.sentence = "" self.prev_index = 0 - return sentence + return PartialQuestion(transcript=transcript, emotion=None) else: return None diff --git a/src/modules/text_to_speech/text_to_speech.py b/src/modules/text_to_speech/text_to_speech.py index ad49dab..92eb66b 100644 --- a/src/modules/text_to_speech/text_to_speech.py +++ b/src/modules/text_to_speech/text_to_speech.py @@ -37,9 +37,10 @@ def _normalize_transcript(raw: str) -> str: else f"{_DEFAULT_INSTRUCTION}<|endofprompt|>{raw}" ) -_END_TEXT = object() # sentinel pushed into the text queue to close synth + +_END_TEXT = object() # sentinel pushed into the text queue to close synth _END_AUDIO = object() # sentinel pushed into the audio queue when synth completes -_DONE = object() # sentinel for exhausted sync generator +_DONE = object() # sentinel for exhausted sync generator @serve.deployment(name="TTS", max_ongoing_requests=200) @@ -66,7 +67,7 @@ def __init__( sys.path.insert(0, matcha_path) from cosyvoice.cli.cosyvoice import CosyVoice3 - + # Resolve the reference transcript here (deploy time on the GPU worker) # rather than at module import: importing this module must not require # HURI_VOICE_TRANSCRIPT, since modules.py imports it inside a broad @@ -181,7 +182,9 @@ async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: i if is_first: self._session_id = str(uuid.uuid4()) self._audio_q = asyncio.Queue() - print(f"[TTS-client] [{self._session_id}] opening new utterance session") + print( + f"[TTS-client] [{self._session_id}] opening new utterance session" + ) await self._handle.start_session.remote(self._session_id) self._stream_task = asyncio.create_task( self._drain_audio(self._session_id, self._audio_q) @@ -210,7 +213,9 @@ async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: i print(f"[TTS-client] [{sid}] utterance complete ({count} chunks)") sample_rate = await self._handle.get_sample_rate.remote() - yield Audio(data=np.array([], dtype=np.float32), sample_rate=sample_rate, end=True) + yield Audio( + data=np.array([], dtype=np.float32), sample_rate=sample_rate, end=True + ) finally: async with self._push_lock: self._session_id = None diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index c1bac03..da4deab 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -39,7 +39,11 @@ async def process(self, _): elif isinstance(data, Audio): logger.info( "[Sender:%s] Audio samples=%d sr=%d end=%s pts=%.3fs", - self.input_type, data.data.shape[0], data.sample_rate, data.end, data.pts, + self.input_type, + data.data.shape[0], + data.sample_rate, + data.end, + data.pts, ) header = struct.pack(">IBd", data.sample_rate, int(data.end), data.pts) await self.ws.send_bytes(self._prefix(header + data.data.tobytes())) @@ -47,7 +51,10 @@ async def process(self, _): n_frames = data.poses.shape[0] logger.info( "[Sender:%s] Motion frames=%d fps=%d pts=%.3fs", - self.input_type, n_frames, data.fps, data.pts, + self.input_type, + n_frames, + data.fps, + data.pts, ) header = struct.pack(">dII", data.pts, data.fps, n_frames) body = (