Decorator marking a test that requires Flash Attention. These tests are skipped when neither the Flash Attention package nor its `kernels` fallback is available.
(test_case)
| 669 | |
| 670 | |
def require_flash_attn(test_case):
    """
    Decorator marking a test that requires Flash Attention.

    The test runs when `is_flash_attn_2_available()` is true, or when `is_kernels_available()` is true and the
    fallback kernel named by `FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"]` can be obtained through
    `kernels.get_kernel`. Otherwise the test is skipped.

    Args:
        test_case: The test function or class to decorate.

    Returns:
        The result of applying `unittest.skipUnless(...)` to `test_case`.
    """
    if is_flash_attn_2_available():
        # Flash Attention itself is usable, so the fallback kernel does not need to be probed at all.
        flash_attn_usable = True
    elif not is_kernels_available():
        # Without `kernels`, the probe below could not change the outcome.
        flash_attn_usable = False
    else:
        try:
            from kernels import get_kernel

            # NOTE(review): presumably this fetches/loads the kernel and may need network access -- confirm.
            get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
        except Exception:
            # Deliberately broad: any failure to obtain the fallback kernel means "not available", and the
            # decorated test is skipped instead of erroring at decoration time.
            flash_attn_usable = False
        else:
            flash_attn_usable = True

    return unittest.skipUnless(flash_attn_usable, "test requires Flash Attention")(test_case)
| 688 | |
| 689 | |
| 690 | def require_kernels(test_case): |
nothing calls this directly
no test coverage detected