DmitriiKhizbullin committed on
Commit
0b20b6b
Β·
1 Parent(s): 7180fe0

Copy from camel repo

Browse files
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
"""Entry point of the Space: fetch the chat datasets, then launch the explorer UI."""

from apps.data_explorer.data_explorer import main
from apps.data_explorer.downloader import download_data

if __name__ == "__main__":
    # Make sure the ZIP datasets exist locally before the Gradio app starts.
    download_data()
    main()
apps/data_explorer/data_explorer.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio-based web UI to explore the Camel dataset.
3
+ """
4
+
5
+ import argparse
6
+ import random
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import gradio as gr
10
+
11
+ from apps.data_explorer.loader import Datasets, load_datasets
12
+
13
+
14
def parse_arguments():
    """ Get command line arguments.

    Returns:
        argparse.Namespace: Parsed known arguments. Unknown arguments
        are printed to stdout and ignored.
    """

    parser = argparse.ArgumentParser("Camel data explorer")
    parser.add_argument(
        '--data-path', type=str, default=None,
        help='Path to the folder with ZIP datasets containing JSONs')
    parser.add_argument('--default-dataset', type=str, default=None,
                        help='Default dataset name selected from ZIPs')
    # NOTE: `type=bool` is broken with argparse - any non-empty string
    # (including "False") is truthy - so boolean flags use store_true.
    parser.add_argument('--share', action='store_true',
                        help='Expose the web UI to Gradio')
    parser.add_argument('--server-port', type=int, default=8080,
                        help='Port to run the web page on')
    parser.add_argument('--inbrowser', action='store_true',
                        help='Open the web UI in the default browser on launch')
    parser.add_argument(
        '--concurrency-count', type=int, default=10,
        help='Number of concurrent threads at Gradio websocket queue. ' +
        'Increase to serve more requests but keep an eye on RAM usage.')
    args, unknown = parser.parse_known_args()
    if unknown:
        print("Unknown args: ", unknown)
    return args
37
+
38
+
39
def construct_ui(blocks, datasets: Datasets, default_dataset: str = None):
    """ Build Gradio UI and populate with chat data from JSONs.

    Args:
        blocks: Gradio blocks
        datasets (Datasets): Several parsed
            multi-JSON dataset with chats.
        default_dataset (str): Default selection of the dataset.

    Returns:
        None
    """

    if default_dataset is None:
        default_dataset = "ai_society_chat"

    # The "misalignment" dataset is gated behind an AGREE/DECLINE
    # disclaimer below; every other dataset is browsable immediately.
    misalignment_set_names = {"misalignment"}
    ordinary_datasets = [
        v for v in datasets.keys() if v not in misalignment_set_names
    ]
    misalignment_datasets = [
        v for v in datasets.keys() if v in misalignment_set_names
    ]
    # Fall back to the first ordinary dataset, then to the first
    # misalignment one, if the requested default is not available.
    default_dataset_name = default_dataset \
        if default_dataset in datasets.keys() \
        else ordinary_datasets[0] if len(ordinary_datasets) > 0 \
        else misalignment_datasets[0] if len(misalignment_datasets) > 0 \
        else ""
    dataset_names = list(datasets.keys())

    with gr.Row().style():
        with gr.Column(scale=2):
            with gr.Row():
                # "NODEFAULT" is a placeholder; the real default is applied
                # by the blocks.load handler at the bottom (presumably so
                # that the .change event fires on startup - verify).
                dataset_dd = gr.Dropdown(dataset_names, label="Select dataset",
                                         value="NODEFAULT", interactive=True)
            with gr.Row():
                disclaimer_ta = gr.Markdown(
                    "## By clicking AGREE I consent to use the dataset "
                    "for purely educational and academic purposes and "
                    "not use it for any fraudulent activity; and I take "
                    "all the responsibility if the data is used in a "
                    "malicious application.", visible=False)
            with gr.Row():
                with gr.Column(scale=1):
                    accept_disclaimer_bn = gr.Button("AGREE", visible=False)
                with gr.Column(scale=1):
                    decline_disclaimer_bn = gr.Button("DECLINE", visible=False)
            with gr.Row():
                with gr.Column(scale=3):
                    assistant_dd = gr.Dropdown([], label="ASSISTANT", value="",
                                               interactive=True)
                with gr.Column(scale=3):
                    user_dd = gr.Dropdown([], label="USER", value="",
                                          interactive=True)
        with gr.Column(scale=1):
            gr.Markdown(
                "## CAMEL: Communicative Agents for \"Mind\" Exploration"
                " of Large Scale Language Model Society\n"
                "Github repo: [https://github.com/lightaime/camel]"
                "(https://github.com/lightaime/camel)\n"
                '<div style="display:flex; justify-content:center;">'
                '<img src="https://raw.githubusercontent.com/lightaime/camel/'
                'master/misc/logo.png" alt="Logo" style="max-width:50%;">'
                '</div>')

    task_dd = gr.Dropdown([], label="Original task", value="",
                          interactive=True)
    specified_task_ta = gr.TextArea(label="Specified task", lines=2)
    chatbot = gr.Chatbot()
    # Whether the user has accepted the misalignment disclaimer.
    accepted_st = gr.State(False)

    def set_default_dataset() -> Dict:
        """ Trigger for app load.

        Returns:
            Dict: Update dict for dataset_dd.
        """
        return gr.update(value=default_dataset_name)

    def check_if_misalignment(dataset_name: str, accepted: bool) \
            -> Tuple[Dict, Dict, Dict]:
        """ Display AGREE/DECLINE if needed.

        Args:
            dataset_name (str): Currently selected dataset name.
            accepted (bool): Whether the disclaimer has been accepted.

        Returns:
            Tuple: Visibility updates for the buttons.
        """

        if dataset_name == "misalignment" and not accepted:
            return gr.update(visible=True), \
                gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), \
                gr.update(visible=False), gr.update(visible=False)

    def enable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Update the state of the accepted disclaimer.

        Returns:
            Tuple: New state and visibility updates for the buttons.
        """

        return True, gr.update(visible=False), \
            gr.update(visible=False), gr.update(visible=False)

    def disable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Update the state of the accepted disclaimer.

        Returns:
            Tuple: New state and visibility updates for the buttons.
        """

        return False, gr.update(visible=False), \
            gr.update(visible=False), gr.update(visible=False)

    def update_dataset_selection(dataset_name: str,
                                 accepted: bool) -> Tuple[Dict, Dict]:
        """ Update roles based on the selected dataset.

        Args:
            dataset_name (str): Name of the loaded .zip dataset.
            accepted (bool): If the disclaimer has been accepted.

        Returns:
            Tuple[Dict, Dict]: New Assistant and User roles.
        """

        if dataset_name == "misalignment" and not accepted:
            # If user did not accept the misalignment policy,
            # keep the old selection.
            return (gr.update(value="N/A",
                    choices=[]), gr.update(value="N/A", choices=[]))

        # Pick a random role pair so that a chat is shown right away.
        dataset = datasets[dataset_name]
        assistant_roles = dataset['assistant_roles']
        user_roles = dataset['user_roles']
        assistant_role = random.choice(assistant_roles) \
            if len(assistant_roles) > 0 else ""
        user_role = random.choice(user_roles) if len(user_roles) > 0 else ""
        return (gr.update(value=assistant_role, choices=assistant_roles),
                gr.update(value=user_role, choices=user_roles))

    def roles_dd_change(dataset_name: str, assistant_role: str,
                        user_role: str) -> Dict:
        """ Update the displayed chat upon inputs change.

        Args:
            dataset_name (str): Name of the loaded .zip dataset.
            assistant_role (str): Assistant dropdown value.
            user_role (str): User dropdown value.

        Returns:
            Dict: New original roles state dictionary.
        """
        matrix = datasets[dataset_name]['matrix']
        if (assistant_role, user_role) in matrix:
            record: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            original_task_options = list(record.keys())
            original_task = original_task_options[0]
        else:
            # No chat exists for this role pair.
            original_task = "N/A"
            original_task_options = []

        choices = gr.Dropdown.update(choices=original_task_options,
                                     value=original_task, interactive=True)
        return choices

    def build_chat_history(messages: Dict[int, Dict]) -> List[Tuple]:
        """ Structures chatbot contents from the loaded data.

        Args:
            messages (Dict[int, Dict]): Messages loaded from JSON.

        Returns:
            List[Tuple]: Chat history in chatbot UI element format.
        """
        # Pair consecutive USER/ASSISTANT messages into (question, answer)
        # tuples; messages of any other role type are skipped.
        history = []
        curr_qa = (None, None)
        for k in sorted(messages.keys()):
            msg = messages[k]
            content = msg['content']
            if msg['role_type'] == "USER":
                if curr_qa[0] is not None:
                    # Two USER messages in a row: flush the unanswered one.
                    history.append(curr_qa)
                    curr_qa = (content, None)
                else:
                    curr_qa = (content, None)
            elif msg['role_type'] == "ASSISTANT":
                curr_qa = (curr_qa[0], content)
                history.append(curr_qa)
                curr_qa = (None, None)
            else:
                pass
        return history

    def task_dd_change(dataset_name: str, assistant_role: str, user_role: str,
                       original_task: str) -> Tuple[str, List]:
        """ Load task details and chatbot history into UI elements.

        Args:
            dataset_name (str): Name of the loaded .zip dataset.
            assistant_role (str): An assistant role.
            user_role (str): A user role.
            original_task (str): The original task.

        Returns:
            Tuple[str, List]: New contents of the specified task
                and chatbot history UI elements.
        """

        matrix = datasets[dataset_name]['matrix']
        if (assistant_role, user_role) in matrix:
            task_dict: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            if original_task in task_dict:
                chat = task_dict[original_task]
                specified_task = chat['specified_task']
                history = build_chat_history(chat['messages'])
            else:
                specified_task = "N/A"
                history = []
        else:
            specified_task = "N/A"
            history = []
        return specified_task, history

    # Wiring: selecting a dataset first shows/hides the disclaimer,
    # then refreshes the role dropdowns.
    dataset_dd.change(check_if_misalignment, [dataset_dd, accepted_st],
                      [disclaimer_ta, accept_disclaimer_bn,
                       decline_disclaimer_bn]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    accept_disclaimer_bn.click(enable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    decline_disclaimer_bn.click(disable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    # Either role dropdown change re-populates the task dropdown.
    func_args = (roles_dd_change, [dataset_dd, assistant_dd, user_dd], task_dd)
    assistant_dd.change(*func_args)
    user_dd.change(*func_args)

    task_dd.change(task_dd_change,
                   [dataset_dd, assistant_dd, user_dd, task_dd],
                   [specified_task_ta, chatbot])

    # Apply the real default dataset once the app has loaded.
    blocks.load(set_default_dataset, None, dataset_dd)
291
+
292
+
293
def construct_blocks(data_path: str, default_dataset: Optional[str]):
    """ Construct the Blocks app but do not launch it.

    Args:
        data_path (str): Path to the set of ZIP datasets.
        default_dataset (Optional[str]): Name of the default dataset,
            without extension.

    Returns:
        gr.Blocks: Blocks instance.
    """

    print("Loading the dataset...")
    datasets = load_datasets(data_path)
    print("Dataset is loaded")

    print("Getting Data Explorer web server online...")

    with gr.Blocks() as blocks:
        construct_ui(blocks, datasets, default_dataset)

    return blocks
315
+
316
+
317
def main():
    """ Entry point. """

    args = parse_arguments()

    blocks = construct_blocks(args.data_path, args.default_dataset)

    # Queue incoming requests, then serve on all interfaces.
    queued_app = blocks.queue(args.concurrency_count)
    queued_app.launch(share=args.share, inbrowser=args.inbrowser,
                      server_name="0.0.0.0", server_port=args.server_port)

    print("Exiting.")
329
+
330
+
331
+ if __name__ == "__main__":
332
+ main()
apps/data_explorer/downloader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ REPO_ROOT = os.path.realpath(
7
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
8
+
9
+
10
def download_data():
    """ Download the chat ZIP datasets into `<repo root>/datasets/`.

    The ai_society and code datasets are fetched from the Hugging Face
    Hub first; on any failure the function falls back to the public GCS
    bucket. The misalignment dataset is always fetched from GCS.
    """

    print("Downloading...")

    data_dir = os.path.join(REPO_ROOT, "datasets/")

    os.makedirs(data_dir, exist_ok=True)

    try:
        for repo_id, filename in (("camel-ai/ai_society",
                                   "ai_society_chat.zip"),
                                  ("camel-ai/code", "code_chat.zip")):
            hf_hub_download(repo_id=repo_id, repo_type="dataset",
                            filename=filename, local_dir=data_dir,
                            local_dir_use_symlinks=False)
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit. Any Hub failure (network, auth,
        # missing file) triggers the GCS fallback below.
        for name in ("ai_society_chat.zip", "code_chat.zip"):
            data_url = ("https://storage.googleapis.com/"
                        f"camel-bucket/datasets/private/{name}")
            file_path = os.path.join(data_dir, os.path.split(data_url)[1])
            urllib.request.urlretrieve(data_url, file_path)

    # The misalignment dataset is not on the HF Hub; always use GCS.
    data_url = ("https://storage.googleapis.com/"
                "camel-bucket/datasets/private/misalignment.zip")
    file_path = os.path.join(data_dir, os.path.split(data_url)[1])
    urllib.request.urlretrieve(data_url, file_path)

    print("Download done")
39
+
40
+
41
+ if __name__ == "__main__":
42
+ download_data()
apps/data_explorer/loader.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Everything related to parsing the data JSONs into UI-compatible format.
3
+ """
4
+
5
+ import glob
6
+ import json
7
+ import os
8
+ import re
9
+ import zipfile
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
11
+
12
+ from tqdm import tqdm
13
+
14
+ ChatHistory = Dict[str, Any]
15
+ ParsedChatHistory = Dict[str, Any]
16
+ AllChats = Dict[str, Any]
17
+ Datasets = Dict[str, AllChats]
18
+
19
+ REPO_ROOT = os.path.realpath(
20
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
21
+
22
+
23
class AutoZip:
    """ Iterator over the parsed contents of JSON files inside a ZIP.

    Args:
        zip_path (str): Path to the ZIP archive.
        ext (str): Only members whose name ends with this extension
            are yielded (default ".json").
    """

    def __init__(self, zip_path: str, ext: str = ".json"):
        self.zip_path = zip_path
        self.zip = zipfile.ZipFile(zip_path, "r")
        self.fl = [f for f in self.zip.filelist if f.filename.endswith(ext)]
        # Initialize the cursor here so that calling next() on a fresh
        # instance raises StopIteration instead of AttributeError
        # (previously self.index was only created in __iter__).
        self.index = 0

    def __next__(self):
        if self.index >= len(self.fl):
            raise StopIteration
        else:
            finfo = self.fl[self.index]
            with self.zip.open(finfo) as f:
                raw_json = json.loads(f.read().decode("utf-8"))
            self.index += 1
            return raw_json

    def __len__(self):
        return len(self.fl)

    def __iter__(self):
        # Restart iteration from the beginning on each new for-loop.
        self.index = 0
        return self
45
+
46
+
47
def parse(raw_chat: ChatHistory) -> Union[ParsedChatHistory, None]:
    """ Gets the JSON raw chat data, validates it and transforms
        into an easy to work with form.

    Args:
        raw_chat (ChatHistory): In-memory loaded JSON data file.

    Returns:
        Union[ParsedChatHistory, None]: Parsed chat data or None
            if there were parsing errors.
    """

    # The assistant role name is everything before the
    # "_RoleType.ASSISTANT" suffix; reject missing or empty roles.
    # (Previously only "role_1" was checked for presence; a chat missing
    # "role_2"/"original_task"/"specified_task" raised KeyError.)
    role_1 = raw_chat.get("role_1")
    if role_1 is None or "_RoleType.ASSISTANT" not in role_1:
        return None
    assistant_role = role_1.split("_RoleType.ASSISTANT")[0]
    if len(assistant_role) <= 0:
        return None

    role_2 = raw_chat.get("role_2")
    if role_2 is None or "_RoleType.USER" not in role_2:
        return None
    user_role = role_2.split("_RoleType.USER")[0]
    if len(user_role) <= 0:
        return None

    original_task = raw_chat.get("original_task")
    if original_task is None or len(original_task) <= 0:
        return None

    specified_task = raw_chat.get("specified_task")
    if specified_task is None or len(specified_task) <= 0:
        return None

    # Collect "message_<N>" entries keyed by their integer index.
    messages = dict()
    for key in raw_chat:
        match = re.search("message_(?P<number>[0-9]+)", key)
        if match:
            number = int(match.group("number"))
            messages[number] = raw_chat[key]

    return dict(
        assistant_role=assistant_role,
        user_role=user_role,
        original_task=original_task,
        specified_task=specified_task,
        messages=messages,
    )
104
+
105
+
106
def load_zip(zip_path: str) -> AllChats:
    """ Load all JSONs from a zip file and parse them.

    Args:
        zip_path (str): Path to the ZIP file.

    Returns:
        AllChats: A dictionary with all possible assistant and
            user roles and the matrix of chats.
    """

    zip_inst = AutoZip(zip_path)
    parsed_list = []
    for raw_chat in tqdm(iter(zip_inst)):
        parsed = parse(raw_chat)
        if parsed is None:
            # Skip malformed chats rather than failing the whole dataset.
            continue
        parsed_list.append(parsed)

    assistant_roles = set()
    user_roles = set()
    for parsed in parsed_list:
        assistant_roles.add(parsed['assistant_role'])
        user_roles.add(parsed['user_role'])
    assistant_roles = list(sorted(assistant_roles))
    user_roles = list(sorted(user_roles))
    # Maps an (assistant, user) role pair to the chats for that pair,
    # keyed by original task. (Annotation fixed: the values are dicts
    # keyed by task, not lists.)
    matrix: Dict[Tuple[str, str], Dict[str, Dict]] = dict()
    for parsed in parsed_list:
        key = (parsed['assistant_role'], parsed['user_role'])
        original_task = parsed['original_task']
        new_item = {
            k: v
            for k, v in parsed.items()
            if k not in {'assistant_role', 'user_role', 'original_task'}
        }
        matrix.setdefault(key, dict())[original_task] = new_item

    return dict(
        assistant_roles=assistant_roles,
        user_roles=user_roles,
        matrix=matrix,
    )
151
+
152
+
153
def load_datasets(path: Optional[str] = None) -> Datasets:
    """ Load all JSONs from a set of zip files and parse them.

    Args:
        path (str): path to the folder with ZIP datasets.

    Returns:
        Datasets: A dictionary of dataset name and dataset contents.
    """

    if path is None:
        path = os.path.join(REPO_ROOT, "datasets")

    # Every *.zip in the folder is one dataset, named after the file
    # without its extension.
    zip_files = glob.glob(os.path.join(path, "*.zip"))
    return {
        os.path.splitext(os.path.basename(zip_file))[0]: load_zip(zip_file)
        for zip_file in tqdm(zip_files)
    }