From 53b800e59a36d23631c7590ff389737b258ec551 Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 5 Oct 2025 12:59:42 +0000 Subject: [PATCH 01/13] add watchdog :D --- src/hip_attn/utils/sglang_watchdog.py | 124 ++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/hip_attn/utils/sglang_watchdog.py diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py new file mode 100644 index 00000000..5c73b9f8 --- /dev/null +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -0,0 +1,124 @@ +import datetime +import sys +import os +import subprocess +import threading +import time +import requests + +def log(*args): + comment = " ".join([str(a) for a in args]) + timestamp = "{:%Y-%m-%d %H:%M:%S}".format(datetime.datetime.now()) + print(f"\033[91m[{timestamp} sglang_watchdog] {comment}\033[0m", flush=True) + +class Watchdog: + def __init__( + self, + timeout_bootup = 300, + ): + self.timeout_bootup = 300 + self.timeout_tick = 60 + self.sleep_step = 1 + self.proc: subprocess.Popen = None + self.argv: list[str] = None + self.running: bool = True + + def start_subprocess(self): + args = [ + "python", + "-m", + "sglang.launch_server", + *self.argv + ] + flatten_args = " ".join(args) + log(f"Start subprocess using following command: {flatten_args}") + self.proc = subprocess.Popen(args) + log(f"Start subprocess communication.") + return_code = self.proc.wait() + log(f"Return code is {return_code}") + + def kill_subprocess(self): + log(f"Start kill subprocess") + self.proc.kill() + self.proc = None + log(f"Finish kill subprocess") + + def wait_for_health(self, timeout: int): + t_start = time.time() + while (time.time() - t_start) < timeout: + try: + response = requests.get(self.health_endpoint, timeout=timeout) + response.raise_for_status() + return + except requests.ConnectionError: + time.sleep(self.sleep_step) + raise TimeoutError() + + def main_watchdog(self): + while True: + try: + t_boot = time.time() + booted = False + while ( + (time.time() - t_boot) < self.timeout_bootup + and self.proc.returncode is None + and not booted + ): + try: + self.wait_for_health(timeout=self.timeout_bootup) + booted = True + except (TimeoutError, requests.HTTPError): + # NOTE: may process is not started yet + pass + time.sleep(self.sleep_step) + + if not booted: raise TimeoutError() + + while True: + self.wait_for_health(timeout=self.timeout_tick) + time.sleep(self.sleep_step) + + except (TimeoutError, requests.HTTPError): + self.kill_subprocess() + time.sleep(self.sleep_step) + + def main_starter(self): + while True: + self.start_subprocess() + time.sleep(self.sleep_step) + + def start(self): + if "--" in sys.argv: + argv = sys.argv[sys.argv.index("--") + 1:] + else: + argv = sys.argv[1:] + + assert "--host" in argv + assert "--port" in argv + self.host = argv[argv.index("--host") + 1] + self.port = argv[argv.index("--port") + 1] + self.health_endpoint = f"http://{self.host}:{self.port}/health" + log(f"Watching: {self.health_endpoint}") + + self.argv = argv + + self.thread_watchdog = threading.Thread( + target=self.main_watchdog, + daemon=True + ) + self.thread_starter = threading.Thread( + target=self.main_starter, + daemon=True + ) + + self.thread_watchdog.start() + self.thread_starter.start() + + self.thread_watchdog.join() + self.thread_starter.join() + + self.running = False + +if __name__ == '__main__': + dog = Watchdog() + dog.start() \ No newline at end of file From beeac38eef00ded3b03749cb5a9e8f35cdf947d1 Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 5 Oct 2025 13:08:28 +0000 Subject: [PATCH 02/13] fix watchdog --- src/hip_attn/utils/sglang_watchdog.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py index 5c73b9f8..7947d137 100644 --- a/src/hip_attn/utils/sglang_watchdog.py +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -44,15 +44,8 @@ def kill_subprocess(self): log(f"Finish kill subprocess") def wait_for_health(self, timeout: int): - t_start = time.time() - while (time.time() - t_start) < timeout: - try: - response = requests.get(self.health_endpoint, timeout=timeout) - response.raise_for_status() - return - except requests.ConnectionError: - time.sleep(self.sleep_step) - raise TimeoutError() + response = requests.get(self.health_endpoint, timeout=timeout) + response.raise_for_status() def main_watchdog(self): while True: @@ -66,8 +59,9 @@ def main_watchdog(self): ): try: self.wait_for_health(timeout=self.timeout_bootup) + log("Server booted successfully.") booted = True - except (TimeoutError, requests.HTTPError): + except (TimeoutError, requests.HTTPError, requests.ConnectionError): # NOTE: may process is not started yet pass time.sleep(self.sleep_step) @@ -75,8 +69,10 @@ def main_watchdog(self): if not booted: raise TimeoutError() while True: + log("Try watch dog.") self.wait_for_health(timeout=self.timeout_tick) - time.sleep(self.sleep_step) + log("Done watch dog successfully.") + time.sleep(self.timeout_tick) except (TimeoutError, requests.HTTPError): self.kill_subprocess() @@ -111,8 +107,9 @@ def start(self): daemon=True ) - self.thread_watchdog.start() self.thread_starter.start() + time.sleep(self.sleep_step) + self.thread_watchdog.start() self.thread_watchdog.join() self.thread_starter.join() From 744f5eef4114ab26141545878056dcbe421100c7 Mon Sep 17 00:00:00 2001 From: AinL Date: Mon, 6 Oct 2025 11:11:32 +0000 Subject: [PATCH 03/13] fix watchdog --- src/hip_attn/utils/sglang_watchdog.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py index 7947d137..6a12ff45 100644 --- a/src/hip_attn/utils/sglang_watchdog.py +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -4,6 +4,7 @@ import subprocess import threading import time +import traceback import requests def log(*args): @@ -52,6 +53,9 @@ def main_watchdog(self): try: t_boot = time.time() booted = False + while self.proc is None: + log("Watchdog is waiting for process started...") + time.sleep(self.sleep_step) while ( (time.time() - t_boot) < self.timeout_bootup and self.proc.returncode is None @@ -76,6 +80,12 @@ def main_watchdog(self): except (TimeoutError, requests.HTTPError): self.kill_subprocess() + except Exception as ex: + trace = traceback.format_exc() + log(f"Traceback:\n{trace}") + log(f"Unexpected error on watchdog thread: {ex}") + self.kill_subprocess() + time.sleep(self.sleep_step) def main_starter(self): From b69ee2171bcc85c55d68e267808bd815988a206c Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 08:49:32 +0000 Subject: [PATCH 04/13] fix --- src/hip_attn/v1_2/paged_hip.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py index 58fca4b9..7120cb8a 100644 --- a/src/hip_attn/v1_2/paged_hip.py +++ b/src/hip_attn/v1_2/paged_hip.py @@ -9,10 +9,23 @@ import numpy as np import torch import triton -from flash_attn import flash_attn_func from matplotlib import pyplot as plt -from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func -from sgl_kernel.flash_attn import flash_attn_with_kvcache + +try: + from flash_attn import flash_attn_func +except ImportError: + flash_attn_func = None + +try: + from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func + from sgl_kernel.flash_attn import flash_attn_with_kvcache + IS_AMD = False +except ImportError: + # FIXME: better AMD detection algorithm + IS_AMD = True + + from flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func + from flash_attn import flash_attn_with_kvcache from hip_attn.v1_2.hip_config import HiPAttentionConfig from hip_attn.v1_2.utils import capture From 4b434d74c7a441b8792415afe43a4bfec02b4525 Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 09:02:05 +0000 Subject: [PATCH 05/13] fix --- src/hip_attn/v1_2/attention_extend.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/hip_attn/v1_2/attention_extend.py b/src/hip_attn/v1_2/attention_extend.py index 36ddf535..36441d13 100644 --- a/src/hip_attn/v1_2/attention_extend.py +++ b/src/hip_attn/v1_2/attention_extend.py @@ -16,7 +16,12 @@ from hip_attn.utils.rope import adjust_rope from hip_attn.v1_2.attention_decode_bsa import decode_block_sparse_attention from hip_attn.v1_2.attention_extend_bsa import block_sparse_attention -from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang + +try: + from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang +except ImportError: + block_sparse_attention_tilelang = None + from hip_attn.v1_2.attention_metadata import ( EnsembleScoreStage, EvalScoreStage, From 97cf3caa2b3191c8fc7a14fb125c6f5ca5d151ef Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 09:02:53 +0000 Subject: [PATCH 06/13] fix --- src/hip_attn/v1_2/attention_extend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hip_attn/v1_2/attention_extend.py b/src/hip_attn/v1_2/attention_extend.py index 36441d13..bc710ea8 100644 --- a/src/hip_attn/v1_2/attention_extend.py +++ b/src/hip_attn/v1_2/attention_extend.py @@ -19,7 +19,7 @@ try: from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang -except ImportError: +except ImportError, OSError: block_sparse_attention_tilelang = None from hip_attn.v1_2.attention_metadata import ( From b94b5116c565d57976f9e8569e5fea60cd8687da Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 09:03:22 +0000 Subject: [PATCH 07/13] fix --- src/hip_attn/v1_2/attention_extend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hip_attn/v1_2/attention_extend.py b/src/hip_attn/v1_2/attention_extend.py index bc710ea8..449c2c02 100644 --- a/src/hip_attn/v1_2/attention_extend.py +++ b/src/hip_attn/v1_2/attention_extend.py @@ -19,7 +19,7 @@ try: from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang -except ImportError, OSError: +except (ImportError, OSError): block_sparse_attention_tilelang = None from hip_attn.v1_2.attention_metadata import ( From c17e4d99e43aca0214bb573ef66d556868ba291b Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 11:22:41 +0000 Subject: [PATCH 08/13] fix --- src/hip_attn/utils/sglang_watchdog.py | 16 +++++- src/hip_attn/v1_2/paged_hip.py | 59 ++++++++++++++++++++- src/hip_attn/v1_2/query_sparse_attention.py | 2 +- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py index 6a12ff45..1971205c 100644 --- a/src/hip_attn/utils/sglang_watchdog.py +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -1,3 +1,4 @@ +import argparse import datetime import sys import os @@ -15,9 +16,8 @@ def log(*args): class Watchdog: def __init__( self, - timeout_bootup = 300, ): - self.timeout_bootup = 300 + self.timeout_bootup = 600 self.timeout_tick = 60 self.sleep_step = 1 self.proc: subprocess.Popen = None @@ -95,10 +95,22 @@ def main_starter(self): def start(self): if "--" in sys.argv: + my_args = sys.argv[1:sys.argv.index("--")] argv = sys.argv[sys.argv.index("--") + 1:] else: + my_args = [] argv = sys.argv[1:] + parser = argparse.ArgumentParser() + parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int) + parser.add_argument("--timeout", default=self.timeout_tick, type=int) + parser.add_argument("--sleep-step", default=self.sleep_step, type=int) + + args = parser.parse_args(my_args) + self.timeout_bootup = args.timeout_bootup + self.timeout_tick = args.timeout + self.sleep_step = args.sleep_step + assert "--host" in argv assert "--port" in argv self.host = argv[argv.index("--host") + 1] diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py index 7120cb8a..bb79ea9a 100644 --- a/src/hip_attn/v1_2/paged_hip.py +++ b/src/hip_attn/v1_2/paged_hip.py @@ -2,7 +2,7 @@ import math import os import warnings -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import cv2 import numba @@ -25,7 +25,62 @@ IS_AMD = True from flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func - from flash_attn import flash_attn_with_kvcache + from flash_attn import flash_attn_with_kvcache as __flash_attn_with_kvcache + + def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, + ver=3, + ): + return flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + block_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, # -1 means infinite context window + softcap=softcap, # 0.0 means deactivated + rotary_interleaved=rotary_interleaved, + alibi_slopes=None, + num_splits=num_splits, + return_softmax_lse=return_softmax_lse, + ) from hip_attn.v1_2.hip_config import HiPAttentionConfig from hip_attn.v1_2.utils import capture diff --git a/src/hip_attn/v1_2/query_sparse_attention.py b/src/hip_attn/v1_2/query_sparse_attention.py index d001e213..c862ba1f 100644 --- a/src/hip_attn/v1_2/query_sparse_attention.py +++ b/src/hip_attn/v1_2/query_sparse_attention.py @@ -1907,7 +1907,7 @@ def forward( assert rope_cos.ndim == 2 assert extend_backend in ["self_extend", "nope"] - if rope_sin is not None: + if (rope_sin is not None) and (extend_backend in ["self_extend"]): HEAD_DIM_K_ROPE = rope_sin.shape[-1] HEAD_DIM_K_NOPE = HEAD_DIM_K - HEAD_DIM_K_ROPE else: From fee868ed6ed36eb2b85cffdf93baf7a3e02d5ed3 Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 11:26:08 +0000 Subject: [PATCH 09/13] fix --- src/hip_attn/v1_2/paged_hip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py index bb79ea9a..6a5f4c1a 100644 --- a/src/hip_attn/v1_2/paged_hip.py +++ b/src/hip_attn/v1_2/paged_hip.py @@ -60,7 +60,7 @@ def flash_attn_with_kvcache( sinks=None, ver=3, ): - return flash_attn_with_kvcache( + return __flash_attn_with_kvcache( q, k_cache, v_cache, From 3685169ea3e43ed76728723ab05b95f808a3e0cc Mon Sep 17 00:00:00 2001 From: AinL Date: Sun, 12 Oct 2025 11:31:44 +0000 Subject: [PATCH 10/13] fix --- .../mixed_landmark_0814_no_extend_qsa.json | 60 +++++++++++++++++-- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/configs/mixed_landmark_0814_no_extend_qsa.json b/configs/mixed_landmark_0814_no_extend_qsa.json index 1fc11ac2..c69ea1e5 100644 --- a/configs/mixed_landmark_0814_no_extend_qsa.json +++ b/configs/mixed_landmark_0814_no_extend_qsa.json @@ -7,24 +7,74 @@ "__delta_attention_args": "window_0-diff_1-w_16-dense_decode-smooth", "using_extend": false, "dense_layers": [0, 1, 2, 47, 46, 45], - "mask_refresh_interval": [96], + "mask_refresh_interval": [96, 32, 16], "layers": [ { "sliding_window_size": 1024, "sliding_window_size_for_masking_step": [1024, 1024, 1024], - "second_stage_k": 1024, + "second_stage_k": 2048, "sink_token_size": 1024, "sa_extend_backend": "self_extend", - "stages": [ { } ] + "stages": [ + { + "stage_block_size_q":128, + "stage_block_stride_q":4, + "stage_chunk_size":256, + "stage_k":null, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":32, + "stage_k":65536, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":8, + "stage_k":8192, + "stage_stride":1, + "using_landmark":false + } + ] }, { "sliding_window_size": 1024, "sliding_window_size_for_masking_step": [1024, 1024, 1024], - "second_stage_k": 1024, + "second_stage_k": 2048, "sink_token_size": 1024, "sa_extend_backend": "self_extend", "scan_extend_backend": "none", - "stages": [ { } ] + "stages": [ + { + "stage_block_size_q":128, + "stage_block_stride_q":4, + "stage_chunk_size":256, + "stage_k":null, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":32, + "stage_k":65536, + "stage_stride":1, + "using_landmark":false + }, + { + "stage_block_size_q":64, + "stage_block_stride_q":1, + "stage_chunk_size":8, + "stage_k":8192, + "stage_stride":1, + "using_landmark":false + } + ] } ], "prefill_layers": [ From b0568c66a922c7ba903063bbb244f3b8c68fa106 Mon Sep 17 00:00:00 2001 From: AinL Date: Thu, 16 Oct 2025 01:08:04 +0000 Subject: [PATCH 11/13] watchdog bug fix --- src/hip_attn/utils/sglang_watchdog.py | 83 ++++++++++++++------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py index 1971205c..0baca976 100644 --- a/src/hip_attn/utils/sglang_watchdog.py +++ b/src/hip_attn/utils/sglang_watchdog.py @@ -40,8 +40,10 @@ def start_subprocess(self): def kill_subprocess(self): log(f"Start kill subprocess") - self.proc.kill() - self.proc = None + if self.proc is not None: + self.proc.kill() + self.proc = None + subprocess.call(["pkill", "sglang"]) log(f"Finish kill subprocess") def wait_for_health(self, timeout: int): @@ -94,49 +96,52 @@ def main_starter(self): time.sleep(self.sleep_step) def start(self): - if "--" in sys.argv: - my_args = sys.argv[1:sys.argv.index("--")] - argv = sys.argv[sys.argv.index("--") + 1:] - else: - my_args = [] - argv = sys.argv[1:] - - parser = argparse.ArgumentParser() - parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int) - parser.add_argument("--timeout", default=self.timeout_tick, type=int) - parser.add_argument("--sleep-step", default=self.sleep_step, type=int) + try: + if "--" in sys.argv: + my_args = sys.argv[1:sys.argv.index("--")] + argv = sys.argv[sys.argv.index("--") + 1:] + else: + my_args = [] + argv = sys.argv[1:] + + parser = argparse.ArgumentParser() + parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int) + parser.add_argument("--timeout", default=self.timeout_tick, type=int) + parser.add_argument("--sleep-step", default=self.sleep_step, type=int) - args = parser.parse_args(my_args) - self.timeout_bootup = args.timeout_bootup - self.timeout_tick = args.timeout - self.sleep_step = args.sleep_step - - assert "--host" in argv - assert "--port" in argv - self.host = argv[argv.index("--host") + 1] - self.port = argv[argv.index("--port") + 1] - self.health_endpoint = f"http://{self.host}:{self.port}/health" - log(f"Watching: {self.health_endpoint}") + args = parser.parse_args(my_args) + self.timeout_bootup = args.timeout_bootup + self.timeout_tick = args.timeout + self.sleep_step = args.sleep_step + + assert "--host" in argv + assert "--port" in argv + self.host = argv[argv.index("--host") + 1] + self.port = argv[argv.index("--port") + 1] + self.health_endpoint = f"http://{self.host}:{self.port}/health" + log(f"Watching: {self.health_endpoint}") - self.argv = argv + self.argv = argv - self.thread_watchdog = threading.Thread( - target=self.main_watchdog, - daemon=True - ) - self.thread_starter = threading.Thread( - target=self.main_starter, - daemon=True - ) + self.thread_watchdog = threading.Thread( + target=self.main_watchdog, + daemon=True + ) + self.thread_starter = threading.Thread( + target=self.main_starter, + daemon=True + ) - self.thread_starter.start() - time.sleep(self.sleep_step) - self.thread_watchdog.start() + self.thread_starter.start() + time.sleep(self.sleep_step) + self.thread_watchdog.start() - self.thread_watchdog.join() - self.thread_starter.join() + self.thread_watchdog.join() + self.thread_starter.join() - self.running = False + self.running = False + except KeyboardInterrupt: + self.kill_subprocess() if __name__ == '__main__': dog = Watchdog() From 618c15168f9642846d9222f6149c3ad1d8d593f2 Mon Sep 17 00:00:00 2001 From: AinL Date: Mon, 3 Nov 2025 02:12:33 +0000 Subject: [PATCH 12/13] add benchmark --- scripts/bench_latency_paged_attn.py | 195 ++++++++++++++++++++++++++++ src/hip_attn/v1_2/paged_hip.py | 14 +- 2 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 scripts/bench_latency_paged_attn.py diff --git a/scripts/bench_latency_paged_attn.py b/scripts/bench_latency_paged_attn.py new file mode 100644 index 00000000..6553c0b9 --- /dev/null +++ b/scripts/bench_latency_paged_attn.py @@ -0,0 +1,195 @@ +""" +python scripts/benchmark_latency_paged_attn.py +""" + +import os +import json +import traceback +import torch +from transformers import AutoConfig +import triton +from hip_attn.v1_2.paged_hip import forward_paged_hip, HiPAttentionConfig + +def forward_seq_len( + dtype: torch.dtype, + seq_len: int, + q_head: int, + kv_head: int, + head_dim: int, + hip_config: HiPAttentionConfig, + batch_size: int = 1, +): + device = torch.device("cuda:0") + + query = torch.rand( + (batch_size * seq_len, q_head, head_dim), + dtype=torch.bfloat16, + device=device + ) + k_cache = torch.rand( + # NOTE: + 1 is special behavior on SGlang. I am not sure about it is exists in vLLM too. + ((seq_len + 1) * batch_size, kv_head, head_dim), + dtype=torch.bfloat16, + device=device, + ).to(dtype) + v_cache = k_cache.clone() + positions = torch.arange(0, batch_size * seq_len, dtype=torch.long, device=device) + positions = positions % seq_len + seq_lens = torch.zeros((batch_size,), dtype=torch.long, device=device) + seq_lens[:] = seq_len + block_table = torch.arange(0, batch_size * seq_len, dtype=torch.long, device=device) + block_table = block_table.view(batch_size, seq_len) + layer_id = 10 + logit_cap = None + orig_context_length = seq_len + max_context_length = seq_len + is_kv_cache_offload_enable = False + rope_range = (0, head_dim) + extend_prefix_lens_cpu = [0,] * batch_size + extend_seq_lens_cpu = [seq_len,] * batch_size + + torch.cuda.synchronize() + + start = torch.cuda.Event(True) + end = torch.cuda.Event(True) + + start.record() + + forward_paged_hip( + query=query, + sm_scale=1 / (head_dim ** 0.5), + batch_size=batch_size, + k_cache=k_cache, + v_cache=v_cache, + offload_cache=None, + positions=positions, + seq_lens=seq_lens, + req_to_tokens=None, + req_pool_indices=None, + block_table=block_table, + rope_cos=None, + rope_sin=None, + layer_id=layer_id, + logit_cap=logit_cap, + orig_context_len=orig_context_length, + max_context_len=max_context_length, + hip_config=hip_config, + is_kv_cache_offload_enabled=is_kv_cache_offload_enable, + rope_range=rope_range, + extend_prefix_lens_cpu=extend_prefix_lens_cpu, + extend_seq_lens_cpu=extend_seq_lens_cpu, + is_decode=False, + ) + + end.record() + end.synchronize() + return start.elapsed_time(end) + +def try_set_environ(name: str, value): + if name in os.environ: + return + os.environ[name] = value + +def evaluate_autotune( + sa_block_size: int, + bsa_block_k: int, + hip_config: HiPAttentionConfig, +): + model_name = "Qwen/Qwen3-235B-A22B-Instruct-2507" + + try_set_environ("BSA_K", "32") + try_set_environ("BSA_EXACT_K", "32") + try_set_environ("BSA_BLOCK_K", str(bsa_block_k)) + try_set_environ("HIP_DEBUG_DELTA_QSA", "1") + try_set_environ("HIP_DEBUG_RECOMPUTE_SPLIT", "0") + try_set_environ("TRITON_PRINT_AUTOTUNING", "1") + try_set_environ("SA_BLOCK_SIZE", str(sa_block_size)) + try_set_environ("SA_DECODE_BLOCK_SIZE", "128") + try_set_environ("HIP_DISABLE_AUTOTUNE", "0") + + n_warmup = 3 + n_measure = 20 + n_tp = 8 + + config = AutoConfig.from_pretrained(model_name) + q_head = config.num_attention_heads + q_head = triton.cdiv(q_head, n_tp) + kv_head = config.num_key_value_heads + kv_head = triton.cdiv(kv_head, n_tp) + head_dim = config.hidden_size // config.num_attention_heads + + seq_lens = [32, 64, 128, 256, 384, 512, 768, 1024] + dtypes = [torch.bfloat16, torch.float8_e5m2] + + data = [] + + for seq_len in seq_lens: + for dtype in dtypes: + for _ in range(n_warmup): + try: + forward_seq_len( + dtype, + seq_len * 1024, + q_head, + kv_head, + head_dim, + hip_config=hip_config, + ) + except Exception: + traceback.print_exc() + + latencies = [] + for _ in range(n_measure): + try: + latency = forward_seq_len( + dtype, + seq_len * 1024, + q_head, + kv_head, + head_dim, + hip_config=hip_config, + ) + exception = "" + except Exception: + latency = float("nan") + exception = traceback.format_exc() + latencies.append(latency) + latency = sum(latencies) / len(latencies) + + data_point = { + "dtype": str(dtype), + "seq_len": seq_len, + "model": model_name, + "latency": latency, + "exception": exception, + } + print(data_point, flush=True) + data.append(data_point) + + return data + +def main(): + hip_config = HiPAttentionConfig( + json_or_path="./configs/mixed_landmark_0814_no_extend_qsa.json", + json_override='{"__seq_thresh_fa3": 0}' + ) + + bsa_block_ks = [32, 64] + sa_block_sizes = [64, 128, 256] + + data = [] + + for bsa_block_k in bsa_block_ks: + for sa_block_size in sa_block_sizes: + data.extend(evaluate_autotune( + bsa_block_k=bsa_block_k, + sa_block_size=sa_block_size, + hip_config=hip_config, + )) + + os.makedirs("saves/bench_latency_paged_attn", exist_ok=True) + with open("saves/bench_latency_paged_attn/measures.json", "w") as f: + json.dump(data, f) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py index 6a5f4c1a..7df14fa3 100644 --- a/src/hip_attn/v1_2/paged_hip.py +++ b/src/hip_attn/v1_2/paged_hip.py @@ -198,9 +198,14 @@ def flash_attn_varlen_func( split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, + get_tp_group, ) - SGLANG_DIST_ACTIVATED = True + try: + get_tp_group() + SGLANG_DIST_ACTIVATED = True + except AssertionError: + SGLANG_DIST_ACTIVATED = False except ImportError as ex: SGLANG_DIST_ACTIVATED = False @@ -433,8 +438,11 @@ def forward_paged_hip( positions=positions[start_len : start_len + seq_len], seq_lens=seq_lens[idx_batch : idx_batch + 1], req_to_tokens=req_to_tokens, - req_pool_indices=req_pool_indices[idx_batch : idx_batch + 1], - block_table=None, + req_pool_indices=( + req_pool_indices[idx_batch : idx_batch + 1] + if req_pool_indices is not None else None + ), + block_table=block_table, rope_cos=rope_cos, rope_sin=rope_sin, rope_range=rope_range, From 4bdc62c2fbce45fabc9c737863e1af1d31f889df Mon Sep 17 00:00:00 2001 From: AinL Date: Mon, 3 Nov 2025 02:38:37 +0000 Subject: [PATCH 13/13] fix --- scripts/bench_latency_paged_attn.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/scripts/bench_latency_paged_attn.py b/scripts/bench_latency_paged_attn.py index 6553c0b9..eda3e2a9 100644 --- a/scripts/bench_latency_paged_attn.py +++ b/scripts/bench_latency_paged_attn.py @@ -5,6 +5,7 @@ import os import json import traceback +import pandas as pd import torch from transformers import AutoConfig import triton @@ -108,7 +109,7 @@ def evaluate_autotune( try_set_environ("HIP_DISABLE_AUTOTUNE", "0") n_warmup = 3 - n_measure = 20 + n_measure = 100 n_tp = 8 config = AutoConfig.from_pretrained(model_name) @@ -159,6 +160,8 @@ def evaluate_autotune( data_point = { "dtype": str(dtype), "seq_len": seq_len, + "bsa_block_k": bsa_block_k, + "sa_block_size": sa_block_size, "model": model_name, "latency": latency, "exception": exception, @@ -174,8 +177,8 @@ def main(): json_override='{"__seq_thresh_fa3": 0}' ) - bsa_block_ks = [32, 64] - sa_block_sizes = [64, 128, 256] + bsa_block_ks = [64, 32] + sa_block_sizes = [256, 128, 64] data = [] @@ -190,6 +193,9 @@ def main(): os.makedirs("saves/bench_latency_paged_attn", exist_ok=True) with open("saves/bench_latency_paged_attn/measures.json", "w") as f: json.dump(data, f) + + df = pd.DataFrame(data) + df.to_csv("saves/bench_latency_paged_attn/measures.csv") if __name__ == "__main__": main() \ No newline at end of file