Add fault tolerance to SessionRouter and it's children

This commit is contained in:
Chord 2016-05-13 23:31:27 -04:00
parent 10f54b7980
commit 4d52d65b33
5 changed files with 110 additions and 33 deletions

View file

@ -15,7 +15,8 @@ import net.psforever.packet.control.{ClientStart, ServerStart}
import net.psforever.packet.crypto._ import net.psforever.packet.crypto._
/** /**
* Actor that stores crypto state for a connection and filters away any packet metadata. * Actor that stores crypto state for a connection, appropriately encrypts and decrypts packets,
* and passes packets along to the next hop once processed.
*/ */
class CryptoSessionActor extends Actor with MDCContextAware { class CryptoSessionActor extends Actor with MDCContextAware {
private[this] val log = org.log4s.getLogger private[this] val log = org.log4s.getLogger
@ -23,7 +24,7 @@ class CryptoSessionActor extends Actor with MDCContextAware {
var leftRef : ActorRef = ActorRef.noSender var leftRef : ActorRef = ActorRef.noSender
var rightRef : ActorRef = ActorRef.noSender var rightRef : ActorRef = ActorRef.noSender
var cryptoDHState = new CryptoInterface.CryptoDHState() var cryptoDHState : Option[CryptoInterface.CryptoDHState] = None
var cryptoState : Option[CryptoInterface.CryptoStateWithMAC] = None var cryptoState : Option[CryptoInterface.CryptoStateWithMAC] = None
val random = new SecureRandom() val random = new SecureRandom()
@ -36,6 +37,11 @@ class CryptoSessionActor extends Actor with MDCContextAware {
var clientChallenge = ByteVector.empty var clientChallenge = ByteVector.empty
var clientChallengeResult = ByteVector.empty var clientChallengeResult = ByteVector.empty
// Don't leak crypto object memory even on an exception
override def postStop() = {
cleanupCrypto()
}
def receive = Initializing def receive = Initializing
def Initializing : Receive = { def Initializing : Receive = {
@ -83,8 +89,13 @@ class CryptoSessionActor extends Actor with MDCContextAware {
p match { p match {
case CryptoPacket(seq, ClientChallengeXchg(time, challenge, p, g)) => case CryptoPacket(seq, ClientChallengeXchg(time, challenge, p, g)) =>
cryptoDHState = Some(new CryptoInterface.CryptoDHState())
val dh = cryptoDHState.get
// initialize our crypto state from the client's P and G // initialize our crypto state from the client's P and G
cryptoDHState.start(p, g) dh.start(p, g)
// save the client challenge // save the client challenge
clientChallenge = ServerChallengeXchg.getCompleteChallenge(time, challenge) clientChallenge = ServerChallengeXchg.getCompleteChallenge(time, challenge)
@ -99,7 +110,7 @@ class CryptoSessionActor extends Actor with MDCContextAware {
serverChallenge = ServerChallengeXchg.getCompleteChallenge(serverTime, randomChallenge) serverChallenge = ServerChallengeXchg.getCompleteChallenge(serverTime, randomChallenge)
val packet = PacketCoding.CreateCryptoPacket(seq, val packet = PacketCoding.CreateCryptoPacket(seq,
ServerChallengeXchg(serverTime, randomChallenge, cryptoDHState.getPublicKey)) ServerChallengeXchg(serverTime, randomChallenge, dh.getPublicKey))
val sentPacket = sendResponse(packet) val sentPacket = sendResponse(packet)
@ -128,7 +139,11 @@ class CryptoSessionActor extends Actor with MDCContextAware {
// save the packet we got for a MAC check later // save the packet we got for a MAC check later
serverMACBuffer ++= msg.drop(3) serverMACBuffer ++= msg.drop(3)
val agreedValue = cryptoDHState.agree(clientPublicKey) val dh = cryptoDHState.get
val agreedValue = dh.agree(clientPublicKey)
// we are now done with the DH crypto object
dh.close
/*println("Agreed: " + agreedValue) /*println("Agreed: " + agreedValue)
println(s"Client challenge: $clientChallenge")*/ println(s"Client challenge: $clientChallenge")*/
@ -232,17 +247,23 @@ class CryptoSessionActor extends Actor with MDCContextAware {
log.error(error) log.error(error)
} }
def resetState() : Unit = { def cleanupCrypto() = {
context.become(receive) if(cryptoDHState.isDefined) {
cryptoDHState.get.close
// reset the crypto primitives cryptoDHState = None
cryptoDHState.close }
cryptoDHState = new CryptoInterface.CryptoDHState()
if(cryptoState.isDefined) { if(cryptoState.isDefined) {
cryptoState.get.close cryptoState.get.close
cryptoState = None cryptoState = None
} }
}
def resetState() : Unit = {
context.become(receive)
// reset the crypto primitives
cleanupCrypto()
serverChallenge = ByteVector.empty serverChallenge = ByteVector.empty
serverChallengeResult = ByteVector.empty serverChallengeResult = ByteVector.empty

View file

@ -98,11 +98,7 @@ class LoginSessionActor extends Actor with MDCContextAware {
0, 1, 2, 685276011, username, 0, false) 0, 1, 2, 685276011, username, 0, false)
sendResponse(PacketCoding.CreateGamePacket(0, response)) sendResponse(PacketCoding.CreateGamePacket(0, response))
updateServerList
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
context.system.scheduler.schedule(0 seconds, 250 milliseconds, self, UpdateServerList())
case default => log.debug(s"Unhandled GamePacket ${pkt}") case default => log.debug(s"Unhandled GamePacket ${pkt}")
} }

View file

@ -64,7 +64,7 @@ object PsLogin {
//val system = ActorSystem("PsLogin", Some(ConfigFactory.parseMap(config)), None, Some(MDCPropagatingExecutionContextWrapper(ExecutionContext.Implicits.global))) //val system = ActorSystem("PsLogin", Some(ConfigFactory.parseMap(config)), None, Some(MDCPropagatingExecutionContextWrapper(ExecutionContext.Implicits.global)))
val system = ActorSystem("PsLogin", ConfigFactory.parseMap(config)) val system = ActorSystem("PsLogin", ConfigFactory.parseMap(config))
val session = system.actorOf(Props[SessionRouter], "session-router") val listener = system.actorOf(Props(new UdpListener(Props[SessionRouter], "session-router",
val listener = system.actorOf(Props(new UdpListener(session, InetAddress.getLocalHost, 51000)), "login-udp-endpoint") InetAddress.getLocalHost, 51000)), "login-udp-endpoint")
} }
} }

View file

@ -8,6 +8,7 @@ import scodec.bits._
import scala.collection.mutable import scala.collection.mutable
import MDCContextAware.Implicits._ import MDCContextAware.Implicits._
import akka.actor.MDCContextAware.MdcMsg import akka.actor.MDCContextAware.MdcMsg
import akka.actor.SupervisorStrategy.Stop
final case class RawPacket(data : ByteVector) final case class RawPacket(data : ByteVector)
final case class ResponsePacket(data : ByteVector) final case class ResponsePacket(data : ByteVector)
@ -27,6 +28,12 @@ class SessionRouter extends Actor with MDCContextAware {
var sessionId = 0L // this is a connection session, not an actual logged in session ID var sessionId = 0L // this is a connection session, not an actual logged in session ID
var inputRef : ActorRef = ActorRef.noSender var inputRef : ActorRef = ActorRef.noSender
override def supervisorStrategy = OneForOneStrategy() { case _ => Stop }
override def preStart = {
log.info("SessionRouter started...ready for PlanetSide sessions")
}
/* /*
Login sessions are divided between two actors. the crypto session actor transparently handles all of the cryptographic Login sessions are divided between two actors. the crypto session actor transparently handles all of the cryptographic
setup of the connection. Once a correct crypto session has been established, all packets, after being decrypted setup of the connection. Once a correct crypto session has been established, all packets, after being decrypted
@ -47,9 +54,12 @@ class SessionRouter extends Actor with MDCContextAware {
def initializing : Receive = { def initializing : Receive = {
case Hello() => case Hello() =>
inputRef = sender() inputRef = sender()
inputRef ! SendPacket(hex"41414141", new InetSocketAddress("8.8.8.8", 51000))
context.become(started) context.become(started)
case _ => case default =>
log.error("Unknown message") log.error(s"Unknown message $default. Stopping...")
context.stop(self) context.stop(self)
} }
@ -67,7 +77,8 @@ class SessionRouter extends Actor with MDCContextAware {
idBySocket{from} = session.id idBySocket{from} = session.id
sessionById{session.id} = session sessionById{session.id} = session
sessionByActor{session.pipeline.head} = session sessionByActor{session.startOfPipe} = session
sessionByActor{session.nextOfStart} = session
MDC("sessionId") = session.id.toString MDC("sessionId") = session.id.toString
@ -80,25 +91,56 @@ class SessionRouter extends Actor with MDCContextAware {
MDC.clear() MDC.clear()
} }
case ResponsePacket(msg) => case ResponsePacket(msg) =>
val session = sessionByActor{sender()} val session = sessionByActor.get(sender())
log.trace(s"Sending response ${msg}") // drop any old queued messages from old actors
//if(session.isDefined) {
log.trace(s"Sending response ${msg}")
inputRef ! SendPacket(msg, session.address) inputRef ! SendPacket(msg, session.get.address)
case _ => log.error("Unknown message") //}
case Terminated(actor) =>
val terminatedSession = sessionByActor.get(actor)
if(terminatedSession.isDefined) {
removeSessionById(terminatedSession.get.id, s"${actor.path.name} died")
}
case default =>
log.error(s"Unknown message $default")
} }
def createNewSession(address : InetSocketAddress) = { def createNewSession(address : InetSocketAddress) = {
val id = newSessionId val id = newSessionId
val cryptoSession = context.actorOf(Props[CryptoSessionActor], val cryptoSession = context.actorOf(Props[CryptoSessionActor],
"crypto-session" + id.toString) "crypto-session-" + id.toString)
val loginSession = context.actorOf(Props[LoginSessionActor], val loginSession = context.actorOf(Props[LoginSessionActor],
"login-session" + id.toString) "login-session-" + id.toString)
context.watch(cryptoSession)
context.watch(loginSession)
SessionState(id, address, List(cryptoSession, loginSession)) SessionState(id, address, List(cryptoSession, loginSession))
} }
def removeSessionById(id : Long, reason : String) : Unit = {
val sessionOption = sessionById.get(id)
if(!sessionOption.isDefined)
return
val session = sessionOption.get
// TODO: add some sort of delay to prevent old session packets from coming through
// kill all session specific actors
session.pipeline.foreach(_ ! PoisonPill)
session.pipeline.foreach(sessionByActor remove _)
sessionById.remove(id)
idBySocket.remove(session.address)
log.info(s"Stopping session ${id} (reason: $reason)")
}
def newSessionId = { def newSessionId = {
val oldId = sessionId val oldId = sessionId
sessionId += 1 sessionId += 1

View file

@ -1,7 +1,8 @@
// Copyright (c) 2016 PSForever.net to present // Copyright (c) 2016 PSForever.net to present
import java.net.{InetAddress, InetSocketAddress} import java.net.{InetAddress, InetSocketAddress}
import akka.actor.{Actor, ActorLogging, ActorRef, Identify} import akka.actor.SupervisorStrategy.{Restart, Stop}
import akka.actor.{Actor, ActorRef, OneForOneStrategy, Props, Terminated}
import akka.io._ import akka.io._
import scodec.bits.ByteVector import scodec.bits.ByteVector
import scodec.interop.akka._ import scodec.interop.akka._
@ -11,21 +12,29 @@ final case class SendPacket(msg : ByteVector, to : InetSocketAddress)
final case class Hello() final case class Hello()
final case class HelloFriend(next: ActorRef) final case class HelloFriend(next: ActorRef)
class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extends Actor { class UdpListener(nextActorProps : Props, nextActorName : String, address : InetAddress, port : Int) extends Actor {
private val logger = org.log4s.getLogger private val log = org.log4s.getLogger
override def supervisorStrategy = OneForOneStrategy() {
case _ => Stop
}
import context.system import context.system
IO(Udp) ! Udp.Bind(self, new InetSocketAddress(address, port)) IO(Udp) ! Udp.Bind(self, new InetSocketAddress(address, port))
var bytesRecevied = 0L var bytesRecevied = 0L
var bytesSent = 0L var bytesSent = 0L
var nextActor : ActorRef = Actor.noSender
def receive = { def receive = {
case Udp.Bound(local) => case Udp.Bound(local) =>
logger.info(s"Now listening on UDP:$local") log.info(s"Now listening on UDP:$local")
createNextActor()
nextActor ! Hello()
context.become(ready(sender())) context.become(ready(sender()))
case default =>
log.error(s"Unexpected message $default")
} }
def ready(socket: ActorRef): Receive = { def ready(socket: ActorRef): Receive = {
@ -37,6 +46,15 @@ class UdpListener(nextActor: ActorRef, address : InetAddress, port : Int) extend
nextActor ! ReceivedPacket(data.toByteVector, remote) nextActor ! ReceivedPacket(data.toByteVector, remote)
case Udp.Unbind => socket ! Udp.Unbind case Udp.Unbind => socket ! Udp.Unbind
case Udp.Unbound => context.stop(self) case Udp.Unbound => context.stop(self)
case default => logger.error(s"Unhandled message: $default") case Terminated(actor) =>
log.error(s"Next actor ${actor.path.name} has died...restarting")
createNextActor()
case default => log.error(s"Unhandled message: $default")
}
def createNextActor() = {
nextActor = context.actorOf(nextActorProps, nextActorName)
context.watch(nextActor)
nextActor ! Hello()
} }
} }