diff --git a/grpc-server.py b/grpc_server.py similarity index 96% rename from grpc-server.py rename to grpc_server.py index 52f2f64..3c5ebcd 100644 --- a/grpc-server.py +++ b/grpc_server.py @@ -26,6 +26,7 @@ def GetCoachActions(self, request:pb2.State, context): return actions def GetTrainerActions(self, request:pb2.State, context): + print(f"Fullstate wm time:{request.full_world_model.cycle}") actions = self.trainer_agent.get_actions(request.world_model) return actions diff --git a/gym_env.py b/gym_env.py new file mode 100644 index 0000000..3ff6775 --- /dev/null +++ b/gym_env.py @@ -0,0 +1,150 @@ +import gymnasium as gym +from queue import Queue +import numpy as np +from service_pb2 import GameModeType, ServerParam, PlayerType, PlayerParam, Side, State, TrainerAction, TrainerActions, Vector2D, WorldModel +import service_pb2 as pb2 +from service_pb2 import Body_KickOneStep +from pyrusgeom.vector_2d import Vector2D as V2D +class CustomGymEnv(gym.Env): + def __init__(self) -> None: + super().__init__() + self.player_action_queue = Queue(1) + self.trainer_action_queue = Queue(10) + self.observation_queue = Queue(1) + self.episode_reward = 0 + self.intermittent_rewards=[] + self.previous_intermittent_wm: WorldModel = None + self.old_observation = None + self.ANGLE_DIVS = 4 + self.POWER_DIVS = 4 + self.server_param: ServerParam = None + self.player_param: PlayerParam = None + self.player_type: PlayerType = None + self.ANGLE_STEP = 360/self.ANGLE_DIVS + + # observation: normalized self pos, normalized ball pos + self.observation_space = gym.spaces.Box(low=np.array([-1,-1,-1,-1]),high=np.array([1,1,1,1]) + ,shape=(4,),dtype=np.float64) + # action space: kick, 18 angles, 4 power levels + # use multidescrete + self.action_space = gym.spaces.MultiDiscrete([self.ANGLE_DIVS,self.POWER_DIVS]) + + def append_intermittent_rewards(self, wm:WorldModel): + if self.previous_intermittent_wm is None: + self.previous_intermittent_wm = wm + return + ball_pos = wm.ball.position + old_ball_pos = self.previous_intermittent_wm.ball.position + self.intermittent_rewards.append(ball_pos.x - old_ball_pos.x) + self.previous_intermittent_wm = wm + + + + + + def observation_to_ndarray(self, observation: State) -> np.ndarray: + wm = observation.world_model + self_pos = wm.self.position + ball_pos = wm.ball.position + hl = self.server_param.pitch_half_length + hw = self.server_param.pitch_half_width + normalized_self_pos = [self_pos.x/hl, self_pos.y/hw] + normalized_ball_pos = [ball_pos.x/hl, ball_pos.y/hw] + return np.array(normalized_self_pos + normalized_ball_pos) + + def calculate_reward(self, observation:State, old_observation:State) -> float: + if observation.world_model.game_mode_type == GameModeType.KickOff_: + last_ball_x = old_observation.world_model.ball.position.x + if last_ball_x > 0: + actual_ball = Vector2D(x=52.5,y=0) + else: + actual_ball = Vector2D(x=-52.5,y=0) + ball_pos = actual_ball + else: + ball_pos = observation.world_model.ball.position + old_ball_pos = old_observation.world_model.ball.position + # print(f'Ball Pos:({ball_pos.x},{ball_pos.y}), old ball pos : ({old_ball_pos.x},{old_ball_pos.y})') + return ball_pos.x - old_ball_pos.x + + def wait_for_observation_and_return(self): + # print("Awaiting observation") + observation = self.observation_queue.get(block=True) + # print("RECEIVED OBSERVATION") + self.old_observation = observation + # if self.old_observation is None: + # print("OLD OBS IS NONE!!!!!!!!") + return self.observation_to_ndarray(observation), {} + + def clear_actions_queue(self): + with self.player_action_queue.mutex: + self.player_action_queue.queue.clear() + + def clear_observation_queue(self): + with self.observation_queue.mutex: + self.observation_queue.queue.clear() + + def do_action(self, action, clear_actions: bool = False): + if clear_actions: + self.clear_actions_queue() + self.player_action_queue.put(action,block=False) + + def gym_action_to_soccer_action(self, action, wm:WorldModel): + # angle = -np.pi + action[0] * self.ANGLE_STEP + absolute_angle = -180 + action[0] * self.ANGLE_STEP + power_step = self.server_param.ball_speed_max / self.POWER_DIVS + power = (action[1] +1) * power_step + # print(f"Body Dir: {wm.self.body_direction}") + pos = wm.self.position + target = V2D.polar2vector(10,absolute_angle)+ V2D(x=pos.x,y=pos.y) + + # print(f"Absolute Angle: {absolute_angle}, Power: {power}") + + return Body_KickOneStep(first_speed=power,target_point=Vector2D(x=target.x(),y=target.y()),force_mode=True) + + + def step(self, action): + # print("STEP START") + self.clear_observation_queue() + self.do_action(action) + # print("await observation") + observation:State = self.observation_queue.get() + # print("###############") + game_mode = observation.world_model.game_mode_type + # print(f"Game Cycle: {observation.world_model.cycle}, Game Mode : {game_mode}") + # todo: hack to find which goal was scored in since goal_l or goal_r dont send observations. + + reward = 0 + + if self.old_observation is not None: + reward = self.calculate_reward(observation, self.old_observation) + self.episode_reward += reward + + done = observation.world_model.game_mode_type != GameModeType.PlayOn + # print(f"Done = {done}") + self.old_observation = observation + return self.observation_to_ndarray(observation), reward, done, False, {} + + def get_trainer_reset_commands(self) -> TrainerActions: + actions = TrainerActions() + zero_vec = Vector2D(x=0.,y=0.) + player_vec =Vector2D(x=-5.,y=0.) + actions.actions.append(TrainerAction(do_change_mode=pb2.DoChangeMode(game_mode_type=GameModeType.PlayOn,side=Side.LEFT))) + actions.actions.append(TrainerAction(do_move_ball=pb2.DoMoveBall(position=zero_vec,velocity=zero_vec))) + actions.actions.append(TrainerAction(do_recover=pb2.DoRecover())) + actions.actions.append(TrainerAction(do_move_player=pb2.DoMovePlayer(our_side=True, uniform_number= 1, position= player_vec,body_direction=0))) + return actions + + + def reset(self, seed = -1): + # print("RESETING ENV!") + print(f"Episode reward:{self.episode_reward}") + self.episode_reward = 0 + self.old_observation = None + self.clear_actions_queue() + self.clear_observation_queue() + if self.trainer_action_queue.empty: + self.trainer_action_queue.put(0) + # todo add reset action + # reset action sent, to unblock player action + self.do_action([-1,-1], clear_actions=True) + return self.wait_for_observation_and_return() \ No newline at end of file diff --git a/gym_grpc_server.py b/gym_grpc_server.py new file mode 100644 index 0000000..40891c4 --- /dev/null +++ b/gym_grpc_server.py @@ -0,0 +1,119 @@ +import time +from grpc_server import Game +import threading +from concurrent import futures +import grpc +from gym_env import CustomGymEnv +import service_pb2_grpc as pb2_grpc +import service_pb2 as pb2 +from stable_baselines3.common.env_checker import check_env +from stable_baselines3 import PPO,A2C +from queue import Empty, Full + +trainer_started = False +class GymGame(Game): + def __init__(self, gym_env:CustomGymEnv): + super().__init__() + self.was_set_play_before = False + self.gym_env: CustomGymEnv = gym_env + + def SendServerParams(self, request: pb2.ServerParam, context): + self.gym_env.server_param = request + return super().SendServerParams(request, context) + + def SendPlayerParams(self, request: pb2.PlayerParam, context): + self.gym_env.player_param = request + return super().SendPlayerParams(request, context) + + def SendPlayerType(self, request: pb2.PlayerType, context): + self.gym_env.player_type = request + return super().SendPlayerType(request, context) + + def GetTrainerActions(self, request: pb2.State, context): + global trainer_started + trainer_started = True + try: + self.gym_env.trainer_action_queue.get(block=False) + return self.gym_env.get_trainer_reset_commands() + except Empty: + return pb2.TrainerActions() + + def GetPlayerActions(self, request, context): + # append neck action first to the action list + + actions = pb2.PlayerActions() + wm = request.world_model + # self.gym_env.append_intermittent_rewards(wm) + actions.actions.append(pb2.PlayerAction(neck_turn_to_ball=pb2.Neck_TurnToBall())) + if wm.game_mode_type != pb2.GameModeType.PlayOn and not self.was_set_play_before: + self.was_set_play_before = True + self.gym_env.clear_actions_queue() + self.gym_env.observation_queue.put(request) + return actions + if wm.game_mode_type == pb2.GameModeType.PlayOn: + self.was_set_play_before = False + if not request.world_model.self.is_kickable: + # if the ball is not kickable, return Intercept + intercept_action = pb2.Body_Intercept(save_recovery=False, face_point=wm.ball.position) + actions.actions.append(pb2.PlayerAction(body_intercept=intercept_action)) + return actions + # if the ball is kickable, send observation to the gym env + self.gym_env.clear_actions_queue() + self.gym_env.clear_observation_queue() + self.gym_env.observation_queue.put(request,block=True) + action = self.gym_env.player_action_queue.get(block = True) + if action[0] == -1: + # is from reset + return actions + + selected_kick_action: pb2.Body_KickOneStep = self.gym_env.gym_action_to_soccer_action(action, request.world_model) + # convert the action to the grpc action + actions.actions.append(pb2.PlayerAction(body_kick_one_step=selected_kick_action)) + return actions + +def serve(gym_env:CustomGymEnv): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=22)) + pb2_grpc.add_GameServicer_to_server(GymGame(gym_env), server) + server.add_insecure_port('localhost:50051') + server.start() + print("Decision Server started. Listening on port 50051...") + try: + while True: + time.sleep(60 * 60 * 24) # Sleep for a day or any desired interval + except KeyboardInterrupt: + print("Shutting down the server...") + server.stop(0) + +if __name__ == "__main__": + gym_env = CustomGymEnv() + server_thread = threading.Thread(target=serve, args=(gym_env,)) + server_thread.start() + print("Await trainer") + while not trainer_started: + pass + # gym_env.reset() + model = A2C('MlpPolicy',gym_env, verbose= 1) + model = model.learn(1000,progress_bar=True) + print("Model trained") + print("?????????????????????????????????????????????????????????????????????????") + gym_env.clear_actions_queue() + gym_env.clear_observation_queue() + with gym_env.trainer_action_queue.mutex: + gym_env.trainer_action_queue.queue.clear() + observation, _ = gym_env.reset() + # observation, _ = gym_env.wait_for_observation_and_return() + # observation,_ = gym_env.wait_for_observation_and_return() + while server_thread.is_alive(): + # get action from the model + action, _ = model.predict(observation) + # action = gym_env.action_space.sample() + print(f"Action: {action}") + # get observation from the environment + observation, reward, done,truncated, info = gym_env.step(action) + print(f"Observation: {observation}, Reward: {reward}, Done: {done}, Info: {info}") + if done: + observation, info = gym_env.reset() + print("Environment reset") + else: + print("Environment not reset") + diff --git a/src/SampleTrainerAgent.py b/src/SampleTrainerAgent.py index 5debe9e..197f052 100644 --- a/src/SampleTrainerAgent.py +++ b/src/SampleTrainerAgent.py @@ -20,6 +20,7 @@ def get_actions(self, wm:pb2.WorldModel) -> pb2.TrainerActions: print(f'cycle: {self.wm.ball.position.x}, {self.wm.ball.position.y}') if self.wm.cycle % 100 == 0: + print("Sending trainer action") actions.actions.append( pb2.TrainerAction( do_move_ball=pb2.DoMoveBall(