Add fault tolerance to SessionRouter and it's children

This commit is contained in:
Chord 2016-05-13 23:31:27 -04:00
parent 10f54b7980
commit 4d52d65b33
5 changed files with 110 additions and 33 deletions

View file

@ -15,7 +15,8 @@ import net.psforever.packet.control.{ClientStart, ServerStart}
import net.psforever.packet.crypto._
/**
* Actor that stores crypto state for a connection and filters away any packet metadata.
* Actor that stores crypto state for a connection, appropriately encrypts and decrypts packets,
* and passes packets along to the next hop once processed.
*/
class CryptoSessionActor extends Actor with MDCContextAware {
private[this] val log = org.log4s.getLogger
@ -23,7 +24,7 @@ class CryptoSessionActor extends Actor with MDCContextAware {
var leftRef : ActorRef = ActorRef.noSender
var rightRef : ActorRef = ActorRef.noSender
var cryptoDHState = new CryptoInterface.CryptoDHState()
var cryptoDHState : Option[CryptoInterface.CryptoDHState] = None
var cryptoState : Option[CryptoInterface.CryptoStateWithMAC] = None
val random = new SecureRandom()
@ -36,6 +37,11 @@ class CryptoSessionActor extends Actor with MDCContextAware {
var clientChallenge = ByteVector.empty
var clientChallengeResult = ByteVector.empty
// Don't leak crypto object memory even on an exception
override def postStop() = {
cleanupCrypto()
}
def receive = Initializing
def Initializing : Receive = {
@ -83,8 +89,13 @@ class CryptoSessionActor extends Actor with MDCContextAware {
p match {
case CryptoPacket(seq, ClientChallengeXchg(time, challenge, p, g)) =>
cryptoDHState = Some(new CryptoInterface.CryptoDHState())
val dh = cryptoDHState.get
// initialize our crypto state from the client's P and G
cryptoDHState.start(p, g)
dh.start(p, g)
// save the client challenge
clientChallenge = ServerChallengeXchg.getCompleteChallenge(time, challenge)
@ -99,7 +110,7 @@ class CryptoSessionActor extends Actor with MDCContextAware {
serverChallenge = ServerChallengeXchg.getCompleteChallenge(serverTime, randomChallenge)
val packet = PacketCoding.CreateCryptoPacket(seq,
ServerChallengeXchg(serverTime, randomChallenge, cryptoDHState.getPublicKey))
ServerChallengeXchg(serverTime, randomChallenge, dh.getPublicKey))
val sentPacket = sendResponse(packet)
@ -128,7 +139,11 @@ class CryptoSessionActor extends Actor with MDCContextAware {
// save the packet we got for a MAC check later
serverMACBuffer ++= msg.drop(3)
val agreedValue = cryptoDHState.agree(clientPublicKey)
val dh = cryptoDHState.get
val agreedValue = dh.agree(clientPublicKey)
// we are now done with the DH crypto object
dh.close
/*println("Agreed: " + agreedValue)
println(s"Client challenge: $clientChallenge")*/
@ -232,17 +247,23 @@ class CryptoSessionActor extends Actor with MDCContextAware {
log.error(error)
}
def resetState() : Unit = {
context.become(receive)
// reset the crypto primitives
cryptoDHState.close
cryptoDHState = new CryptoInterface.CryptoDHState()
def cleanupCrypto() = {
if(cryptoDHState.isDefined) {
cryptoDHState.get.close
cryptoDHState = None
}
if(cryptoState.isDefined) {
cryptoState.get.close
cryptoState = None
}
}
def resetState() : Unit = {
context.become(receive)
// reset the crypto primitives
cleanupCrypto()
serverChallenge = ByteVector.empty
serverChallengeResult = ByteVector.empty

View file

@ -98,11 +98,7 @@ class LoginSessionActor extends Actor with MDCContextAware {
0, 1, 2, 685276011, username, 0, false)
sendResponse(PacketCoding.CreateGamePacket(0, response))
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
context.system.scheduler.schedule(0 seconds, 250 milliseconds, self, UpdateServerList())
updateServerList
case default => log.debug(s"Unhandled GamePacket ${pkt}")
}

View file

@ -64,7 +64,7 @@ object PsLogin {
//val system = ActorSystem("PsLogin", Some(ConfigFactory.parseMap(config)), None, Some(MDCPropagatingExecutionContextWrapper(ExecutionContext.Implicits.global)))
val system = ActorSystem("PsLogin", ConfigFactory.parseMap(config))
val session = system.actorOf(Props[SessionRouter], "session-router")
val listener = system.actorOf(Props(new UdpListener(session, InetAddress.getLocalHost, 51000)), "login-udp-endpoint")
val listener = system.actorOf(Props(new UdpListener(Props[SessionRouter], "session-router",
InetAddress.getLocalHost, 51000)), "login-udp-endpoint")
}
}

View file

@ -8,6 +8,7 @@ import scodec.bits._
import scala.collection.mutable
import MDCContextAware.Implicits._
import akka.actor.MDCContextAware.MdcMsg
import akka.actor.SupervisorStrategy.Stop
final case class RawPacket(data : ByteVector)
final case class ResponsePacket(data : ByteVector)
@ -27,6 +28,12 @@ class SessionRouter extends Actor with MDCContextAware {
var sessionId = 0L // this is a connection session, not an actual logged in session ID
var inputRef : ActorRef = ActorRef.noSender
override def supervisorStrategy = OneForOneStrategy() { case _ => Stop }
override def preStart = {
log.info("SessionRouter started...ready for PlanetSide sessions")
}
/*
Login sessions are divided between two actors. the crypto session actor transparently handles all of the cryptographic
setup of the connection. Once a correct crypto session has been established, all packets, after being decrypted
@ -47,9 +54,12 @@ class SessionRouter extends Actor with MDCContextAware {
def initializing : Receive = {
case Hello() =>
inputRef = sender()
inputRef ! SendPacket(hex"41414141", new InetSocketAddress("8.8.8.8", 51000))
context.become(started)
case _ =>
log.error("Unknown message")
case default =>
log.error(s"Unknown message $default. Stopping...")
context.stop(self)
}
@ -67,7 +77,8 @@ class SessionRouter extends Actor with MDCContextAware {
idBySocket{from} = session.id
sessionById{session.id} = session
sessionByActor{session.pipeline.head} = session
sessionByActor{session.startOfPipe} = session
sessionByActor{session.nextOfStart} = session
MDC("sessionId") = session.id.toString
@ -80,25 +91,56 @@ class SessionRouter extends Actor with MDCContextAware {
MDC.clear()
}
case ResponsePacket(msg) =>
val session = sessionByActor{sender()}
val session = sessionByActor.get(sender())
log.trace(s"Sending response ${msg}")
// drop any old queued messages from old actors
//if(session.isDefined) {
log.trace(s"Sending response ${msg}")
inputRef ! SendPacket(msg, session.address)
case _ => log.error("Unknown message")
inputRef ! SendPacket(msg, session.get.address)
//}
case Terminated(actor) =>
val terminatedSession = sessionByActor.get(actor)
if(terminatedSession.isDefined) {
removeSessionById(terminatedSession.get.id, s"${actor.path.name} died")
}
case default =>
log.error(s"Unknown message $default")
}
def createNewSession(address : InetSocketAddress) = {
val id = newSessionId
val cryptoSession = context.actorOf(Props[CryptoSessionActor],
"crypto-session" + id.toString)
"crypto-session-" + id.toString)
val loginSession = context.actorOf(Props[LoginSessionActor],
"login-session" + id.toString)
"login-session-" + id.toString)
context.watch(cryptoSession)
context.watch(loginSession)
SessionState(id, address, List(cryptoSession, loginSession))
}
def removeSessionById(id : Long, reason : String) : Unit = {
val sessionOption = sessionById.get(id)
if(!sessionOption.isDefined)
return
val session = sessionOption.get
// TODO: add some sort of delay to prevent old session packets from coming through
// kill all session specific actors
session.pipeline.foreach(_ ! PoisonPill)
session.pipeline.foreach(sessionByActor remove _)
sessionById.remove(id)
idBySocket.remove(session.address)
log.info(s"Stopping session ${id} (reason: $reason)")
}
def newSessionId = {
val oldId = sessionId
sessionId += 1

View file

@ -1,7 +1,8 @@
// Copyright (c) 2016 PSForever.net to present
import java.net.{InetAddress, InetSocketAddress}
import akka.actor.{Actor, ActorLogging, ActorRef, Identify}
import akka.actor.SupervisorStrategy.{Restart, Stop}
import akka.actor.{Actor, ActorRef, OneForOneStrategy, Props, Terminated}
import akka.io._
import scodec.bits.ByteVector
import scodec.interop.akka._
@ -11,21 +12,29 @@ final case class SendPacket(msg : ByteVector, to : InetSocketAddress)
final case class Hello()
final case class HelloFriend(next: ActorRef)
class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extends Actor {
private val logger = org.log4s.getLogger
class UdpListener(nextActorProps : Props, nextActorName : String, address : InetAddress, port : Int) extends Actor {
private val log = org.log4s.getLogger
override def supervisorStrategy = OneForOneStrategy() {
case _ => Stop
}
import context.system
IO(Udp) ! Udp.Bind(self, new InetSocketAddress(address, port))
var bytesRecevied = 0L
var bytesSent = 0L
var nextActor : ActorRef = Actor.noSender
def receive = {
case Udp.Bound(local) =>
logger.info(s"Now listening on UDP:$local")
log.info(s"Now listening on UDP:$local")
createNextActor()
nextActor ! Hello()
context.become(ready(sender()))
case default =>
log.error(s"Unexpected message $default")
}
def ready(socket: ActorRef): Receive = {
@ -37,6 +46,15 @@ class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extend
nextActor ! ReceivedPacket(data.toByteVector, remote)
case Udp.Unbind => socket ! Udp.Unbind
case Udp.Unbound => context.stop(self)
case default => logger.error(s"Unhandled message: $default")
case Terminated(actor) =>
log.error(s"Next actor ${actor.path.name} has died...restarting")
createNextActor()
case default => log.error(s"Unhandled message: $default")
}
def createNextActor() = {
nextActor = context.actorOf(nextActorProps, nextActorName)
context.watch(nextActor)
nextActor ! Hello()
}
}