Skip to content

Commit

Permalink
weighted choices initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
filipworksdev committed Feb 16, 2022
1 parent cff88cb commit 4dd4a05
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
48 changes: 46 additions & 2 deletions modules/goblin/rand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,49 @@ Variant Rand::choice(const Variant &p_from) {
return Variant();
}

Variant Rand::choices(const Variant &p_from, int count, const Array &p_weights) {
switch (p_from.get_type()) {
case Variant::POOL_BYTE_ARRAY:
case Variant::POOL_INT_ARRAY:
case Variant::POOL_REAL_ARRAY:
case Variant::POOL_STRING_ARRAY:
case Variant::POOL_VECTOR2_ARRAY:
case Variant::POOL_VECTOR3_ARRAY:
case Variant::POOL_COLOR_ARRAY:
case Variant::ARRAY: {
Array arr = p_from;
ERR_FAIL_COND_V_MSG(arr.empty(), Array(), "Array is empty.");
ERR_FAIL_COND_V_MSG(arr.size() != p_weights.size(), Array(), "Array and weights unequal size.");

int weights_sum = 0;
for (int i = 0; i < p_weights.size(); i++) {
if (p_weights.get(i).get_type() == Variant::INT) {
weights_sum += (int)p_weights.get(i);
} else {
ERR_FAIL_V_MSG(Array(), "Weights are not integers.");
}
}

Array choices = Array();
while(choices.size() < count) {
float remaining_distance = randf() * weights_sum;
for (int i = 0; i < p_weights.size(); i++) {
remaining_distance -= (int)p_weights.get(i);
if (remaining_distance < 0) {
choices.append(p_from.get(i));
break;
}
}
}

return choices;
} break;
default: {
ERR_FAIL_V_MSG(Variant(), "Unsupported: the type must be Array.");
}
}
}

void Rand::shuffle(Array p_array) {
if (p_array.size() < 2) {
return;
Expand Down Expand Up @@ -239,7 +282,7 @@ Color Rand::color() {
return color;
}

String Rand::uuid_v4() {
String Rand::uuid() {
Ref<Crypto> crypto = Ref<Crypto>(Crypto::create());
PoolByteArray data = crypto->generate_random_bytes(16);

Expand All @@ -259,12 +302,13 @@ void Rand::_bind_methods() {
ClassDB::bind_method(D_METHOD("f", "from", "to"), &Rand::f);

ClassDB::bind_method(D_METHOD("choice", "from"), &Rand::choice);
ClassDB::bind_method(D_METHOD("choices", "from", "count", "weights"), &Rand::choices);
ClassDB::bind_method(D_METHOD("shuffle", "array"), &Rand::shuffle);
ClassDB::bind_method(D_METHOD("decision", "probability"), &Rand::decision);
ClassDB::bind_method(D_METHOD("roll", "count", "faces"), &Rand::roll);
ClassDB::bind_method(D_METHOD("roll_notation", "dice_notation"), &Rand::roll_notation);
ClassDB::bind_method(D_METHOD("color"), &Rand::color);
ClassDB::bind_method(D_METHOD("uuid_v4"), &Rand::uuid_v4);
ClassDB::bind_method(D_METHOD("uuid"), &Rand::uuid);

ADD_PROPERTY_DEFAULT("seed", 0);
ADD_PROPERTY_DEFAULT("state", 0);
Expand Down
3 changes: 2 additions & 1 deletion modules/goblin/rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ class Rand : public RandomNumberGenerator {
real_t f(real_t from, real_t to);

Variant choice(const Variant &p_from);
Variant choices(const Variant &p_from, int count, const Array &p_weights);
void shuffle(Array p_array);
bool decision(float probability);
Variant roll(uint32_t count, uint32_t sides);
Variant roll_notation(const String notation);
Color color();
String uuid_v4();
String uuid();

Rand();
~Rand() {};
Expand Down

0 comments on commit 4dd4a05

Please sign in to comment.