信号中间件

hs-net 提供三个信号钩子,用于在请求生命周期中注入自定义逻辑。

三种信号

信号触发时机接收参数可返回
on_request_before请求发送前RequestModelRequestModelResponse
on_response_after响应返回后ResponseResponse
on_request_retry请求重试时Exception

请求前中间件

before.py
with SyncNet() as net:

    @net.on_request_before
    def log_request(req_data):
        """记录每个请求。"""
        print(f"→ {req_data.method} {req_data.url}")

修改请求参数

modify_request.py
@net.on_request_before
def add_auth(req_data):
    """自动添加认证头。"""
    req_data.headers["Authorization"] = "Bearer my-token"
    return req_data  # 返回修改后的 RequestModel

短路返回(请求缓存)

返回 Response 对象会跳过实际网络请求:

cache.py
cache = {}

@net.on_request_before
def check_cache(req_data):
    """缓存命中则直接返回,不发送网络请求。"""
    if req_data.url in cache:
        return cache[req_data.url]  # 返回 Response,跳过请求

@net.on_response_after
def save_cache(resp):
    """将响应存入缓存。"""
    cache[resp.url] = resp

响应后中间件

after.py
@net.on_response_after
def track_status(resp):
    """统计响应状态码。"""
    print(f"← {resp.status_code} {resp.url}")

替换响应

返回新的 Response 会替换原始响应:

replace_response.py
@net.on_response_after
def filter_response(resp):
    """修改或替换响应。"""
    if some_condition:
        return modified_response
    # 不返回则使用原始响应

重试中间件

retry.py
@net.on_request_retry
def on_retry(exc):
    """记录每次重试的原因。"""
    print(f"⟳ 重试: {type(exc).__name__}: {exc}")

异步中间件

异步客户端 Net 支持 async 中间件:

async_middleware.py
async with Net() as net:

    @net.on_request_before
    async def async_before(req_data):
        # 可以执行异步操作
        token = await get_token_from_redis()
        req_data.headers["Authorization"] = f"Bearer {token}"
        return req_data

    @net.on_response_after
    async def async_after(resp):
        # 异步存储
        await save_to_database(resp.json_data)
混合使用

异步客户端的中间件可以同时注册同步和异步回调,框架会自动识别并正确调用。

多个中间件

同一信号可以注册多个中间件,按注册顺序执行:

multi.py
@net.on_request_before
def middleware_1(req_data):
    print("第一个中间件")

@net.on_request_before
def middleware_2(req_data):
    print("第二个中间件")

# 执行顺序: middleware_1 → middleware_2

实战:请求计时

timing.py
import time

with SyncNet() as net:
    timings = {}

    @net.on_request_before
    def start_timer(req_data):
        timings[req_data.url] = time.time()

    @net.on_response_after
    def end_timer(resp):
        elapsed = time.time() - timings[resp.url]
        print(f"{resp.url} 耗时 {elapsed:.3f}s")

    net.get("https://example.com")
    # => https://example.com 耗时 0.234s

实战:请求统计

stats.py
import asyncio

async def main():
    stats = {"total": 0, "success": 0, "fail": 0}

    async with Net(concurrency=5) as net:

        @net.on_request_before
        async def count(req_data):
            stats["total"] += 1

        @net.on_response_after
        async def track(resp):
            if resp.ok:
                stats["success"] += 1
            else:
                stats["fail"] += 1

        urls = [f"https://example.com/?p={i}" for i in range(20)]
        await asyncio.gather(*[net.get(url) for url in urls])

    print(f"总计: {stats['total']}, 成功: {stats['success']}, 失败: {stats['fail']}")