# Copyright (c) 2022-2024.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
from typing import Optional, Union
from vmas import scenarios
from vmas.simulator.environment import Environment, Wrapper
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import DEVICE_TYPING
[docs]
def make_env(
    scenario: Union[str, BaseScenario],
    num_envs: int,
    device: DEVICE_TYPING = "cpu",
    continuous_actions: bool = True,
    wrapper: Optional[Union[Wrapper, str]] = None,
    max_steps: Optional[int] = None,
    seed: Optional[int] = None,
    dict_spaces: bool = False,
    multidiscrete_actions: bool = False,
    clamp_actions: bool = False,
    grad_enabled: bool = False,
    terminated_truncated: bool = False,
    wrapper_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Build a VMAS environment, optionally wrapped for an external framework.

    When ``scenario`` is a string it is resolved through ``vmas.scenarios``
    (a missing ``".py"`` suffix is appended automatically) and the loaded
    module's ``Scenario`` is instantiated. The scenario is then used to
    construct an :class:`~vmas.simulator.environment.Environment`, which is
    returned directly or handed to the requested wrapper.

    Args:
        scenario (Union[str, BaseScenario]): Scenario to load. Either the name
            of a file in the ``vmas.scenarios`` folder (with or without the
            ``.py`` extension) or a
            :class:`~vmas.simulator.scenario.BaseScenario`.
        num_envs (int): Number of vectorized simulation environments. VMAS
            simulates these in a single batch using PyTorch, so this value
            also determines the batch size of the environment.
        device (Union[str, int, torch.device], optional): Device on which the
            simulation runs. All tensors created by VMAS are placed on this
            device. Defaults to ``"cpu"``.
        continuous_actions (bool, optional): Whether to use continuous
            actions. If ``False``, actions are discrete. The number of actions
            and their size depend on the chosen scenario. Defaults to ``True``.
        wrapper (Union[Wrapper, str], optional): Wrapper to apply to the
            environment. A string is upper-cased and looked up by name in
            :class:`~vmas.simulator.environment.Wrapper`; for example
            ``"rllib"``, ``"gym"``, ``"gymnasium"`` or ``"gymnasium_vec"``.
            Defaults to ``None`` (no wrapper).
        max_steps (int, optional): Horizon of the task. Each VMAS scenario can
            be terminating or not; if ``max_steps`` is given, the scenario is
            additionally terminated whenever this horizon is reached.
            Defaults to ``None`` (infinite horizon).
        seed (int, optional): Seed for the environment. Defaults to ``None``.
        dict_spaces (bool, optional): Whether to use dictionary spaces of the
            form ``{"agent_name": tensor, ...}`` for obs, rewards and info
            instead of tuples. Defaults to ``False``: obs, rewards and info
            are tuples whose length is the number of agents.
        multidiscrete_actions (bool, optional): Whether to use multidiscrete
            action spaces when ``continuous_actions=False``. Defaults to
            ``False``: the action space is ``Discrete`` and is the cartesian
            product of the discrete action spaces available to an agent.
        clamp_actions (bool, optional): Whether to clamp input actions to
            their range instead of throwing an error when
            ``continuous_actions==True`` and actions are out of bounds.
            Defaults to ``False``.
        grad_enabled (bool, optional): If ``True``, the simulator does not
            call ``detach()`` on input actions, so gradients can be taken
            from the simulator output. Defaults to ``False``.
        terminated_truncated (bool, optional): Whether the step method
            outputs separate terminated and truncated flags rather than a
            single done flag. Defaults to ``False``.
        wrapper_kwargs (dict, optional): Keyword arguments forwarded to the
            wrapper. Defaults to ``None``, which is treated as an empty dict.
        **kwargs (dict, optional): Additional keyword arguments forwarded to
            the environment constructor, intended for the
            :class:`~vmas.simulator.scenario.BaseScenario` class.

    Returns:
        The created :class:`~vmas.simulator.environment.Environment` when
        ``wrapper`` is ``None``; otherwise the result of calling the wrapper's
        ``get_env`` on that environment.

    Examples:
        >>> from vmas import make_env
        >>> env = make_env(
        ...     "waterfall",
        ...     num_envs=3,
        ...     num_agents=2,
        ... )
        >>> print(env.reset())
    """
    # A string identifies a scenario script: normalise the file name, load
    # the module and instantiate its Scenario.
    if isinstance(scenario, str):
        script_name = scenario if scenario.endswith(".py") else scenario + ".py"
        scenario = scenarios.load(script_name).Scenario()

    env = Environment(
        scenario,
        num_envs=num_envs,
        device=device,
        continuous_actions=continuous_actions,
        max_steps=max_steps,
        seed=seed,
        dict_spaces=dict_spaces,
        multidiscrete_actions=multidiscrete_actions,
        clamp_actions=clamp_actions,
        grad_enabled=grad_enabled,
        terminated_truncated=terminated_truncated,
        **kwargs,
    )

    # Without a wrapper the bare environment is the result.
    if wrapper is None:
        return env

    # Wrapper names are matched case-insensitively against Wrapper members.
    if isinstance(wrapper, str):
        wrapper = Wrapper[wrapper.upper()]

    extra_wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
    return wrapper.get_env(env, **extra_wrapper_kwargs)