  • Value Iteration Code
    IT & Computer Engineering / Deep Learning · 2020. 12. 6. 21:01
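
    Below are the two files for a value iteration agent on a 5x5 grid world:
    environment.py (the tkinter GUI plus the Env class) and value_iteration.py
    (the agent itself). Since TRANSITION_PROB = 1, transitions are
    deterministic, and the Bellman optimality update the code performs on each
    sweep reduces to

        V_{k+1}(s) = \max_{a} [ r(s, a) + \gamma V_k(s') ]

    where s' is the state reached from s by taking action a, and the discount
    factor is \gamma = 0.9.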

    environment.py

     

    import tkinter as tk
    import time
    import numpy as np
    import random
    from PIL import ImageTk, Image
     
    PhotoImage = ImageTk.PhotoImage
    UNIT = 100  # number of pixels per grid cell
    HEIGHT = 5  # grid world height (rows)
    WIDTH = 5  # grid world width (columns)
    TRANSITION_PROB = 1
    POSSIBLE_ACTIONS = [0, 1, 2, 3]  # up, down, left, right
    ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # actions as coordinate offsets
    REWARDS = []
     
     
    class GraphicDisplay(tk.Tk):
        def __init__(self, value_iteration):
            super(GraphicDisplay, self).__init__()
            self.title('Value Iteration')
            self.geometry('{0}x{1}'.format(HEIGHT * UNIT, HEIGHT * UNIT + 50))
            self.texts = []
            self.arrows = []
            self.env = Env()
            self.agent = value_iteration
            self.iteration_count = 0
            self.improvement_count = 0
            self.is_moving = 0
            (self.up, self.down, self.left,
             self.right), self.shapes = self.load_images()
            self.canvas = self._build_canvas()
            self.text_reward(2, 2, "R : 1.0")
            self.text_reward(1, 2, "R : -1.0")
            self.text_reward(2, 1, "R : -1.0")
     
        def _build_canvas(self):
            canvas = tk.Canvas(self, bg='white',
                               height=HEIGHT * UNIT,
                               width=WIDTH * UNIT)
            # initialize the control buttons
            iteration_button = tk.Button(self, text="Calculate",
                                         command=self.calculate_value)
            iteration_button.configure(width=10, activebackground="#33B5E5")
            canvas.create_window(WIDTH * UNIT * 0.13, (HEIGHT * UNIT) + 10,
                                 window=iteration_button)
     
            policy_button = tk.Button(self, text="Print Policy",
                                      command=self.print_optimal_policy)
            policy_button.configure(width=10, activebackground="#33B5E5")
            canvas.create_window(WIDTH * UNIT * 0.37, (HEIGHT * UNIT) + 10,
                                 window=policy_button)
     
            policy_button = tk.Button(self, text="Move",
                                      command=self.move_by_policy)
            policy_button.configure(width=10, activebackground="#33B5E5")
            canvas.create_window(WIDTH * UNIT * 0.62, (HEIGHT * UNIT) + 10,
                                 window=policy_button)
     
            policy_button = tk.Button(self, text="Clear", command=self.clear)
            policy_button.configure(width=10, activebackground="#33B5E5")
            canvas.create_window(WIDTH * UNIT * 0.87, (HEIGHT * UNIT) + 10,
                                 window=policy_button)
     
            # draw the grid lines
            for col in range(0, WIDTH * UNIT, UNIT):  # 0 to 400 in steps of 100
                x0, y0, x1, y1 = col, 0, col, HEIGHT * UNIT
                canvas.create_line(x0, y0, x1, y1)
            for row in range(0, HEIGHT * UNIT, UNIT):  # 0 to 400 in steps of 100
                x0, y0, x1, y1 = 0, row, HEIGHT * UNIT, row
                canvas.create_line(x0, y0, x1, y1)
     
            # add the images to the canvas
            self.rectangle = canvas.create_image(50, 50, image=self.shapes[0])
            canvas.create_image(250, 150, image=self.shapes[1])
            canvas.create_image(150, 250, image=self.shapes[1])
            canvas.create_image(250, 250, image=self.shapes[2])
     
            canvas.pack()
     
            return canvas
     
        def load_images(self):
            PhotoImage = ImageTk.PhotoImage
            up = PhotoImage(Image.open("../img/up.png").resize((13, 13)))
            right = PhotoImage(Image.open("../img/right.png").resize((13, 13)))
            left = PhotoImage(Image.open("../img/left.png").resize((13, 13)))
            down = PhotoImage(Image.open("../img/down.png").resize((13, 13)))
            rectangle = PhotoImage(
                Image.open("../img/rectangle.png").resize((65, 65)))
            triangle = PhotoImage(
                Image.open("../img/triangle.png").resize((65, 65)))
            circle = PhotoImage(Image.open("../img/circle.png").resize((65, 65)))
            return (up, down, left, right), (rectangle, triangle, circle)
     
        def clear(self):
     
            if self.is_moving == 0:
                self.iteration_count = 0
                self.improvement_count = 0
                for i in self.texts:
                    self.canvas.delete(i)
     
                for i in self.arrows:
                    self.canvas.delete(i)
     
                self.agent.value_table = [[0.0] * WIDTH for _ in range(HEIGHT)]

                x, y = self.canvas.coords(self.rectangle)
                self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
     
        def reset(self):
            self.update()
            time.sleep(0.5)
            self.canvas.delete(self.rectangle)
            return self.canvas.coords(self.rectangle)
     
        def text_value(self, row, col, contents, font='Helvetica', size=12,
                       style='normal', anchor="nw"):
            origin_x, origin_y = 85, 70
            x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)
            font = (font, str(size), style)
            text = self.canvas.create_text(x, y, fill="black", text=contents,
                                           font=font, anchor=anchor)
            return self.texts.append(text)
     
        def text_reward(self, row, col, contents, font='Helvetica', size=12,
                        style='normal', anchor="nw"):
            origin_x, origin_y = 5, 5
            x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)
            font = (font, str(size), style)
            text = self.canvas.create_text(x, y, fill="black", text=contents,
                                           font=font, anchor=anchor)
            return self.texts.append(text)
     
        def rectangle_move(self, action):
            base_action = np.array([0, 0])
            location = self.find_rectangle()
            self.render()
            if action == 0 and location[0] > 0:  # up
                base_action[1] -= UNIT
            elif action == 1 and location[0] < HEIGHT - 1:  # down
                base_action[1] += UNIT
            elif action == 2 and location[1] > 0:  # left
                base_action[0] -= UNIT
            elif action == 3 and location[1] < WIDTH - 1:  # right
                base_action[0] += UNIT
     
            self.canvas.move(self.rectangle, base_action[0],
                             base_action[1])  # move agent
     
        def find_rectangle(self):
            temp = self.canvas.coords(self.rectangle)
            x = (temp[0] / 100) - 0.5
            y = (temp[1] / 100) - 0.5
            return int(y), int(x)
     
        def move_by_policy(self):
     
            if self.improvement_count != 0 and self.is_moving != 1:
                self.is_moving = 1
                x, y = self.canvas.coords(self.rectangle)
                self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
     
                x, y = self.find_rectangle()
                while len(self.agent.get_action([x, y])) != 0:
                    action = random.sample(self.agent.get_action([x, y]), 1)[0]
                    self.after(100, self.rectangle_move(action))
                    x, y = self.find_rectangle()
                self.is_moving = 0
     
        def draw_one_arrow(self, col, row, action):
            if col == 2 and row == 2:
                return
            if action == 0:  # up
                origin_x, origin_y = 50 + (UNIT * row), 10 + (UNIT * col)
                self.arrows.append(self.canvas.create_image(origin_x, origin_y,
                                                            image=self.up))
            elif action == 1:  # down
                origin_x, origin_y = 50 + (UNIT * row), 90 + (UNIT * col)
                self.arrows.append(self.canvas.create_image(origin_x, origin_y,
                                                            image=self.down))
            elif action == 3:  # right
                origin_x, origin_y = 90 + (UNIT * row), 50 + (UNIT * col)
                self.arrows.append(self.canvas.create_image(origin_x, origin_y,
                                                            image=self.right))
            elif action == 2:  # left
                origin_x, origin_y = 10 + (UNIT * row), 50 + (UNIT * col)
                self.arrows.append(self.canvas.create_image(origin_x, origin_y,
                                                            image=self.left))
     
        def draw_from_values(self, state, action_list):
            i = state[0]
            j = state[1]
            for action in action_list:
                self.draw_one_arrow(i, j, action)
     
        def print_values(self, values):
            for i in range(WIDTH):
                for j in range(HEIGHT):
                    self.text_value(i, j, round(values[i][j], 2))
     
        def render(self):
            time.sleep(0.1)
            self.canvas.tag_raise(self.rectangle)
            self.update()
     
        def calculate_value(self):
            self.iteration_count += 1
            for i in self.texts:
                self.canvas.delete(i)
            self.agent.value_iteration()
            self.print_values(self.agent.value_table)
     
        def print_optimal_policy(self):
            self.improvement_count += 1
            for i in self.arrows:
                self.canvas.delete(i)
            for state in self.env.get_all_states():
                action = self.agent.get_action(state)
                self.draw_from_values(state, action)
     
     
    class Env:
        def __init__(self):
            self.transition_probability = TRANSITION_PROB
            self.width = WIDTH  # Width of Grid World
            self.height = HEIGHT  # Height of GridWorld
            self.reward = [[0] * WIDTH for _ in range(HEIGHT)]
            self.possible_actions = POSSIBLE_ACTIONS
            self.reward[2][2] = 1  # reward 1 for the circle
            self.reward[1][2] = -1  # reward -1 for a triangle
            self.reward[2][1] = -1  # reward -1 for a triangle
            self.all_state = []
     
            for x in range(WIDTH):
                for y in range(HEIGHT):
                    state = [x, y]
                    self.all_state.append(state)
     
        def get_reward(self, state, action):
            next_state = self.state_after_action(state, action)
            return self.reward[next_state[0]][next_state[1]]
     
        def state_after_action(self, state, action_index):
            action = ACTIONS[action_index]
            return self.check_boundary([state[0] + action[0], state[1] + action[1]])
     
        @staticmethod
        def check_boundary(state):
            state[0] = (0 if state[0] < 0 else WIDTH - 1
                        if state[0] > WIDTH - 1 else state[0])
            state[1] = (0 if state[1] < 0 else HEIGHT - 1
                        if state[1] > HEIGHT - 1 else state[1])
            return state
     
        def get_transition_prob(self, state, action):
            return self.transition_probability
     
        def get_all_states(self):
            return self.all_state
     
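
    A minimal usage sketch for Env (illustrative only, not part of the
    original file): state_after_action applies the action's coordinate offset
    and check_boundary clamps the result to the grid, so stepping off an edge
    leaves the agent where it was.

        env = Env()
        print(env.state_after_action([0, 0], 2))  # left from the corner, clamped: [0, 0]
        print(env.get_reward([2, 1], 3))          # moving right onto the circle at (2, 2): 1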

     

    value_iteration.py

    import numpy as np
    from environment import GraphicDisplay, Env
     
     
    class ValueIteration:
        def __init__(self, env):
            # the environment the agent acts in
            self.env = env
            # initialize the value function as a 2-D list
            self.value_table = [[0.0] * env.width for _ in range(env.height)]
            # discount factor
            self.discount_factor = 0.9
     
        # compute the next value function with the Bellman optimality equation
        def value_iteration(self):
            # initialize the next value function
            next_value_table = [[0.0] * self.env.width
                                for _ in range(self.env.height)]
     
            # apply the Bellman optimality equation to every state
            for state in self.env.get_all_states():
                # the value of the terminal state is 0
                if state == [2, 2]:
                    next_value_table[state[0]][state[1]] = 0.0
                    continue
     
                # Bellman optimality equation
                value_list = []
                for action in self.env.possible_actions:
                    next_state = self.env.state_after_action(state, action)
                    reward = self.env.get_reward(state, action)
                    next_value = self.get_value(next_state)
                    value_list.append((reward + self.discount_factor * next_value))
     
                # take the maximum as the next value (this is where it differs from the expectation equation)
                next_value_table[state[0]][state[1]] = max(value_list)
     
            self.value_table = next_value_table
     
        # return the greedy action(s) derived from the current value function
        def get_action(self, state):
            if state == [2, 2]:
                return []
     
            # for every action, compute the Q-value: reward + (discount factor * next state's value)
            value_list = []
            for action in self.env.possible_actions:
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                value = (reward + self.discount_factor * next_value)
                value_list.append(value)
     
            # return the action(s) with the highest Q-value (several if tied)
            max_idx_list = np.argwhere(value_list == np.amax(value_list))
            action_list = max_idx_list.flatten().tolist()
            # unlike policy iteration, no explicit policy is returned: with the
            # optimality equation the policy never appears explicitly, so actions
            # are chosen purely from the maximum of the value function
            return action_list
     
        def get_value(self, state):
            return self.value_table[state[0]][state[1]]
     
     
    if __name__ == "__main__":
        env = Env()
        value_iteration = ValueIteration(env)
        grid_world = GraphicDisplay(value_iteration)
        grid_world.mainloop()
     
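
    As the comment in get_action notes, the greedy policy is never stored; it
    is read off the value function directly:

        \pi(s) = \arg\max_{a} [ r(s, a) + \gamma V(s') ]

    Running value_iteration.py opens the grid world window. Each click of
    "Calculate" performs one sweep of the Bellman optimality update over all
    25 states, and "Print Policy" draws the resulting greedy actions as
    arrows.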
