Commit ec68c05b authored by sbl1996@126.com

Add vloss clip

parent c02fbd19
...@@ -130,6 +130,9 @@ class Args: ...@@ -130,6 +130,9 @@ class Args:
logits_threshold: Optional[float] = None logits_threshold: Optional[float] = None
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0""" """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
vloss_clip: Optional[float] = None
"""the value loss clipping coefficient"""
ent_coef: float = 0.01 ent_coef: float = 0.01
"""coefficient of the entropy""" """coefficient of the entropy"""
vf_coef: float = 1.0 vf_coef: float = 1.0
...@@ -718,6 +721,9 @@ if __name__ == "__main__": ...@@ -718,6 +721,9 @@ if __name__ == "__main__":
v_loss = v_loss / n_valids v_loss = v_loss / n_valids
ent_loss = ent_loss / n_valids ent_loss = ent_loss / n_valids
if args.vloss_clip is not None:
v_loss = jnp.minimum(v_loss, args.vloss_clip)
loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl)) return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
......
...@@ -582,7 +582,7 @@ static const ankerl::unordered_dense::map<uint8_t, uint8_t> location2id = ...@@ -582,7 +582,7 @@ static const ankerl::unordered_dense::map<uint8_t, uint8_t> location2id =
DEFINE_X_TO_ID_FUN(location_to_id, location2id) DEFINE_X_TO_ID_FUN(location_to_id, location2id)
#define POS_NONE 0x0 // xyz materials (overlay) #define POS_NONE 0x0 // xyz materials (overlay) ???
static const std::map<uint8_t, std::string> position2str = { static const std::map<uint8_t, std::string> position2str = {
{POS_NONE, "none"}, {POS_NONE, "none"},
...@@ -646,7 +646,7 @@ static const std::map<uint32_t, std::string> race2str = { ...@@ -646,7 +646,7 @@ static const std::map<uint32_t, std::string> race2str = {
{RACE_CREATORGOD, "Creator God"}, {RACE_CREATORGOD, "Creator God"},
{RACE_WYRM, "Wyrm"}, {RACE_WYRM, "Wyrm"},
{RACE_CYBERSE, "Cyberse"}, {RACE_CYBERSE, "Cyberse"},
{RACE_ILLUSION, "Illusion'"}}; {RACE_ILLUSION, "Illusion"}};
static const ankerl::unordered_dense::map<uint32_t, uint8_t> race2id = static const ankerl::unordered_dense::map<uint32_t, uint8_t> race2id =
make_ids(race2str); make_ids(race2str);
...@@ -2086,6 +2086,11 @@ private: ...@@ -2086,6 +2086,11 @@ private:
f_cards(offset, 6) = 1; f_cards(offset, 6) = 1;
} else { } else {
f_cards(offset, 5) = position_to_id(c.position_); f_cards(offset, 5) = position_to_id(c.position_);
if (hide && (location == LOCATION_DECK || location == LOCATION_HAND ||
location == LOCATION_EXTRA)) {
f_cards(offset, 5) = position_to_id(POS_FACEDOWN);
// fmt::println("location: {}, position: {}\n", location2str.at(location), position_to_string(c.position_));
}
} }
if (!hide) { if (!hide) {
f_cards(offset, 7) = attribute_to_id(c.attribute_); f_cards(offset, 7) = attribute_to_id(c.attribute_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment