  • SARSA Code
    IT & Computer Science / Deep Learning · 2020. 12. 6. 23:31
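
    A SARSA agent learning a 5x5 grid world, in two files: environment.py builds the Tkinter grid-world environment, and agent.py implements the SARSA agent and its training loop.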

    environment.py

     

    import time
    import numpy as np
    import tkinter as tk
    from PIL import ImageTk, Image
     
    np.random.seed(1)
    PhotoImage = ImageTk.PhotoImage
    UNIT = 100  # pixels per grid cell
    HEIGHT = 5  # grid world height
    WIDTH = 5  # grid world width
     
     
    class Env(tk.Tk):
        def __init__(self):
            super(Env, self).__init__()
            self.action_space = ['u', 'd', 'l', 'r']  # up, down, left, right
            self.n_actions = len(self.action_space)
            self.title('SARSA')
            self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
            self.shapes = self.load_images()
            self.canvas = self._build_canvas()
            self.texts = []
     
        def _build_canvas(self):
            canvas = tk.Canvas(self, bg='white',
                               height=HEIGHT * UNIT,
                               width=WIDTH * UNIT)
            # draw the grid lines
            for c in range(0, WIDTH * UNIT, UNIT):  # vertical lines at 0, 100, ..., 400
                x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT
                canvas.create_line(x0, y0, x1, y1)
            for r in range(0, HEIGHT * UNIT, UNIT):  # horizontal lines at 0, 100, ..., 400
                x0, y0, x1, y1 = 0, r, WIDTH * UNIT, r
                canvas.create_line(x0, y0, x1, y1)
     
            # add the images to the canvas
            self.rectangle = canvas.create_image(50, 50, image=self.shapes[0])
            self.triangle1 = canvas.create_image(250, 150, image=self.shapes[1])
            self.triangle2 = canvas.create_image(150, 250, image=self.shapes[1])
            self.circle = canvas.create_image(250, 250, image=self.shapes[2])
     
            canvas.pack()
     
            return canvas
     
        def load_images(self):
            rectangle = PhotoImage(
                Image.open("../img/rectangle.png").resize((65, 65)))
            triangle = PhotoImage(
                Image.open("../img/triangle.png").resize((65, 65)))
            circle = PhotoImage(
                Image.open("../img/circle.png").resize((65, 65)))
     
            return rectangle, triangle, circle
     
        def text_value(self, row, col, contents, action, font='Helvetica', size=10,
                       style='normal', anchor="nw"):
            if action == 0:
                origin_x, origin_y = 7, 42
            elif action == 1:
                origin_x, origin_y = 85, 42
            elif action == 2:
                origin_x, origin_y = 42, 5
            else:
                origin_x, origin_y = 42, 77
     
            x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)
            font = (font, str(size), style)
            text = self.canvas.create_text(x, y, fill="black", text=contents,
                                           font=font, anchor=anchor)
            self.texts.append(text)
     
        def print_value_all(self, q_table):
            for i in self.texts:
                self.canvas.delete(i)
            self.texts.clear()
            for x in range(HEIGHT):
                for y in range(WIDTH):
                    for action in range(0, 4):
                        state = [x, y]
                        if str(state) in q_table.keys():
                            temp = q_table[str(state)][action]
                            self.text_value(y, x, round(temp, 3), action)
     
        def coords_to_state(self, coords):
            x = int((coords[0] - 50) / 100)
            y = int((coords[1] - 50) / 100)
            return [x, y]
     
        def reset(self):
            self.update()
            time.sleep(0.5)
            x, y = self.canvas.coords(self.rectangle)
            self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
            self.render()
            return self.coords_to_state(self.canvas.coords(self.rectangle))
     
        def step(self, action):
            state = self.canvas.coords(self.rectangle)
            base_action = np.array([0, 0])
            self.render()
     
            if action == 0:  # up
                if state[1] > UNIT:
                    base_action[1] -= UNIT
            elif action == 1:  # down
                if state[1] < (HEIGHT - 1) * UNIT:
                    base_action[1] += UNIT
            elif action == 2:  # left
                if state[0] > UNIT:
                    base_action[0] -= UNIT
            elif action == 3:  # right
                if state[0] < (WIDTH - 1) * UNIT:
                    base_action[0] += UNIT
     
            # move the agent
            self.canvas.move(self.rectangle, base_action[0], base_action[1])
            # raise the agent (red rectangle) above the other images
            self.canvas.tag_raise(self.rectangle)
            next_state = self.canvas.coords(self.rectangle)
     
            # reward function
            if next_state == self.canvas.coords(self.circle):
                reward = 100
                done = True
            elif next_state in [self.canvas.coords(self.triangle1),
                                self.canvas.coords(self.triangle2)]:
                reward = -100
                done = True
            else:
                reward = 0
                done = False
     
            next_state = self.coords_to_state(next_state)
            return next_state, reward, done
     
        def render(self):
            time.sleep(0.03)
            self.update()
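
    For reference, here is a minimal sketch of how this Env can be driven on its own, with a random policy standing in for the SARSA agent defined below (this loop is illustrative and not part of the original code):

    import numpy as np
    from environment import Env

    env = Env()
    state = env.reset()  # agent returns to the top-left cell, state [0, 0]
    done = False
    while not done:
        action = np.random.choice(env.n_actions)     # random action: 0=up, 1=down, 2=left, 3=right
        next_state, reward, done = env.step(action)  # +100 at the circle, -100 at a triangle, else 0
        state = next_state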

     

    agent.py

    import numpy as np
    import random
    from collections import defaultdict
    from environment import Env
     
     
    class SARSAgent:
        def __init__(self, actions):
            self.actions = actions
            self.step_size = 0.01  # learning rate
            self.discount_factor = 0.9  # discount factor
            self.epsilon = 0.1  # exploration rate for the epsilon-greedy policy
            # Q-function table, every entry initialized to 0
            self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])
     
        # update the Q-function from the <s, a, r, s', a'> sample (Bellman expectation equation)
        def learn(self, state, action, reward, next_state, next_action):
            state, next_state = str(state), str(next_state)
            current_q = self.q_table[state][action]
            next_state_q = self.q_table[next_state][next_action]
            td = reward + self.discount_factor * next_state_q - current_q
            new_q = current_q + self.step_size * td  # SARSA Q-function update
            self.q_table[state][action] = new_q
     
        # return an action according to the epsilon-greedy policy
        def get_action(self, state):
            if np.random.rand() < self.epsilon:
                # return a random action
                action = np.random.choice(self.actions)
            else:
                # return the greedy action under the current Q-function
                state = str(state)
                q_list = self.q_table[state]
                action = arg_max(q_list)
            return action
     
     
    # return the best action for the given Q-values, breaking ties at random
    def arg_max(q_list):
        max_idx_list = np.argwhere(q_list == np.amax(q_list))
        max_idx_list = max_idx_list.flatten().tolist()
        return random.choice(max_idx_list)
     
     
    if __name__ == "__main__":
        env = Env()  # set up the environment
        agent = SARSAgent(actions=list(range(env.n_actions)))  # create the agent

        for episode in range(1000):  # run 1000 episodes
            # reset the game environment and the state
            state = env.reset()
            # choose an action for the current state
            action = agent.get_action(state)
     
            while True:
                env.render()
     
                # take the action, then receive the next state, the reward, and whether the episode is done
                next_state, reward, done = env.step(action)
                # choose the next action from the next state
                next_action = agent.get_action(next_state)
                # update the Q-function with the <s, a, r, s', a'> sample
                agent.learn(state, action, reward, next_state, next_action)
     
                state = next_state
                action = next_action
     
                # display all Q-function values on the screen
                env.print_value_all(agent.q_table)
     
                if done:
                    break
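
    To make learn() concrete, here is one SARSA update worked by hand with the hyperparameters above; the numbers are illustrative (a first visit to a state-action pair that reaches the circle):

    # Q(s,a) <- Q(s,a) + step_size * (reward + discount_factor * Q(s',a') - Q(s,a))
    step_size, discount_factor = 0.01, 0.9
    current_q = 0.0  # Q(s, a): table entries start at 0
    reward = 100     # the step reached the circle
    next_q = 0.0     # Q(s', a') for the action chosen at the next state
    td = reward + discount_factor * next_q - current_q  # TD error = 100.0
    new_q = current_q + step_size * td                  # 0.0 + 0.01 * 100.0 = 1.0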
