diff --git a/build.sbt b/build.sbt index ad9736b1..19a934fd 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,8 @@ lazy val commonSettings = Seq( "org.fusesource.jansi" % "jansi" % "1.12", "org.scoverage" %% "scalac-scoverage-plugin" % "1.1.1", "com.github.nscala-time" %% "nscala-time" % "2.12.0", - "com.github.mauricio" %% "mysql-async" % "0.2.21" + "com.github.mauricio" %% "mysql-async" % "0.2.21", + "org.ini4j" % "ini4j" % "0.5.4" ) ) diff --git a/common/src/main/scala/net/psforever/config/ConfigParser.scala b/common/src/main/scala/net/psforever/config/ConfigParser.scala new file mode 100644 index 00000000..173a27eb --- /dev/null +++ b/common/src/main/scala/net/psforever/config/ConfigParser.scala @@ -0,0 +1,183 @@ +// Copyright (c) 2019 PSForever +package net.psforever.config + +import org.ini4j +import scala.reflect.ClassTag +import scala.annotation.implicitNotFound +import scala.concurrent.duration._ + +case class ConfigValueMapper[T](name: String)(f: (String => Option[T])) { + def apply(t: String): Option[T] = f(t) +} + +object ConfigValueMapper { + implicit val toInt : ConfigValueMapper[Int] = ConfigValueMapper[Int]("toInt") { e => + try { + Some(e.toInt) + } catch { + case e: Exception => None + } + } + + implicit val toBool : ConfigValueMapper[Boolean] = ConfigValueMapper[Boolean]("toBool") { e => + if (e == "yes") { + Some(true) + } else if (e == "no") { + Some(false) + } else { + None + } + } + + implicit val toFloat : ConfigValueMapper[Float] = ConfigValueMapper[Float]("toFloat") { e => + try { + Some(e.toFloat) + } catch { + case e: Exception => None + } + } + + implicit val toDuration : ConfigValueMapper[Duration] = ConfigValueMapper[Duration]("toDuration") { e => + try { + Some(Duration(e)) + } catch { + case e: Exception => None + } + } + + implicit val toStr : ConfigValueMapper[String] = ConfigValueMapper[String]("toString") { e => + Some(e) + } +} + +sealed trait ConfigEntry { + type Value + val key : String + val default : Value + def getType : String + val constraints : Seq[Constraint[Value]] + def read(v: String): Option[Value] +} + +final case class ConfigEntryString(key: String, default : String, constraints : Constraint[String]*) extends ConfigEntry { + type Value = String + def getType = "String" + def read(v : String) = ConfigValueMapper.toStr(v) +} + +final case class ConfigEntryInt(key: String, default : Int, constraints : Constraint[Int]*) extends ConfigEntry { + type Value = Int + def getType = "Int" + def read(v : String) = ConfigValueMapper.toInt(v) +} + +final case class ConfigEntryBool(key: String, default : Boolean, constraints : Constraint[Boolean]*) extends ConfigEntry { + type Value = Boolean + def getType = "Bool" + def read(v : String) = ConfigValueMapper.toBool(v) +} + +final case class ConfigEntryFloat(key: String, default : Float, constraints : Constraint[Float]*) extends ConfigEntry { + type Value = Float + def getType = "Float" + def read(v : String) = ConfigValueMapper.toFloat(v) +} + +final case class ConfigEntryTime(key: String, default : Duration, constraints : Constraint[Duration]*) extends ConfigEntry { + type Value = Duration + def getType = "Time" + def read(v : String) = ConfigValueMapper.toDuration(v) +} + +case class ConfigSection(name: String, entries: ConfigEntry*) + +@implicitNotFound("Nothing was inferred") +sealed trait ConfigTypeRequired[-T] + +object ConfigTypeRequired { + implicit object cfgTypeRequired extends ConfigTypeRequired[Any] + //We do not want Nothing to be inferred, so make an ambiguous implicit + implicit object `\n The Get[T] call needs a type T matching the corresponding ConfigEntry` extends ConfigTypeRequired[Nothing] +} + +trait ConfigParser { + protected var config_map : Map[String, Any] + protected val config_template : Seq[ConfigSection] + + // Misuse of this function can lead to run time exceptions when the types don't match + def Get[T : ConfigTypeRequired](key : String) : T = config_map(key).asInstanceOf[T] + + def Load(filename : String) : ValidationResult = { + val ini = new org.ini4j.Ini() + config_map = Map() + + try { + ini.load(new java.io.File(filename)) + } catch { + case e : org.ini4j.InvalidFileFormatException => + return Invalid(e.getMessage) + case e : java.io.FileNotFoundException => + return Invalid(e.getMessage) + } + + val result : Seq[ValidationResult] = config_template.map { section => + val sectionIni = ini.get(section.name) + + if (sectionIni == null) + Seq(Invalid("section.missing", section.name)) + else + section.entries.map(parseSection(sectionIni, _)) + }.reduceLeft((x, y) => x ++ y) + + val errors : Seq[Invalid] = result.collect { case iv : Invalid => iv } + + if (errors.length > 0) + errors.reduceLeft((x, y) => x ++ y) + else + // run post-parse validation only if we successfully parsed + postParseChecks + } + + def FormatErrors(invalidResult : Invalid) : Seq[String] = { + var count = 0; + + invalidResult.errors.map((error) => { + var message = error.message; + + if (error.args.length > 0) + message += " ("+error.args(0)+")" + + count += 1; + s"Error ${count}: ${message}" + }); + } + + protected def postParseChecks : ValidationResult = { + Valid + } + + protected def parseSection(sectionIni : org.ini4j.Profile.Section, entry : ConfigEntry) : ValidationResult = { + var rawValue = sectionIni.get(entry.key) + val full_key = sectionIni.getName + "." + entry.key + + val value = if (rawValue == null) { + // warn about defaults from unset parameters? + entry.default + } else { + rawValue = rawValue.stripPrefix("\"").stripSuffix("\"") + + entry.read(rawValue) match { + case Some(v) => v + case None => return Invalid(ValidationError(String.format("%s: value format error (expected: %s)", full_key, entry.getType))) + } + } + config_map += (full_key -> value) + + ParameterValidator(entry.constraints, Some(value)) match { + case v @ Valid => v + case i @ Invalid(errors) => { + Invalid(errors.map(x => ValidationError(x.messages.map(full_key + ": " + _), x.args: _*))) + } + } + } +} diff --git a/common/src/main/scala/net/psforever/config/ConfigValidation.scala b/common/src/main/scala/net/psforever/config/ConfigValidation.scala new file mode 100644 index 00000000..b50a2039 --- /dev/null +++ b/common/src/main/scala/net/psforever/config/ConfigValidation.scala @@ -0,0 +1,290 @@ +// Copyright (c) 2019 PSForever +// Lifted from https://raw.githubusercontent.com/playframework/playframework/2.7.x/core/play/src/main/scala/play/api/data/validation/Validation.scala +package net.psforever.config + + +/** + * A form constraint. + * + * @tparam T type of values handled by this constraint + * @param name the constraint name, to be displayed to final user + * @param args the message arguments, to format the constraint name + * @param f the validation function + */ +case class Constraint[-T](name: Option[String], args: Seq[Any])(f: (T => ValidationResult)) { + + /** + * Run the constraint validation. + * + * @param t the value to validate + * @return the validation result + */ + def apply(t: T): ValidationResult = f(t) +} + +/** + * This object provides helpers for creating `Constraint` values. + * + * For example: + * {{{ + * val negative = Constraint[Int] { + * case i if i < 0 => Valid + * case _ => Invalid("Must be a negative number.") + * } + * }}} + */ +object Constraint { + + /** + * Creates a new anonymous constraint from a validation function. + * + * @param f the validation function + * @return a constraint + */ + def apply[T](f: (T => ValidationResult)): Constraint[T] = apply(None, Nil)(f) + + /** + * Creates a new named constraint from a validation function. + * + * @param name the constraint name + * @param args the constraint arguments, used to format the constraint name + * @param f the validation function + * @return a constraint + */ + def apply[T](name: String, args: Any*)(f: (T => ValidationResult)): Constraint[T] = apply(Some(name), args.toSeq)(f) + +} + +/** + * Defines a set of built-in constraints. + */ +object Constraints extends Constraints + +/** + * Defines a set of built-in constraints. + * + * @define emailAddressDoc Defines an ‘emailAddress’ constraint for `String` values which will validate email addresses. + * + * '''name'''[constraint.email] + * '''error'''[error.email] + * + * @define nonEmptyDoc Defines a ‘required’ constraint for `String` values, i.e. one in which empty strings are invalid. + * + * '''name'''[constraint.required] + * '''error'''[error.required] + */ +trait Constraints { + + private val emailRegex = + """^[a-zA-Z0-9\.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$""".r + + /** + * $emailAddressDoc + */ + def emailAddress(errorMessage: String = "error.email"): Constraint[String] = Constraint[String]("constraint.email") { + e => + if (e == null) Invalid(ValidationError(errorMessage)) + else if (e.trim.isEmpty) Invalid(ValidationError(errorMessage)) + else + emailRegex + .findFirstMatchIn(e) + .map(_ => Valid) + .getOrElse(Invalid(ValidationError(errorMessage))) + } + + /** + * $emailAddressDoc + * + */ + def emailAddress: Constraint[String] = emailAddress() + + /** + * $nonEmptyDoc + */ + def nonEmpty(errorMessage: String = "error.required"): Constraint[String] = + Constraint[String]("constraint.required") { o => + if (o == null) Invalid(ValidationError(errorMessage)) + else if (o.trim.isEmpty) Invalid(ValidationError(errorMessage)) + else Valid + } + + /** + * $nonEmptyDoc + * + */ + def nonEmpty: Constraint[String] = nonEmpty() + + /** + * Defines a minimum value for `Ordered` values, by default the value must be greater than or equal to the constraint parameter + * + * '''name'''[constraint.min(minValue)] + * '''error'''[error.min(minValue)] or [error.min.strict(minValue)] + */ + def min[T]( + minValue: T, + strict: Boolean = false, + errorMessage: String = "error.min", + strictErrorMessage: String = "error.min.strict" + )(implicit ordering: scala.math.Ordering[T]): Constraint[T] = Constraint[T]("constraint.min", minValue) { o => + (ordering.compare(o, minValue).signum, strict) match { + case (1, _) | (0, false) => Valid + case (_, false) => Invalid(ValidationError(errorMessage, minValue)) + case (_, true) => Invalid(ValidationError(strictErrorMessage, minValue)) + } + } + + /** + * Defines a maximum value for `Ordered` values, by default the value must be less than or equal to the constraint parameter + * + * '''name'''[constraint.max(maxValue)] + * '''error'''[error.max(maxValue)] or [error.max.strict(maxValue)] + */ + def max[T]( + maxValue: T, + strict: Boolean = false, + errorMessage: String = "error.max", + strictErrorMessage: String = "error.max.strict" + )(implicit ordering: scala.math.Ordering[T]): Constraint[T] = Constraint[T]("constraint.max", maxValue) { o => + (ordering.compare(o, maxValue).signum, strict) match { + case (-1, _) | (0, false) => Valid + case (_, false) => Invalid(ValidationError(errorMessage, maxValue)) + case (_, true) => Invalid(ValidationError(strictErrorMessage, maxValue)) + } + } + + /** + * Defines a minimum length constraint for `String` values, i.e. the string’s length must be greater than or equal to the constraint parameter + * + * '''name'''[constraint.minLength(length)] + * '''error'''[error.minLength(length)] + */ + def minLength(length: Int, errorMessage: String = "error.minLength"): Constraint[String] = + Constraint[String]("constraint.minLength", length) { o => + require(length >= 0, "string minLength must not be negative") + if (o == null) Invalid(ValidationError(errorMessage, length)) + else if (o.size >= length) Valid + else Invalid(ValidationError(errorMessage, length)) + } + + /** + * Defines a maximum length constraint for `String` values, i.e. the string’s length must be less than or equal to the constraint parameter + * + * '''name'''[constraint.maxLength(length)] + * '''error'''[error.maxLength(length)] + */ + def maxLength(length: Int, errorMessage: String = "error.maxLength"): Constraint[String] = + Constraint[String]("constraint.maxLength", length) { o => + require(length >= 0, "string maxLength must not be negative") + if (o == null) Invalid(ValidationError(errorMessage, length)) + else if (o.size <= length) Valid + else Invalid(ValidationError(errorMessage, length)) + } + + /** + * Defines a regular expression constraint for `String` values, i.e. the string must match the regular expression pattern + * + * '''name'''[constraint.pattern(regex)] or defined by the name parameter. + * '''error'''[error.pattern(regex)] or defined by the error parameter. + */ + def pattern( + regex: => scala.util.matching.Regex, + name: String = "constraint.pattern", + error: String = "error.pattern" + ): Constraint[String] = Constraint[String](name, () => regex) { o => + require(regex != null, "regex must not be null") + require(name != null, "name must not be null") + require(error != null, "error must not be null") + + if (o == null) Invalid(ValidationError(error, regex)) + else regex.unapplySeq(o).map(_ => Valid).getOrElse(Invalid(ValidationError(error, regex))) + } + +} + +/** + * A validation result. + */ +sealed trait ValidationResult + +/** + * Validation was a success. + */ +case object Valid extends ValidationResult + +/** + * Validation was a failure. + * + * @param errors the resulting errors + */ +case class Invalid(errors: Seq[ValidationError]) extends ValidationResult { + + /** + * Combines these validation errors with another validation failure. + * + * @param other validation failure + * @return a new merged `Invalid` + */ + def ++(other: Invalid): Invalid = Invalid(this.errors ++ other.errors) +} + +/** + * This object provides helper methods to construct `Invalid` values. + */ +object Invalid { + + /** + * Creates an `Invalid` value with a single error. + * + * @param error the validation error + * @return an `Invalid` value + */ + def apply(error: ValidationError): Invalid = Invalid(Seq(error)) + + /** + * Creates an `Invalid` value with a single error. + * + * @param error the validation error message + * @param args the validation error message arguments + * @return an `Invalid` value + */ + def apply(error: String, args: Any*): Invalid = Invalid(Seq(ValidationError(error, args: _*))) +} + +object ParameterValidator { + def apply[T](constraints: Iterable[Constraint[T]], optionalParam: Option[T]*) = + optionalParam.flatMap { + _.map { param => + constraints.flatMap { + _(param) match { + case i: Invalid => Some(i) + case _ => None + } + } + } + }.flatten match { + case Nil => Valid + case invalids => + invalids.reduceLeft { (a, b) => + a ++ b + } + } +} + +/** + * A validation error. + * + * @param messages the error message, if more then one message is passed it will use the last one + * @param args the error message arguments + */ +case class ValidationError(messages: Seq[String], args: Any*) { + + lazy val message = messages.last + +} + +object ValidationError { + + + def apply(message: String, args: Any*) = new ValidationError(Seq(message), args: _*) + +} diff --git a/config/worldserver.ini.dist b/config/worldserver.ini.dist new file mode 100644 index 00000000..eb12d0e0 --- /dev/null +++ b/config/worldserver.ini.dist @@ -0,0 +1,9 @@ +####################################### +# PSForever Server configuration file # +####################################### +[worldserver] +ListeningPort = 51001 + +[loginserver] +ListeningPort = 51000 + diff --git a/pslogin/src/main/scala/PsLogin.scala b/pslogin/src/main/scala/PsLogin.scala index 58183746..04b7c051 100644 --- a/pslogin/src/main/scala/PsLogin.scala +++ b/pslogin/src/main/scala/PsLogin.scala @@ -11,6 +11,7 @@ import ch.qos.logback.core.joran.spi.JoranException import ch.qos.logback.core.status._ import ch.qos.logback.core.util.StatusPrinter import com.typesafe.config.ConfigFactory +import net.psforever.config.{Valid, Invalid} import net.psforever.crypto.CryptoInterface import net.psforever.objects.zones._ import net.psforever.objects.guid.TaskResolver @@ -102,6 +103,33 @@ object PsLogin { } } + def loadConfig(configDirectory : String) = { + val worldConfigFile = configDirectory + File.separator + "worldserver.ini" + // For fallback when no user-specific config file has been created + val worldDefaultConfigFile = configDirectory + File.separator + "worldserver.ini.dist" + + val worldConfigToLoad = if ((new File(worldConfigFile)).exists()) { + worldConfigFile + } else if ((new File(worldDefaultConfigFile)).exists()) { + println("WARNING: loading the default worldserver.ini.dist config file") + println("WARNING: Please create a worldserver.ini file to override server defaults") + + worldDefaultConfigFile + } else { + println("FATAL: unable to load any worldserver.ini file") + sys.exit(1) + } + + WorldConfig.Load(worldConfigToLoad) match { + case Valid => + println("Loaded world config from " + worldConfigFile) + case i : Invalid => + println("FATAL: Error loading config from " + worldConfigFile) + println(WorldConfig.FormatErrors(i).mkString("\n")) + sys.exit(1) + } + } + def parseArgs(args : Array[String]) : Unit = { if(args.length == 1) { LoginConfig.serverIpAddress = InetAddress.getByName(args{0}) @@ -125,9 +153,15 @@ object PsLogin { configDirectory = System.getProperty("prog.home") + File.separator + "config" } - initializeLogging(configDirectory + File.separator + "logback.xml") parseArgs(this.args) + val loggingConfigFile = configDirectory + File.separator + "logback.xml" + + loadConfig(configDirectory) + + println(s"Initializing logging from ${loggingConfigFile}...") + initializeLogging(loggingConfigFile) + /** Initialize the PSCrypto native library * * PSCrypto provides PlanetSide specific crypto that is required to communicate with it. @@ -194,9 +228,8 @@ object PsLogin { SessionPipeline("world-session-", Props[WorldSessionActor]) ) - val loginServerPort = 51000 - val worldServerPort = 51001 - + val loginServerPort = WorldConfig.Get[Int]("loginserver.ListeningPort") + val worldServerPort = WorldConfig.Get[Int]("worldserver.ListeningPort") // Uncomment for network simulation // TODO: make this config or command flag diff --git a/pslogin/src/main/scala/WorldConfig.scala b/pslogin/src/main/scala/WorldConfig.scala new file mode 100644 index 00000000..5f5bff41 --- /dev/null +++ b/pslogin/src/main/scala/WorldConfig.scala @@ -0,0 +1,28 @@ +// Copyright (c) 2019 PSForever +import net.psforever.config._ + +object WorldConfig extends ConfigParser { + protected var config_map : Map[String, Any] = Map() + + protected val config_template = Seq( + ConfigSection("loginserver", + ConfigEntryInt("ListeningPort", 51000, Constraints.min(1), Constraints.max(65535)) + ), + ConfigSection("worldserver", + ConfigEntryInt("ListeningPort", 51001, Constraints.min(1), Constraints.max(65535)) + ) + ) + + override def postParseChecks : ValidationResult = { + var errors : Invalid = Invalid("") + + if (Get[Int]("worldserver.ListeningPort") == Get[Int]("loginserver.ListeningPort")) + errors = errors ++ Invalid("worldserver.ListeningPort must be different from loginserver.ListeningPort") + + if (errors.errors.length > 1) + // drop the first error using tail (it was a placeholder) + Invalid(errors.errors.tail) + else + Valid + } +} diff --git a/pslogin/src/test/resources/testconfig.ini b/pslogin/src/test/resources/testconfig.ini new file mode 100644 index 00000000..64c98ee4 --- /dev/null +++ b/pslogin/src/test/resources/testconfig.ini @@ -0,0 +1,19 @@ +# This is a comment +[default] +string = a string +string_quoted = "a string" +int = 31 +time = 1 second +time2 = 100 milliseconds +float = 0.1 +bool_true = yes +bool_false = no +# missing + +[bad] +bad_int = not a number +bad_time = 10 +bad_float = A +bad_bool = dunno +bad_int_range = -1 +bad_int_range2 = 3 diff --git a/pslogin/src/test/scala/ConfigTest.scala b/pslogin/src/test/scala/ConfigTest.scala new file mode 100644 index 00000000..56f9285c --- /dev/null +++ b/pslogin/src/test/scala/ConfigTest.scala @@ -0,0 +1,79 @@ +// Copyright (c) 2019 PSForever +import scala.io.Source +import org.specs2.mutable._ +import net.psforever.config._ +import scala.concurrent.duration._ + +class ConfigTest extends Specification { + val testConfig = getClass.getResource("/testconfig.ini").getPath + + "WorldConfig" should { + "have no errors" in { + WorldConfig.Load("config/worldserver.ini.dist") mustEqual Valid + } + } + + "TestConfig" should { + "parse" in { + TestConfig.Load(testConfig) mustEqual Valid + TestConfig.Get[String]("default.string") mustEqual "a string" + TestConfig.Get[String]("default.string_quoted") mustEqual "a string" + TestConfig.Get[Int]("default.int") mustEqual 31 + TestConfig.Get[Duration]("default.time") mustEqual (1 second) + TestConfig.Get[Duration]("default.time2") mustEqual (100 milliseconds) + TestConfig.Get[Float]("default.float") mustEqual 0.1f + TestConfig.Get[Boolean]("default.bool_true") mustEqual true + TestConfig.Get[Boolean]("default.bool_false") mustEqual false + TestConfig.Get[Int]("default.missing") mustEqual 1337 + } + } + + "TestBadConfig" should { + "not parse" in { + val error = TestBadConfig.Load(testConfig).asInstanceOf[Invalid] + val check_errors = List( + ValidationError("bad.bad_int: value format error (expected: Int)"), + ValidationError("bad.bad_time: value format error (expected: Time)"), + ValidationError("bad.bad_float: value format error (expected: Float)"), + ValidationError("bad.bad_bool: value format error (expected: Bool)"), + ValidationError("bad.bad_int_range: error.min", 0), + ValidationError("bad.bad_int_range2: error.max", 2) + ) + + error.errors mustEqual check_errors + } + } +} + +object TestConfig extends ConfigParser { + protected var config_map : Map[String, Any] = Map() + + protected val config_template = Seq( + ConfigSection("default", + ConfigEntryString("string", ""), + ConfigEntryString("string_quoted", ""), + ConfigEntryInt("int", 0), + ConfigEntryTime("time", 0 seconds), + ConfigEntryTime("time2", 0 seconds), + ConfigEntryFloat("float", 0.0f), + ConfigEntryBool("bool_true", false), + ConfigEntryBool("bool_false", true), + ConfigEntryInt("missing", 1337) + ) + ) +} + +object TestBadConfig extends ConfigParser { + protected var config_map : Map[String, Any] = Map() + + protected val config_template = Seq( + ConfigSection("bad", + ConfigEntryInt("bad_int", 0), + ConfigEntryTime("bad_time", 0 seconds), + ConfigEntryFloat("bad_float", 0.0f), + ConfigEntryBool("bad_bool", false), + ConfigEntryInt("bad_int_range", 0, Constraints.min(0)), + ConfigEntryInt("bad_int_range2", 0, Constraints.min(0), Constraints.max(2)) + ) + ) +}