(level: int)
| 253 | |
| 254 | |
def get_config_by_level(level: int) -> list[BenchmarkConfig]:
    """Build the list of benchmark configurations to run for a given ``level``.

    Levels 0-2 are cumulative: each level keeps every config of the lower levels
    and appends a few hand-picked ones. Level 3 and above instead return the full
    cross product of attention implementation x compile mode x kernelization x
    continuous batching.

    Args:
        level: Breadth of the sweep.
            * ``< 0``: no configs.
            * ``0``-``2``: the curated tiers built at the bottom of this function.
            * ``3``: exhaustive sweep using only the ``None`` and ``"default"``
              compile modes.
            * ``>= 4``: exhaustive sweep over every mode listed in
              ``BenchmarkConfig.all_compiled_modes`` instead.

    Returns:
        A new list of ``BenchmarkConfig`` objects.
    """
    configs: list[BenchmarkConfig] = []

    # Level 3 or higher: early return with every combination of settings.
    if level >= 3:
        # Other compile modes rarely gain much, so they are only swept from level 4 upwards.
        compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
        # Only try kernelize=True when kernelization is available. The explicit list keeps
        # the iteration order obvious and avoids duplicate configs.
        kernelize_options = [False, True] if KERNELIZATION_AVAILABLE else [False]
        for attn_implementation in BenchmarkConfig.all_attn_implementations:
            for cm in compile_modes:
                # A `None` mode maps to compile_kwargs=None (presumably "no compilation").
                compile_kwargs = {"mode": cm} if cm is not None else None
                for kernelize_on in kernelize_options:
                    for cb_on in (False, True):
                        configs.append(
                            BenchmarkConfig(
                                attn_implementation=attn_implementation,
                                compile_kwargs=compile_kwargs,
                                kernelize=kernelize_on,
                                continuous_batching=cb_on,
                            )
                        )
        return configs

    # Otherwise, build the curated tiers cumulatively.
    if level >= 0:
        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}))
    if level >= 1:
        configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
        configs.append(BenchmarkConfig(attn_implementation="eager", compile_kwargs={}))
        configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
    if level >= 2:
        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_kwargs={}))
        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}, kernelize=True))
        configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
        configs.append(BenchmarkConfig(attn_implementation="sdpa", continuous_batching=True))
    return configs
no test coverage detected