(self, llm_model, flow_model, hift_model)
| 65 | self.hift_cache_dict = {} |
| 66 | |
def load(self, llm_model, flow_model, hift_model):
    """Load checkpoint weights into the llm, flow and hift sub-models.

    Each checkpoint is read with ``torch.load`` mapped onto ``self.device``
    and applied with ``strict=True``, so a missing or unexpected key raises
    instead of being ignored. After loading, every sub-model is moved to
    ``self.device`` and switched to eval mode.

    Args:
        llm_model: Checkpoint location handed to ``torch.load`` for ``self.llm``.
        flow_model: Checkpoint location handed to ``torch.load`` for ``self.flow``.
        hift_model: Checkpoint location handed to ``torch.load`` for ``self.hift``.
            Every ``'generator.'`` occurrence is removed from its keys first.
    """
    # NOTE(review): torch.load unpickles the file; depending on the installed
    # PyTorch version's `weights_only` default this can execute arbitrary code,
    # so only point these arguments at trusted checkpoints.
    for attr_name, checkpoint in (('llm', llm_model), ('flow', flow_model)):
        weights = torch.load(checkpoint, map_location=self.device)
        getattr(self, attr_name).load_state_dict(weights, strict=True)
        getattr(self, attr_name).to(self.device).eval()

    # in case hift_model is a hifigan model: drop the 'generator.' part of
    # each key so the names line up with self.hift's own parameters
    raw_hift_weights = torch.load(hift_model, map_location=self.device)
    hift_weights = {}
    for key, tensor in raw_hift_weights.items():
        hift_weights[key.replace('generator.', '')] = tensor
    self.hift.load_state_dict(hift_weights, strict=True)
    self.hift.to(self.device).eval()
| 76 | |
| 77 | def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): |
| 78 | llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) |
no outgoing calls
no test coverage detected