(self, config: BltConfig)
| 901 | |
| 902 | class BltModel(BltPreTrainedModel): |
def __init__(self, config: BltConfig):
    """Build the BLT model's submodules from ``config``.

    Constructs, in this order: the local encoder, the global transformer,
    the local decoder and the hash n-gram embedding table. If
    ``config.patch_in_forward`` is truthy, a ``BltPatcher`` is also built,
    switched to eval mode and frozen; otherwise ``self.patcher`` is ``None``.
    Finishes by calling ``self.post_init()``.

    Args:
        config: Top-level ``BltConfig``. Fields read here: ``encoder_config``,
            ``global_config``, ``decoder_config``, ``patcher_config``,
            ``patch_in_forward``, ``encoder_hash_byte_group_nb_functions``,
            ``encoder_hash_byte_group_size`` and
            ``encoder_hash_byte_group_vocab``.
    """
    super().__init__(config)
    self.gradient_checkpointing = False  # disabled by default
    self.config = config

    # Keep this construction order: it fixes the parameter registration order
    # (and, if these constructors draw random init values, the RNG sequence).
    self.local_encoder = BltLocalEncoder(config.encoder_config)
    self.global_transformer = BltGlobalTransformer(config.global_config)
    self.local_decoder = BltLocalDecoder(config.decoder_config)

    # One embedding table with `encoder_hash_byte_group_vocab` rows for every
    # (hash function, byte-group size) combination. Presumably each combination
    # indexes its own row range at lookup time; TODO confirm against forward().
    hash_fn_count = config.encoder_hash_byte_group_nb_functions
    group_size_count = len(config.encoder_hash_byte_group_size)
    embedding_rows = config.encoder_hash_byte_group_vocab * (hash_fn_count * group_size_count)
    self.encoder_hash_tok_embedding = nn.Embedding(embedding_rows, config.encoder_config.hidden_size)

    # Optional patcher, built only when `patch_in_forward` is set.
    self.patcher = BltPatcher(config.patcher_config) if self.config.patch_in_forward else None
    if self.patcher is not None:
        # Keep the patcher frozen: eval mode, and no gradients for its parameters.
        # NOTE(review): nn.Module.train() recurses into child modules, so a later
        # model.train() call will put the patcher back in training mode. Confirm
        # that this is acceptable.
        self.patcher.eval()
        self.patcher.requires_grad_(False)

    # PreTrainedModel finalisation hook (e.g. weight initialisation); it needs
    # all submodules to exist, so it runs last.
    self.post_init()
| 922 | |
| 923 | @merge_with_config_defaults |
| 924 | @capture_outputs |
no test coverage detected