(test_case)
| 716 | |
| 717 | |
def require_all_flash_attn(test_case):
    """
    Decorator marking a test that requires every mainline Flash Attention implementation.

    The test is skipped unless all of the following hold:
    - Flash Attention 2 is usable, either because `is_flash_attn_2_available()` is True or because
      the `kernels` fallback `FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"]` can be loaded with
      `kernels.get_kernel`;
    - `is_flash_attn_3_available()` is True;
    - `is_flash_attn_4_available()` is True.

    Args:
        test_case: the test function or class to decorate.

    Returns:
        `test_case` wrapped by `unittest.skipUnless` with the message
        "test requires all mainline Flash Attention packages".
    """
    newer_versions_available = is_flash_attn_3_available() and is_flash_attn_4_available()
    flash_attn_2_usable = is_flash_attn_2_available()

    # The kernel fallback is only consulted when it can change the skip decision: FA2 itself is
    # missing, FA3/FA4 are present, and `kernels` is installed. This runs when the decorator is
    # applied, so short-circuiting avoids calling `get_kernel` needlessly.
    # NOTE(review): `get_kernel` presumably loads (and possibly downloads) a kernel -- confirm cost.
    if newer_versions_available and not flash_attn_2_usable and is_kernels_available():
        try:
            from kernels import get_kernel

            get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
        except Exception:
            # Deliberately best-effort: any failure to obtain the fallback kernel just means the
            # FA2 requirement is not met, so the test gets skipped rather than erroring.
            pass
        else:
            flash_attn_2_usable = True

    return unittest.skipUnless(
        flash_attn_2_usable and newer_versions_available,
        "test requires all mainline Flash Attention packages",
    )(test_case)
| 738 | |
| 739 | |
| 740 | def require_flash_linear_attention(test_case): |
nothing calls this directly
no test coverage detected