Delta-Vector committed on
Commit
0b5d878
·
verified ·
1 Parent(s): 33dbf8a

Upload sharegpt_polar.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sharegpt_polar.py +462 -0
sharegpt_polar.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any, Iterable
6
+
7
+ from datasets import Dataset, load_dataset
8
+ import httpx
9
+
10
+ import verifiers as vf
11
+ from verifiers.types import Messages, State
12
+
13
# Reward model identifier used as the default for the `model` / `reward_model` arguments in this module.
DEFAULT_MODEL = "internlm/POLAR-7B"
# URL path appended to the reward server's base URL when posting scoring requests.
POOL_ENDPOINT = "/pooling"
15
+
16
+
17
+ def _ensure_messages(conversations: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
18
+ messages: list[dict[str, Any]] = []
19
+ for turn in conversations:
20
+ role = turn.get("from") or turn.get("role")
21
+ content = turn.get("value") or turn.get("content")
22
+ if role == "system":
23
+ messages.append({"role": "system", "content": content})
24
+ elif role == "human" or role == "user":
25
+ messages.append({"role": "user", "content": content})
26
+ elif role in {"gpt", "assistant"}:
27
+ messages.append({"role": "assistant", "content": content})
28
+ return messages
29
+
30
+
31
+ def _has_assistant(conversations: Iterable[dict[str, Any]]) -> bool:
32
+ return any(
33
+ (turn.get("from") or turn.get("role")) in {"gpt", "assistant"}
34
+ for turn in conversations
35
+ )
36
+
37
+
38
+ def _partition_conversation(
39
+ messages: list[dict[str, Any]]
40
+ ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[list[dict[str, Any]]], list[str]]:
41
+ assistant_indices = [idx for idx, msg in enumerate(messages) if msg["role"] == "assistant"]
42
+ if not assistant_indices:
43
+ raise ValueError("Conversation must include at least one assistant response")
44
+
45
+ first_assistant_idx = assistant_indices[0]
46
+ prompt_messages = messages[:first_assistant_idx]
47
+ if not any(msg["role"] == "user" for msg in prompt_messages):
48
+ raise ValueError("Conversation must include a user message before the first assistant turn")
49
+
50
+ reference_messages = [messages[idx] for idx in assistant_indices]
51
+
52
+ future_turns: list[list[dict[str, Any]]] = []
53
+ user_contexts: list[str] = []
54
+ assistant_indices_with_end = assistant_indices + [len(messages)]
55
+ for current_idx, next_idx in zip(assistant_indices, assistant_indices_with_end[1:]):
56
+ env_msgs: list[dict[str, Any]] = []
57
+ user_context_lines: list[str] = []
58
+ for i in range(current_idx + 1, next_idx):
59
+ turn = messages[i]
60
+ role = turn["role"]
61
+ content = turn["content"]
62
+ if role == "system":
63
+ continue
64
+ if role == "user":
65
+ line = (content or "").strip()
66
+ if line:
67
+ user_context_lines.append(line)
68
+ env_msgs.append({"role": "user", "content": line})
69
+ else:
70
+ env_msgs.append(turn)
71
+ future_turns.append(env_msgs)
72
+ user_contexts.append("\n".join(user_context_lines).strip())
73
+
74
+ return prompt_messages, reference_messages, future_turns, user_contexts
75
+
76
+
77
+ def _extract_last_value(value: Any) -> float | None:
78
+ current: Any = value
79
+ while isinstance(current, list) and current:
80
+ current = current[-1]
81
+ if isinstance(current, (int, float)):
82
+ return float(current)
83
+ return None
84
+
85
+
86
class PoolingClient:
    """Small async client that posts encoded samples to a reward server's pooling endpoint."""

    def __init__(
        self,
        base_url: str,
        model: str = DEFAULT_MODEL,
        timeout: float = 30.0,
        logger: logging.Logger | None = None,
        enable_logging: bool = False,
    ):
        """Store connection settings.

        Trailing slashes are removed from ``base_url``; if the result does not
        start with ``"http"`` it is prefixed with ``"https://"``.
        """
        normalized_url = base_url.rstrip("/")
        if not normalized_url.startswith("http"):
            normalized_url = f"https://{normalized_url}"
        self.base_url = normalized_url
        self.timeout = timeout
        self.model = model
        self.logger = logger or logging.getLogger("sharegpt_polar.PoolingClient")
        self.enable_logging = enable_logging

    @staticmethod
    def encode(sample: dict[str, Any]) -> str:
        """Render one sample as the text sent to the reward model.

        ``sample`` may hold ``"prompt"``, ``"reference"`` and ``"output"``
        message lists. Each list is flattened to its non-empty ``content``
        values joined by newlines. The result is the prompt+reference text,
        the literal ``<|reward|>``, the prompt+output text, and the literal
        ``[UNUSED_TOKEN_130]``.
        """

        def _flatten(key: str) -> str:
            turns = sample.get(key)
            if not turns:
                return ""
            pieces = [turn.get("content", "") for turn in turns if turn.get("content")]
            return "\n".join(pieces)

        prompt_text = _flatten("prompt")
        reference_text = _flatten("reference")
        output_text = _flatten("output")

        # A missing side degrades to the bare prompt text rather than "prompt\n".
        if reference_text:
            reference_cat = f"{prompt_text}\n{reference_text}"
        else:
            reference_cat = prompt_text
        if output_text:
            output_cat = f"{prompt_text}\n{output_text}"
        else:
            output_cat = prompt_text

        return f"{reference_cat}<|reward|>{output_cat}[UNUSED_TOKEN_130]"

    async def score(self, payload: list[dict[str, Any]]) -> dict[str, Any] | list[Any]:
        """POST the encoded ``payload`` to the pooling endpoint and return the decoded JSON body.

        Raises:
            RuntimeError: if the server answers with an error status; the
                status code and response text are included in the message.
        """
        encoded_payload = [self.encode(item) for item in payload]
        verbose = self.enable_logging
        if verbose:
            request_details = {
                "payload_size": len(payload),
                "model": self.model,
                "endpoint": self.base_url,
            }
            self.logger.debug("Sending reward request", extra=request_details)

        target_url = f"{self.base_url}{POOL_ENDPOINT}"
        request_body = {"model": self.model, "input": encoded_payload}
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                target_url,
                json=request_body,
                headers={"Content-Type": "application/json"},
            )
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                failed = exc.response
                if verbose:
                    self.logger.error(
                        "Reward server request failed",
                        extra={"status": failed.status_code, "body": failed.text},
                    )
                raise RuntimeError(
                    f"Pooling request failed: {failed.status_code} {failed.text}"
                ) from exc
            if verbose:
                self.logger.debug(
                    "Received reward response",
                    extra={"status": response.status_code, "model": self.model},
                )
            return response.json()
159
+
160
+
161
async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: PoolingClient,
    logger: logging.Logger | None = None,
    enable_logging: bool = False,
    **_: Any,
) -> float:
    """Score the last assistant turn of ``completion`` with the reward server.

    A single sample made of ``prompt``, ``info["reference"]`` (default ``[]``)
    and the last assistant message is sent through ``reward_client.score``.
    The reward is the last numeric value found in, in order of preference:
    ``response["data"][0]["data"]``, ``response["rewards"]``, or the response
    itself when it is a list.

    Args:
        prompt: Prompt messages of the rollout.
        completion: Messages produced during the rollout.
        info: Per-example metadata; only ``"reference"`` is read.
        reward_client: Client used to query the reward server.
        logger: Logger to use; defaults to ``sharegpt_polar.reward``.
        enable_logging: When true, emit debug/error records.

    Returns:
        The extracted reward, or ``0.0`` if ``completion`` has no assistant turn.

    Raises:
        RuntimeError: if the client reports a failed request, or if no numeric
            reward can be extracted from the response.
    """
    log = logger or logging.getLogger("sharegpt_polar.reward")

    assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"]
    if not assistant_turns:
        if enable_logging:
            log.debug(
                "No assistant turn available for reward",
                extra={"prompt": prompt, "completion": completion},
            )
        return 0.0

    payload = [
        {
            "prompt": prompt,
            "reference": info.get("reference", []),
            "output": [assistant_turns[-1]],
        }
    ]
    try:
        data = await reward_client.score(payload)
    except RuntimeError as err:
        if enable_logging:
            log.exception(
                "Reward request failed", extra={"error": str(err), "payload": payload}
            )
        raise
    if enable_logging:
        log.debug("Reward response received", extra={"response": data})

    if isinstance(data, dict):
        # Validate the nested shape step by step: an empty or malformed "data"
        # entry must fall through to the other formats (and finally to the
        # RuntimeError below) instead of escaping as IndexError/KeyError/TypeError.
        entries = data.get("data")
        if isinstance(entries, list) and entries and isinstance(entries[0], dict):
            last_value = _extract_last_value(entries[0].get("data"))
            if last_value is not None:
                return last_value
        rewards = data.get("rewards")
        if rewards:
            last_value = _extract_last_value(rewards)
            if last_value is not None:
                return last_value
    elif isinstance(data, list) and data:
        last_value = _extract_last_value(data)
        if last_value is not None:
            return last_value

    if enable_logging:
        log.error("Unexpected reward payload", extra={"response": data})
    raise RuntimeError(f"Unexpected reward response: {data}")
217
+
218
+
219
class ShareGPTPolarEnv(vf.MultiTurnEnv):
    """Multi-turn environment that replays the recorded follow-up turns of a conversation.

    The methods read ``state["info"]["reference"]`` (one entry per recorded
    assistant turn) and ``state["info"]["future_turns"]`` (one list of
    follow-up messages per assistant turn); ``load_environment`` in this
    module builds dataset rows with that ``info`` layout.
    """

    def __init__(
        self,
        dataset: Dataset,
        rubric: vf.Rubric,
        *,
        enable_logging: bool = False,
        **kwargs: Any,
    ):
        """Forward ``dataset``, ``rubric`` and ``kwargs`` to the base class and set up logging."""
        super().__init__(dataset=dataset, rubric=rubric, **kwargs)
        self.enable_logging = enable_logging
        self.logger = logging.getLogger("sharegpt_polar.env")

    async def setup_state(self, state: State, **kwargs: Any) -> State:
        """Copy ``info["future_turns"]`` into ``state["future_turns"]`` unless already set."""
        state.setdefault("future_turns", state["info"].get("future_turns", []))
        return state

    async def is_completed(self, messages: Messages, state: State, **kwargs: Any) -> bool:
        """Return True once ``state["turn"]`` reaches the number of reference assistant turns."""
        total_turns = len(state["info"].get("reference", []))
        if self.enable_logging:
            self.logger.debug(
                "Checking completion state",
                extra={"current_turn": state.get("turn", 0), "total_turns": total_turns},
            )
        return state.get("turn", 0) >= total_turns

    async def env_response(self, messages: Messages, state: State, **kwargs: Any) -> tuple[Messages, State]:
        """Return the recorded follow-up messages for the turn that just finished.

        The block at index ``state["turn"] - 1`` of ``state["future_turns"]``
        is returned; an out-of-range index yields an empty message list.
        """
        future_turns: list[list[dict[str, Any]]] = state.get("future_turns", [])
        # NOTE(review): the -1 presumably accounts for the base class incrementing
        # "turn" before calling env_response — confirm against vf.MultiTurnEnv.
        turn_index = state.get("turn", 0) - 1
        if self.enable_logging:
            self.logger.debug(
                "Providing future turn",
                extra={"turn_index": turn_index, "future_turn_count": len(future_turns)},
            )
        if 0 <= turn_index < len(future_turns):
            return future_turns[turn_index], state
        return [], state

    def process_chat_format_vllm(  # type: ignore[override]
        self,
        prompt: list[dict[str, Any]],
        completion: list[dict[str, Any]],
        state: State,
        processing_class: Any,
        mask_env_responses: bool = False,
    ) -> tuple[list[int], list[int], list[int], list[int], list[float]]:
        """Build token ids, masks and logprobs for a chat-format rollout.

        Assistant turns take their tokens and logprobs from the matching entry
        of ``state["responses"]`` via ``self.parse_chat_completion_tokens`` /
        ``self.parse_chat_completion_logprobs`` (not defined in this class;
        presumably inherited) and are masked with 1. Each run of consecutive
        non-assistant messages is tokenized by applying the chat template with
        and without the run and keeping the tokens after the common prefix;
        those tokens get logprob 0.0 and mask 0 when ``mask_env_responses`` is
        true, else 1.

        Returns:
            ``(prompt_ids, prompt_mask, completion_ids, completion_mask,
            completion_logprobs)``; ``prompt_mask`` is all zeros.
        """
        # Clean messages to remove tool-related fields that might trigger template errors
        def clean_message(msg: dict[str, Any]) -> dict[str, Any]:
            return {k: v for k, v in msg.items() if k not in {"tool_calls", "tool_call_id"}}

        # Pair every completion message with its raw response: assistant turns
        # consume ``responses`` in order, every other role is paired with None.
        responses = state.get("responses", [])
        responses_idx = 0
        zipped: list[tuple[dict[str, Any], Any | None]] = []
        for turn in completion:
            if turn.get("role") == "assistant":
                zipped.append((turn, responses[responses_idx]))
                responses_idx += 1
            else:
                zipped.append((turn, None))
        assert len(responses) == responses_idx, "Responses not fully consumed"
        assert len(zipped) == len(completion), "Length mismatch"

        clean_prompt = [clean_message(msg) for msg in prompt]
        prompt_ids: list[int] = processing_class.apply_chat_template(
            conversation=clean_prompt,  # type: ignore[arg-type]
            add_generation_prompt=True,
            tools=None,
        )
        # Running transcript of everything tokenized so far; used as the
        # baseline when diffing templates for non-assistant runs below.
        messages_consumed = [clean_message(m) for m in prompt]
        prompt_mask: list[int] = [0] * len(prompt_ids)
        completion_ids: list[int] = []
        completion_mask: list[int] = []
        completion_logprobs: list[float] = []
        i = 0
        while i < len(zipped):
            message, response = zipped[i]
            clean_msg = clean_message(message)
            if message.get("role") == "assistant":
                if response is not None:
                    completion_turn_ids = self.parse_chat_completion_tokens(response)
                    completion_turn_mask = [1] * len(completion_turn_ids)
                    completion_turn_logprobs = self.parse_chat_completion_logprobs(response)
                else:
                    completion_turn_ids = []
                    completion_turn_mask = []
                    completion_turn_logprobs = []
                completion_ids.extend(completion_turn_ids)
                completion_mask.extend(completion_turn_mask)
                completion_logprobs.extend(completion_turn_logprobs)
                messages_consumed.append(clean_msg)
                i += 1
                continue

            # Gather the whole run of non-assistant messages starting at i.
            consecutive_messages = [clean_msg]
            j = i + 1
            while j < len(zipped) and zipped[j][0].get("role") != "assistant":
                consecutive_messages.append(clean_message(zipped[j][0]))
                j += 1

            # Tokenize the transcript with and without the run; the tokens that
            # follow the shared prefix are attributed to the run.
            base_tokens: list[int] = processing_class.apply_chat_template(
                conversation=messages_consumed,  # type: ignore[arg-type]
                add_generation_prompt=True,
                tools=None,
            )
            extended_tokens: list[int] = processing_class.apply_chat_template(
                conversation=messages_consumed + consecutive_messages,  # type: ignore[arg-type]
                add_generation_prompt=True,
                tools=None,
            )
            prefix_len = 0
            max_len = min(len(base_tokens), len(extended_tokens))
            while prefix_len < max_len and base_tokens[prefix_len] == extended_tokens[prefix_len]:
                prefix_len += 1
            # The two tokenizations diverged before the end of the baseline.
            if self.enable_logging and prefix_len != len(base_tokens):
                self.logger.debug(
                    "Token prefix adjusted",
                    extra={"prefix_len": prefix_len, "base_len": len(base_tokens)},
                )
            completion_turn_ids = extended_tokens[prefix_len:]
            if mask_env_responses:
                completion_turn_mask = [0] * len(completion_turn_ids)
            else:
                completion_turn_mask = [1] * len(completion_turn_ids)
            # Environment tokens have no sampled logprobs; 0.0 is used as filler.
            completion_turn_logprobs = [0.0] * len(completion_turn_ids)
            completion_ids.extend(completion_turn_ids)
            completion_mask.extend(completion_turn_mask)
            completion_logprobs.extend(completion_turn_logprobs)
            messages_consumed.extend(consecutive_messages)
            i = j

        return (
            prompt_ids,
            prompt_mask,
            completion_ids,
            completion_mask,
            completion_logprobs,
        )
356
+
357
+
358
def load_environment(
    dataset_name: str | None = None,
    *,
    dataset_split: str = "train",
    dataset_files: dict[str, str] | None = None,
    data_path: str | Path | None = None,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    max_turns: int = -1,
    enable_logging: bool = False,
    logger: logging.Logger | None = None,
    **env_kwargs: Any,
) -> ShareGPTPolarEnv:
    """Build a ``ShareGPTPolarEnv`` from a conversation dataset and a reward server.

    Each row's ``"conversations"`` list is converted into a prompt plus an
    ``info`` dict holding ``reference``, ``future_turns`` and
    ``user_contexts``. Rows that cannot be converted are dropped.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; takes
            precedence over ``data_path``.
        dataset_split: Split to load when ``dataset_name`` is given.
        dataset_files: ``data_files`` mapping used with ``dataset_name``.
        data_path: Local JSON file, loaded with the ``"json"`` builder.
        server_address: Base address of the reward server.
        reward_model: Model name sent with every reward request.
        reward_scheme: Rubric class to instantiate; defaults to ``vf.Rubric``.
        max_turns: Forwarded to the environment.
        enable_logging: Enable log records in the client, reward and env.
        logger: Logger shared by client and reward; defaults to ``sharegpt_polar``.
        **env_kwargs: Extra environment arguments; ``max_concurrent``
            defaults to 1 when not supplied.

    Raises:
        ValueError: if neither ``dataset_name`` nor ``data_path`` is given.
    """
    if dataset_name is None and data_path is None:
        raise ValueError("Either 'dataset_name' or 'data_path' must be provided")

    if dataset_name is not None:
        hf_dataset = load_dataset(dataset_name, split=dataset_split, data_files=dataset_files)
    else:
        hf_dataset = load_dataset("json", data_files=str(data_path), split="train")

    def _invalid_row() -> dict[str, Any]:
        # Placeholder for rows that are filtered out after the map step.
        return {
            "prompt": [],
            "info": {
                "reference": [],
                "future_turns": [],
            },
            "valid": False,
        }

    def to_multi_turn(example: dict[str, Any]) -> dict[str, Any]:
        conversations = example.get("conversations") or []
        if not _has_assistant(conversations):
            return _invalid_row()

        messages = _ensure_messages(conversations)
        try:
            prompt, reference, future_turns, user_contexts = _partition_conversation(messages)
        except ValueError:
            # E.g. the conversation opens with an assistant turn. Drop the row
            # like every other malformed one instead of aborting the whole map.
            return _invalid_row()

        flattened_future = [msg for block in future_turns for msg in block]
        # Replayed turns must all be non-empty user messages.
        if any(msg.get("role") != "user" or not msg.get("content") for msg in flattened_future):
            return _invalid_row()

        return {
            "prompt": prompt,
            "info": {
                "reference": reference,
                "future_turns": future_turns,
                "user_contexts": user_contexts,
            },
            "valid": True,
        }

    dataset = hf_dataset.map(to_multi_turn, remove_columns=hf_dataset.column_names)
    dataset = dataset.filter(lambda example: example.get("valid", False))
    if "valid" in dataset.column_names:
        dataset = dataset.remove_columns("valid")

    effective_logger = logger or logging.getLogger("sharegpt_polar")
    if enable_logging:
        effective_logger.info(
            "Initializing ShareGPTPolar environment",
            extra={
                "dataset_name": dataset_name,
                "data_path": str(data_path) if data_path else None,
                "server_address": server_address,
            },
        )

    client = PoolingClient(
        base_url=server_address,
        model=reward_model,
        logger=effective_logger,
        enable_logging=enable_logging,
    )

    rubric_cls = reward_scheme or vf.Rubric
    rubric = rubric_cls(funcs=[polar_reward])
    # Made available to polar_reward through the rubric's class_objects mapping.
    rubric.class_objects["reward_client"] = client
    rubric.class_objects["logger"] = effective_logger
    rubric.class_objects["enable_logging"] = enable_logging

    env_kwargs.setdefault("max_concurrent", 1)
    return ShareGPTPolarEnv(
        dataset=dataset,
        rubric=rubric,
        max_turns=max_turns,
        enable_logging=enable_logging,
        **env_kwargs,
    )
461
+
462
+