|
|
import json |
|
|
from collections import Counter |
|
|
import numpy as np |
|
|
from typing import List, Dict |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Inclusive upper bound of each length bucket, in ascending order, paired with
# the label used in the printed report and in the returned distribution.
_LENGTH_BUCKETS = (
    (100, '0-100'),
    (500, '101-500'),
    (1000, '501-1000'),
    (2000, '1001-2000'),
    (3000, '2001-3000'),
    (4000, '3001-4000'),
    (5000, '4001-5000'),
    (6000, '5001-6000'),
)
# Label for every length above the last bound in _LENGTH_BUCKETS.
_OVERFLOW_BUCKET = '6000+'


def _collect_assistant_lengths(file_path: str) -> List[int]:
    """Return the character length of each assistant message in *file_path*.

    The file is read one line at a time and every non-blank line is parsed as
    a separate JSON document. Lines that are not valid JSON, or that do not
    hold an object with a ``messages`` list, are reported and skipped.
    """
    lengths: List[int] = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not parse errors.
                continue

            # Keep the try body to the single call that can raise.
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error parsing line: {e}")
                continue

            messages = item.get('messages') if isinstance(item, dict) else None
            if not isinstance(messages, list):
                print(f"Skipping line {line_number}: no 'messages' list found")
                continue

            for message in messages:
                if not isinstance(message, dict):
                    continue
                if message.get('role') != 'assistant':
                    continue
                content = message.get('content')
                # Only string content has a meaningful character count; a
                # missing/null or structured (non-string) content is skipped.
                if isinstance(content, str):
                    lengths.append(len(content))
    return lengths


def _bucket_lengths(lengths: List[int]) -> Dict[str, int]:
    """Count how many lengths fall into each range of ``_LENGTH_BUCKETS``."""
    counts = {label: 0 for _, label in _LENGTH_BUCKETS}
    counts[_OVERFLOW_BUCKET] = 0
    for length in lengths:
        for upper_bound, label in _LENGTH_BUCKETS:
            if length <= upper_bound:
                counts[label] += 1
                break
        else:
            # No bound matched: the length exceeds the largest bucket.
            counts[_OVERFLOW_BUCKET] += 1
    return counts


def _save_plots(lengths: List[int], length_ranges: Dict[str, int],
                output_prefix: str) -> None:
    """Save the histogram and the per-range bar chart as PNG files."""
    # Histogram of the raw lengths.
    plt.figure(figsize=(12, 6))
    try:
        plt.hist(lengths, bins=100, edgecolor='black')
        plt.title('Distribution of Assistant Response Lengths')
        plt.xlabel('Length (characters)')
        plt.ylabel('Frequency')
        plt.savefig(f'{output_prefix}dialogue_length_distribution.png')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close()

    # Bar chart of the counts per length range.
    plt.figure(figsize=(12, 6))
    try:
        plt.bar(list(length_ranges.keys()), list(length_ranges.values()))
        plt.title('Distribution of Response Lengths by Range')
        plt.xlabel('Length Range')
        plt.ylabel('Count')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(f'{output_prefix}dialogue_length_ranges.png')
    finally:
        plt.close()


def analyze_dialogue_lengths(file_path: str, output_prefix: str = '') -> Dict:
    """Summarize the lengths of assistant messages stored in a JSON Lines file.

    Each non-blank line of *file_path* is parsed as one JSON object whose
    ``messages`` list holds dicts with ``role`` and ``content`` keys. The
    character length of every string ``content`` whose role is ``assistant``
    is collected; malformed lines and records are reported and skipped
    instead of aborting the run.

    The summary is printed to stdout and two plots are written to the current
    working directory: ``<output_prefix>dialogue_length_distribution.png``
    and ``<output_prefix>dialogue_length_ranges.png``.

    Args:
        file_path: Path of the UTF-8 encoded input file.
        output_prefix: Text prepended to both plot file names. With the
            default empty string the file names are fixed, so a later call
            overwrites the plots of an earlier one; pass a distinct prefix
            per dataset to keep them apart.

    Returns:
        A dict with the keys ``total_responses``, ``max_length``,
        ``avg_length``, ``median_length`` and ``distribution`` (range label
        mapped to the percentage of responses in that range), or an empty
        dict when the file holds no assistant responses.

    Raises:
        OSError: If *file_path* cannot be opened.
    """
    lengths = _collect_assistant_lengths(file_path)

    if not lengths:
        print(f"No valid assistant responses found in {file_path}")
        return {}

    # Basic statistics.
    max_length = max(lengths)
    avg_length = np.mean(lengths)
    median_length = np.median(lengths)

    # Share of responses per length range; total is non-zero at this point.
    length_ranges = _bucket_lengths(lengths)
    total = len(lengths)
    percentages = {k: (v / total) * 100 for k, v in length_ranges.items()}

    print(f"\nAnalysis Results for {file_path}:")
    print(f"Total number of assistant responses: {total}")
    print(f"Maximum length: {max_length} characters")
    print(f"Average length: {avg_length:.2f} characters")
    print(f"Median length: {median_length:.2f} characters")
    print("\nLength Distribution:")
    for range_name, percentage in percentages.items():
        print(f"{range_name}: {percentage:.2f}%")

    _save_plots(lengths, length_ranges, output_prefix)

    return {
        'total_responses': total,
        'max_length': max_length,
        'avg_length': avg_length,
        'median_length': median_length,
        'distribution': percentages,
    }
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: analyze the training split first, then the test split.
    train_path = 'dataset_cotSFTtrain.json'
    test_path = 'dataset_cotSFTtest.json'

    # NOTE(review): both runs save their plots under the same file names, so
    # the images from the test split replace those from the training split.
    train_results = analyze_dialogue_lengths(train_path)
    test_results = analyze_dialogue_lengths(test_path)