(config_class, output_dir)
| 1114 | |
| 1115 | |
| 1116 | def build_composite_models(config_class, output_dir): |
| 1117 | import tempfile |
| 1118 | |
| 1119 | from transformers import ( |
| 1120 | BertConfig, |
| 1121 | BertLMHeadModel, |
| 1122 | BertModel, |
| 1123 | BertTokenizer, |
| 1124 | BertTokenizerFast, |
| 1125 | EncoderDecoderModel, |
| 1126 | GPT2Config, |
| 1127 | GPT2LMHeadModel, |
| 1128 | GPT2Tokenizer, |
| 1129 | GPT2TokenizerFast, |
| 1130 | SpeechEncoderDecoderModel, |
| 1131 | VisionEncoderDecoderModel, |
| 1132 | VisionTextDualEncoderModel, |
| 1133 | ViTConfig, |
| 1134 | ViTModel, |
| 1135 | Wav2Vec2Config, |
| 1136 | Wav2Vec2Model, |
| 1137 | Wav2Vec2Processor, |
| 1138 | ) |
| 1139 | |
| 1140 | # These will be removed at the end if they are empty |
| 1141 | result = {"error": None, "warnings": []} |
| 1142 | |
| 1143 | if config_class.model_type == "encoder-decoder": |
| 1144 | encoder_config_class = BertConfig |
| 1145 | decoder_config_class = BertConfig |
| 1146 | encoder_processor = (BertTokenizerFast, BertTokenizer) |
| 1147 | decoder_processor = (BertTokenizerFast, BertTokenizer) |
| 1148 | encoder_class = BertModel |
| 1149 | decoder_class = BertLMHeadModel |
| 1150 | model_class = EncoderDecoderModel |
| 1151 | elif config_class.model_type == "vision-encoder-decoder": |
| 1152 | encoder_config_class = ViTConfig |
| 1153 | decoder_config_class = GPT2Config |
| 1154 | encoder_processor = (ViTImageProcessor,) |
| 1155 | decoder_processor = (GPT2TokenizerFast, GPT2Tokenizer) |
| 1156 | encoder_class = ViTModel |
| 1157 | decoder_class = GPT2LMHeadModel |
| 1158 | model_class = VisionEncoderDecoderModel |
| 1159 | elif config_class.model_type == "speech-encoder-decoder": |
| 1160 | encoder_config_class = Wav2Vec2Config |
| 1161 | decoder_config_class = BertConfig |
| 1162 | encoder_processor = (Wav2Vec2Processor,) |
| 1163 | decoder_processor = (BertTokenizerFast, BertTokenizer) |
| 1164 | encoder_class = Wav2Vec2Model |
| 1165 | decoder_class = BertLMHeadModel |
| 1166 | model_class = SpeechEncoderDecoderModel |
| 1167 | elif config_class.model_type == "vision-text-dual-encoder": |
| 1168 | # Not an encoder-decoder but an encoder-encoder model. We keep the same variable names as above to keep the code simple |
| 1169 | encoder_config_class = ViTConfig |
| 1170 | decoder_config_class = BertConfig |
| 1171 | encoder_processor = (ViTImageProcessor,) |
| 1172 | decoder_processor = (BertTokenizerFast, BertTokenizer) |
| 1173 | encoder_class = ViTModel |
no test coverage detected