CCOD-DQN (Contention Window Optimization in IEEE 802.11ax Networks with Deep Reinforcement Learning)
Please read "Contention Window Optimization in IEEE 802.11ax Networks with Deep Reinforcement Learning" article published at WCNC 2021. Preprint available at Arxiv - https://arxiv.org/pdf/2003.01492 first. The original code for the paper can be found at https://github.com/wwydmanski/RLinWiFi. But in this lab, I will try to show how to measure the throughputs for the traditional 802.11ax (CSMACA) and CCOD-DQN. But for CCOD-DQN, I use PARL framework, not tensorflow or keras, to do the reinforcement learning.
Please follow the instructions at https://github.com/tkn-tub/ns3-gym to install ns3-gym. Also, follow the instructions at https://github.com/PaddlePaddle/PARL to install PARL.
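Note: test_dqn.py below uses the old fluid-based PARL API (parl.layers and paddle.fluid). That API was removed in PARL 2.x, so you need a PARL 1.x release together with PaddlePaddle 1.x. For example (these exact version pins are my suggestion, not something verified in this lab):

pip install paddlepaddle==1.8.5 parl==1.3.1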
[Steps]
Prepare cw.cc and test_dqn.py under scratch/myrlwifi
cw.cc (adapted from https://github.com/wwydmanski/RLinWiFi/blob/master/linear-mesh/cw.cc)
#include "ns3/core-module.h" #include "ns3/network-module.h" #include "ns3/applications-module.h" #include "ns3/wifi-module.h" #include "ns3/mobility-module.h" #include "ns3/csma-module.h" #include "ns3/internet-module.h" #include "ns3/flow-monitor-module.h" #include "ns3/opengym-module.h" #include "ns3/propagation-module.h" #include "ns3/ipv4-flow-classifier.h" #include "ns3/yans-wifi-channel.h" #include <cmath> #include <ctime> #include <sstream> #include <fstream> #include <string> #include <math.h> #include <ctime> #include <iomanip> #include <deque> #include <algorithm> #include <csignal> #define PI 3.14159265 using namespace ns3; using namespace std; NS_LOG_COMPONENT_DEFINE ("wifi1"); void recordHistory(); double SimTime = 100.0; uint64_t lastTotalRx = 0; uint32_t mactxno,macrxno,phyrxok,phyrxerror,phytx; Ptr<FlowMonitor> monitor; FlowMonitorHelper flowmon; double envStepTime = 0.1; double simulationTime = 10; double current_time = 0.0; bool verbose = false; int end_delay = 0; bool dry_run = false; uint32_t CW = 0; uint32_t history_length = 20; string type = "discrete"; deque<float> history; Ptr<PacketSink> sinkApp; Ptr<OpenGymSpace> MyGetObservationSpace(void) { current_time += envStepTime; float low = 0.0; float high = 10.0; std::vector<uint32_t> shape = { history_length, }; std::string dtype = TypeNameGet<float>(); Ptr<OpenGymBoxSpace> space = CreateObject<OpenGymBoxSpace>(low, high, shape, dtype); if (verbose) NS_LOG_UNCOND("MyGetObservationSpace: " << space); return space; } Ptr<OpenGymSpace> MyGetActionSpace(void) { float low = 0.0; float high = 10.0; std::vector<uint32_t> shape = { 1, }; std::string dtype = TypeNameGet<float>(); Ptr<OpenGymBoxSpace> space = CreateObject<OpenGymBoxSpace>(low, high, shape, dtype); if (verbose) NS_LOG_UNCOND("MyGetActionSpace: " << space); return space; } uint64_t g_rxPktNum = 0; uint64_t g_txPktNum = 0; uint64_t my_rxPktNum=0; std::string MyGetExtraInfo(void) { static float ticks = 0.0; static float lastValue = 0.0; //g_rxPktNum = sinkApp->GetTotalRxPkt(); g_rxPktNum = my_rxPktNum; //std::cout << "in MyGetExtraInfo(), g_rxPktNum=" << g_rxPktNum << std::endl; float obs = g_rxPktNum - lastValue; lastValue = g_rxPktNum; ticks += envStepTime; float sentMbytes = obs * (1500 - 20 - 8 - 8) * 8.0 / 1024 / 1024; std::string myInfo = std::to_string(sentMbytes); myInfo = myInfo + "|" + to_string(CW); if (verbose) NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); return myInfo; } bool MyExecuteActions(Ptr<OpenGymDataContainer> action) { if (verbose) NS_LOG_UNCOND("MyExecuteActions: " << action); Ptr<OpenGymBoxContainer<float>> box = DynamicCast<OpenGymBoxContainer<float>>(action); std::vector<float> actionVector = box->GetData(); if (type == "discrete") { CW = pow(2, int(4 + actionVector.at(0))); } else if (type == "continuous") { CW = pow(2, actionVector.at(0) + 4); } else if (type == "direct_continuous") { CW = actionVector.at(0); } else { std::cout << "Unsupported agent type!" << endl; exit(0); } if (verbose) { NS_LOG_UNCOND("actionVector.at(0): " << actionVector.at(0)); }
uint32_t min_cw = 16; uint32_t max_cw = 1024; CW = min(max_cw, max(CW, min_cw)); if (verbose) { NS_LOG_UNCOND("CW: " << CW); } if(!dry_run){ //Config::Set("/$ns3::NodeListPriv/NodeList/*/$ns3::Node/DeviceList/*/$ns3::WifiNetDevice/Mac/$ns3::RegularWifiMac/BE_Txop/$ns3::QosTxop/MinCw", UintegerValue(CW)); //Config::Set("/$ns3::NodeListPriv/NodeList/*/$ns3::Node/DeviceList/*/$ns3::WifiNetDevice/Mac/$ns3::RegularWifiMac/BE_Txop/$ns3::QosTxop/MaxCw", UintegerValue(CW)); Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(CW)); Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(CW)); //std::cout << "MinCw and MaxCw are set to " << CW << std::endl; } return true; } float MyGetReward(void) { static float ticks = 0.0; static uint32_t last_packets = 0; static float last_reward = 0.0; ticks += envStepTime; g_rxPktNum = my_rxPktNum; //g_rxPktNum = sinkApp->GetTotalRxPkt(); float res = g_rxPktNum - last_packets; //Need to understand why float reward = res * (1500 - 20 - 8 - 8) * 8.0 / 1024 / 1024 / (5 * 150 * envStepTime) * 10; last_packets = g_rxPktNum; if (ticks <= 2 * envStepTime) return 0.0; if (verbose) NS_LOG_UNCOND("MyGetReward: " << reward); if(reward>1.0f || reward<0.0f) reward = last_reward; last_reward = reward; return last_reward; } Ptr<OpenGymDataContainer> MyGetObservation() { recordHistory(); std::vector<uint32_t> shape = { history_length, }; Ptr<OpenGymBoxContainer<float>> box = CreateObject<OpenGymBoxContainer<float>>(shape); for (uint32_t i = 0; i < history.size(); i++) { if (history[i] >= -100 && history[i] <= 100) box->AddValue(history[i]); else box->AddValue(0); } for (uint32_t i = history.size(); i < history_length; i++) { box->AddValue(0); } if (verbose) NS_LOG_UNCOND("MyGetObservation: " << box); return box; } bool MyGetGameOver(void) { // bool isGameOver = (ns3::Simulator::Now().GetSeconds() > simulationTime + end_delay + 1.0); /* if (verbose) { bool isGameOver = false; static float stepCounter = 0.0; stepCounter += 1; if (stepCounter == 200) { isGameOver = true; } NS_LOG_UNCOND("MyGetGameOver: " << isGameOver); return isGameOver; } */ return false; } void ScheduleNextStateRead(double envStepTime, Ptr<OpenGymInterface> openGymInterface) { // if(ns3::Simulator::Now().GetSeconds()<simulationTime + end_delay + 1.0) // { Simulator::Schedule(Seconds(envStepTime), &ScheduleNextStateRead, envStepTime, openGymInterface); // } openGymInterface->NotifyCurrentState(); } void recordHistory() { static uint32_t last_rx = 0; static uint32_t last_tx = 0; static uint32_t calls = 0; calls++; g_rxPktNum = my_rxPktNum; //g_rxPktNum = sinkApp->GetTotalRxPkt(); float received = g_rxPktNum - last_rx; float sent = g_txPktNum - last_tx; float errs = sent - received; float ratio; ratio = errs / sent; history.push_front(ratio); if (history.size() > history_length) { history.pop_back(); } last_rx = g_rxPktNum; last_tx = g_txPktNum; } void packetReceived(Ptr<const Packet> packet) { //std::cout << "packetReceived() is called, pktsize=" << packet->GetSize() << "bytes" << std::endl; my_rxPktNum++; } void packetSent(Ptr<const Packet> packet) { //std::cout << "packetSent() is called, pktsize=" << packet->GetSize() << "bytes" << std::endl; g_txPktNum++; } void signalHandler(int signum) { cout << "Interrupt signal " << signum << " received.\n"; exit(signum); } int main(int argc, char *argv[]) { uint32_t nSta = 1; uint32_t cwmin = 15; uint32_t cwmax = 1023; uint32_t openGymPort = 5555;
double txStartTime = 0.1;
int mcs = 6; int channelWidth = 20; int guardInterval = 800; signal(SIGTERM, signalHandler); CommandLine cmd; cmd.AddValue("openGymPort", "Specify port number. Default: 5555", openGymPort); cmd.AddValue("CW", "Value of Contention Window", CW); cmd.AddValue("historyLength", "Length of history window", history_length); cmd.AddValue("verbose", "Tell echo applications to log if true", verbose); cmd.AddValue("dryRun", "Execute scenario with BEB and no agent interaction", dry_run); cmd.AddValue("simTime", "Simulation time in seconds. Default: 10s", simulationTime); cmd.AddValue("envStepTime", "Step time in seconds. Default: 0.1s", envStepTime); cmd.AddValue ("nSta", "Number of wifi STA devices", nSta); cmd.AddValue ("cwmin", "Minimum contention window size", cwmin); cmd.AddValue ("cwmax", "Maximum contention window size", cwmax); cmd.AddValue ("agentType", "Agent Type", type); cmd.Parse (argc, argv); Config::SetDefault ("ns3::WifiRemoteStationManager::FragmentationThreshold", StringValue ("2200")); Config::SetDefault ("ns3::WifiRemoteStationManager::RtsCtsThreshold", StringValue ("2200")); NS_LOG_UNCOND("Ns3Env parameters:"); NS_LOG_UNCOND("--nSta: " << nSta); NS_LOG_UNCOND("--simulationTime: " << simulationTime); NS_LOG_UNCOND("--openGymPort: " << openGymPort); NS_LOG_UNCOND("--envStepTime: " << envStepTime); NS_LOG_UNCOND("--agentType: " << type); NS_LOG_UNCOND("--dryRun: " << dry_run); NS_LOG_UNCOND("--verbose: " << verbose); WifiMacHelper wifiMac; WifiHelper wifiHelper; wifiHelper.SetStandard (WIFI_PHY_STANDARD_80211ax_5GHZ); std::ostringstream oss; oss << "HeMcs" << mcs; wifiHelper.SetRemoteStationManager("ns3::ConstantRateWifiManager", "DataMode", StringValue(oss.str()), "ControlMode", StringValue(oss.str())); Ptr<MatrixPropagationLossModel> lossModel = CreateObject<MatrixPropagationLossModel>(); lossModel->SetDefaultLoss(50); YansWifiChannelHelper channel = YansWifiChannelHelper::Default (); Ptr<YansWifiChannel> chan = channel.Create(); chan->SetPropagationLossModel(lossModel); chan->SetPropagationDelayModel(CreateObject<ConstantSpeedPropagationDelayModel>()); YansWifiPhyHelper wifiPhy; wifiPhy = YansWifiPhyHelper::Default(); wifiPhy.SetChannel(chan); wifiPhy.Set("GuardInterval", TimeValue(NanoSeconds(guardInterval))); NodeContainer wifiStaNodes; wifiStaNodes.Create (nSta); NodeContainer wifiApNode; wifiApNode.Create (uint32_t (1)); Ssid ssid = Ssid ("wifi1"); wifiMac.SetType ("ns3::ApWifiMac", "Ssid", SsidValue (ssid)); NetDeviceContainer apDevice; apDevice = wifiHelper.Install (wifiPhy, wifiMac, wifiApNode); wifiMac.SetType ("ns3::StaWifiMac", "Ssid", SsidValue (ssid));
NetDeviceContainer staDevices; staDevices = wifiHelper.Install (wifiPhy, wifiMac, wifiStaNodes); Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/Phy/ChannelWidth", UintegerValue(channelWidth));
std::cout << "----------------------------" << std::endl; if (!dry_run) { Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(CW)); Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(CW)); } else { NS_LOG_UNCOND("Default CW"); Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(cwmin)); Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(cwmax)); } MobilityHelper mobility; Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator> (); positionAlloc->Add (Vector (0.0, 0.0, 0.0)); float rho = 0.5; for (uint32_t i = 0; i < nSta; i++) { double theta = i * 2 * PI / nSta; positionAlloc->Add (Vector (rho * cos(theta), rho * sin(theta), 0.0)); std::cout << "node " << i << " position:" << "(" << rho * cos(theta) << "," << rho * sin(theta) << ",0.0)" << std::endl; }
mobility.SetPositionAllocator (positionAlloc); mobility.SetMobilityModel ("ns3::ConstantPositionMobilityModel"); mobility.Install (wifiApNode); mobility.Install (wifiStaNodes); InternetStackHelper stack; stack.Install(wifiApNode); stack.Install(wifiStaNodes); Ipv4AddressHelper address; address.SetBase("10.1.1.0", "255.255.255.0"); Ipv4InterfaceContainer ApInterface = address.Assign(apDevice); Ipv4InterfaceContainer StaInterface = address.Assign(staDevices); for (uint32_t i = 0; i < nSta; i++) { OnOffHelper onoff("ns3::UdpSocketFactory", Address(InetSocketAddress(ApInterface.GetAddress(0), 9))); onoff.SetConstantRate(DataRate ("100000kb/s"), 1500 - 20 - 8 - 8); ApplicationContainer temp = onoff.Install(wifiStaNodes.Get (i)); temp.Start(Seconds(txStartTime)); temp.Stop(Seconds(simulationTime)); } PacketSinkHelper sink ("ns3::UdpSocketFactory", Address(InetSocketAddress(ApInterface.GetAddress(0), 9))); ApplicationContainer Serverapp = sink.Install(wifiApNode.Get (0)); Serverapp.Start(Seconds (0.0)); //Config::ConnectWithoutContext("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/Phy/PhyTxBegin", MakeCallback(&packetSent)); Config::ConnectWithoutContext("/NodeList/*/DeviceList/*/Mac/MacTx", MakeCallback(&packetSent)); Config::ConnectWithoutContext ("/NodeList/*/DeviceList/*/Mac/MacRx", MakeCallback (&packetReceived)); Ipv4GlobalRoutingHelper::PopulateRoutingTables (); sinkApp = DynamicCast<PacketSink> (Serverapp.Get (0)); FlowMonitorHelper flowmon; Ptr<FlowMonitor> monitor = flowmon.InstallAll (); Ptr<OpenGymInterface> openGymInterface = CreateObject<OpenGymInterface>(openGymPort); openGymInterface->SetGetActionSpaceCb(MakeCallback(&MyGetActionSpace)); openGymInterface->SetGetObservationSpaceCb(MakeCallback(&MyGetObservationSpace)); openGymInterface->SetGetGameOverCb(MakeCallback(&MyGetGameOver)); openGymInterface->SetGetObservationCb(MakeCallback(&MyGetObservation)); openGymInterface->SetGetRewardCb(MakeCallback(&MyGetReward)); openGymInterface->SetGetExtraInfoCb(MakeCallback(&MyGetExtraInfo)); openGymInterface->SetExecuteActionsCb(MakeCallback(&MyExecuteActions)); Simulator::Schedule(Seconds(1.0), &ScheduleNextStateRead, envStepTime, openGymInterface); Simulator::Stop(Seconds(simulationTime + 1.0 + envStepTime*(history_length+1))); Simulator::Run(); Ptr<Ipv4FlowClassifier> classifier = DynamicCast<Ipv4FlowClassifier>(flowmon.GetClassifier()); std::map<FlowId, FlowMonitor::FlowStats> stats = monitor->GetFlowStats(); double lastRxTime = 0; double firstRxTime = simulationTime + 10;; double flowThr; double timediff; uint32_t totalRx =0; uint32_t totalTx =0; uint32_t totalRxBytes =0; for(std::map<FlowId, FlowMonitor::FlowStats>::const_iterator set = stats.begin(); set != stats.end(); set++) { if(lastRxTime < set->second.timeLastRxPacket.GetSeconds()) { lastRxTime = set->second.timeLastRxPacket.GetSeconds(); } if(firstRxTime > set->second.timeFirstRxPacket.GetSeconds()) { firstRxTime = set->second.timeFirstRxPacket.GetSeconds(); } totalRx += set->second.rxPackets; totalTx += set->second.txPackets; totalRxBytes += set->second.rxBytes; Ipv4FlowClassifier::FiveTuple t = classifier->FindFlow(set->first); timediff = set->second.timeLastRxPacket.GetSeconds() - set->second.timeFirstRxPacket.GetSeconds(); flowThr = set->second.rxBytes * 8.0 / timediff / 1000 / 1000; std::cout << "Flow " << set->first << " (" << t.sourceAddress << " -> " << t.destinationAddress << ")\tThroughput: " << flowThr << " Mbps\tTime: " << set->second.timeLastRxPacket.GetSeconds() - set->second.timeFirstRxPacket.GetSeconds() << " 
s\tRx packets " << set->second.rxPackets << std::endl; //std::cout << "packetsDropped:" << set->second.packetsDropped.size() << std::endl; } std::cout << "totalTx:" << totalTx << " totalRx:" << totalRx << std::endl; std::cout << "sinkApp->GetTotalRxPkt()=" << sinkApp->GetTotalRxPkt() << std::endl; std::cout << "g_txPktNum=" << g_txPktNum << std::endl; std::cout << "my_rxPktNum=" << my_rxPktNum << std::endl; //std::cout << "totalRxBytes=" << totalRxBytes << std::endl; //std::cout << "sinkApp->GetTotalRx()=" << sinkApp->GetTotalRx() << std::endl; double totalBytes = sinkApp->GetTotalRx(); float throughput = totalBytes * 8.0/1000/1000/(lastRxTime - firstRxTime); std::cout << "throughput:\t" << throughput << " Mbps" << std::endl; //std::cout << "cwmin: " << cwmin << ", cwmax: " << cwmax << ", nSta: " << nSta << std::endl; //std::cout << "firstRxTime: " << firstRxTime << "sec,\t lastRxTime: " << lastRxTime << "sec" << std::endl;
openGymInterface->NotifySimulationEnd(); Simulator::Destroy (); } |
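Before moving on to the agent, note how MyExecuteActions() maps the agent's action onto the contention window: for the "discrete" agent type, action a gives CW = 2^(a+4), clamped to [16, 1024]. This throwaway Python snippet (for illustration only, not part of the lab files) shows the mapping and why act_dim is set to 7 in test_dqn.py:

# mirror of the "discrete" branch of MyExecuteActions() in cw.cc
for a in range(7):               # act_dim = 7 in test_dqn.py
    cw = 2 ** (a + 4)            # CW = pow(2, 4 + a)
    cw = min(1024, max(cw, 16))  # same clamp as in cw.cc
    print(a, cw)                 # 0->16, 1->32, ..., 6->1024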
test_dqn.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from ns3gym import ns3env
import numpy as np
import random
import os
import parl
from parl import layers
import copy
import paddle.fluid as fluid
import collections

MEMORY_SIZE = 20000
MEMORY_WARMUP_SIZE = 100
BATCH_SIZE = 32
LEARNING_RATE = 0.001
GAMMA = 0.9


class Model(parl.Model):
    def __init__(self, act_dim):
        hid1_size = 128
        hid2_size = 128
        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=act_dim, act=None)

    def value(self, obs):
        h1 = self.fc1(obs)
        h2 = self.fc2(h1)
        Q = self.fc3(h2)
        return Q


class DQN(parl.Algorithm):
    def __init__(self, model, act_dim=None, gamma=None, lr=None):
        self.model = model
        self.target_model = copy.deepcopy(model)
        assert isinstance(act_dim, int)
        assert isinstance(gamma, float)
        assert isinstance(lr, float)
        self.act_dim = act_dim
        self.gamma = gamma
        self.lr = lr

    def predict(self, obs):
        return self.model.value(obs)

    def learn(self, obs, action, reward, next_obs, terminal):
        # target Q: r + (1 - done) * gamma * max_a' Q_target(s', a')
        next_pred_value = self.target_model.value(next_obs)
        best_v = layers.reduce_max(next_pred_value, dim=1)
        best_v.stop_gradient = True
        terminal = layers.cast(terminal, dtype='float32')
        target = reward + (1.0 - terminal) * self.gamma * best_v

        pred_value = self.model.value(obs)
        action_onehot = layers.one_hot(action, self.act_dim)
        action_onehot = layers.cast(action_onehot, dtype='float32')
        pred_action_value = layers.reduce_sum(
            layers.elementwise_mul(action_onehot, pred_value), dim=1)

        cost = layers.square_error_cost(pred_action_value, target)
        cost = layers.reduce_mean(cost)
        optimizer = fluid.optimizer.Adam(learning_rate=self.lr)
        optimizer.minimize(cost)
        return cost

    def sync_target(self):
        self.model.sync_weights_to(self.target_model)


class Agent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim, e_greed=0.1, e_greed_decrement=0):
        assert isinstance(obs_dim, int)
        assert isinstance(act_dim, int)
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(Agent, self).__init__(algorithm)
        self.global_step = 0
        self.update_target_steps = 200
        self.e_greed = e_greed
        self.e_greed_decrement = e_greed_decrement

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(name='obs', shape=[self.obs_dim], dtype='float32')
            self.value = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(name='obs', shape=[self.obs_dim], dtype='float32')
            action = layers.data(name='act', shape=[1], dtype='int32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

    def sample(self, obs):
        # epsilon-greedy exploration
        sample = np.random.rand()
        if sample < self.e_greed:
            act = np.random.randint(self.act_dim)
        else:
            act = self.predict(obs)
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        pred_Q = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.value])[0]
        pred_Q = np.squeeze(pred_Q, axis=0)
        act = np.argmax(pred_Q)
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        if self.global_step % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_step += 1

        act = np.expand_dims(act, -1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int32'),
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost


class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)
        return np.array(obs_batch).astype('float32'), \
            np.array(action_batch).astype('float32'), \
            np.array(reward_batch).astype('float32'), \
            np.array(next_obs_batch).astype('float32'), \
            np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)


port = 5555
seed = 1
env = ns3env.Ns3Env(port=port, simSeed=seed)
env.reset()

stepIdx = 0
obs_dim = 20   # matches history_length in cw.cc
act_dim = 7    # discrete actions 0..6 -> CW = 16..1024

model = Model(act_dim=act_dim)
algorithm = DQN(model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(algorithm, obs_dim=obs_dim, act_dim=act_dim,
              e_greed=0.1, e_greed_decrement=1e-6)
rpm = ReplayMemory(MEMORY_SIZE)

try:
    obs = env.reset()
    print("Step: ", stepIdx)
    print("---obs:", obs)

    while True:
        # fill the replay memory before learning starts
        while len(rpm) < MEMORY_WARMUP_SIZE:
            stepIdx += 1
            action = agent.sample(obs)
            action2 = np.array([action])
            print("---action: ", action)
            next_obs, reward, done, info = env.step(action2)
            print("---obs, reward, done, info: ", next_obs, reward, done, info)
            rpm.append((obs, action, reward, next_obs, done))
            obs = next_obs
            print("Step: ", stepIdx)

        if stepIdx % 5 == 0:
            print("=" * 20, "agent learn", "=" * 20)
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs, batch_done)

        action = agent.sample(obs)
        action2 = np.array([action])
        print("---action: ", action)
        next_obs, reward, done, info = env.step(action2)
        print("---obs, reward, done, info: ", next_obs, reward, done, info)
        rpm.append((obs, action, reward, next_obs, done))
        obs = next_obs

        if done:
            print("done")
            break

        stepIdx += 1
        print("Step: ", stepIdx)

except KeyboardInterrupt:
    print("Ctrl-C -> Exit")
finally:
    env.close()
    print("Done")
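The script above only trains the agent and prints the per-step results. If you want to keep the trained policy for later evaluation, PARL agents provide save() and restore(). A minimal sketch (the checkpoint filename is my own choice, not part of this lab):

# e.g. right before env.close():
agent.save('./dqn_cw.ckpt')     # persist network weights (hypothetical path)
# ...and in a separate evaluation run:
agent.restore('./dqn_cw.ckpt')
action = agent.predict(obs)     # greedy action, no epsilon-greedy exploration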
[Executions] (my test OS environment: Ubuntu 18.04)
For CSMA/CA (30 nodes, simulation time: 10 sec):
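In the first terminal, start the ns-3 side with the default BEB behavior (dryRun). A sketch of the command, assuming the waf target is named myrlwifi after the scratch subdirectory:

./waf --run "myrlwifi --dryRun=1 --nSta=30 --simTime=10"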
Open another terminal
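In the second terminal, start the Python agent. Since cw.cc always creates the OpenGym interface, the agent is still needed to step the simulation in dry-run mode, even though its CW actions are ignored there:

python3 test_dqn.py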
(Wait)
In the first terminal, you can see that the throughput for CSMA/CA is 39.3 Mbps.
For CCOD-DQN
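Again, use two terminals. First, start the ns-3 side without the dry-run flag so the agent's CW values actually take effect (same assumed target name as above):

./waf --run "myrlwifi --nSta=30 --simTime=10"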
Open another terminal
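In the second terminal, run the DQN agent:

python3 test_dqn.py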
(Wait)
You can see that the throughput for CCOD-DQN is 51.5933 Mbps, better than CSMA/CA.
If you are interested in improving 802.11 throughput via reinforcement learning, you can also refer to my work at https://nqucsie.myqnapcloud.com/smallko/setl-rl.htm
Last Modified: 2022/2/27
[Author]
Dr. Chih-Heng Ke
Department of Computer Science and Information Engineering, National Quemoy University, Kinmen, Taiwan
Email: smallko@gmail.com