diff --git a/pslogin/src/main/scala/CryptoSessionActor.scala b/pslogin/src/main/scala/CryptoSessionActor.scala index 5f76c4f7..6a8421b8 100644 --- a/pslogin/src/main/scala/CryptoSessionActor.scala +++ b/pslogin/src/main/scala/CryptoSessionActor.scala @@ -15,7 +15,8 @@ import net.psforever.packet.control.{ClientStart, ServerStart} import net.psforever.packet.crypto._ /** - * Actor that stores crypto state for a connection and filters away any packet metadata. + * Actor that stores crypto state for a connection, appropriately encrypts and decrypts packets, + * and passes packets along to the next hop once processed. */ class CryptoSessionActor extends Actor with MDCContextAware { private[this] val log = org.log4s.getLogger @@ -23,7 +24,7 @@ class CryptoSessionActor extends Actor with MDCContextAware { var leftRef : ActorRef = ActorRef.noSender var rightRef : ActorRef = ActorRef.noSender - var cryptoDHState = new CryptoInterface.CryptoDHState() + var cryptoDHState : Option[CryptoInterface.CryptoDHState] = None var cryptoState : Option[CryptoInterface.CryptoStateWithMAC] = None val random = new SecureRandom() @@ -36,6 +37,11 @@ class CryptoSessionActor extends Actor with MDCContextAware { var clientChallenge = ByteVector.empty var clientChallengeResult = ByteVector.empty + // Don't leak crypto object memory even on an exception + override def postStop() = { + cleanupCrypto() + } + def receive = Initializing def Initializing : Receive = { @@ -83,8 +89,13 @@ class CryptoSessionActor extends Actor with MDCContextAware { p match { case CryptoPacket(seq, ClientChallengeXchg(time, challenge, p, g)) => + + cryptoDHState = Some(new CryptoInterface.CryptoDHState()) + + val dh = cryptoDHState.get + // initialize our crypto state from the client's P and G - cryptoDHState.start(p, g) + dh.start(p, g) // save the client challenge clientChallenge = ServerChallengeXchg.getCompleteChallenge(time, challenge) @@ -99,7 +110,7 @@ class CryptoSessionActor extends Actor with MDCContextAware { serverChallenge = ServerChallengeXchg.getCompleteChallenge(serverTime, randomChallenge) val packet = PacketCoding.CreateCryptoPacket(seq, - ServerChallengeXchg(serverTime, randomChallenge, cryptoDHState.getPublicKey)) + ServerChallengeXchg(serverTime, randomChallenge, dh.getPublicKey)) val sentPacket = sendResponse(packet) @@ -128,7 +139,11 @@ class CryptoSessionActor extends Actor with MDCContextAware { // save the packet we got for a MAC check later serverMACBuffer ++= msg.drop(3) - val agreedValue = cryptoDHState.agree(clientPublicKey) + val dh = cryptoDHState.get + val agreedValue = dh.agree(clientPublicKey) + + // we are now done with the DH crypto object + dh.close /*println("Agreed: " + agreedValue) println(s"Client challenge: $clientChallenge")*/ @@ -232,17 +247,23 @@ class CryptoSessionActor extends Actor with MDCContextAware { log.error(error) } - def resetState() : Unit = { - context.become(receive) - - // reset the crypto primitives - cryptoDHState.close - cryptoDHState = new CryptoInterface.CryptoDHState() + def cleanupCrypto() = { + if(cryptoDHState.isDefined) { + cryptoDHState.get.close + cryptoDHState = None + } if(cryptoState.isDefined) { cryptoState.get.close cryptoState = None } + } + + def resetState() : Unit = { + context.become(receive) + + // reset the crypto primitives + cleanupCrypto() serverChallenge = ByteVector.empty serverChallengeResult = ByteVector.empty diff --git a/pslogin/src/main/scala/LoginSessionActor.scala b/pslogin/src/main/scala/LoginSessionActor.scala index ac8feaa2..1d9ff0c2 100644 --- a/pslogin/src/main/scala/LoginSessionActor.scala +++ b/pslogin/src/main/scala/LoginSessionActor.scala @@ -98,11 +98,7 @@ class LoginSessionActor extends Actor with MDCContextAware { 0, 1, 2, 685276011, username, 0, false) sendResponse(PacketCoding.CreateGamePacket(0, response)) - - import scala.concurrent.duration._ - import scala.concurrent.ExecutionContext.Implicits.global - context.system.scheduler.schedule(0 seconds, 250 milliseconds, self, UpdateServerList()) - + updateServerList case default => log.debug(s"Unhandled GamePacket ${pkt}") } diff --git a/pslogin/src/main/scala/PsLogin.scala b/pslogin/src/main/scala/PsLogin.scala index 12f5be92..0114cd22 100644 --- a/pslogin/src/main/scala/PsLogin.scala +++ b/pslogin/src/main/scala/PsLogin.scala @@ -64,7 +64,7 @@ object PsLogin { //val system = ActorSystem("PsLogin", Some(ConfigFactory.parseMap(config)), None, Some(MDCPropagatingExecutionContextWrapper(ExecutionContext.Implicits.global))) val system = ActorSystem("PsLogin", ConfigFactory.parseMap(config)) - val session = system.actorOf(Props[SessionRouter], "session-router") - val listener = system.actorOf(Props(new UdpListener(session, InetAddress.getLocalHost, 51000)), "login-udp-endpoint") + val listener = system.actorOf(Props(new UdpListener(Props[SessionRouter], "session-router", + InetAddress.getLocalHost, 51000)), "login-udp-endpoint") } } diff --git a/pslogin/src/main/scala/SessionRouter.scala b/pslogin/src/main/scala/SessionRouter.scala index ebf2d9d5..137996b5 100644 --- a/pslogin/src/main/scala/SessionRouter.scala +++ b/pslogin/src/main/scala/SessionRouter.scala @@ -8,6 +8,7 @@ import scodec.bits._ import scala.collection.mutable import MDCContextAware.Implicits._ import akka.actor.MDCContextAware.MdcMsg +import akka.actor.SupervisorStrategy.Stop final case class RawPacket(data : ByteVector) final case class ResponsePacket(data : ByteVector) @@ -27,6 +28,12 @@ class SessionRouter extends Actor with MDCContextAware { var sessionId = 0L // this is a connection session, not an actual logged in session ID var inputRef : ActorRef = ActorRef.noSender + override def supervisorStrategy = OneForOneStrategy() { case _ => Stop } + + override def preStart = { + log.info("SessionRouter started...ready for PlanetSide sessions") + } + /* Login sessions are divided between two actors. the crypto session actor transparently handles all of the cryptographic setup of the connection. Once a correct crypto session has been established, all packets, after being decrypted @@ -47,9 +54,12 @@ class SessionRouter extends Actor with MDCContextAware { def initializing : Receive = { case Hello() => inputRef = sender() + + inputRef ! SendPacket(hex"41414141", new InetSocketAddress("8.8.8.8", 51000)) + context.become(started) - case _ => - log.error("Unknown message") + case default => + log.error(s"Unknown message $default. Stopping...") context.stop(self) } @@ -67,7 +77,8 @@ class SessionRouter extends Actor with MDCContextAware { idBySocket{from} = session.id sessionById{session.id} = session - sessionByActor{session.pipeline.head} = session + sessionByActor{session.startOfPipe} = session + sessionByActor{session.nextOfStart} = session MDC("sessionId") = session.id.toString @@ -80,25 +91,56 @@ class SessionRouter extends Actor with MDCContextAware { MDC.clear() } case ResponsePacket(msg) => - val session = sessionByActor{sender()} + val session = sessionByActor.get(sender()) - log.trace(s"Sending response ${msg}") + // drop any old queued messages from old actors + //if(session.isDefined) { + log.trace(s"Sending response ${msg}") - inputRef ! SendPacket(msg, session.address) - case _ => log.error("Unknown message") + inputRef ! SendPacket(msg, session.get.address) + //} + case Terminated(actor) => + val terminatedSession = sessionByActor.get(actor) + + if(terminatedSession.isDefined) { + removeSessionById(terminatedSession.get.id, s"${actor.path.name} died") + } + case default => + log.error(s"Unknown message $default") } def createNewSession(address : InetSocketAddress) = { val id = newSessionId val cryptoSession = context.actorOf(Props[CryptoSessionActor], - "crypto-session" + id.toString) + "crypto-session-" + id.toString) val loginSession = context.actorOf(Props[LoginSessionActor], - "login-session" + id.toString) + "login-session-" + id.toString) + + context.watch(cryptoSession) + context.watch(loginSession) SessionState(id, address, List(cryptoSession, loginSession)) } + def removeSessionById(id : Long, reason : String) : Unit = { + val sessionOption = sessionById.get(id) + + if(!sessionOption.isDefined) + return + + val session = sessionOption.get + + // TODO: add some sort of delay to prevent old session packets from coming through + // kill all session specific actors + session.pipeline.foreach(_ ! PoisonPill) + session.pipeline.foreach(sessionByActor remove _) + sessionById.remove(id) + idBySocket.remove(session.address) + + log.info(s"Stopping session ${id} (reason: $reason)") + } + def newSessionId = { val oldId = sessionId sessionId += 1 diff --git a/pslogin/src/main/scala/UdpListener.scala b/pslogin/src/main/scala/UdpListener.scala index c45b3ebf..06d9dbdd 100644 --- a/pslogin/src/main/scala/UdpListener.scala +++ b/pslogin/src/main/scala/UdpListener.scala @@ -1,7 +1,8 @@ // Copyright (c) 2016 PSForever.net to present import java.net.{InetAddress, InetSocketAddress} -import akka.actor.{Actor, ActorLogging, ActorRef, Identify} +import akka.actor.SupervisorStrategy.{Restart, Stop} +import akka.actor.{Actor, ActorRef, OneForOneStrategy, Props, Terminated} import akka.io._ import scodec.bits.ByteVector import scodec.interop.akka._ @@ -11,21 +12,29 @@ final case class SendPacket(msg : ByteVector, to : InetSocketAddress) final case class Hello() final case class HelloFriend(next: ActorRef) -class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extends Actor { - private val logger = org.log4s.getLogger +class UdpListener(nextActorProps : Props, nextActorName : String, address : InetAddress, port : Int) extends Actor { + private val log = org.log4s.getLogger + + override def supervisorStrategy = OneForOneStrategy() { + case _ => Stop + } import context.system IO(Udp) ! Udp.Bind(self, new InetSocketAddress(address, port)) var bytesRecevied = 0L var bytesSent = 0L + var nextActor : ActorRef = Actor.noSender def receive = { case Udp.Bound(local) => - logger.info(s"Now listening on UDP:$local") + log.info(s"Now listening on UDP:$local") + + createNextActor() - nextActor ! Hello() context.become(ready(sender())) + case default => + log.error(s"Unexpected message $default") } def ready(socket: ActorRef): Receive = { @@ -37,6 +46,15 @@ class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extend nextActor ! ReceivedPacket(data.toByteVector, remote) case Udp.Unbind => socket ! Udp.Unbind case Udp.Unbound => context.stop(self) - case default => logger.error(s"Unhandled message: $default") + case Terminated(actor) => + log.error(s"Next actor ${actor.path.name} has died...restarting") + createNextActor() + case default => log.error(s"Unhandled message: $default") + } + + def createNextActor() = { + nextActor = context.actorOf(nextActorProps, nextActorName) + context.watch(nextActor) + nextActor ! Hello() } } \ No newline at end of file