# Copyright (c) 2022-2024.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
import typing
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
from torch import Tensor
from vmas.simulator.core import Agent, World
from vmas.simulator.utils import (
AGENT_INFO_TYPE,
AGENT_OBS_TYPE,
AGENT_REWARD_TYPE,
INITIAL_VIEWER_SIZE,
VIEWER_DEFAULT_ZOOM,
)
if typing.TYPE_CHECKING:
from vmas.simulator.rendering import Geom
[docs]
class BaseScenario(ABC):
"""Base class for scenarios.
This is the class that scenarios inherit from.
The methods that are **compulsory to instantiate** are:
- :class:`make_world`
- :class:`reset_world_at`
- :class:`observation`
- :class:`reward`
The methods that are **optional to instantiate** are:
- :class:`info`
- :class:`extra_render`
- :class:`process_action`
"""
def __init__(self) -> None:
    """Set up a scenario with no world and the default rendering settings.

    Do not override. ``self._world`` starts as ``None`` and is assigned later
    by :class:`env_make_world`; the rendering attributes below receive their
    default values here.
    """
    self._world = None
    self.viewer_size = INITIAL_VIEWER_SIZE
    """The size of the rendering viewer window. This can be changed in the :class:`~make_world` function. """
    self.viewer_zoom = VIEWER_DEFAULT_ZOOM
    """The zoom of the rendering camera (a lower value means more zoom). This can be changed in the :class:`~make_world` function. """
    self.render_origin = (0.0, 0.0)
    """The origin of the rendering camera when ``agent_index_to_focus`` is None in the ``render()`` arguments. This can be changed in the :class:`~make_world` function. """
    self.plot_grid = False
    """Whether to plot a grid in the scenario rendering background. This can be changed in the :class:`~make_world` function. """
    self.grid_spacing = 0.1
    """If :class:`~plot_grid`, the distance between lines in the background grid. This can be changed in the :class:`~make_world` function. """
@property
def world(self):
    """The :class:`~vmas.simulator.core.World` associated to this scenario.

    Returns:
        The object currently stored in ``self._world``.

    Raises:
        AssertionError: if ``self._world`` is still ``None``, i.e. the world
            has not been set yet.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped when Python runs with ``-O``, which would make this property
    # silently return ``None``. Exception type and message are unchanged.
    if self._world is None:
        raise AssertionError(
            "You first need to set `self._world` in the `make_world` method"
        )
    return self._world
[docs]
def to(self, device: torch.device):
    """Move the scenario and its world to ``device``.

    Every :class:`torch.Tensor` stored directly as an instance attribute is
    replaced by the result of calling ``.to(device)`` on it. Afterwards the
    world is cast through its own ``to`` method.

    Args:
        device (Union[str, int, torch.device]): the device to cast to
    """
    instance_attrs = self.__dict__
    # Collect the converted tensors first, then write them back in one go.
    converted = {
        name: stored.to(device)
        for name, stored in instance_attrs.items()
        if isinstance(stored, Tensor)
    }
    instance_attrs.update(converted)
    self.world.to(device)
def env_make_world(self, batch_dim: int, device: torch.device, **kwargs) -> World:
    """Create the world via :class:`make_world` and store it on the scenario.

    Do not override.

    Args:
        batch_dim (int): forwarded unchanged to :class:`make_world`.
        device (torch.device): forwarded unchanged to :class:`make_world`.
        kwargs (dict, optional): forwarded unchanged to :class:`make_world`.

    Returns:
        The world returned by :class:`make_world`, which is also stored in
        ``self._world``.
    """
    created_world = self.make_world(batch_dim, device, **kwargs)
    self._world = created_world
    return created_world
def env_reset_world_at(self, env_index: typing.Optional[int]):
    """Reset the world, then run the scenario-specific reset.

    Do not override.

    Args:
        env_index (int, optional): passed first to ``self.world.reset`` and
            then to :class:`reset_world_at`.
    """
    world = self.world
    # The world is reset before the scenario hook is invoked.
    world.reset(env_index)
    self.reset_world_at(env_index)
def env_process_action(self, agent: Agent):
    """Run the action-processing pipeline for ``agent``.

    Do not override. The steps are, in order:

    1. if ``agent.action_script`` is not ``None``, call
       ``agent.action_callback`` with the world;
    2. call the customizable :class:`process_action` hook;
    3. call ``agent.dynamics.check_and_process_action()``.

    Args:
        agent (Agent): the agent whose action is processed
    """
    is_scripted = agent.action_script is not None
    if is_scripted:
        agent.action_callback(self.world)

    # Customizable action processor
    self.process_action(agent)

    agent.dynamics.check_and_process_action()
[docs]
@abstractmethod
def make_world(self, batch_dim: int, device: torch.device, **kwargs) -> World:
    """Create the world of the scenario.

    This function needs to be implemented when creating a scenario.
    In it the user should instantiate the world and insert agents and
    landmarks in it.

    Args:
        batch_dim (int): the number of vectorized environments.
        device (Union[str, int, torch.device]): the device of the environment.
        kwargs (dict, optional): named arguments passed from environment creation

    Returns:
        :class:`~vmas.simulator.core.World` : the :class:`~vmas.simulator.core.World`
        instance which is automatically set in :class:`~world`.

    Examples:
        >>> import torch
        >>> from vmas.simulator.core import Agent, World, Landmark, Sphere, Box
        >>> from vmas.simulator.scenario import BaseScenario
        >>> from vmas.simulator.utils import Color
        >>> class Scenario(BaseScenario):
        ...     def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        ...         # Pass any kwargs you desire when creating the environment
        ...         n_agents = kwargs.get("n_agents", 5)
        ...
        ...         # Create world
        ...         world = World(batch_dim, device, dt=0.1, drag=0.25, dim_c=0)
        ...         # Add agents
        ...         for i in range(n_agents):
        ...             agent = Agent(
        ...                 name=f"agent {i}",
        ...                 collide=True,
        ...                 mass=1.0,
        ...                 shape=Sphere(radius=0.04),
        ...                 max_speed=None,
        ...                 color=Color.BLUE,
        ...                 u_range=1.0,
        ...             )
        ...             world.add_agent(agent)
        ...         # Add landmarks
        ...         for i in range(5):
        ...             landmark = Landmark(
        ...                 name=f"landmark {i}",
        ...                 collide=True,
        ...                 movable=False,
        ...                 shape=Box(length=0.3, width=0.1),
        ...                 color=Color.RED,
        ...             )
        ...             world.add_landmark(landmark)
        ...         return world
    """
    raise NotImplementedError
[docs]
@abstractmethod
def reset_world_at(self, env_index: Optional[int] = None):
    """Reset the world at the specified ``env_index``.

    When a ``None`` index is passed, the world should make a vectorized
    (batched) reset. The ``entity.set_x()`` methods already have this logic
    integrated and will perform batched operations when index is ``None``.

    When this function is called, all entities have already had their state
    reset to zeros according to the ``env_index``. In this function you
    should change the values of the reset states according to your task.
    For example, some functions you might want to use are:

    - ``entity.set_pos()``,
    - ``entity.set_vel()``,
    - ``entity.set_rot()``,
    - ``entity.set_ang_vel()``.

    Implementors can access the world at :class:`world`.

    To increase performance, torch tensors should be created with the device
    already set, like: ``torch.tensor(..., device=self.world.device)``

    Args:
        env_index (int, optional): index of the environment to reset.
            If ``None`` a vectorized reset should be performed.

    Spawning at fixed positions

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> import torch
        >>> class Scenario(BaseScenario):
        ...     def reset_world_at(self, env_index):
        ...         for i, agent in enumerate(self.world.agents):
        ...             agent.set_pos(
        ...                 torch.tensor(
        ...                     [-0.2 + 0.1 * i, 1.0],
        ...                     dtype=torch.float32,
        ...                     device=self.world.device,
        ...                 ),
        ...                 batch_index=env_index,
        ...             )
        ...         for i, landmark in enumerate(self.world.landmarks):
        ...             landmark.set_pos(
        ...                 torch.tensor(
        ...                     [0.2 if i % 2 else -0.2, 0.6 - 0.3 * i],
        ...                     dtype=torch.float32,
        ...                     device=self.world.device,
        ...                 ),
        ...                 batch_index=env_index,
        ...             )
        ...             landmark.set_rot(
        ...                 torch.tensor(
        ...                     [torch.pi / 4 if i % 2 else -torch.pi / 4],
        ...                     dtype=torch.float32,
        ...                     device=self.world.device,
        ...                 ),
        ...                 batch_index=env_index,
        ...             )

    Spawning at random positions

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> from vmas.simulator.utils import ScenarioUtils
        >>> class Scenario(BaseScenario):
        ...     def reset_world_at(self, env_index):
        ...         ScenarioUtils.spawn_entities_randomly(
        ...             self.world.agents + self.world.landmarks,
        ...             self.world,
        ...             env_index,
        ...             min_dist_between_entities=0.02,
        ...             x_bounds=(-1.0, 1.0),
        ...             y_bounds=(-1.0, 1.0),
        ...         )
    """
    raise NotImplementedError
[docs]
@abstractmethod
def observation(self, agent: Agent) -> AGENT_OBS_TYPE:
    """Compute the observations for ``agent`` in a vectorized way.

    The returned tensor should contain the observations for ``agent`` in all
    envs and should have shape ``(self.world.batch_dim, n_agent_obs)``, or be
    a dict with leaves following that shape.

    Implementors can access the world at :class:`world`.

    To increase performance, torch tensors should be created with the device
    already set, like: ``torch.tensor(..., device=self.world.device)``

    Args:
        agent (Agent): the agent to compute the observations for

    Returns:
        Union[torch.Tensor, Dict[str, torch.Tensor]]: the observation

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> import torch
        >>> class Scenario(BaseScenario):
        ...     def observation(self, agent):
        ...         # get positions of all landmarks in this agent's reference frame
        ...         landmark_rel_poses = []
        ...         for landmark in self.world.landmarks:
        ...             landmark_rel_poses.append(landmark.state.pos - agent.state.pos)
        ...         return torch.cat([agent.state.pos, agent.state.vel, *landmark_rel_poses], dim=-1)

    You can also return observations in a dictionary

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> class Scenario(BaseScenario):
        ...     def observation(self, agent):
        ...         return {"pos": agent.state.pos, "vel": agent.state.vel}
    """
    raise NotImplementedError
[docs]
@abstractmethod
def reward(self, agent: Agent) -> AGENT_REWARD_TYPE:
    """Compute the reward for ``agent`` in a vectorized way.

    The returned tensor should contain the reward for ``agent`` in all envs
    and should have shape ``(self.world.batch_dim)`` and dtype
    ``torch.float``.

    Implementors can access the world at :class:`world`.

    To increase performance, torch tensors should be created with the device
    already set, like: ``torch.tensor(..., device=self.world.device)``

    Args:
        agent (Agent): the agent to compute the reward for

    Returns:
        torch.Tensor: reward tensor of shape ``(self.world.batch_dim)``

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> import torch
        >>> class Scenario(BaseScenario):
        ...     def reward(self, agent):
        ...         # reward every agent proportionally to distance from first landmark
        ...         rew = -torch.linalg.vector_norm(agent.state.pos - self.world.landmarks[0].state.pos, dim=-1)
        ...         return rew
    """
    raise NotImplementedError
[docs]
def done(self):
    """Compute the done flag for each env in a vectorized way.

    The returned tensor should contain the ``done`` for all envs and should
    have shape ``(n_envs)`` and dtype ``torch.bool``.

    Implementors can access the world at :class:`world`.

    To increase performance, torch tensors should be created with the device
    already set, like: ``torch.tensor(..., device=self.world.device)``

    By default, this function returns all ``False`` s.

    The scenario can still be done if ``max_steps`` has been set at
    environment construction.

    Returns:
        torch.Tensor: done tensor of shape ``(self.world.batch_dim)``

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> import torch
        >>> class Scenario(BaseScenario):
        ...     def done(self):
        ...         # return done when all agents have battery level lower than a threshold
        ...         return torch.stack([a.battery_level < threshold for a in self.world.agents], dim=-1).all(-1)
    """
    # A single ``False`` element on the world's device ...
    never_done = torch.tensor([False], device=self.world.device)
    # ... expanded (as a view, without copying) to one entry per environment.
    n_envs = self.world.batch_dim
    return never_done.expand(n_envs)
[docs]
def info(self, agent: Agent) -> AGENT_INFO_TYPE:
    """Compute the info dict for ``agent`` in a vectorized way.

    The returned dict should have a key for each info of interest and the
    corresponding value should be a tensor of shape ``(n_envs, info_size)``.

    By default this function returns an empty dictionary.

    Implementors can access the world at :class:`world`.

    To increase performance, torch tensors should be created with the device
    already set, like: ``torch.tensor(..., device=self.world.device)``

    Args:
        agent (Agent): the agent to compute the info for

    Returns:
        Union[torch.Tensor, Dict[str, torch.Tensor]]: the info
    """
    # Default implementation: nothing to report (a new dict on every call).
    return dict()
[docs]
def process_action(self, agent: Agent):
    """Hook for processing the agent actions before the simulation step.

    This function can be overridden. It has access to the world through the
    :class:`world` attribute.

    It can also be used for any other type of computation that has to happen
    after the input actions have been set but before the simulation step.

    For example here you can manage additional actions before passing them
    to the dynamics.

    By default this function does nothing.

    Args:
        agent (Agent): the agent to process the action of

    Examples:
        >>> from vmas.simulator.scenario import BaseScenario
        >>> from vmas.simulator.utils import TorchUtils
        >>> class Scenario(BaseScenario):
        ...     def process_action(self, agent):
        ...         # Clamp square to circle
        ...         agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, agent.u_range)
        ...         # Can use a PID controller
        ...         agent.controller.process_force()
        ...         return
    """
    # Intentionally a no-op in the base class.
    return None