Source code for vmas.make_env

#  Copyright (c) 2022-2024.
#  ProrokLab (https://www.proroklab.org/)
#  All rights reserved.

from typing import Optional, Union

from vmas import scenarios
from vmas.simulator.environment import Environment, Wrapper
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import DEVICE_TYPING


def make_env(
    scenario: Union[str, BaseScenario],
    num_envs: int,
    device: DEVICE_TYPING = "cpu",
    continuous_actions: bool = True,
    wrapper: Optional[Union[Wrapper, str]] = None,
    max_steps: Optional[int] = None,
    seed: Optional[int] = None,
    dict_spaces: bool = False,
    multidiscrete_actions: bool = False,
    clamp_actions: bool = False,
    grad_enabled: bool = False,
    terminated_truncated: bool = False,
    wrapper_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Create a vmas environment.

    Args:
        scenario (Union[str, BaseScenario]): Scenario to load. Can be the name of a
            file in the ``vmas.scenarios`` folder (the ``.py`` extension is optional)
            or a :class:`~vmas.simulator.scenario.BaseScenario` instance, which is
            used as-is.
        num_envs (int): Number of vectorized simulation environments. VMAS performs
            vectorized simulations using PyTorch. This argument indicates the number
            of vectorized environments that should be simulated in a batch. It will
            also determine the batch size of the environment.
        device (Union[str, int, torch.device], optional): Device for simulation. All
            the tensors created by VMAS will be placed on this device.
            Default is ``"cpu"``.
        continuous_actions (bool, optional): Whether to use continuous actions. If
            ``False``, actions will be discrete. The number of actions and their size
            will depend on the chosen scenario. Default is ``True``.
        wrapper (Union[Wrapper, str], optional): Wrapper to apply to the environment.
            For example, it can be ``"rllib"``, ``"gym"``, ``"gymnasium"``,
            ``"gymnasium_vec"``. A string is upper-cased and looked up as
            ``Wrapper[name]``. Default is ``None`` (no wrapper).
        max_steps (int, optional): Horizon of the task. Defaults to ``None``
            (infinite horizon). Each VMAS scenario can be terminating or not. If
            ``max_steps`` is specified, the scenario is also terminated whenever this
            horizon is reached.
        seed (int, optional): Seed for the environment. Defaults to ``None``.
        dict_spaces (bool, optional): Whether to use dictionary spaces with format
            ``{"agent_name": tensor, ...}`` for obs, rewards, and info instead of
            tuples. Defaults to ``False``: obs, rewards, info are tuples with length
            equal to the number of agents.
        multidiscrete_actions (bool, optional): Whether to use multidiscrete action
            spaces when ``continuous_actions=False``. Default is ``False``: the action
            space will be ``Discrete``, and it will be the cartesian product of the
            discrete action spaces available to an agent.
        clamp_actions (bool, optional): Whether to clamp input actions to their range
            instead of throwing an error when ``continuous_actions==True`` and actions
            are out of bounds. Default is ``False``.
        grad_enabled (bool, optional): If ``True`` the simulator will not call
            ``detach()`` on input actions and gradients can be taken from the
            simulator output. Default is ``False``.
        terminated_truncated (bool, optional): Whether to use terminated and
            truncated flags in the output of the step method (or a single done flag).
            Default is ``False``.
        wrapper_kwargs (dict, optional): Keyword arguments to pass to the wrapper's
            ``get_env``. Ignored when ``wrapper`` is ``None``. Default is ``None``
            (treated as ``{}``).
        **kwargs (dict, optional): Keyword arguments forwarded to the
            :class:`~vmas.simulator.environment.Environment` constructor and, from
            there, to the :class:`~vmas.simulator.scenario.BaseScenario` class
            (e.g. ``num_agents``).

    Returns:
        The :class:`~vmas.simulator.environment.Environment` when ``wrapper`` is
        ``None``; otherwise the result of ``wrapper.get_env(env, **wrapper_kwargs)``.

    Examples:
        >>> from vmas import make_env
        >>> env = make_env(
        ...     "waterfall",
        ...     num_envs=3,
        ...     num_agents=2,
        ... )
        >>> print(env.reset())
    """
    # Load the scenario from a script in ``vmas.scenarios`` when it is given by name.
    if isinstance(scenario, str):
        if not scenario.endswith(".py"):
            scenario += ".py"
        scenario = scenarios.load(scenario).Scenario()

    env = Environment(
        scenario,
        num_envs=num_envs,
        device=device,
        continuous_actions=continuous_actions,
        max_steps=max_steps,
        seed=seed,
        dict_spaces=dict_spaces,
        multidiscrete_actions=multidiscrete_actions,
        clamp_actions=clamp_actions,
        grad_enabled=grad_enabled,
        terminated_truncated=terminated_truncated,
        **kwargs,
    )

    if wrapper is None:
        # wrapper_kwargs is intentionally ignored when no wrapper is requested.
        return env

    # Wrappers may be passed by name (e.g. "gym"); look them up by upper-cased name.
    if isinstance(wrapper, str):
        wrapper = Wrapper[wrapper.upper()]
    if wrapper_kwargs is None:
        wrapper_kwargs = {}
    return wrapper.get_env(env, **wrapper_kwargs)