diff --git a/CMakeLists.txt b/CMakeLists.txt index 881633c..eda975c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,10 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(BUILD_ZECSY_TESTS "Build tests?" ON) +find_package(magic_enum) + add_library(s2ga STATIC s2ga.hpp) +target_link_libraries(s2ga PUBLIC magic_enum::magic_enum) set_target_properties(s2ga PROPERTIES LINKER_LANGUAGE CXX) ####################################################### diff --git a/conanfile.txt b/conanfile.txt index 30594fc..d8e61e0 100644 --- a/conanfile.txt +++ b/conanfile.txt @@ -1,5 +1,6 @@ [requires] catch2/3.8.0 +magic_enum/0.9.7 [generators] CMakeDeps diff --git a/s2ga.hpp b/s2ga.hpp index 3adad1b..9669893 100644 --- a/s2ga.hpp +++ b/s2ga.hpp @@ -1,7 +1,9 @@ -#include -#include #include -#include +#include +#include +#include +#include +#include namespace s2ga { @@ -67,6 +69,82 @@ namespace s2ga }; static_assert(std::uniform_random_bit_generator); + + template + requires std::is_enum_v && ((std::is_same_v), ...) + std::bitset()> bm(const Args&... v) //aka "bitmask" + { + constexpr size_t N = magic_enum::enum_count(); + std::bitset bitmask; + + ((bitmask.set(magic_enum::enum_integer(v))), ...); + return bitmask; + } + + template + requires std::is_enum_v + struct action + { + static constexpr size_t N = magic_enum::enum_count(); + + std::string name = "ACTION"; + float cost = 1.0f; + + std::bitset positive_effects{}; + std::bitset negative_effects{}; + std::bitset positive_preconds{}; + std::bitset negative_preconds{}; + + bool preconds_met(const std::bitset& state) const + { + return ((state & positive_preconds) == positive_preconds) && + (state & ~negative_preconds).none(); + } + + std::bitset apply(const std::bitset& state) const + { + return (state | positive_effects) & ~negative_effects; + } + + void print() + { + std::cout << std::format("[{}, {}]", name, cost) << std::endl; + + if(positive_effects.any() || negative_effects.any()) + { + std::cout << "Effects:" << std::endl; + for(size_t i = 0; i < N; ++i) + { + if(positive_effects.test(i) || negative_effects.test(i)) + { + auto enum_value = magic_enum::enum_cast(i).value(); + auto value_name = magic_enum::enum_name(enum_value); + + std::cout << std::format(" {}{}", + negative_effects.test(i) ? "-" : "+", + value_name) << std::endl; + } + } + } + + if(positive_preconds.any() || negative_preconds.any()) + { + std::cout << "Preconditions:" << std::endl; + for(size_t i = 0; i < N; ++i) + { + if(positive_preconds.test(i) || negative_preconds.test(i)) + { + auto enum_value = magic_enum::enum_cast(i).value(); + auto value_name = magic_enum::enum_name(enum_value); + + std::cout << std::format(" {}{}", + negative_preconds.test(i) ? "-" : "+", + value_name) << std::endl; + } + } + } + } + }; }; // namespace s2ga inline void foo() diff --git a/tests/test.cpp b/tests/test.cpp index d81e79c..00997bc 100644 --- a/tests/test.cpp +++ b/tests/test.cpp @@ -5,13 +5,44 @@ #include #include #include -#include #include #include using namespace Catch; +using namespace s2ga; -constexpr double TOLERANCE = 1e-4; +TEST_CASE("(dummy)", "[test]") +{ + enum State + { + HAS_AMMO, + HAS_WEAPON, + ENEMY_ALIVE, + SATISFIED + }; + + action shoot_action; + shoot_action.name = "SHOOT"; + shoot_action.positive_preconds = bm(HAS_AMMO, HAS_WEAPON, ENEMY_ALIVE); + shoot_action.negative_preconds = bm(SATISFIED); + shoot_action.positive_effects = bm(SATISFIED); + shoot_action.negative_effects = bm(ENEMY_ALIVE); + shoot_action.cost = 0.5f; + + shoot_action.print(); + /* + Will print this: + [SHOOT, 0.5] + Effects: + -ENEMY_ALIVE + +SATISFIED + Preconditions: + +HAS_AMMO + +HAS_WEAPON + +ENEMY_ALIVE + -SATISFIED + */ +} TEST_CASE("lehmer64 rng", "[test]") { @@ -166,190 +197,3 @@ TEST_CASE("random benchmarking", "[benchmark]") } }; } - -inline double sphere(double x, double y) -{ - return std::pow(x, 2.0) + std::pow(y, 2.0); -} - -TEST_CASE("sphere(0, 0) = 0", "[test]") -{ - REQUIRE(sphere(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double ackley(double x, double y) -{ - return -20.0 * std::exp(-0.2 * std::sqrt(0.5 * sphere(x, y))) - - std::exp(0.5 * (std::cos(2.0 * std::numbers::pi * x) + - std::cos(2.0 * std::numbers::pi * y))) + - std::numbers::e + 20.0; -} - -TEST_CASE("ackley(0, 0) = 0", "[test]") -{ - REQUIRE(ackley(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double rastrigin(double x, double y) -{ - const double A = 10.0; - return 2.0 * A + - (std::pow(x, 2.0) - A * std::cos(2.0 * std::numbers::pi * x)) + - (std::pow(y, 2.0) - A * std::cos(2.0 * std::numbers::pi * y)); -} - -TEST_CASE("rastrigin(0, 0) = 0", "[test]") -{ - REQUIRE(rastrigin(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double rosenbrock(double x, double y) -{ - return 100.0 * std::pow(y - std::pow(x, 2.0), 2.0) + std::pow(1.0 - x, 2.0); -} - -TEST_CASE("rosenbrock(1, 1) = 0", "[test]") -{ - REQUIRE(rosenbrock(1.0, 1.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double bill(double x, double y) -{ - return std::pow(1.5 - x + x * y, 2.0) + - std::pow(2.25 - x + x * std::pow(y, 2.0), 2.0) + - std::pow(2.625 - x + x * std::pow(y, 3.0), 2.0); -} - -TEST_CASE("bill(3, 0.5) = 0", "[test]") -{ - REQUIRE(bill(3.0, 0.5) == Approx(0.0).margin(TOLERANCE)); -} - -inline double goldstein_price(double x, double y) -{ - return (1.0 + std::pow(x + y + 1.0, 2.0) * - (19.0 - 14.0 * x + 3.0 * std::pow(x, 2.0) - 14.0 * y + - 6.0 * x * y + 3.0 * std::pow(y, 2.0))) * - (30.0 + std::pow(2.0 * x - 3.0 * y, 2.0) * - (18.0 - 32.0 * x + 12.0 * std::pow(x, 2.0) + 48.0 * y - - 36.0 * x * y + 27.0 * std::pow(y, 2.0))); -} - -TEST_CASE("goldstein_price(0, -1) = 3", "[test]") -{ - REQUIRE(goldstein_price(0.0, -1.0) == Approx(3.0).margin(TOLERANCE)); -} - -inline double booth(double x, double y) -{ - return std::pow(x + 2.0 * y - 7.0, 2.0) + std::pow(2.0 * x + y - 5.0, 2.0); -} - -TEST_CASE("booth(1, 3) = 0", "[test]") -{ - REQUIRE(booth(1.0, 3.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double bukin_n6(double x, double y) -{ - return 100.0 * std::sqrt(std::abs(y - 0.01 * std::pow(x, 2.0))) + - 0.01 * std::abs(x + 10.0); -} - -TEST_CASE("bukin_n6(-10, 1) = 0", "[test]") -{ - REQUIRE(bukin_n6(-10.0, 1.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double matyas(double x, double y) -{ - return 0.26 * sphere(x, y) - 0.48 * x * y; -} - -TEST_CASE("matyas(0, 0) = 0", "[test]") -{ - REQUIRE(matyas(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double sin2(double x) -{ - return std::pow(std::sin(x), 2.0); -} - -inline double levi_n13(double x, double y) -{ - return sin2(3.0 * std::numbers::pi * x) + - std::pow(x - 1.0, 2.0) * (1.0 + sin2(3.0 * std::numbers::pi * y)) + - std::pow(y - 1.0, 2.0) * (1.0 + sin2(2.0 * std::numbers::pi * y)); -} - -TEST_CASE("levi_n13(1, 1) = 0", "[test]") -{ - REQUIRE(levi_n13(1.0, 1.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double three_hump_camel(double x, double y) -{ - return 2.0 * std::pow(x, 2.0) - 1.05 * std::pow(x, 4.0) + - std::pow(x, 6.0) / 6.0 + x * y + std::pow(y, 2.0); -} - -TEST_CASE("three_hump_camel(0, 0) = 0", "[test]") -{ - REQUIRE(three_hump_camel(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double eggholder(double x, double y) -{ - return -(y + 47.0) * std::sin(std::sqrt(std::abs(x / 2.0 + (y + 47.0)))) - - x * std::sin(std::sqrt(std::abs(x - (y + 47.0)))); -} - -TEST_CASE("eggholder(512, 404.2319) = -959.6407", "[test]") -{ - REQUIRE(eggholder(512.0, 404.2319) == Approx(-959.6407).margin(TOLERANCE)); -} - -inline double mccormick(double x, double y) -{ - return std::sin(x + y) + std::pow(x - y, 2.0) - 1.5 * x + 2.5 * y + 1.0; -} - -TEST_CASE("mccormick(-0.54719, -1.54719) = -1.9133", "[test]") -{ - REQUIRE(mccormick(-0.54719, -1.54719) == Approx(-1.9133).margin(TOLERANCE)); -} - -inline double schaffer_n2(double x, double y) -{ - return 0.5 + (sin2(std::pow(x, 2.0) - std::pow(y, 2.0)) - 0.5) / - std::pow(1.0 + 0.001 * sphere(x, y), 2.0); -} - -TEST_CASE("schaffer_n2(0, 0) = 0", "[test]") -{ - REQUIRE(schaffer_n2(0.0, 0.0) == Approx(0.0).margin(TOLERANCE)); -} - -inline double cos2(double x) -{ - return std::pow(std::cos(x), 2.0); -} - -inline double schaffer_n4(double x, double y) -{ - return 0.5 + - (cos2(std::sin(std::abs(std::pow(x, 2.0) - std::pow(y, 2.0)))) - - 0.5) / - std::pow(1.0 + 0.001 * sphere(x, y), 2.0); -} - -TEST_CASE("schaffer_n4(0, 1.25313) = 0.292579", "[test]") -{ - REQUIRE(schaffer_n4(0.0, 1.25313) == Approx(0.292579).margin(TOLERANCE)); -} - -TEST_CASE("Should pass", "[test]") -{ - REQUIRE_NOTHROW(foo()); -}