#信号中间件
hs-net 提供三个信号钩子,用于在请求生命周期中注入自定义逻辑。
#三种信号
| 信号 | 触发时机 | 接收参数 | 可返回 |
|---|---|---|---|
on_request_before | 请求发送前 | RequestModel | RequestModel 或 Response |
on_response_after | 响应返回后 | Response | Response |
on_request_retry | 请求重试时 | Exception | — |
#请求前中间件
before.py
with SyncNet() as net:
@net.on_request_before
def log_request(req_data):
"""记录每个请求。"""
print(f"→ {req_data.method} {req_data.url}")#修改请求参数
modify_request.py
@net.on_request_before
def add_auth(req_data):
"""自动添加认证头。"""
req_data.headers["Authorization"] = "Bearer my-token"
return req_data # 返回修改后的 RequestModel#短路返回(请求缓存)
返回 Response 对象会跳过实际网络请求:
cache.py
cache = {}
@net.on_request_before
def check_cache(req_data):
"""缓存命中则直接返回,不发送网络请求。"""
if req_data.url in cache:
return cache[req_data.url] # 返回 Response,跳过请求
@net.on_response_after
def save_cache(resp):
"""将响应存入缓存。"""
cache[resp.url] = resp#响应后中间件
after.py
@net.on_response_after
def track_status(resp):
"""统计响应状态码。"""
print(f"← {resp.status_code} {resp.url}")#替换响应
返回新的 Response 会替换原始响应:
replace_response.py
@net.on_response_after
def filter_response(resp):
"""修改或替换响应。"""
if some_condition:
return modified_response
# 不返回则使用原始响应#重试中间件
retry.py
@net.on_request_retry
def on_retry(exc):
"""记录每次重试的原因。"""
print(f"⟳ 重试: {type(exc).__name__}: {exc}")#异步中间件
异步客户端 Net 支持 async 中间件:
async_middleware.py
async with Net() as net:
@net.on_request_before
async def async_before(req_data):
# 可以执行异步操作
token = await get_token_from_redis()
req_data.headers["Authorization"] = f"Bearer {token}"
return req_data
@net.on_response_after
async def async_after(resp):
# 异步存储
await save_to_database(resp.json_data)混合使用
异步客户端的中间件可以同时注册同步和异步回调,框架会自动识别并正确调用。
#多个中间件
同一信号可以注册多个中间件,按注册顺序执行:
multi.py
@net.on_request_before
def middleware_1(req_data):
print("第一个中间件")
@net.on_request_before
def middleware_2(req_data):
print("第二个中间件")
# 执行顺序: middleware_1 → middleware_2#实战:请求计时
timing.py
import time
with SyncNet() as net:
timings = {}
@net.on_request_before
def start_timer(req_data):
timings[req_data.url] = time.time()
@net.on_response_after
def end_timer(resp):
elapsed = time.time() - timings[resp.url]
print(f"{resp.url} 耗时 {elapsed:.3f}s")
net.get("https://example.com")
# => https://example.com 耗时 0.234s#实战:请求统计
stats.py
import asyncio
async def main():
stats = {"total": 0, "success": 0, "fail": 0}
async with Net(concurrency=5) as net:
@net.on_request_before
async def count(req_data):
stats["total"] += 1
@net.on_response_after
async def track(resp):
if resp.ok:
stats["success"] += 1
else:
stats["fail"] += 1
urls = [f"https://example.com/?p={i}" for i in range(20)]
await asyncio.gather(*[net.get(url) for url in urls])
print(f"总计: {stats['total']}, 成功: {stats['success']}, 失败: {stats['fail']}")
