Unexpected output from `actor_network`: expected `Distribution` objects
Hi everybody, I have a problem training a complex network with the TF-Agents PPO agent:
```python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense,
                                     LSTM, MaxPooling2D, Reshape, concatenate)
from tf_agents.networks import network


class CustomActorNetwork(network.Network):

    def __init__(self, input_tensor_spec, output_tensor_spec,
                 name='CustomActorNetwork'):
        super(CustomActorNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec, state_spec=(), name=name)
        # Convolutional trunk.
        self.conv1 = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu')
        self.pool1 = MaxPooling2D((2, 2))
        self.conv2 = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), activation='relu')
        self.pool2 = MaxPooling2D((2, 2))
        self.conv3 = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), activation='relu')
        self.pool3 = MaxPooling2D((2, 2))
        self.reshape = Reshape((-1, 128))
        self.lstm = LSTM(512, return_sequences=False, use_bias=True)
        self.batch_norm = BatchNormalization()
        # One (128 -> 64) dense branch per action head.
        self.dense1 = [Dense(128, activation='relu') for _ in range(4)]
        self.dense2 = [Dense(64, activation='relu') for _ in range(4)]
        # Four action heads: position (3), up_band (100), down_band (100), volume (30).
        self.position = Dense(3, activation='softmax')
        self.up_band = Dense(100, activation='softmax')
        self.down_band = Dense(100, activation='softmax')
        self.volume = Dense(30, activation='softmax')

    def call(self, observation, step_type=None, network_state=(), training=False):
        x = self.conv1(observation)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.pool3(x)
        x = self.reshape(x)
        x = self.lstm(x)
        x = self.batch_norm(x)

        # Head 1: position.
        x_out1 = self.dense1[0](x)
        x_out1 = self.dense2[0](x_out1)
        out1_logits = self.position(x_out1)
        out1_sample = tfp.distributions.Categorical(logits=out1_logits).sample()

        # Heads 2 and 3 are conditioned on head 1's logits.
        xc = concatenate([out1_logits, x], axis=-1)
        x_out2 = self.dense1[1](xc)
        x_out2 = self.dense2[1](x_out2)
        out2_logits = self.up_band(x_out2)
        out2_sample = tfp.distributions.Categorical(logits=out2_logits).sample()

        x_out3 = self.dense1[2](xc)
        x_out3 = self.dense2[2](x_out3)
        out3_logits = self.down_band(x_out3)
        out3_sample = tfp.distributions.Categorical(logits=out3_logits).sample()

        # Head 4 is conditioned on all earlier logits.
        xc2 = concatenate([xc, out2_logits, out3_logits], axis=-1)
        x_out4 = self.dense1[3](xc2)
        x_out4 = self.dense2[3](x_out4)
        out4_logits = self.volume(x_out4)
        out4_sample = tfp.distributions.Categorical(logits=out4_logits).sample()

        # Stack the four sampled actions into a (batch, 4) tensor.
        return tf.stack([out1_sample, out2_sample, out3_sample, out4_sample],
                        axis=-1), network_state
```
This is my action spec:
```python
array_spec.BoundedArraySpec(
    shape=(4,),
    dtype=np.int32,
    minimum=[0, 0, 0, 0],
    maximum=[2, 99, 99, 29],
    name='action')
```
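Each of the four dimensions is meant to line up with one softmax head: `position` (3 values), `up_band` and `down_band` (100 each), and `volume` (30). For reference, my understanding of how this looks on the TF side, a minimal sketch using the stock `tensor_spec.from_spec` helper (`TFPyEnvironment` normally does this conversion itself):

```python
import numpy as np
from tf_agents.specs import array_spec, tensor_spec

action_spec = array_spec.BoundedArraySpec(
    shape=(4,), dtype=np.int32,
    minimum=[0, 0, 0, 0], maximum=[2, 99, 99, 29], name='action')

# Converts the numpy-based spec into the BoundedTensorSpec the TF agent
# works with; each of the four dimensions maps onto one softmax head.
action_tensor_spec = tensor_spec.from_spec(action_spec)
```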
And this is my agent:
```python
actor_net = Model.CustomActorNetwork(observation_spec, action_spec)
value_net = Model.CustomValueNetwork(observation_spec, action_spec)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4)
train_step_counter = tf.Variable(0)

tf_agent = ppo_agent.PPOAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=action_spec,
    optimizer=optimizer,
    actor_net=actor_net,
    value_net=value_net,
    num_epochs=5,
    train_step_counter=train_step_counter)
```
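For context, `train_env` is my custom py environment wrapped for TF, and `observation_spec` / `action_spec` are taken from it; a minimal sketch of the wiring (`py_train_env` is a placeholder name):

```python
from tf_agents.environments import tf_py_environment

# Wrap the pure-Python environment so specs and time steps become tensors.
train_env = tf_py_environment.TFPyEnvironment(py_train_env)
observation_spec = train_env.observation_spec()
action_spec = train_env.action_spec()
```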
And this is my error:
```
Exception has occurred: ValueError
Unexpected output from `actor_network`. Expected `Distribution` objects, but saw output spec: TensorSpec(shape=(4,), dtype=tf.int32, name=None)
  In call to configurable 'PPOPolicy' (<class 'tf_agents.agents.ppo.ppo_policy.PPOPolicy'>)
  In call to configurable 'PPOAgent' (<class 'tf_agents.agents.ppo.ppo_agent.PPOAgent'>)
```
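From the message I gather that `PPOPolicy` wants the actor network to return `tfp.distributions.Distribution` objects, not sampled integer actions like my `tf.stack` of `.sample()` results. A minimal sketch of what I think is expected, assuming the spec can be split per dimension so the built-in `ActorDistributionNetwork` can attach a `Categorical` head to each (the split spec and layer sizes are my guesses, not verified):

```python
import tensorflow as tf
from tf_agents.networks import actor_distribution_network
from tf_agents.specs import tensor_spec

# Hypothetical: split the (4,) spec into four scalar specs so each action
# dimension gets its own Categorical projection (3, 100, 100, 30 classes).
split_action_spec = tuple(
    tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=m, name=n)
    for n, m in [('position', 2), ('up_band', 99),
                 ('down_band', 99), ('volume', 29)])

# ActorDistributionNetwork emits Distribution objects, which is what
# PPOPolicy checks for. observation_spec is the same one used above;
# the conv/fc sizes roughly mirror my custom network.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    split_action_spec,
    conv_layer_params=[(32, 3, 1), (64, 3, 1), (128, 3, 1)],
    fc_layer_params=(128, 64))
```

If that is the right direction, I suppose the `action_spec` passed to `PPOAgent` (and the environment's spec) would have to become the same tuple, and keeping my conditional head structure would need a `network.DistributionNetwork` subclass instead of a plain `network.Network`. I haven't verified either.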
Can anybody help me?