yan_z 2020. 12. 6. 23:31

environment.py

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import time
import numpy as np
import tkinter as tk
from PIL import ImageTk, Image
 
np.random.seed(1)  # fixed seed for numpy's global RNG (Python's `random` is not seeded here)
PhotoImage = ImageTk.PhotoImage
UNIT = 100  # pixels per grid cell
HEIGHT = 5  # grid world height (number of rows)
WIDTH = 5  # grid world width (number of columns)
 
 
class Env(tk.Tk):
    """Tkinter grid-world environment for the SARSA example.

    The agent (red square image) starts in the top-left cell and moves one
    cell per step. Landing on the circle ends the episode with reward +100;
    landing on either triangle ends it with reward -100; any other move gives
    reward 0. States are returned as ``[col, row]`` lists.
    """

    def __init__(self):
        super(Env, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']  # up, down, left, right
        self.n_actions = len(self.action_space)
        self.title('SARSA')
        # Tk geometry strings are "<width>x<height>" in pixels.
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
        # Kept on self so the tkinter images stay referenced while displayed.
        self.shapes = self.load_images()
        self.canvas = self._build_canvas()
        # Canvas item ids of the Q-value labels currently on screen.
        self.texts = []

    def _build_canvas(self):
        """Create the canvas, draw the grid and place the agent/goal/obstacles."""
        canvas = tk.Canvas(self, bg='white',
                           height=HEIGHT * UNIT,
                           width=WIDTH * UNIT)
        # Vertical grid lines at x = 0, UNIT, ..., (WIDTH - 1) * UNIT.
        for c in range(0, WIDTH * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT
            canvas.create_line(x0, y0, x1, y1)
        # Horizontal grid lines at y = 0, UNIT, ..., (HEIGHT - 1) * UNIT.
        for r in range(0, HEIGHT * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, WIDTH * UNIT, r
            canvas.create_line(x0, y0, x1, y1)

        def cell_center(col, row):
            # create_image anchors at the image centre by default, so every
            # sprite is positioned at the centre of its cell.
            return col * UNIT + UNIT / 2, row * UNIT + UNIT / 2

        # Add the images to the canvas.
        self.rectangle = canvas.create_image(*cell_center(0, 0),
                                             image=self.shapes[0])
        self.triangle1 = canvas.create_image(*cell_center(2, 1),
                                             image=self.shapes[1])
        self.triangle2 = canvas.create_image(*cell_center(1, 2),
                                             image=self.shapes[1])
        self.circle = canvas.create_image(*cell_center(2, 2),
                                          image=self.shapes[2])

        canvas.pack()

        return canvas

    def load_images(self):
        """Load the rectangle, triangle and circle sprites resized to 65x65.

        Paths are relative to the current working directory.
        """
        rectangle = PhotoImage(
            Image.open("../img/rectangle.png").resize((65, 65)))
        triangle = PhotoImage(
            Image.open("../img/triangle.png").resize((65, 65)))
        circle = PhotoImage(
            Image.open("../img/circle.png").resize((65, 65)))

        return rectangle, triangle, circle

    def text_value(self, row, col, contents, action, font='Helvetica', size=10,
                   style='normal', anchor="nw"):
        """Draw one Q-value label inside cell (row, col).

        The label for each action sits near the matching edge of the cell:
        up -> top, down -> bottom, left -> left side, right -> right side.

        Returns the canvas item id of the created text; the id is also
        appended to ``self.texts`` so ``print_value_all`` can remove it.
        """
        # (dy, dx) pixel offsets from the cell's top-left corner.
        if action == 0:
            dy, dx = 7, 42
        elif action == 1:
            dy, dx = 85, 42
        elif action == 2:
            dy, dx = 42, 5
        else:
            dy, dx = 42, 77

        x, y = dx + (UNIT * col), dy + (UNIT * row)
        font = (font, str(size), style)
        text = self.canvas.create_text(x, y, fill="black", text=contents,
                                       font=font, anchor=anchor)
        self.texts.append(text)
        return text

    def print_value_all(self, q_table):
        """Redraw the Q-values of every known state, replacing old labels.

        ``q_table`` maps ``str([col, row])`` to a list of per-action values.
        """
        for text_id in self.texts:
            self.canvas.delete(text_id)
        self.texts.clear()
        for x in range(WIDTH):
            for y in range(HEIGHT):
                key = str([x, y])
                # Membership test rather than indexing: q_table may be a
                # defaultdict, where indexing would insert unseen states.
                if key not in q_table:
                    continue
                for action, value in enumerate(q_table[key]):
                    self.text_value(y, x, round(value, 3), action)

    def coords_to_state(self, coords):
        """Convert a cell-centre canvas position into the state ``[col, row]``."""
        x = int((coords[0] - UNIT / 2) / UNIT)
        y = int((coords[1] - UNIT / 2) / UNIT)
        return [x, y]

    def reset(self):
        """Move the agent back to the top-left cell and return its state."""
        self.update()
        time.sleep(0.5)
        x, y = self.canvas.coords(self.rectangle)
        self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
        self.render()
        return self.coords_to_state(self.canvas.coords(self.rectangle))

    def step(self, action):
        """Move the agent one cell and return ``(next_state, reward, done)``.

        A move that would leave the grid leaves the agent in place.
        """
        state = self.canvas.coords(self.rectangle)
        base_action = np.array([0, 0])
        self.render()

        # The agent sits at a cell centre (k * UNIT + UNIT / 2), so these
        # bounds only reject moves out of the edge rows/columns.
        if action == 0:  # up
            if state[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if state[1] < (HEIGHT - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # left
            if state[0] > UNIT:
                base_action[0] -= UNIT
        elif action == 3:  # right
            if state[0] < (WIDTH - 1) * UNIT:
                base_action[0] += UNIT

        # Move the agent.
        self.canvas.move(self.rectangle, base_action[0], base_action[1])
        # Keep the agent (red square) drawn on top of the other sprites.
        self.canvas.tag_raise(self.rectangle)
        next_state = self.canvas.coords(self.rectangle)

        # Reward function.
        if next_state == self.canvas.coords(self.circle):
            reward = 100
            done = True
        elif next_state in [self.canvas.coords(self.triangle1),
                            self.canvas.coords(self.triangle2)]:
            reward = -100
            done = True
        else:
            reward = 0
            done = False

        next_state = self.coords_to_state(next_state)
        return next_state, reward, done

    def render(self):
        """Pause briefly, then process pending Tk events and redraw."""
        time.sleep(0.03)
        self.update()
cs

 

agent.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import random
from collections import defaultdict
from environment import Env
 
 
class SARSAgent:
    """Tabular SARSA agent with an epsilon-greedy policy.

    Q-values are stored under the key ``str(state)`` as a list holding one
    value per action, lazily initialised to 0.0.
    NOTE(review): the greedy branch returns a list index, so ``actions`` is
    assumed to be ``[0, 1, ..., n-1]`` (as the training script passes).
    """

    def __init__(self, actions):
        self.actions = actions
        self.step_size = 0.01  # learning rate
        self.discount_factor = 0.9  # discount factor (gamma)
        self.epsilon = 0.1  # exploration probability
        # Q-table whose unseen states start at 0.0 for every action; sized
        # from `actions` rather than a hard-coded 4.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state, next_action):
        """Apply one SARSA update from the sample <s, a, r, s', a'>.

        Q(s, a) <- Q(s, a) + step_size * (r + gamma * Q(s', a') - Q(s, a))
        """
        state, next_state = str(state), str(next_state)
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        td = reward + self.discount_factor * next_state_q - current_q
        new_q = current_q + self.step_size * td  # SARSA update
        self.q_table[state][action] = new_q

    def get_action(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy."""
        if np.random.rand() < self.epsilon:
            # Explore: uniformly random action.
            action = np.random.choice(self.actions)
        else:
            # Exploit: greedy action w.r.t. the current Q-values.
            state = str(state)
            q_list = self.q_table[state]
            action = arg_max(q_list)
        return action
 
 
def arg_max(q_list):
    """Return the index of a maximal Q-value, breaking ties uniformly at random.

    Raises ValueError if ``q_list`` is empty.
    """
    q_values = np.asarray(q_list)
    # Convert explicitly: comparing a plain list against a numpy scalar only
    # works through numpy's reflected __eq__, which is easy to break.
    best = np.flatnonzero(q_values == q_values.max())
    return random.choice(best.tolist())
 
 
if __name__ == "__main__":
    # Build the environment and a SARSA agent over its action indices.
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):  # train for 1000 episodes
        # Reset the world and pick the first action for the start state.
        state = env.reset()
        action = agent.get_action(state)

        done = False
        while not done:
            env.render()

            # Act, then observe next state, reward and episode termination.
            next_state, reward, done = env.step(action)
            # SARSA is on-policy: choose a' now and learn from <s, a, r, s', a'>.
            next_action = agent.get_action(next_state)
            agent.learn(state, action, reward, next_state, next_action)

            state, action = next_state, next_action

            # Show every Q-value on the grid.
            env.print_value_all(agent.q_table)
cs