Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def _fake_agent_loop():
rollout_ctl = MagicMock()
rollout_ctl.continue_generation.remote = AsyncMock()
rollout_ctl.pause_generation.remote = AsyncMock()
rollout_ctl.cleanup_after_pause.remote = AsyncMock()
rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
agent_loop = MagicMock()
agent_loop.rollout_ctl = rollout_ctl
Expand Down
11 changes: 8 additions & 3 deletions tests/rl/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ class MockRolloutState:
def __init__(self, id, seq_staleness=1, status=Status.COMPLETED, reward_score=None):
self.id = id
self.uid = id
self.message_uid = None
self.session_uid = None
self.status = status
self.finish_reason = "abort" if status == Status.ABORTED else "stop"
self.seq_staleness = seq_staleness
self.response_ids = []
self.extra_fields = {}
Expand Down Expand Up @@ -87,9 +90,13 @@ def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None):
mock_agent_loop = MagicMock()
mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
mock_agent_loop.rollout_ctl.cleanup_after_pause.remote = AsyncMock(return_value=None)
mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})

async def mock_pause():
await mock_agent_loop.rollout_ctl.pause_generation.remote()

mock_agent_loop.pause = mock_pause

sleep_by_id = sleep_by_id or {}

async def mock_gen(rs, **kwargs):
Expand Down Expand Up @@ -626,7 +633,6 @@ async def test_async_produce_strategy_pause_produce_is_explicit(self):
expired = await self.replay_buffer.count(task_name, Status.EXPIRED)
self.assertEqual(completed + aborted + expired, 3)
self.assertEqual(mock_agent_loop.rollout_ctl.pause_generation.remote.await_count, 1)
self.assertEqual(mock_agent_loop.rollout_ctl.cleanup_after_pause.remote.await_count, 1)

async def test_async_produce_strategy_pause_produce_collects_without_cancelling(self):
# 验证 pending task 在 pause 等待窗口内完成时会被收集,而不是直接取消丢失结果。
Expand Down Expand Up @@ -657,7 +663,6 @@ async def test_async_produce_strategy_pause_produce_collects_without_cancelling(
expired = await self.replay_buffer.count(task_name, Status.EXPIRED)
self.assertEqual(completed + aborted + expired, 3)
self.assertEqual(mock_agent_loop.rollout_ctl.pause_generation.remote.await_count, 1)
self.assertEqual(mock_agent_loop.rollout_ctl.cleanup_after_pause.remote.await_count, 1)

async def test_async_produce_strategy_returns_update_abort_without_sampling(self):
# 验证 update_event 已设置时策略立即返回 UPDATE_WEIGHT_AND_ABORT,不再采样新 rollout。
Expand Down
30 changes: 0 additions & 30 deletions tests/rl/test_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from unittest.mock import AsyncMock, MagicMock, patch

from xtuner.v1.data_proto.rl_data import Status
from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker
from xtuner.v1.rl.rollout.sglang import SGLangWorker
from xtuner.v1.rl.rollout.worker import RolloutWorker
from xtuner.v1.utils.httpx_utils import HttpRequestErrorType
Expand Down Expand Up @@ -61,13 +60,6 @@ async def test_pause_generation_sets_abort_flag(self):
self.assertTrue(worker.receive_abort_request.is_set())
worker._send_abort_request.assert_awaited_once_with()

async def test_cleanup_after_pause_is_noop_by_default(self):
worker = RolloutWorker.__new__(RolloutWorker)

result = await worker.cleanup_after_pause()

self.assertIsNone(result)

async def test_send_abort_request_uses_abort_timeout(self):
worker = RolloutWorker.__new__(RolloutWorker)
worker.server_url = "http://test"
Expand All @@ -92,28 +84,6 @@ async def test_send_abort_request_uses_abort_timeout(self):
json={"abort_all": True},
)

async def test_lmdeploy_cleanup_after_pause_clears_shared_store_when_routed_experts_enabled(self):
worker = LMDeployWorker.__new__(LMDeployWorker)
worker.enable_return_routed_experts = True
worker.logger = MagicMock()
lmdeploy_actor = MagicMock()
lmdeploy_actor.clear.remote = AsyncMock(return_value=None)

with patch("xtuner.v1.rl.rollout.lmdeploy.ray.get_actor", return_value=lmdeploy_actor) as get_actor:
await worker.cleanup_after_pause()

get_actor.assert_called_once_with("shared_store", namespace="lmdeploy")
lmdeploy_actor.clear.remote.assert_awaited_once_with()

async def test_lmdeploy_cleanup_after_pause_skips_without_routed_experts(self):
worker = LMDeployWorker.__new__(LMDeployWorker)
worker.enable_return_routed_experts = False

with patch("xtuner.v1.rl.rollout.lmdeploy.ray.get_actor") as get_actor:
await worker.cleanup_after_pause()

get_actor.assert_not_called()

async def test_safe_post_request_returns_aborted_on_cancellation(self):
worker = RolloutWorker.__new__(RolloutWorker)
worker.receive_abort_request = threading.Event()
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/data_proto/rl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def reset_rollout_response(rollout_state: RolloutState) -> RolloutState:
from ray import ObjectRef as RayObjectRef

if isinstance(routed_experts, RayObjectRef):
ray.internal.free([routed_experts])
ray.internal.free([routed_experts], local_only=False)
rollout_state.routed_experts = None
prompt_ids = getattr(rollout_state, "prompt_ids", None)
rollout_state.tokens = list(prompt_ids) if prompt_ids is not None else None
Expand Down
69 changes: 64 additions & 5 deletions xtuner/v1/rl/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import asyncio
import math
from abc import ABC, abstractmethod
from typing import TypeAlias, cast

import ray
from pydantic import BaseModel, ConfigDict
from ray.actor import ActorClass, ActorProxy
from ray.util.placement_group import PlacementGroup
Expand All @@ -21,6 +23,10 @@
from xtuner.v1.utils.processing_utils import load_processor, load_tokenizer


AGENT_LOOP_CONCURRENCY_GROUP_GENERATE = "generate"
DEFAULT_JUDGER_CANCEL_TIMEOUT_S = 5.0


class AgentLoopConfig(ABC, BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
hf_checkpoint: str
Expand All @@ -35,6 +41,13 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) -
logger=logger,
)

get_generate_concurrency = rollout_controller.get_generate_concurrency
if hasattr(get_generate_concurrency, "remote"):
total_generate_concurrency = ray.get(get_generate_concurrency.remote())
else:
total_generate_concurrency = get_generate_concurrency()
concurrency = max(1, math.ceil(total_generate_concurrency / self.cpu_resources.num_workers))

register_cpu_resources(
name=f"agent_loop:{self.__class__.__name__}",
cpu_resources=self.cpu_resources,
Expand All @@ -44,12 +57,14 @@ def build(self, rollout_controller, judger: Judger | None = None, logger=None) -
return self._build_router(
rollout_controller=rollout_controller,
cpu_resources=self.cpu_resources,
concurrency=concurrency,
judger=judger,
logger=logger,
)
return self._build_ray_actor(
rollout_controller=rollout_controller,
cpu_resources=self.cpu_resources,
concurrency=concurrency,
judger=judger,
logger=logger,
)
Expand All @@ -66,14 +81,20 @@ def _build_ray_actor(
self,
rollout_controller: RolloutController,
cpu_resources: CPUResourcesConfig,
concurrency: int,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
logger=None,
) -> RayAgentLoopProxy:
ray_agent_loop = ray.remote(
concurrency_groups={
AGENT_LOOP_CONCURRENCY_GROUP_GENERATE: concurrency,
},
)(AgentLoopActor)
return cast(
"RayAgentLoopProxy",
CPUActorLauncher.build_actor(
AgentLoopActor,
ray_agent_loop,
self,
rollout_controller,
judger,
Expand All @@ -89,15 +110,21 @@ def _build_ray_actors(
self,
rollout_controller: RolloutController,
cpu_resources: CPUResourcesConfig,
concurrency: int,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
logger=None,
start_bundle_idx: int = 0,
) -> list[RayAgentLoopProxy]:
ray_agent_loop = ray.remote(
concurrency_groups={
AGENT_LOOP_CONCURRENCY_GROUP_GENERATE: concurrency,
},
)(AgentLoopActor)
return cast(
list["RayAgentLoopProxy"],
CPUActorLauncher.build_actors(
AgentLoopActor,
ray_agent_loop,
self,
rollout_controller,
judger,
Expand All @@ -114,6 +141,7 @@ def _build_router(
self,
rollout_controller: RolloutController,
cpu_resources: CPUResourcesConfig,
concurrency: int,
pg: PlacementGroup | None = None,
judger: Judger | None = None,
logger=None,
Expand All @@ -123,6 +151,7 @@ def _build_router(
workers=self._build_ray_actors(
rollout_controller=rollout_controller,
cpu_resources=cpu_resources,
concurrency=concurrency,
pg=pg,
judger=judger,
logger=logger,
Expand Down Expand Up @@ -165,6 +194,20 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l
group_samples = await generated_samples
return group_samples

async def pause(self) -> None:
# Base AgentLoop only pauses rollout generation.
#
# We intentionally do not define generic judger pause behavior in the
# Judger base class. Judger subclasses can implement judge() in very
# different ways, and one base pause implementation cannot cover all of
# them. Requiring users to follow a base-class pause protocol would also
# increase the mental overhead of writing a new judge() implementation.
#
# For now, only SingleTurnAgentLoop defines how to pause an in-flight
# judger call. Other AgentLoop subclasses should override pause() if
# they need their own judger pause semantics.
await self.rollout_ctl.pause_generation.remote() # type: ignore[attr-defined]


class RouterAgentLoop:
def __init__(self, workers: list[RayAgentLoopProxy], rollout_ctl: RolloutController):
Expand Down Expand Up @@ -204,6 +247,11 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l
def get_worker_status(self) -> dict[str, int]:
return {str(worker): load for worker, load in self._worker_loads.items()}

async def pause(self) -> None:
await asyncio.gather(
*(worker.pause.remote() for worker in self.workers),
)


async def get_agent_loop_rollout_ctl(agent_loop: AgentLoopSpec) -> RolloutController:
rollout_ctl = getattr(agent_loop, "rollout_ctl", None)
Expand All @@ -230,19 +278,30 @@ def __init__(
logger=logger,
)

@ray_method
@ray_method(concurrency_group=AGENT_LOOP_CONCURRENCY_GROUP_GENERATE)
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
return await self.agent_loop.generate_sample(rollout_state, **kwargs)

@ray_method
@ray_method(concurrency_group=AGENT_LOOP_CONCURRENCY_GROUP_GENERATE)
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
return await self.agent_loop.generate_group(rollout_state, **kwargs)

@ray_method
async def get_rollout_ctl(self):
return self.agent_loop.rollout_ctl

@ray_method
async def pause(self) -> None:
return await self.agent_loop.pause()


RayAgentLoop = cast(ActorClass[AgentLoopActor], CPUActorLauncher.to_actor_class(AgentLoopActor))
RayAgentLoop = cast(
ActorClass[AgentLoopActor],
ray.remote(
concurrency_groups={
AGENT_LOOP_CONCURRENCY_GROUP_GENERATE: 1000,
},
)(AgentLoopActor),
)
RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor]
AgentLoopSpec: TypeAlias = AgentLoop | RayAgentLoopProxy | RouterAgentLoop
57 changes: 52 additions & 5 deletions xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from typing import overload

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.utils import create_task
from xtuner.v1.rl.utils import cancel_and_drain, create_task

from .agent_loop import AgentLoop, AgentLoopConfig
from .agent_loop import DEFAULT_JUDGER_CANCEL_TIMEOUT_S, AgentLoop, AgentLoopConfig


class SingleTurnAgentLoopConfig(AgentLoopConfig):
Expand Down Expand Up @@ -62,6 +63,51 @@ def __init__(
):
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
self.enable_batch_judge = enable_batch_judge
self._pause_event = asyncio.Event()

@overload
async def run_judger(self, rollout_state: RolloutState) -> RolloutState: ...

@overload
async def run_judger(self, rollout_state: list[RolloutState]) -> list[RolloutState]: ...

async def run_judger(self, rollout_state: RolloutState | list[RolloutState]) -> RolloutState | list[RolloutState]:
assert self.judger is not None
judge_task = create_task(self.judger.judge(rollout_state))
pause_task = create_task(self._pause_event.wait())
try:
done, _ = await asyncio.wait({judge_task, pause_task}, return_when=asyncio.FIRST_COMPLETED)
if judge_task in done:
return await judge_task
try:
return await asyncio.wait_for(
asyncio.shield(judge_task),
timeout=DEFAULT_JUDGER_CANCEL_TIMEOUT_S,
)
except asyncio.TimeoutError:
await cancel_and_drain([judge_task])
for sample in rollout_state if isinstance(rollout_state, list) else [rollout_state]:
sample.status = Status.ABORTED
sample.finish_reason = "abort"
sample.reward = None
return rollout_state
except asyncio.CancelledError:
await cancel_and_drain([judge_task])
for sample in rollout_state if isinstance(rollout_state, list) else [rollout_state]:
sample.status = Status.ABORTED
sample.finish_reason = "abort"
sample.reward = None
return rollout_state
finally:
await cancel_and_drain([pause_task])

async def pause(self) -> None:
self._pause_event.set()
# TODO: Decide whether Judger needs an explicit pause API for resources not owned by SingleTurnAgentLoop.
try:
await super().pause()
finally:
self._pause_event.clear()

async def generate_sample(
self,
Expand All @@ -78,7 +124,7 @@ async def generate_sample(
return rollout_state
if self.judger is not None and not self.enable_batch_judge:
# 如果开启了批量打分,则在 generate_group 里统一打分,不在这里逐条打分
rollout_state = await self.judger.judge(rollout_state)
rollout_state = await self.run_judger(rollout_state)
return rollout_state

async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
Expand All @@ -90,6 +136,7 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l
generated_samples = asyncio.gather(*pending_tasks)
group_samples = await generated_samples
if self.judger is not None and self.enable_batch_judge:
# 批量打分
group_samples = await self.judger.judge(group_samples)
if not any(sample.status == Status.ABORTED for sample in group_samples):
# 批量打分
group_samples = await self.run_judger(group_samples)
return group_samples
Loading
Loading