Multi-Agent Reinforcement Learning: MAPPO Prototype in Python
1. High‐Level Structure
- Config: Hyperparameters and paths in one place.
- Environment (PromptEnv): Encapsulates prompt-embedding generation and reward logic.
- Networks (ActorCritic): Shared encoder + separate policy/value heads.
- Buffer (RolloutBuffer): Stores one-step (or multi-step) transitions, computes advantages/returns.
- Agent (MAPPOAgent): Wraps the networks, provides select_action(), evaluate(), save/load.
- Trainer (MAPPOTrainer): Orchestrates experience collection, GAE, PPO updates, logging, checkpointing.
- Main: Instantiates everything, kicks off training.
1.1. Python Standard Imports & Device Setup
import os
import time
import random
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
- We use logging for standardized, leveled output.
- DEVICE handling automatically uses CUDA if it is available.
# Device selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("MAPPO")
2. Config
Centralize all hyperparameters and file paths. If you grow this into multiple experiments, replace it with a YAML/JSON config plus Hydra or argparse later.
class Config:
    # Environment / Data
    EMBED_DIM: int = 32      # Dimension of prompt embeddings
    NUM_AGENTS: int = 4      # Discrete action space size
    HIDDEN_DIM: int = 64     # Hidden size for networks

    # PPO Hyperparameters
    LEARNING_RATE_ACTOR: float = 3e-4
    LEARNING_RATE_CRITIC: float = 1e-3
    GAMMA: float = 0.99
    GAE_LAMBDA: float = 0.95
    CLIP_EPS: float = 0.2
    ENTROPY_COEF: float = 0.01
    VALUE_LOSS_COEF: float = 0.5
    MAX_GRAD_NORM: float = 0.5

    # Training loop
    TOTAL_UPDATES: int = 2000  # Number of PPO update iterations
    BATCH_SIZE: int = 128      # Batch size per update
    PPO_EPOCHS: int = 4        # Number of PPO passes per batch

    # Checkpointing
    CHECKPOINT_DIR: str = "./checkpoints"
    SAVE_INTERVAL: int = 500   # Save model every N updates

    # Reproducibility
    SEED: int = 42

    @classmethod
    def init_seed(cls):
        torch.manual_seed(cls.SEED)
        np.random.seed(cls.SEED)
        random.seed(cls.SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(cls.SEED)
- Call Config.init_seed() at the start of your main routine to fix seeds across PyTorch, NumPy, and Python's random module.
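If you want lightweight command-line overrides before reaching for Hydra or a YAML config, a minimal argparse sketch might look like this (the flag names are illustrative, not part of the prototype):

import argparse

def apply_cli_overrides() -> None:
    """Hypothetical CLI overrides; values are written back onto the Config class
    so that classmethods such as Config.init_seed() also see them."""
    parser = argparse.ArgumentParser(description="MAPPO prototype")
    parser.add_argument("--total-updates", type=int, default=Config.TOTAL_UPDATES)
    parser.add_argument("--batch-size", type=int, default=Config.BATCH_SIZE)
    parser.add_argument("--seed", type=int, default=Config.SEED)
    args = parser.parse_args()

    Config.TOTAL_UPDATES = args.total_updates
    Config.BATCH_SIZE = args.batch_size
    Config.SEED = args.seed

Calling apply_cli_overrides() at the top of main(), before constructing the trainer, would be enough for quick experiments.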
3. Environment: PromptEnv
Encapsulates prompt-embedding generation and reward logic. In production, swap the synthetic embedding generated in reset() for your real embedding pipeline (e.g., SBERT + a projection layer). The interface mimics a minimal one-step Gym environment:
- reset(): returns a new prompt embedding.
- step(action): given an action (agent index), returns (next_obs, reward, done, info). Here, done=True immediately, since each "episode" is a single decision.
class PromptEnv:
    """
    Synthetic one-step environment for MAPPO:
      - Observation: a prompt embedding (Config.EMBED_DIM).
      - Action: integer in [0 .. Config.NUM_AGENTS-1].
      - Reward: cosine similarity between prompt and agent-specialty embedding + noise.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        # Pre-generate fixed agent embeddings (NUM_AGENTS x EMBED_DIM)
        self.agent_embeddings = torch.randn(cfg.NUM_AGENTS, cfg.EMBED_DIM, device=DEVICE)
        # We'll generate a new prompt each reset()
        self.current_prompt: Optional[torch.Tensor] = None

    def reset(self) -> torch.Tensor:
        """
        Generate a new random prompt embedding.
        In production, replace with actual SBERT projection logic.
        Returns: prompt_embedding (tensor, shape [EMBED_DIM])
        """
        # Simulate SBERT: random normal, then normalize
        prompt = torch.randn(self.cfg.EMBED_DIM, device=DEVICE)
        prompt = F.normalize(prompt, dim=0)
        self.current_prompt = prompt
        return prompt

    def step(self, action: int) -> Tuple[torch.Tensor, float, bool, dict]:
        """
        Given an action (agent index), compute the reward. We treat each episode
        as one step: next_obs is an unused placeholder and done is always True.
        """
        assert (
            0 <= action < self.cfg.NUM_AGENTS
        ), f"Action {action} out of range [0, {self.cfg.NUM_AGENTS - 1}]"

        # Cosine similarity ([-1, 1]) between prompt & agent
        agent_vec = self.agent_embeddings[action]
        prompt_vec = self.current_prompt
        cos_sim = F.cosine_similarity(
            prompt_vec.unsqueeze(0), agent_vec.unsqueeze(0)
        ).item()

        # Add small Gaussian noise
        reward = cos_sim + float(np.random.normal(scale=0.1))

        # This environment is one-step; done is always True
        done = True
        info = {}

        # next_obs would matter in a multi-step setting; here we call reset() explicitly
        next_obs = torch.zeros(self.cfg.EMBED_DIM, device=DEVICE)  # placeholder
        return next_obs, reward, done, info
- In a real Ohwise pipeline, reset() would embed a user's text (via SBERT), and step() would forward the prompt to the chosen agent and measure actual user feedback (e.g., a satisfaction score or task success). For now, we keep it synthetic.
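As a quick sanity check of the interface (a throwaway snippet, not part of the training loop):

# Smoke test of the one-step environment
env = PromptEnv(Config())
obs = env.reset()                          # [EMBED_DIM] normalized prompt embedding
next_obs, reward, done, info = env.step(action=0)
print(obs.shape, round(reward, 3), done)   # e.g. torch.Size([32]) 0.214 True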
4. Networks: ActorCritic
A single PyTorch module with a shared encoder (two-layer MLP) followed by a policy head (producing logits over actions) and a value head.
class ActorCritic(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_actions: int):
        super().__init__()
        # Shared encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Policy head
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        # Value head
        self.value_head = nn.Linear(hidden_dim, 1)

        # Initialize weights (orthogonal + small gain)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=nn.init.calculate_gain("relu"))
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through shared encoder → policy logits & state value.
        :param x: shape [batch_size, input_dim]
        :returns:
            - logits: [batch_size, num_actions]
            - values: [batch_size, 1]
        """
        hidden = self.encoder(x)
        logits = self.policy_head(hidden)
        values = self.value_head(hidden)
        return logits, values
- We use orthogonal initialization for stability.
- forward() returns both logits (for the action distribution) and values (the critic's estimate).
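A quick shape check on a dummy batch (illustrative only):

# Dummy forward pass to confirm output shapes
net = ActorCritic(Config.EMBED_DIM, Config.HIDDEN_DIM, Config.NUM_AGENTS).to(DEVICE)
dummy = torch.randn(8, Config.EMBED_DIM, device=DEVICE)   # batch of 8 fake prompts
logits, values = net(dummy)
print(logits.shape, values.shape)                         # torch.Size([8, 4]) torch.Size([8, 1])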
5. Rollout Buffer: RolloutBuffer
Stores one-step transitions (and can be extended to multi-step). After collecting BATCH_SIZE transitions, we compute advantages and returns. For multi-step or multi-agent scenarios, you'd adapt GAE across time; here, each "episode" is a single step.
class RolloutBuffer:
    """
    Buffer to collect one-step transitions for PPO updates.
    Stores:
      - observations (tensor [BATCH, EMBED_DIM])
      - actions      (tensor [BATCH])
      - log_probs    (tensor [BATCH])
      - rewards      (tensor [BATCH])
      - values       (tensor [BATCH])
    After a full batch: compute advantages and returns.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.obs_buf: List[torch.Tensor] = []
        self.actions_buf: List[int] = []
        self.log_probs_buf: List[float] = []
        self.rewards_buf: List[float] = []
        self.values_buf: List[float] = []

    def add(
        self,
        obs: torch.Tensor,
        action: int,
        log_prob: float,
        reward: float,
        value: float,
    ) -> None:
        self.obs_buf.append(obs.detach().cpu())
        self.actions_buf.append(action)
        self.log_probs_buf.append(log_prob)
        self.rewards_buf.append(reward)
        self.values_buf.append(value)

    def compute_returns_and_advantages(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Converts lists into tensors and computes:
          - returns: reward (since one-step) → shape [BATCH]
          - advantages: reward - value (since there is no next state)
        Normalizes advantages.
        """
        device = DEVICE
        rewards = torch.tensor(self.rewards_buf, dtype=torch.float32, device=device)  # [BATCH]
        values = torch.tensor(self.values_buf, dtype=torch.float32, device=device)    # [BATCH]

        # One-step: target_value = reward (no gamma * V(next_state) since done=True)
        returns = rewards.clone()

        # Advantage = (reward + gamma * 0) - value
        advantages = rewards - values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Convert other buffers
        obs = torch.stack(self.obs_buf).to(device)                       # [BATCH, EMBED_DIM]
        actions = torch.tensor(self.actions_buf, device=device)          # [BATCH]
        old_log_probs = torch.tensor(self.log_probs_buf, device=device)  # [BATCH]

        return obs, actions, old_log_probs, returns, advantages

    def clear(self) -> None:
        self.obs_buf.clear()
        self.actions_buf.clear()
        self.log_probs_buf.clear()
        self.rewards_buf.clear()
        self.values_buf.clear()
- In a true multi‐step environment, you would accumulate entire trajectories, then run GAE across time to compute multi‐step returns. Here, we simplify because each “episode” is immediately terminal.
6. Agent: MAPPOAgent
Encapsulates the ActorCritic network, its optimizers, action selection, evaluation, and checkpointing.
class MAPPOAgent:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.device = DEVICE

        # Initialize networks
        self.ac_net = ActorCritic(
            input_dim=cfg.EMBED_DIM,
            hidden_dim=cfg.HIDDEN_DIM,
            num_actions=cfg.NUM_AGENTS,
        ).to(self.device)

        # Separate optimizers for actor and critic.
        # Note: in PPO we often update both with a single optimizer, but
        # separating them gives fine-grained LR control. The shared encoder
        # is assigned to the actor optimizer so that it actually gets trained.
        self.actor_optimizer = optim.Adam(
            list(self.ac_net.encoder.parameters())
            + list(self.ac_net.policy_head.parameters()),
            lr=cfg.LEARNING_RATE_ACTOR,
        )
        self.critic_optimizer = optim.Adam(
            self.ac_net.value_head.parameters(), lr=cfg.LEARNING_RATE_CRITIC
        )

    def select_action(self, obs: torch.Tensor) -> Tuple[int, float, float]:
        """
        Given a single observation (prompt embedding), compute:
          - action (int)
          - log_prob of that action (float)
          - state_value (float)
        :param obs: [EMBED_DIM] tensor
        """
        obs = obs.to(self.device).unsqueeze(0)  # [1, EMBED_DIM]
        with torch.no_grad():
            logits, value = self.ac_net(obs)    # logits: [1, NUM_AGENTS], value: [1, 1]
            probs = F.softmax(logits, dim=-1)   # [1, NUM_AGENTS]
            dist = Categorical(probs=probs)
            action = dist.sample()              # [1]
            log_prob = dist.log_prob(action)    # [1]
        return int(action.item()), float(log_prob.item()), float(value.item())

    def evaluate(
        self, obs_batch: torch.Tensor, action_batch: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        For a batch of observations & actions, compute:
          - new log_probs [BATCH]
          - state values  [BATCH]
          - entropy       [scalar]
        """
        logits, values = self.ac_net(obs_batch)      # [BATCH, NUM_AGENTS], [BATCH, 1]
        probs = F.softmax(logits, dim=-1)            # [BATCH, NUM_AGENTS]
        dist = Categorical(probs=probs)
        new_log_probs = dist.log_prob(action_batch)  # [BATCH]
        entropy = dist.entropy().mean()              # scalar
        return new_log_probs, values.squeeze(1), entropy

    def save_checkpoint(self, update_step: int) -> None:
        """
        Saves the actor-critic state dictionary and both optimizer states.
        """
        os.makedirs(self.cfg.CHECKPOINT_DIR, exist_ok=True)
        ckpt_path = os.path.join(self.cfg.CHECKPOINT_DIR, f"mappo_ckpt_{update_step}.pth")
        torch.save(
            {
                "ac_state_dict": self.ac_net.state_dict(),
                "actor_opt_state": self.actor_optimizer.state_dict(),
                "critic_opt_state": self.critic_optimizer.state_dict(),
                "update_step": update_step,
            },
            ckpt_path,
        )
        logger.info(f"Saved checkpoint: {ckpt_path}")

    def load_checkpoint(self, filepath: str) -> int:
        """
        Loads from a given checkpoint file. Returns the `update_step` stored in it.
        """
        checkpoint = torch.load(filepath, map_location=self.device)
        self.ac_net.load_state_dict(checkpoint["ac_state_dict"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_opt_state"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_opt_state"])
        update_step = checkpoint["update_step"]
        logger.info(f"Loaded checkpoint '{filepath}' at update_step={update_step}")
        return update_step
- We separate actor/critic optimizers so that we can tune learning rates independently (the actor often needs a smaller LR); the shared encoder is updated through the actor optimizer.
- select_action() works on a single observation; evaluate() takes a batch and returns new log-probs, values, and entropy (for the PPO loss).
- Checkpointing: saves the combined state in one file. You might later store scheduler states, random seeds, etc. A resume sketch follows below.
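If a long run is interrupted, a simple resume pattern is to load the newest checkpoint before continuing. A sketch (it assumes checkpoints already exist and that you adapt MAPPOTrainer.train() to accept a starting update index):

# Resume from the most recent checkpoint (sketch)
cfg = Config()
agent = MAPPOAgent(cfg)
latest = sorted(
    os.listdir(cfg.CHECKPOINT_DIR),
    key=lambda name: int(name.split("_")[-1].split(".")[0]),
)[-1]
start_step = agent.load_checkpoint(os.path.join(cfg.CHECKPOINT_DIR, latest))
# ...continue training from start_step + 1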
7. Trainer: MAPPOTrainer
Coordinates environment interaction, buffer aggregation, PPO updates, logging, and checkpointing. At each update:
- Collect BATCH_SIZE one-step transitions (obs → action → reward → value).
- Move them into RolloutBuffer.
- Compute advantages and returns.
- Run PPO epochs (several passes over the same batch).
- Every SAVE_INTERVAL updates, checkpoint the model.
class MAPPOTrainer:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.device = DEVICE

        # Initialize environment, agent, buffer
        self.env = PromptEnv(cfg)
        self.agent = MAPPOAgent(cfg)
        self.buffer = RolloutBuffer(cfg)

        # Metrics tracking
        self.update_step = 0
        self.train_start_time = time.time()

    def collect_one_batch(self) -> None:
        """
        Collect BATCH_SIZE one-step transitions:
        (obs, action, log_prob, reward, value) and store them in RolloutBuffer.
        """
        self.buffer.clear()
        for _ in range(self.cfg.BATCH_SIZE):
            # 1) Reset environment to get a new prompt embedding
            obs = self.env.reset()  # [EMBED_DIM]

            # 2) Agent selects an action
            action, log_prob, value = self.agent.select_action(obs)

            # 3) Environment returns reward (and done)
            _, reward, done, _ = self.env.step(action)

            # 4) Store in buffer
            self.buffer.add(obs, action, log_prob, reward, value)

            # Since the environment is one-step, no need to loop on `done`.
            # In a multi-step setting you'd do something like: while not done: ...

    def update(self) -> None:
        """
        Perform a PPO update using the data in RolloutBuffer.
        """
        (
            obs_batch,
            action_batch,
            old_log_probs_batch,
            returns_batch,
            advantages_batch,
        ) = self.buffer.compute_returns_and_advantages()

        for _ in range(self.cfg.PPO_EPOCHS):
            # 1) Evaluate current policy on the batch
            new_log_probs, values_pred, entropy = self.agent.evaluate(
                obs_batch, action_batch
            )

            # 2) Ratio for clipped surrogate
            ratios = torch.exp(new_log_probs - old_log_probs_batch)

            # 3) Surrogate objectives
            surr1 = ratios * advantages_batch
            surr2 = (
                torch.clamp(ratios, 1.0 - self.cfg.CLIP_EPS, 1.0 + self.cfg.CLIP_EPS)
                * advantages_batch
            )
            actor_loss = -torch.min(surr1, surr2).mean() - self.cfg.ENTROPY_COEF * entropy

            # 4) Critic loss (MSE)
            critic_loss = F.mse_loss(values_pred, returns_batch)

            # 5) Combine losses and backprop once through the shared encoder.
            #    A single backward pass avoids re-using a freed graph; each
            #    optimizer then updates its own parameter group.
            total_loss = actor_loss + self.cfg.VALUE_LOSS_COEF * critic_loss
            self.agent.actor_optimizer.zero_grad()
            self.agent.critic_optimizer.zero_grad()
            total_loss.backward()

            # 6) Clip gradients per group, then step both optimizers
            nn.utils.clip_grad_norm_(
                list(self.agent.ac_net.encoder.parameters())
                + list(self.agent.ac_net.policy_head.parameters()),
                self.cfg.MAX_GRAD_NORM,
            )
            nn.utils.clip_grad_norm_(
                self.agent.ac_net.value_head.parameters(), self.cfg.MAX_GRAD_NORM
            )
            self.agent.actor_optimizer.step()
            self.agent.critic_optimizer.step()

        # Log losses & entropy occasionally (values from the last PPO epoch)
        if self.update_step % 100 == 0:
            avg_reward = returns_batch.mean().item()
            logger.info(
                f"Update {self.update_step:04d} | "
                f"Avg Reward: {avg_reward:.3f} | "
                f"Ent: {entropy.item():.3f} | "
                f"Actor Loss: {actor_loss.item():.3f} | "
                f"Critic Loss: {critic_loss.item():.3f}"
            )

    def train(self) -> None:
        """
        Main training loop: for TOTAL_UPDATES, collect a batch and update.
        Checkpoint every SAVE_INTERVAL.
        """
        Config.init_seed()
        logger.info(f"Starting training on device={self.device}")

        for update in range(1, self.cfg.TOTAL_UPDATES + 1):
            self.update_step = update

            # 1) Collect a batch of experiences
            self.collect_one_batch()

            # 2) Perform PPO update
            self.update()

            # 3) Checkpoint
            if update % self.cfg.SAVE_INTERVAL == 0 or update == self.cfg.TOTAL_UPDATES:
                self.agent.save_checkpoint(update)

        elapsed = time.time() - self.train_start_time
        logger.info(f"Training completed in {elapsed / 60:.2f} minutes.")
- We log every 100 updates. In production, you might integrate TensorBoard or Weights & Biases; a minimal TensorBoard sketch follows below.
- This trainer currently blocks the calling thread. If you want background checkpointing or asynchronous evaluation, dispatch those in separate threads or processes.
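For example, a minimal TensorBoard integration (a sketch, assuming torch.utils.tensorboard is installed; the log directory name is arbitrary):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./runs/mappo")   # e.g., create in MAPPOTrainer.__init__

# Inside MAPPOTrainer.update(), after the PPO epochs, you would add:
# writer.add_scalar("reward/avg", returns_batch.mean().item(), self.update_step)
# writer.add_scalar("loss/actor", actor_loss.item(), self.update_step)
# writer.add_scalar("loss/critic", critic_loss.item(), self.update_step)
# writer.add_scalar("policy/entropy", entropy.item(), self.update_step)

writer.close()                                   # e.g., at the end of train()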
8. Main Entry Point
Tie everything together. Optionally parse command-line arguments (omitted here for brevity) or modify Config directly.
def main():
    cfg = Config()
    trainer = MAPPOTrainer(cfg)
    trainer.train()


if __name__ == "__main__":
    main()
9. How to Run & Extend
- Save this file (e.g., mappo.py).
- Install requirements:

pip install torch numpy
# (Add sentence-transformers when you replace the synthetic prompts)

- Run:

python mappo.py

- Checkpoint files will appear in ./checkpoints/mappo_ckpt_<step>.pth. You can later load one with:

agent = MAPPOAgent(Config())
agent.load_checkpoint("./checkpoints/mappo_ckpt_500.pth")
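Once a checkpoint is loaded, routing a new prompt at inference time is just a greedy pass through the policy head. A minimal sketch, assuming you already have a prompt embedding (from PromptEnv.reset() or your real pipeline):

# Greedy routing with a trained policy (sketch)
cfg = Config()
agent = MAPPOAgent(cfg)
agent.load_checkpoint("./checkpoints/mappo_ckpt_2000.pth")
agent.ac_net.eval()

obs = PromptEnv(cfg).reset()                  # or a real prompt embedding
with torch.no_grad():
    logits, _ = agent.ac_net(obs.unsqueeze(0).to(DEVICE))
chosen_agent = int(logits.argmax(dim=-1).item())
print(f"Route this prompt to agent #{chosen_agent}")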
9.1. Replacing Synthetic Prompts with Real Embeddings
- Install and load a SentenceTransformer model:
from sentence_transformers import SentenceTransformer

class PromptEnv:
    def __init__(self, cfg: Config):
        # ...
        self.sbert = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
        # Projection layer to EMBED_DIM (if your SBERT dims ≠ EMBED_DIM)
        self.proj = nn.Linear(384, cfg.EMBED_DIM).to(DEVICE)
        nn.init.orthogonal_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0)

    def reset(self, prompt_text: Optional[str] = None) -> torch.Tensor:
        if prompt_text is None:
            prompt_text = random.choice(self.some_prompt_corpus)
        emb_384 = self.sbert.encode(prompt_text, convert_to_tensor=True).to(DEVICE)
        emb_32 = self.proj(emb_384.unsqueeze(0)).squeeze(0)
        emb_32 = F.normalize(emb_32, dim=0)
        self.current_prompt = emb_32
        return emb_32
- Then, in collect_one_batch(), instead of obs = self.env.reset(), you'd pass the actual user text via reset(prompt_text=...). If collecting data offline, you might loop over a dataset of (text, reward) pairs, as sketched below.
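A sketch of that offline variant, assuming an iterable offline_dataset of (text, reward) pairs and the SBERT-backed reset(prompt_text=...) from above (both names are placeholders):

# Offline batch collection from logged (text, reward) pairs (sketch)
def collect_offline_batch(trainer: MAPPOTrainer, offline_dataset) -> None:
    trainer.buffer.clear()
    for text, logged_reward in offline_dataset:
        obs = trainer.env.reset(prompt_text=text)           # SBERT-backed reset from 9.1
        action, log_prob, value = trainer.agent.select_action(obs)
        # Caveat: the logged reward belongs to whatever routing decision produced
        # the log; strictly off-policy data needs the logged action plus
        # importance weighting, which is beyond this sketch.
        trainer.buffer.add(obs, action, log_prob, logged_reward, value)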
9.2. Multi‐Step Episodes & GAE
If your "environment" involves multi-turn dialogs (e.g., agent → user → agent, and so on), convert RolloutBuffer to store variable-length trajectories and compute GAE accordingly:
# Pseudocode for multi-step GAE (computed backward in time: t = T-1, ..., 0)
advantages[t] = rewards[t] + gamma * values[t+1] * (1 - done[t]) - values[t]
gae[t] = advantages[t] + gamma * lambda * (1 - done[t]) * gae[t+1]
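A concrete single-trajectory implementation of that recursion, assuming values carries one extra bootstrap entry V(s_T) at the end:

def compute_gae(
    rewards: torch.Tensor,   # [T]
    values: torch.Tensor,    # [T + 1]; last entry is the bootstrap value V(s_T)
    dones: torch.Tensor,     # [T]; 1.0 where the episode terminated at step t
    gamma: float = Config.GAMMA,
    lam: float = Config.GAE_LAMBDA,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns (advantages, returns), each of shape [T]."""
    T = rewards.shape[0]
    advantages = torch.zeros(T, device=rewards.device)
    gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return advantages, returns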
10. Why This Is “Senior‐Grade”
- Modularity: Each class has a single responsibility (Env, Agent, Buffer, Trainer). When new features arise (e.g., real user feedback, multi‐step RL), you only modify the relevant module.
- Config Centralization: All hyperparameters live in Config. Switching from synthetic to real embeddings, changing PPO settings, or toggling GPU usage is a one-line change.
- Logging & Checkpointing: Standard logging makes debugging easier; regular checkpointing protects long training jobs.
- Device-Agnostic: All tensors and networks move to DEVICE, automatically using the GPU if available.
- Orthogonal Initialization: Helps training stability, which is especially important in PPO.
- Clear Separation of Actor vs. Critic: Separate optimizers and heads let you tune the learning rate and gradient clipping for each independently.
- Scaffold for Real Data: The synthetic environment can be replaced by your real Ohwise pipeline (embed user prompt → forward to chosen agent → gather reward from user feedback logs) with minimal changes.
- Extensible Buffer: The RolloutBuffer can be extended to multi-step episodes simply by changing how you store transitions and compute GAE.
- Easy Checkpoint Loading: If training is interrupted, you can resume from the last saved step.
In Summary
This refactored prototype gives you a clean, production‐ready foundation for:
- Collecting real data instead of synthetic prompts.
- Fine‐tuning your policy over time as new feedback arrives.
- Scaling out: swap PromptEnv for a distributed data pipeline (e.g., read from a Redis or Postgres table of user logs), use multiple worker processes to collect experiences, and centralize the updates on a learner node.
- Extending: build hierarchical or multi-agent pipelines (e.g., first a retrieval agent, then a summarizer agent) by increasing the action dimension or adding extra heads/layers.