Merge remote-tracking branch 'origin/dev' into feature/2137-refactor-request-api
This commit is contained in:
@@ -1 +1 @@
|
||||
3.0.0b4dev
|
||||
3.0.0b6
|
||||
|
||||
@@ -14,6 +14,8 @@ io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
|
||||
|
||||
game:
|
||||
@@ -29,7 +31,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -110,10 +112,8 @@ agents:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: DatabaseClient
|
||||
- service_name: web_server_web_service
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
@@ -302,63 +302,63 @@ agents:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22:
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
@@ -507,6 +507,24 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -526,7 +544,7 @@ agents:
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
|
||||
@@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb: str = "delete"
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["do_nothing"]
|
||||
# return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid]
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
"""Action which repairs a file."""
|
||||
@@ -443,27 +453,33 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id in [0, 1]:
|
||||
if source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_ip_id == 1:
|
||||
src_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_port_id == 1:
|
||||
if source_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id in (0, 1):
|
||||
if source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_port_id == 1:
|
||||
if dest_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
@@ -943,13 +959,22 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
node_ref = entry["node_ref"]
|
||||
nic_num = entry["nic_num"]
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_ref)
|
||||
ip_address = node_obj.ethernet_port[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
obj = cls(
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=None,
|
||||
ip_address_list=ip_address_list or None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
@@ -46,3 +45,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
super().reset_agent_for_episode()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
@@ -44,6 +44,8 @@ class AgentSettings(BaseModel):
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
@@ -134,6 +136,10 @@ class AbstractAgent(ABC):
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Agent reset logic should go here."""
|
||||
pass
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
@@ -166,6 +172,7 @@ class ProxyAgent(AbstractAgent):
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
@@ -174,6 +181,7 @@ class ProxyAgent(AbstractAgent):
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -78,7 +79,7 @@ class FileObservation(AbstractObservation):
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["health_status"]}
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -204,12 +205,15 @@ class LinkObservation(AbstractObservation):
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 10) + 1
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": utilisation_category}}
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -554,7 +558,7 @@ class NodeObservation(AbstractObservation):
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
@@ -644,10 +648,13 @@ class AclObservation(AbstractObservation):
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
@@ -656,15 +663,26 @@ class AclObservation(AbstractObservation):
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -110,10 +110,17 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.info(
|
||||
f"Could not calculate {self.__class__} reward because "
|
||||
"simulation state did not contain enough information."
|
||||
)
|
||||
return 0.0
|
||||
|
||||
health_status = database_file_state["health_status"]
|
||||
if health_status == "corrupted":
|
||||
if health_status == 2:
|
||||
return -1
|
||||
elif health_status == "good":
|
||||
elif health_status == 1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
@@ -117,7 +115,7 @@ class PrimaiteGame:
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
agent_actions = self.apply_agent_actions() # noqa
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -135,12 +133,15 @@ class PrimaiteGame:
|
||||
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
agent_actions = {}
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
return agent_actions
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
@@ -164,6 +165,7 @@ class PrimaiteGame:
|
||||
self.simulation.reset_component_for_episode(episode=self.episode_counter)
|
||||
for agent in self.agents:
|
||||
agent.reward_function.total_reward = 0.0
|
||||
agent.reset_agent_for_episode()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
@@ -228,31 +230,7 @@ class PrimaiteGame:
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
elif n_type == "router":
|
||||
new_node = Router(
|
||||
hostname=node_cfg["hostname"],
|
||||
num_ports=node_cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in node_cfg:
|
||||
for port_num, port_cfg in node_cfg["ports"].items():
|
||||
new_node.configure_port(
|
||||
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
|
||||
)
|
||||
# new_node.enable_port(port_num)
|
||||
if "acl" in node_cfg:
|
||||
for r_num, r_cfg in node_cfg["acl"].items():
|
||||
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
|
||||
# this: 'r_cfg.get('src_port')'
|
||||
# Port/IPProtocol. TODO Refactor
|
||||
new_node.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("ip_address"),
|
||||
dst_ip_address=r_cfg.get("ip_address"),
|
||||
position=r_num,
|
||||
)
|
||||
new_node = Router.from_config(node_cfg)
|
||||
else:
|
||||
_LOGGER.warning(f"invalid node type {n_type} in config")
|
||||
if "services" in node_cfg:
|
||||
@@ -389,6 +367,7 @@ class PrimaiteGame:
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
game.rl_agents.append(new_agent)
|
||||
|
||||
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 110 KiB |
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
@@ -39,6 +39,15 @@
|
||||
"#### Create a Ray algorithm and pass it our config."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(cfg['agents'][2]['agent_settings'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -76,6 +85,13 @@
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
@@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
@@ -81,13 +81,19 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
if self.agent.flatten_obs:
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
else:
|
||||
return self.agent.observation_manager.space
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if not self.agent.flatten_obs:
|
||||
return self.agent.observation_manager.current_observation
|
||||
else:
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
@@ -166,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -180,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {}
|
||||
infos = {"agent_actions": agent_actions}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
|
||||
@@ -24,9 +24,13 @@ class SessionIOSettings(BaseModel):
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
"""Whether to save system logs."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
@@ -39,9 +43,10 @@ class SessionIO:
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
@@ -54,7 +54,7 @@ class PrimaiteSession:
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
self.io_manager: Optional["SessionIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
@@ -101,9 +101,9 @@ class PrimaiteSession:
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"game": game})
|
||||
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
|
||||
|
||||
@@ -44,3 +44,12 @@ def run(overwrite_existing: bool = True) -> None:
|
||||
print(dst_fp)
|
||||
shutil.copy2(src_fp, dst_fp)
|
||||
_LOGGER.info(f"Reset example notebook: {dst_fp}")
|
||||
|
||||
for src_fp in primaite_root.glob("notebooks/_package_data/*"):
|
||||
dst_fp = example_notebooks_user_dir / "_package_data" / src_fp.name
|
||||
if should_copy_file(src_fp, dst_fp, overwrite_existing):
|
||||
if not Path.exists(example_notebooks_user_dir / "_package_data/"):
|
||||
Path.mkdir(example_notebooks_user_dir / "_package_data/")
|
||||
print(dst_fp)
|
||||
shutil.copy2(src_fp, dst_fp)
|
||||
_LOGGER.info(f"Copied notebook resource to: {dst_fp}")
|
||||
|
||||
@@ -7,11 +7,13 @@ from primaite import _PRIMAITE_ROOT
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
|
||||
class __SimOutput:
|
||||
class _SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
@@ -23,4 +25,4 @@ class __SimOutput:
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
SIM_OUTPUT = __SimOutput()
|
||||
SIM_OUTPUT = _SimOutput()
|
||||
|
||||
@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
|
||||
"List containing all the folders in the file system."
|
||||
deleted_folders: Dict[str, Folder] = {}
|
||||
"List containing all the folders that have been deleted."
|
||||
_folders_by_name: Dict[str, Folder] = {}
|
||||
sys_log: SysLog
|
||||
"Instance of SysLog used to create system logs."
|
||||
sim_root: Path
|
||||
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
|
||||
folder = Folder(name=folder_name, sys_log=self.sys_log)
|
||||
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
self._folder_request_manager.add_request(
|
||||
name=folder.name, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
|
||||
:param folder_name: The folder name.
|
||||
:return: The matching Folder.
|
||||
"""
|
||||
return self._folders_by_name.get(folder_name)
|
||||
for folder in self.folders.values():
|
||||
if folder.name == folder_name:
|
||||
return folder
|
||||
return None
|
||||
|
||||
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
|
||||
"""
|
||||
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name:
|
||||
# check if file with name already exists
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
# If not then create it
|
||||
if not folder:
|
||||
folder = self.create_folder(folder_name)
|
||||
else:
|
||||
# Use root folder if folder_name not supplied
|
||||
folder = self._folders_by_name["root"]
|
||||
folder = self.get_folder("root")
|
||||
|
||||
# Create the file and add it to the folder
|
||||
file = File(
|
||||
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
|
||||
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
if folder.deleted:
|
||||
self.deleted_folders.pop(folder.uuid)
|
||||
|
||||
@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
files: Dict[str, File] = {}
|
||||
"Files stored in the folder."
|
||||
_files_by_name: Dict[str, File] = {}
|
||||
"Files by their name as <file name>.<file type>."
|
||||
deleted_files: Dict[str, File] = {}
|
||||
"Files that have been deleted."
|
||||
|
||||
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
@@ -105,7 +101,7 @@ class Folder(FileSystemItemABC):
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(func=lambda request, context: self._file_request_manager),
|
||||
request_type=RequestType(func=self._file_request_manager),
|
||||
)
|
||||
return rm
|
||||
|
||||
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
|
||||
:return: The matching File.
|
||||
"""
|
||||
# TODO: Increment read count?
|
||||
return self._files_by_name.get(file_name)
|
||||
for file in self.files.values():
|
||||
if file.name == file_name:
|
||||
return file
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
|
||||
"""
|
||||
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
|
||||
raise Exception(f"Invalid file: {file}")
|
||||
|
||||
# check if file with id or name already exists in folder
|
||||
if (force is not True) and file.name in self._files_by_name:
|
||||
if self.get_file(file.name) is not None and not force:
|
||||
raise Exception(f"File with name {file.name} already exists in folder")
|
||||
|
||||
if (force is not True) and file.uuid in self.files:
|
||||
if (file.uuid in self.files) and not force:
|
||||
raise Exception(f"File with uuid {file.uuid} already exists in folder")
|
||||
|
||||
# add to list
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
|
||||
file.folder = self
|
||||
|
||||
@@ -275,11 +273,9 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
if self.files.get(file.uuid):
|
||||
self.files.pop(file.uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
self.deleted_files[file.uuid] = file
|
||||
file.delete()
|
||||
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
|
||||
self._file_request_manager.remove_request(file.uuid)
|
||||
else:
|
||||
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
|
||||
|
||||
@@ -300,7 +296,6 @@ class Folder(FileSystemItemABC):
|
||||
self.deleted_files[file_id] = file
|
||||
|
||||
self.files = {}
|
||||
self._files_by_name = {}
|
||||
|
||||
def restore_file(self, file_uuid: str):
|
||||
"""
|
||||
@@ -316,7 +311,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file_uuid)
|
||||
|
||||
148
src/primaite/simulator/network/creation.py
Normal file
148
src/primaite/simulator/network/creation.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
Each switch is assumed to have one port reserved for connecting to a router, reducing the effective
|
||||
number of ports available for PCs. The function calculates the total number of switches needed
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
Example:
|
||||
>>> num_of_switches_required(5)
|
||||
1
|
||||
>>> num_of_switches_required(24,24)
|
||||
2
|
||||
>>> num_of_switches_required(48,24)
|
||||
3
|
||||
>>> num_of_switches_required(25,10)
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
|
||||
|
||||
def create_office_lan(
|
||||
lan_name: str,
|
||||
subnet_base: int,
|
||||
pcs_ip_block_start: int,
|
||||
num_pcs: int,
|
||||
network: Optional[Network] = None,
|
||||
include_router: bool = True,
|
||||
) -> Network:
|
||||
"""
|
||||
Creates a 2-Tier or 3-Tier office local area network (LAN).
|
||||
|
||||
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
|
||||
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
|
||||
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
|
||||
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
|
||||
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
|
||||
|
||||
|
||||
:param str lan_name: The name to be assigned to the LAN.
|
||||
:param int subnet_base: The subnet base number to be used in the IP addresses.
|
||||
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
|
||||
:param int num_pcs: The number of PCs to be added to the LAN.
|
||||
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
|
||||
created.
|
||||
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
|
||||
:return: The network object with the LAN components added.
|
||||
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
|
||||
"""
|
||||
# Initialise the network if not provided
|
||||
if not network:
|
||||
network = Network()
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
# Create a core switch if more than one edge switch is needed
|
||||
if num_of_switches > 1:
|
||||
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
|
||||
core_switch.power_on()
|
||||
network.add_node(core_switch)
|
||||
core_switch_port = 1
|
||||
|
||||
# Initialise the default gateway to None
|
||||
default_gateway = None
|
||||
|
||||
# Optionally include a router in the LAN
|
||||
if include_router:
|
||||
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
|
||||
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
network.add_node(router)
|
||||
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
|
||||
router.enable_port(1)
|
||||
|
||||
# Initialise the first edge switch and connect to the router or core switch
|
||||
switch_port = 0
|
||||
switch_n = 1
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
if num_of_switches > 1:
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
# Connect the new switch to the router or core switch
|
||||
if num_of_switches > 1:
|
||||
core_switch_port += 1
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Create and add a PC to the network
|
||||
pc = Computer(
|
||||
hostname=f"pc_{i}_{lan_name}",
|
||||
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway=default_gateway,
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc.power_on()
|
||||
network.add_node(pc)
|
||||
|
||||
# Connect the PC to the switch
|
||||
switch_port += 1
|
||||
network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1])
|
||||
switch.switch_ports[switch_port].enable()
|
||||
|
||||
return network
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import secrets
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -274,18 +274,40 @@ class NIC(SimComponent):
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive a network frame from the connected link if the NIC is enabled.
|
||||
Receive a network frame from the connected link, processing it if the NIC is enabled.
|
||||
|
||||
The Frame is passed to the Node.
|
||||
This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks
|
||||
if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the
|
||||
connected node. The method also handles the discarding of frames with TTL expired and logs this event.
|
||||
|
||||
:param frame: The network frame being received.
|
||||
The frame's reception is based on various conditions:
|
||||
- If the NIC is disabled, the frame is not processed.
|
||||
- If the TTL of the frame reaches zero after decrement, it is discarded and logged.
|
||||
- If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted.
|
||||
- All other frames are dropped and logged or printed to the console.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
accept_frame = False
|
||||
|
||||
# Check if it's a broadcast:
|
||||
if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}:
|
||||
accept_frame = True
|
||||
else:
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address:
|
||||
accept_frame = True
|
||||
|
||||
if accept_frame:
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
@@ -436,6 +458,9 @@ class SwitchPort(SimComponent):
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
self.pcap.capture(frame)
|
||||
connected_node: Node = self._connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
@@ -671,17 +696,30 @@ class ARPCache:
|
||||
"""Clear the entire ARP cache, removing all stored entries."""
|
||||
self.arp.clear()
|
||||
|
||||
def send_arp_request(self, target_ip_address: Union[IPv4Address, str]):
|
||||
def send_arp_request(
|
||||
self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None
|
||||
):
|
||||
"""
|
||||
Perform a standard ARP request for a given target IP address.
|
||||
|
||||
Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP
|
||||
address.
|
||||
address. This method can be configured to ignore specific networks when sending out ARP requests,
|
||||
which is useful in environments where certain addresses should not be queried.
|
||||
|
||||
:param target_ip_address: The target IP address to send an ARP request for.
|
||||
:param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP
|
||||
request broadcast. Each address in this list indicates a network which will not be queried during the ARP
|
||||
request process. This is particularly useful in complex network environments where traffic should be
|
||||
minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being
|
||||
sent back to their source.
|
||||
"""
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled:
|
||||
use_nic = True
|
||||
if ignore_networks:
|
||||
for ipv4 in ignore_networks:
|
||||
if ipv4 in nic.ip_network:
|
||||
use_nic = False
|
||||
if nic.enabled and use_nic:
|
||||
self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}")
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
@@ -806,7 +844,6 @@ class ICMP:
|
||||
self.arp.send_arp_request(frame.ip.src_ip_address)
|
||||
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
|
||||
return
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
# Network Layer
|
||||
ip_packet = IPPacket(
|
||||
@@ -821,9 +858,7 @@ class ICMP:
|
||||
sequence=frame.icmp.sequence + 1,
|
||||
)
|
||||
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
|
||||
frame = Frame(
|
||||
ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload
|
||||
)
|
||||
frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload)
|
||||
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
|
||||
|
||||
src_nic.send_frame(frame)
|
||||
@@ -1275,8 +1310,8 @@ class Node(SimComponent):
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self._start_up_actions()
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Turned on")
|
||||
for nic in self.nics.values():
|
||||
if nic._connected_link:
|
||||
@@ -1450,7 +1485,7 @@ class Node(SimComponent):
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.info(f"Added service {service.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
|
||||
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
@@ -1485,7 +1520,7 @@ class Node(SimComponent):
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.info(f"Added application {application.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
class ACLAction(Enum):
|
||||
"""Enum for defining the ACL action types."""
|
||||
|
||||
DENY = 0
|
||||
PERMIT = 1
|
||||
DENY = 2
|
||||
|
||||
|
||||
class ACLRule(SimComponent):
|
||||
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24
|
||||
_default_config: Dict[int, dict] = {}
|
||||
"""Config dict describing how the ACL list should look at episode start"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
if not kwargs.get("implicit_action"):
|
||||
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
for i, rule in enumerate(self._acl):
|
||||
if not rule:
|
||||
continue
|
||||
self._default_config[i] = {"action": rule.action.name}
|
||||
if rule.src_ip_address:
|
||||
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
|
||||
if rule.dst_ip_address:
|
||||
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
|
||||
if rule.src_port:
|
||||
self._default_config[i]["src_port"] = rule.src_port.name
|
||||
if rule.dst_port:
|
||||
self._default_config[i]["dst_port"] = rule.dst_port.name
|
||||
if rule.protocol:
|
||||
self._default_config[i]["protocol"] = rule.protocol.name
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -129,9 +162,9 @@ class AccessControlList(SimComponent):
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
None if request[3] == "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
None if request[5] == "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
@@ -333,11 +366,10 @@ class RouteEntry(SimComponent):
|
||||
"""
|
||||
Represents a single entry in a routing table.
|
||||
|
||||
Attributes:
|
||||
address (IPv4Address): The destination IP address or network address.
|
||||
subnet_mask (IPv4Address): The subnet mask for the network.
|
||||
next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded.
|
||||
metric (int): The cost metric for this route. Default is 0.0.
|
||||
:ivar address: The destination IP address or network address.
|
||||
:ivar subnet_mask: The subnet mask for the network.
|
||||
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
|
||||
:ivar metric: The cost metric for this route. Default is 0.0.
|
||||
|
||||
Example:
|
||||
>>> entry = RouteEntry(
|
||||
@@ -357,12 +389,6 @@ class RouteEntry(SimComponent):
|
||||
metric: float = 0.0
|
||||
"The cost metric for this route. Default is 0.0."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key in {"address", "subnet_mask", "next_hop_ip_address"}:
|
||||
if not isinstance(kwargs[key], IPv4Address):
|
||||
kwargs[key] = IPv4Address(kwargs[key])
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
|
||||
@@ -397,10 +423,10 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
|
||||
routes: List[RouteEntry] = []
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
self._original_state["routes_orig"] = self.routes
|
||||
@@ -442,12 +468,35 @@ class RouteTable(SimComponent):
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
|
||||
"""
|
||||
Sets the next-hop IP address for the default route in a routing table.
|
||||
|
||||
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
|
||||
the method creates a new default route with the specified next-hop IP address. If a default route already
|
||||
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
|
||||
IP address, the method logs this action.
|
||||
|
||||
:param ip_address: The next-hop IP address to be set for the default route.
|
||||
"""
|
||||
if not self.default_route:
|
||||
self.default_route = RouteEntry(
|
||||
ip_address=IPv4Address("0.0.0.0"),
|
||||
subnet_mask=IPv4Address("0.0.0.0"),
|
||||
next_hop_ip_address=ip_address,
|
||||
)
|
||||
else:
|
||||
self.default_route.next_hop_ip_address = ip_address
|
||||
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
|
||||
|
||||
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
|
||||
"""
|
||||
Find the best route for a given destination IP.
|
||||
|
||||
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
|
||||
|
||||
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
|
||||
|
||||
:param destination_ip: The destination IP to find the route for.
|
||||
:return: The best matching RouteEntry, or None if no route matches.
|
||||
"""
|
||||
@@ -467,6 +516,9 @@ class RouteTable(SimComponent):
|
||||
longest_prefix = prefix_len
|
||||
lowest_metric = route.metric
|
||||
|
||||
if not best_route and self.default_route:
|
||||
best_route = self.default_route
|
||||
|
||||
return best_route
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -498,12 +550,26 @@ class RouterARPCache(ARPCache):
|
||||
super().__init__(sys_log)
|
||||
self.router: Router = router
|
||||
|
||||
def process_arp_packet(self, from_nic: NIC, frame: Frame):
|
||||
def process_arp_packet(
|
||||
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Overridden method to process a received ARP packet in a router-specific way.
|
||||
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
|
||||
|
||||
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
|
||||
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
|
||||
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
|
||||
handling packet TTL (Time To Live).
|
||||
|
||||
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
|
||||
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
|
||||
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
|
||||
to find the best route and reattempts ARP request processing if needed.
|
||||
|
||||
:param from_nic: The NIC that received the ARP packet.
|
||||
:param frame: The original ARP frame.
|
||||
:param frame: The frame containing the ARP packet.
|
||||
:param route_table: The routing table of the router.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
|
||||
"""
|
||||
arp_packet = frame.arp
|
||||
|
||||
@@ -531,7 +597,11 @@ class RouterARPCache(ARPCache):
|
||||
)
|
||||
arp_packet.sender_mac_addr = nic.mac_address
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
|
||||
# ARP Request
|
||||
self.sys_log.info(
|
||||
@@ -542,16 +612,32 @@ class RouterARPCache(ARPCache):
|
||||
self.add_arp_cache_entry(
|
||||
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
|
||||
)
|
||||
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_packet, from_nic)
|
||||
|
||||
# If the target IP matches one of the router's NICs
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
if arp_packet.target_ip_address in nic.ip_network:
|
||||
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
# Check Route Table
|
||||
route = route_table.find_best_route(arp_packet.target_ip_address)
|
||||
if route:
|
||||
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
|
||||
if not nic:
|
||||
if not is_reattempt:
|
||||
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
|
||||
else:
|
||||
self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found")
|
||||
return
|
||||
else:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
|
||||
class RouterICMP(ICMP):
|
||||
"""
|
||||
@@ -622,7 +708,7 @@ class RouterICMP(ICMP):
|
||||
return
|
||||
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
|
||||
for nic in self.router.nics.values():
|
||||
@@ -642,7 +728,48 @@ class RouterICMP(ICMP):
|
||||
|
||||
return
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
|
||||
class RouterNIC(NIC):
|
||||
"""
|
||||
A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality.
|
||||
|
||||
This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed
|
||||
to handle network frames in a manner specific to routers, allowing them to efficiently process
|
||||
and route network traffic.
|
||||
"""
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive and process a network frame from the connected link, provided the NIC is enabled.
|
||||
|
||||
This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL
|
||||
expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for
|
||||
this NIC's MAC address or is a broadcast frame.
|
||||
|
||||
Key Differences from Standard NIC:
|
||||
- Does not perform Layer 3 (IP-based) broadcast checks.
|
||||
- Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the connected node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.mac_address}/{self.ip_address}"
|
||||
|
||||
|
||||
class Router(Node):
|
||||
@@ -655,7 +782,7 @@ class Router(Node):
|
||||
"""
|
||||
|
||||
num_ports: int
|
||||
ethernet_ports: Dict[int, NIC] = {}
|
||||
ethernet_ports: Dict[int, RouterNIC] = {}
|
||||
acl: AccessControlList
|
||||
route_table: RouteTable
|
||||
arp: RouterARPCache
|
||||
@@ -674,7 +801,7 @@ class Router(Node):
|
||||
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
for i in range(1, self.num_ports + 1):
|
||||
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(nic)
|
||||
self.ethernet_ports[i] = nic
|
||||
|
||||
@@ -725,13 +852,13 @@ class Router(Node):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
state["num_ports"] = self.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
Route a given frame from a source NIC to its destination.
|
||||
Process a Frame.
|
||||
|
||||
:param frame: The frame to be routed.
|
||||
:param from_nic: The source network interface.
|
||||
@@ -746,25 +873,57 @@ class Router(Node):
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address)
|
||||
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
self.arp.send_arp_request(
|
||||
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
|
||||
)
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
# TODO: Add sys_log here
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
if frame.ip.dst_ip_address in nic.ip_network:
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
else:
|
||||
pass
|
||||
# TODO: Deal with routing from route tables
|
||||
self._route_frame(frame, from_nic)
|
||||
|
||||
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
|
||||
if route:
|
||||
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
|
||||
if re_attempt and not nic:
|
||||
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
"""
|
||||
@@ -773,7 +932,7 @@ class Router(Node):
|
||||
:param frame: The incoming frame.
|
||||
:param from_nic: The network interface where the frame is coming from.
|
||||
"""
|
||||
route_frame = False
|
||||
process_frame = False
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
@@ -805,12 +964,12 @@ class Router(Node):
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
else:
|
||||
if src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
|
||||
else:
|
||||
# All other traffic
|
||||
route_frame = True
|
||||
if route_frame:
|
||||
self.route_frame(frame, from_nic)
|
||||
process_frame = True
|
||||
if process_frame:
|
||||
self.process_frame(frame, from_nic)
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
@@ -873,3 +1032,63 @@ class Router(Node):
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
new = Router(
|
||||
hostname=cfg["hostname"],
|
||||
num_ports=cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
new.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
return new
|
||||
|
||||
@@ -90,12 +90,12 @@ class Switch(Node):
|
||||
self._add_mac_table_entry(src_mac, incoming_port)
|
||||
|
||||
outgoing_port = self.mac_address_table.get(dst_mac)
|
||||
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
|
||||
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
|
||||
outgoing_port.send_frame(frame)
|
||||
else:
|
||||
# If the destination MAC is not in the table, flood to all ports except incoming
|
||||
for port in self.switch_ports.values():
|
||||
if port != incoming_port:
|
||||
if port.enabled and port != incoming_port:
|
||||
port.send_frame(frame)
|
||||
|
||||
def disconnect_link_from_port(self, link: Link, port_number: int):
|
||||
|
||||
@@ -38,9 +38,6 @@ class Application(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
@@ -95,6 +92,9 @@ class Application(IOSoftware):
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Running Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def _application_loop(self):
|
||||
"""The main application loop."""
|
||||
|
||||
@@ -72,7 +72,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -83,7 +83,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
payload: Optional[str] = None,
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
data_manipulation_p_of_success: float = 0.1,
|
||||
repeat: bool = False,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the DataManipulatorBot to communicate with a DatabaseService.
|
||||
@@ -168,6 +168,12 @@ class DataManipulationBot(DatabaseClient):
|
||||
Calls the parent classes execute method before starting the application loop.
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
@@ -198,4 +204,4 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
pass
|
||||
|
||||
@@ -41,6 +41,9 @@ class PacketCapture:
|
||||
|
||||
def setup_logger(self):
|
||||
"""Set up the logger configuration."""
|
||||
if not SIM_OUTPUT.save_pcap_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
@@ -88,5 +91,6 @@ class PacketCapture:
|
||||
|
||||
:param frame: The PCAP frame to capture.
|
||||
"""
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -141,41 +141,76 @@ class SessionManager:
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
"""
|
||||
Receive a payload from the SoftwareManager.
|
||||
Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission.
|
||||
|
||||
If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``.
|
||||
This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an
|
||||
IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP.
|
||||
A new session is established if `session_id` is not provided, and an existing session is used otherwise.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
:param dst_ip_address: The destination IP address or network for broadcast. Optional.
|
||||
:param dst_port: The destination port for the TCP packet. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False.
|
||||
:return: The outcome of sending the frame, or None if sending was unsuccessful.
|
||||
"""
|
||||
is_broadcast = False
|
||||
outbound_nic = None
|
||||
dst_mac_address = None
|
||||
|
||||
# Use session details if session_id is provided
|
||||
if session_id:
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
dst_ip_address = session.with_ip_address
|
||||
dst_port = session.dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
# Determine if the payload is for broadcast or unicast
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
# Handle broadcast transmission
|
||||
if isinstance(dst_ip_address, IPv4Network):
|
||||
is_broadcast = True
|
||||
dst_ip_address = dst_ip_address.broadcast_address
|
||||
if dst_ip_address:
|
||||
# Find a suitable NIC for the broadcast
|
||||
for nic in self.arp_cache.nics.values():
|
||||
if dst_ip_address in nic.ip_network and nic.enabled:
|
||||
dst_mac_address = "ff:ff:ff:ff:ff:ff"
|
||||
outbound_nic = nic
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
return
|
||||
# Resolve MAC address for unicast transmission
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
# Resolve outbound NIC for unicast transmission
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
|
||||
# If MAC address not found, initiate ARP request
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
# Reattempt payload transmission after ARP request
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
# Return None if reattempt fails
|
||||
return
|
||||
|
||||
# Check if outbound NIC and destination MAC address are resolved
|
||||
if not outbound_nic or not dst_mac_address:
|
||||
return False
|
||||
|
||||
# Construct the frame for transmission
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
@@ -189,15 +224,17 @@ class SessionManager:
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
# Manage session for unicast transmission
|
||||
if not (is_broadcast and session_id):
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
# Create a new session if it doesn't exist
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
|
||||
# Send the frame through the NIC
|
||||
return outbound_nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -130,20 +130,28 @@ class SoftwareManager:
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_port=dest_port,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
|
||||
@@ -41,6 +41,9 @@ class SysLog:
|
||||
The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out
|
||||
JSON-like messages.
|
||||
"""
|
||||
if not SIM_OUTPUT.save_sys_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
@@ -91,7 +94,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.debug(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
|
||||
def info(self, msg: str):
|
||||
"""
|
||||
@@ -99,7 +103,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.info(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
|
||||
def warning(self, msg: str):
|
||||
"""
|
||||
@@ -107,7 +112,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.warning(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
|
||||
def error(self, msg: str):
|
||||
"""
|
||||
@@ -115,7 +121,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.error(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
|
||||
def critical(self, msg: str):
|
||||
"""
|
||||
@@ -123,4 +130,5 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.critical(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -22,7 +24,7 @@ class DatabaseService(Service):
|
||||
|
||||
password: Optional[str] = None
|
||||
|
||||
backup_server: IPv4Address = None
|
||||
backup_server_ip: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
|
||||
latest_backup_directory: str = None
|
||||
@@ -36,7 +38,6 @@ class DatabaseService(Service):
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
|
||||
def set_original_state(self):
|
||||
@@ -45,8 +46,8 @@ class DatabaseService(Service):
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"_connections",
|
||||
"backup_server",
|
||||
"connections",
|
||||
"backup_server_ip",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
}
|
||||
@@ -64,7 +65,7 @@ class DatabaseService(Service):
|
||||
|
||||
:param: backup_server_ip: The IP address of the backup server
|
||||
"""
|
||||
self.backup_server = backup_server
|
||||
self.backup_server_ip = backup_server
|
||||
|
||||
def backup_database(self) -> bool:
|
||||
"""Create a backup of the database to the configured backup server."""
|
||||
@@ -73,7 +74,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# check if the backup server was configured
|
||||
if self.backup_server is None:
|
||||
if self.backup_server_ip is None:
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
@@ -81,10 +82,14 @@ class DatabaseService(Service):
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# send backup copy of database file to FTP server
|
||||
if not self.db_file:
|
||||
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
|
||||
return False
|
||||
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
src_file_name=self.db_file.name,
|
||||
src_folder_name="database",
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
)
|
||||
@@ -110,7 +115,7 @@ class DatabaseService(Service):
|
||||
src_file_name="database.db",
|
||||
dest_folder_name="downloads",
|
||||
dest_file_name="database.db",
|
||||
dest_ip_address=self.backup_server,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
)
|
||||
|
||||
if not response:
|
||||
@@ -118,13 +123,10 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
|
||||
|
||||
if self._db_file is None:
|
||||
if self.db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
@@ -134,12 +136,30 @@ class DatabaseService(Service):
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def db_file(self) -> File:
|
||||
"""Returns the database file."""
|
||||
return self.file_system.get_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def folder(self) -> Folder:
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
:param connection_id: A unique identifier for the connection
|
||||
:type connection_id: str
|
||||
:param password: Supplied password. It must match self.password for connection success, defaults to None
|
||||
:type password: Optional[str], optional
|
||||
:return: Response to connection request containing success info.
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
@@ -184,7 +204,7 @@ class DatabaseService(Service):
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
@@ -195,17 +215,8 @@ class DatabaseService(Service):
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
@@ -265,3 +276,19 @@ class DatabaseService(Service):
|
||||
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
|
||||
return payload["status_code"] == 200
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a single timestep of simulation dynamics to this service.
|
||||
|
||||
Here at the first step, the database backup is created, in addition to normal service update logic.
|
||||
"""
|
||||
if timestep == 1:
|
||||
self.backup_database()
|
||||
return super().apply_timestep(timestep)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Perform a database restore when the patching countdown is finished."""
|
||||
super()._update_patch_status()
|
||||
if self._patching_countdown is None:
|
||||
self.restore_backup()
|
||||
|
||||
@@ -89,6 +89,7 @@ class FTPClient(FTPServiceABC):
|
||||
f"{self.name}: Successfully connected to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
|
||||
)
|
||||
self.add_connection(connection_id="server_connection", session_id=session_id)
|
||||
return True
|
||||
else:
|
||||
if is_reattempt:
|
||||
|
||||
@@ -99,5 +99,5 @@ class FTPServer(FTPServiceABC):
|
||||
if payload.status_code is not None:
|
||||
return False
|
||||
|
||||
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import shutil
|
||||
from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -16,6 +16,10 @@ class FTPServiceABC(Service, ABC):
|
||||
Contains shared methods between both classes.
|
||||
"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Returns a Dict of the FTPService state."""
|
||||
return super().describe_state()
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
@@ -52,10 +56,12 @@ class FTPServiceABC(Service, ABC):
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
file_size = payload.ftp_command_args["file_size"]
|
||||
real_file_path = payload.ftp_command_args.get("real_file_path")
|
||||
health_status = payload.ftp_command_args["health_status"]
|
||||
is_real = real_file_path is not None
|
||||
file = self.file_system.create_file(
|
||||
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
|
||||
)
|
||||
file.health_status = health_status
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
|
||||
f"{payload.ftp_command_args['dest_file_name']}"
|
||||
@@ -110,6 +116,7 @@ class FTPServiceABC(Service, ABC):
|
||||
"dest_file_name": dest_file_name,
|
||||
"file_size": file.sim_size,
|
||||
"real_file_path": file.sim_path if file.real else None,
|
||||
"health_status": file.health_status,
|
||||
},
|
||||
packet_payload_size=file.sim_size,
|
||||
status_code=FTPStatusCode.OK if is_response else None,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -43,9 +44,6 @@ class Service(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -98,6 +96,7 @@ class Service(IOSoftware):
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -118,7 +117,6 @@ class Service(IOSoftware):
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
@@ -129,42 +127,39 @@ class Service(IOSoftware):
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.restart_countdown = self.restart_duration
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
@@ -181,5 +176,4 @@ class Service(IOSoftware):
|
||||
if self.restart_countdown <= 0:
|
||||
_LOGGER.debug(f"Restarting finished for service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.restart_countdown -= 1
|
||||
|
||||
@@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -123,7 +124,10 @@ class WebServer(Service):
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception:
|
||||
|
||||
@@ -2,8 +2,8 @@ import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -38,12 +38,12 @@ class SoftwareHealthState(Enum):
|
||||
"Unused state."
|
||||
GOOD = 1
|
||||
"The software is in a good and healthy condition."
|
||||
COMPROMISED = 2
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 3
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
PATCHING = 4
|
||||
PATCHING = 2
|
||||
"The software is undergoing patching or updates."
|
||||
COMPROMISED = 3
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 4
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
|
||||
|
||||
class SoftwareCriticality(Enum):
|
||||
@@ -71,9 +71,9 @@ class Software(SimComponent):
|
||||
|
||||
name: str
|
||||
"The name of the software."
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The actual health state of the software."
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
@@ -195,8 +195,9 @@ class Software(SimComponent):
|
||||
|
||||
def patch(self) -> None:
|
||||
"""Perform a patch on the software."""
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
@@ -282,7 +283,7 @@ class IOSoftware(Software):
|
||||
|
||||
Returns true if the software can perform actions.
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state is NodeOperatingState.OFF:
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
|
||||
return False
|
||||
return True
|
||||
@@ -303,13 +304,13 @@ class IOSoftware(Software):
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
@@ -350,19 +351,22 @@ class IOSoftware(Software):
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network. It delegates the actual sending process to the SoftwareManager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user