Deploy FastAPI ML service to Hugging Face Spaces
This view is limited to 50 files because it contains too many changes.
- .dockerignore +9 -0
- docker-compose.yml +14 -0
- dockerfile +31 -0
- reports/.gitkeep +0 -0
- reports/feedback/feedback_data.csv +3 -0
- reports/figures/.gitkeep +0 -0
- reports/figures/logo_header.svg +38 -0
- reports/unit_and_behavioral_tests/report.md +108 -0
- reports/unit_tests/report.md +122 -0
- requirements.txt +13 -0
- turing/CLI_runner/run_dataset.py +105 -0
- turing/CLI_runner/run_prediction.py +57 -0
- turing/__init__.py +1 -0
- turing/__pycache__/__init__.cpython-312.pyc +0 -0
- turing/__pycache__/config.cpython-312.pyc +0 -0
- turing/__pycache__/dataset.cpython-312.pyc +0 -0
- turing/__pycache__/evaluate_model.cpython-312.pyc +0 -0
- turing/api/__init__.py +0 -0
- turing/api/app.py +115 -0
- turing/api/demo.py +302 -0
- turing/api/schemas.py +22 -0
- turing/config.py +95 -0
- turing/data_validation.py +271 -0
- turing/dataset.py +210 -0
- turing/evaluate_model.py +121 -0
- turing/features.py +678 -0
- turing/modeling/__init__.py +0 -0
- turing/modeling/__pycache__/__init__.cpython-312.pyc +0 -0
- turing/modeling/__pycache__/baseModel.cpython-312.pyc +0 -0
- turing/modeling/baseModel.py +111 -0
- turing/modeling/model_selector.py +145 -0
- turing/modeling/models/__init__.py +15 -0
- turing/modeling/models/__pycache__/miniLM.cpython-312.pyc +0 -0
- turing/modeling/models/__pycache__/miniLmWithClassificationHead.cpython-312.pyc +0 -0
- turing/modeling/models/__pycache__/randomForestTfIdf.cpython-312.pyc +0 -0
- turing/modeling/models/codeBerta.py +463 -0
- turing/modeling/models/graphCodeBert.py +469 -0
- turing/modeling/models/randomForestTfIdf.py +153 -0
- turing/modeling/models/tinyBert.py +441 -0
- turing/modeling/predict.py +195 -0
- turing/modeling/train.py +212 -0
- turing/plots.py +29 -0
- turing/reporting.py +173 -0
- turing/tests/behavioral/test_directional.py +183 -0
- turing/tests/behavioral/test_invariance.py +117 -0
- turing/tests/behavioral/test_minimum_functionality.py +52 -0
- turing/tests/conftest.py +305 -0
- turing/tests/unit/test_api.py +201 -0
- turing/tests/unit/test_config.py +133 -0
- turing/tests/unit/test_dataset.py +95 -0
.dockerignore
ADDED
@@ -0,0 +1,9 @@
+turing/reporting.py
+turing/plots.py
+turing/features.py
+turing/evaluate_model.py
+turing/data_validation.py
+
+turing/CLI_runner
+turing/modeling/train.py
+turing/tests
docker-compose.yml
ADDED
@@ -0,0 +1,14 @@
+services:
+  api:
+    build: .
+    container_name: turing_app
+    image: turing_api
+    ports:
+      - "7860:7860"
+
+    environment:
+      - MLFLOW_TRACKING_USERNAME=${MLFLOW_USER}
+      - MLFLOW_TRACKING_PASSWORD=${MLFLOW_PWD}
+      - DAGSHUB_USER_TOKEN=${DAGSHUB_TOKEN}
+
+    command: uvicorn turing.api.app:app --host 0.0.0.0 --port 7860 --reload
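
Once the stack is up (for example via `docker compose up --build`), the root health-check endpoint defined in `turing/api/app.py` can be probed from the host. A minimal sketch, assuming the service is reachable on localhost:7860 and that `requests` is available in the client environment (it is not part of requirements.txt):

```python
# Minimal smoke test against the containerized API (assumed to be on localhost:7860).
import requests

resp = requests.get("http://localhost:7860/", timeout=10)
resp.raise_for_status()
print(resp.json())  # expected shape: {"status": "ok", "message": ..., "ui_url": "/gradio"}
```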
dockerfile
ADDED
@@ -0,0 +1,31 @@
+FROM python:3.12
+
+# Create a non-root user to run the application and set permissions
+RUN useradd -m -u 1000 turinguser
+RUN mkdir -p /app/models && chown -R turinguser:turinguser /app /app/models
+USER turinguser
+
+# Set environment variables
+# PATH to include local user binaries and project root
+ENV PATH="/home/turinguser/.local/bin:$PATH"
+ENV PROJ_ROOT=/app
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Copy essential files to install dependencies
+COPY --chown=turinguser requirements.txt .
+
+# Install Python dependencies
+RUN pip install --default-timeout=1000 --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 install -v -r requirements.txt --upgrade --default-timeout=1000 --no-cache-dir --break-system-packages
+
+# Copy remaining project files
+COPY --chown=turinguser turing ./turing
+COPY --chown=turinguser reports ./reports
+
+# Expose port 7860 for the FastAPI application
+EXPOSE 7860
+
+# Default command to run the FastAPI application on port 7860
+CMD ["uvicorn", "turing.api.app:app", "--host", "0.0.0.0", "--port", "7860"]
reports/.gitkeep
ADDED
File without changes
reports/feedback/feedback_data.csv
ADDED
@@ -0,0 +1,3 @@
+Timestamp,Input_Text,Language,Model_Prediction,User_Correction
+2025-12-11 22:41:05,# Create output directory,python,Usage,DevelopmentNotes
+2025-12-11 23:05:24,# Entry point for running the API directly with python,python,Usage,DevelopmentNotes
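
The feedback rows written by the Gradio demo can be inspected with the standard library. A minimal sketch, assuming the CSV keeps the header shown above:

```python
import csv

# Read the user-correction feedback collected by the demo UI.
with open("reports/feedback/feedback_data.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        print(row["Timestamp"], row["Model_Prediction"], "->", row["User_Correction"])
```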
reports/figures/.gitkeep
ADDED
File without changes
reports/figures/logo_header.svg
ADDED
reports/unit_and_behavioral_tests/report.md
ADDED
@@ -0,0 +1,108 @@
+
+# Test Execution Report
+
+
+### Environment
+
+
+```text
+Parameter         Value
+Timestamp         2025-11-27 15:44:47
+Context           turing
+Python Version    3.12.12
+Platform          Windows-11-10.0.26100-SP0
+```
+
+
+### Executive Summary
+
+
+```text
+Total    Passed    Failed    Success Rate
+66       35        31        53.0%
+```
+
+
+Detailed Breakdown:
+
+
+### BEHAVIORAL Tests
+
+
+```text
+Module Test Case Result Time Message
+test_directional.py test_java_directional_add_deprecation [ FAILED ] 0.30s turing\tests\behavioral\test_directional.py:16: Assertion...
+test_directional.py test_python_directional_remove_todo [ FAILED ] 0.15s turing\tests\behavioral\test_directional.py:31: Assertion...
+test_directional.py test_pharo_directional_add_responsibility [ FAILED ] 0.13s turing\tests\behavioral\test_directional.py:49: Assertion...
+test_directional.py test_java_directional_contrast_rational [ FAILED ] 0.12s turing\tests\behavioral\test_directional.py:70: Assertion...
+test_directional.py test_python_directional_contrast_todo [ FAILED ] 0.12s turing\tests\behavioral\test_directional.py:87: Assertion...
+test_directional.py test_pharo_directional_contrast_collaborators [ FAILED ] 0.13s turing\tests\behavioral\test_directional.py:112: Assertio...
+test_directional.py test_java_directional_shift_summary_to_expand [ FAILED ] 0.12s turing\tests\behavioral\test_directional.py:132: Assertio...
+test_directional.py test_python_directional_shift_summary_to_devnotes [ FAILED ] 0.12s turing\tests\behavioral\test_directional.py:152: Assertio...
+test_directional.py test_pharo_directional_shift_to_example [ FAILED ] 0.12s turing\tests\behavioral\test_directional.py:173: Assertio...
+test_invariance.py test_python_invariance_parameters[:param user_i... [ FAILED ] 0.22s turing\tests\behavioral\test_invariance.py:15: AssertionE...
+test_invariance.py test_python_invariance_parameters[:PARAM USER_I... [ FAILED ] 0.07s turing\tests\behavioral\test_invariance.py:15: AssertionE...
+test_invariance.py test_python_invariance_parameters[ :param user... [ FAILED ] 0.06s turing\tests\behavioral\test_invariance.py:15: AssertionE...
+test_invariance.py test_python_invariance_parameters[:param user_i... [ FAILED ] 0.06s turing\tests\behavioral\test_invariance.py:15: AssertionE...
+test_invariance.py test_java_invariance_deprecation [ FAILED ] 0.13s turing\tests\behavioral\test_invariance.py:26: AssertionE...
+test_invariance.py test_python_invariance_summary [ FAILED ] 0.13s turing\tests\behavioral\test_invariance.py:45: AssertionE...
+test_invariance.py test_pharo_invariance_intent [ FAILED ] 0.13s turing\tests\behavioral\test_invariance.py:64: AssertionE...
+test_invariance.py test_python_invariance_typos_parameters [ FAILED ] 0.07s turing\tests\behavioral\test_invariance.py:85: AssertionE...
+test_invariance.py test_java_invariance_semantic_summary [ PASS ] 0.32s
+test_minimum_functionality.py test_java_mft[test getfilestatus and related li... [ PASS ] 0.06s
+test_minimum_functionality.py test_java_mft[/* @deprecated Use something else... [ FAILED ] 0.06s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_java_mft[code source of this file http gre... [ FAILED ] 0.06s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_java_mft[this is balanced if each pool is ... [ FAILED ] 0.06s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_java_mft[// For internal use only.-expecte... [ FAILED ] 0.06s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_java_mft[this impl delegates to the old fi... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_java_mft[/** Usage: new MyClass(arg1). */-... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:17:...
+test_minimum_functionality.py test_python_mft[a service specific account of t... [ PASS ] 0.06s
+test_minimum_functionality.py test_python_mft[:param user_id: The ID of the u... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:29:...
+test_minimum_functionality.py test_python_mft[# TODO: Refactor this entire bl... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:29:...
+test_minimum_functionality.py test_python_mft[use this class if you want acce... [ PASS ] 0.06s
+test_minimum_functionality.py test_python_mft[# create a new list by filterin... [ FAILED ] 0.08s turing\tests\behavioral\test_minimum_functionality.py:29:...
+test_minimum_functionality.py test_pharo_mft[i am a simple arrow like arrowhe... [ PASS ] 0.07s
+test_minimum_functionality.py test_pharo_mft[the example below shows how to c... [ PASS ] 0.07s
+test_minimum_functionality.py test_pharo_mft[i provide a data structure indep... [ FAILED ] 0.06s turing\tests\behavioral\test_minimum_functionality.py:43:...
+test_minimum_functionality.py test_pharo_mft[the cache is cleared after each ... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:43:...
+test_minimum_functionality.py test_pharo_mft[it is possible hovewer to custom... [ PASS ] 0.07s
+test_minimum_functionality.py test_pharo_mft[collaborators: BlElement, BlSpac... [ FAILED ] 0.07s turing\tests\behavioral\test_minimum_functionality.py:43:...
+```
+
+
+### UNIT Tests
+
+
+```text
+Module Test Case Result Time Message
+test_config.py test_proj_root_is_correctly_identified [ PASS ] 0.00s
+test_config.py test_directory_paths_are_correctly_structured [ PASS ] 0.00s
+test_config.py test_dataset_constants_are_valid [ PASS ] 0.00s
+test_config.py test_labels_map_and_total_categories_are_correct [ PASS ] 0.00s
+test_config.py test_numeric_parameters_are_positive [ PASS ] 0.00s
+test_config.py test_load_dotenv_is_called_on_module_load [ PASS ] 0.00s
+test_dataset.py test_initialization_paths_are_correct [ FAILED ] 0.00s turing\tests\unit\test_dataset.py:24: AssertionError
+test_dataset.py test_format_labels_for_csv[input_labels0-[1, 0,... [ PASS ] 0.00s
+test_dataset.py test_format_labels_for_csv[[1, 0, 1]-[1, 0, 1]] [ PASS ] 0.00s
+test_dataset.py test_format_labels_for_csv[input_labels2-[]] [ PASS ] 0.00s
+test_dataset.py test_format_labels_for_csv[None-None] [ PASS ] 0.00s
+test_dataset.py test_get_dataset_raises_file_not_found [ PASS ] 0.00s
+test_dataset.py test_get_dataset_success_and_label_parsing [ PASS ] 0.48s
+test_features.py test_config_id_generation [ PASS ] 0.00s
+test_features.py test_config_attributes [ PASS ] 0.00s
+test_features.py test_clean_text_basic [ PASS ] 0.00s
+test_features.py test_clean_text_stopwords [ PASS ] 2.39s
+test_features.py test_clean_text_lemmatization [ PASS ] 0.00s
+test_features.py test_clean_text_handles_none [ PASS ] 0.00s
+test_features.py test_extract_numeric_features [ PASS ] 0.00s
+test_model.py test_model_initialization[randomForestTfIdf] [ PASS ] 0.00s
+test_model.py test_model_initialization[codeBerta] [ PASS ] 0.00s
+test_model.py test_model_setup[randomForestTfIdf] [ PASS ] 0.00s
+test_model.py test_model_setup[codeBerta] [ PASS ] 1.39s
+test_model.py test_model_train[randomForestTfIdf] [ PASS ] 3.06s
+test_model.py test_model_train[codeBerta] [ PASS ] 4.90s
+test_model.py test_model_evaluate[randomForestTfIdf] [ PASS ] 1.39s
+test_model.py test_model_evaluate[codeBerta] [ FAILED ] 6.36s turing\tests\unit\test_model.py:101: AssertionError
+test_model.py test_model_predict[randomForestTfIdf] [ PASS ] 1.36s
+test_model.py test_model_predict[codeBerta] [ PASS ] 5.26s
+```
reports/unit_tests/report.md
ADDED
@@ -0,0 +1,122 @@
+
+# Turing Test Execution Report
+
+
+
+---
+
+
+
+## Environment Information
+
+
+| Parameter | Value |
+|:---------------|:---------------------------|
+| Timestamp | 2025-12-04 18:14:18 |
+| Context | TURING |
+| Python Version | 3.12.12 |
+| Platform | macOS-15.6-arm64-arm-64bit |
+| Architecture | arm64 |
+
+
+---
+
+
+## Executive Summary
+
+
+**Overall Status:** MOSTLY PASSED
+
+
+**Success Rate:** 91.2%
+
+
+| Metric | Count |
+|:-------------|--------:|
+| Total Tests | 34 |
+| Passed | 31 |
+| Failed | 3 |
+| Success Rate | 91.2% |
+
+
+**Visual Progress:**
+
+
+```
+Progress: [█████████████████████████████████████████████░░░░░] 91.2%
+Passed: 31/34 tests
+```
+
+
+---
+
+
+## UNIT Tests
+
+
+### Statistics
+
+
+| Status | Count |
+|:---------|-----------:|
+| Total | 34 |
+| Passed | 31 (91.2%) |
+| Failed | 3 (8.8%) |
+
+
+### Test Results
+
+
+| Module | Test Case | Result | Time | Message |
+|:----------------|:---------------------------------------------------|:---------|:-------|:-----------------------------------------------------|
+| test_api.py | test_health_check_returns_ok | PASS | 0.01s | |
+| test_api.py | test_predict_success_java | PASS | 0.02s | |
+| test_api.py | test_predict_success_python | PASS | 0.00s | |
+| test_api.py | test_predict_success_pharo | PASS | 0.00s | |
+| test_api.py | test_predict_missing_texts | PASS | 0.00s | |
+| test_api.py | test_predict_missing_language | PASS | 0.00s | |
+| test_api.py | test_predict_empty_texts | PASS | 0.00s | |
+| test_api.py | test_predict_error_handling | PASS | 0.00s | |
+| test_api.py | test_predict_invalid_language | PASS | 0.00s | |
+| test_api.py | test_prediction_request_valid | PASS | 0.00s | |
+| test_api.py | test_prediction_response_valid | PASS | 0.00s | |
+| test_config.py | test_proj_root_is_correctly_identified | PASS | 0.00s | |
+| test_config.py | test_directory_paths_are_correctly_structured | PASS | 0.00s | |
+| test_config.py | test_dataset_constants_are_valid | PASS | 0.00s | |
+| test_config.py | test_labels_map_and_total_categories_are_correct | PASS | 0.00s | |
+| test_config.py | test_numeric_parameters_are_positive | PASS | 0.00s | |
+| test_config.py | test_load_dotenv_is_called_on_module_load | PASS | 0.00s | |
+| test_dataset.py | test_initialization_paths_are_correct | FAIL | 0.00s | turing/tests/unit/test_dataset.py:25: AssertionError |
+| test_dataset.py | test_format_labels_for_csv[input_labels0-[1, 0,... | PASS | 0.00s | |
+| test_dataset.py | test_format_labels_for_csv[[1, 0, 1]-[1, 0, 1]] | PASS | 0.00s | |
+| test_dataset.py | test_format_labels_for_csv[input_labels2-[]] | PASS | 0.00s | |
+| test_dataset.py | test_format_labels_for_csv[None-None] | PASS | 0.00s | |
+| test_dataset.py | test_get_dataset_raises_file_not_found | PASS | 0.00s | |
+| test_dataset.py | test_get_dataset_success_and_label_parsing | FAIL | 0.00s | turing/dataset.py:128: FileNotFoundError |
+| test_model.py | test_model_initialization[randomForestTfIdf] | PASS | 0.00s | |
+| test_model.py | test_model_initialization[codeBerta] | PASS | 0.00s | |
+| test_model.py | test_model_setup[randomForestTfIdf] | PASS | 0.00s | |
+| test_model.py | test_model_setup[codeBerta] | PASS | 0.93s | |
+| test_model.py | test_model_train[randomForestTfIdf] | PASS | 2.66s | |
+| test_model.py | test_model_train[codeBerta] | PASS | 7.22s | |
+| test_model.py | test_model_evaluate[randomForestTfIdf] | PASS | 1.31s | |
+| test_model.py | test_model_evaluate[codeBerta] | FAIL | 8.83s | turing/tests/unit/test_model.py:101: AssertionError |
+| test_model.py | test_model_predict[randomForestTfIdf] | PASS | 1.21s | |
+| test_model.py | test_model_predict[codeBerta] | PASS | 5.98s | |
+
+
+---
+
+
+> **ERROR**: 3 test(s) failed. Please review the error messages above.
+
+
+
+---
+
+
+
+*Report generated on 2025-12-04 at 18:14:18*
+
+
+*Powered by Turing Test Suite*
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+fastapi
+uvicorn[standard]
+loguru
+pydantic
+python-dotenv
+mlflow
+numpy
+transformers
+dagshub
+datasets
+accelerate
+scikit-learn
+gradio
turing/CLI_runner/run_dataset.py
ADDED
@@ -0,0 +1,105 @@
+import os
+from pathlib import Path
+import sys
+
+from loguru import logger
+import typer
+from typing_extensions import Annotated
+
+try:
+    from turing.config import INTERIM_DATA_DIR, RAW_DATA_DIR
+    from turing.dataset import DatasetManager
+except ImportError:
+    logger.error("Error: Could not import DatasetManager. Check sys.path configuration.")
+    logger.error(f"Current sys.path: {sys.path}")
+    sys.exit(1)
+
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+proj_root = os.path.dirname(os.path.dirname(script_dir))
+sys.path.append(proj_root)
+
+app = typer.Typer(help="CLI for dataset management (Download, Conversion, and Search).")
+
+
+@app.command()
+def download():
+    """
+    Loads the dataset from Hugging Face and saves it into the "raw" folder.
+    """
+    logger.info("Starting dataset download...")
+    manager = DatasetManager()
+    manager.download_dataset()
+    logger.success("Download complete.")
+
+
+@app.command(name="parquet-to-csv")
+def parquet_to_csv():
+    """
+    Converts all parquet files in the raw data directory
+    to CSV format in the interim data directory.
+    """
+    logger.info("Starting Parquet -> CSV conversion...")
+    manager = DatasetManager()
+    manager.parquet_to_csv()
+    logger.success("Conversion complete.")
+
+
+@app.command()
+def search(
+    filename: Annotated[
+        str, typer.Argument(help="The exact filename to search for (e.g., 'java_train.parquet')")
+    ],
+    directory: Annotated[
+        str,
+        typer.Option(
+            "--directory",
+            "-d",
+            help="Directory to search in. Keywords 'raw' or 'interim' can be used.",
+        ),
+    ] = "raw",
+):
+    """
+    Searches for a file by name in the data directories.
+    """
+    logger.info(f"Initializing search for '{filename}'...")
+    manager = DatasetManager()
+
+    search_path = None
+    if directory.lower() == "raw":
+        search_path = RAW_DATA_DIR
+        logger.info("Searching in 'raw' data directory.")
+    elif directory.lower() == "interim":
+        search_path = INTERIM_DATA_DIR
+        logger.info("Searching in 'interim' data directory.")
+    else:
+        search_path = Path(directory)
+        logger.info(f"Searching in custom path: {search_path}")
+
+    results = manager.search_file(filename, search_directory=search_path)
+
+    if results:
+        logger.success(f"Found {len(results)} file(s):")
+        for res in results:
+            print(f"-> {res}")
+    else:
+        logger.warning(f"File '{filename}' not found in {search_path}.")
+
+
+@app.command(name="show-raw-hf")
+def show_raw_hf():
+    """
+    Loads and displays info about the raw dataset from Hugging Face.
+    """
+    logger.info("Loading raw dataset info from Hugging Face...")
+    manager = DatasetManager()
+    dataset = manager.get_raw_dataset_from_hf()
+    if dataset:
+        logger.info("Dataset info:")
+        print(dataset)
+    else:
+        logger.error("Could not retrieve dataset.")
+
+
+if __name__ == "__main__":
+    app()
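
The Typer app above can be exercised without a shell, for example from a test. A minimal sketch using `typer.testing.CliRunner`; the `search` arguments shown are illustrative only:

```python
from typer.testing import CliRunner

from turing.CLI_runner.run_dataset import app

runner = CliRunner()

# Equivalent to: python -m turing.CLI_runner.run_dataset search java_train.parquet -d raw
result = runner.invoke(app, ["search", "java_train.parquet", "--directory", "raw"])
print(result.exit_code)
print(result.output)
```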
turing/CLI_runner/run_prediction.py
ADDED
@@ -0,0 +1,57 @@
+from pathlib import Path
+import sys
+
+from loguru import logger
+import typer
+
+from turing.modeling.models.randomForestTfIdf import RandomForestTfIdf
+from turing.modeling.predict import ModelInference
+
+# Add project root to sys.path
+current_dir = Path(__file__).resolve().parent
+project_root = current_dir.parent
+if str(project_root) not in sys.path:
+    sys.path.append(str(project_root))
+
+app = typer.Typer()
+
+
+@app.command()
+def main(
+    mlflow_run_id: str = typer.Option(
+        "af1fa5959dc14fa9a29a0a19c11f1b08", help="The MLflow Run ID"
+    ),
+    artifact_name: str = typer.Option(
+        "RandomForestTfIdf_java", help="The name of the model artifact"
+    ),
+    language: str = typer.Option("java", help="The target programming language"),
+):
+    """
+    Run inference using the dataset stored on disk (Standard CML/DVC workflow).
+    """
+    logger.info("Starting CLI inference process...")
+
+    try:
+        # Initialize inference engine
+        inference_engine = ModelInference()
+
+        # Run prediction on the test dataset
+        results = inference_engine.predict_from_mlflow(
+            mlflow_run_id=mlflow_run_id,
+            artifact_name=artifact_name,
+            language=language,
+            model_class=RandomForestTfIdf,
+        )
+
+        # Output results
+        print("\n--- Prediction Results ---")
+        print(results)
+        print("--------------------------")
+
+    except Exception as e:
+        logger.error(f"CLI Prediction failed: {e}")
+        raise typer.Exit(code=1)
+
+
+if __name__ == "__main__":
+    app()
turing/__init__.py
ADDED
@@ -0,0 +1 @@
+from turing import config  # noqa: F401
turing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (195 Bytes).
turing/__pycache__/config.cpython-312.pyc
ADDED
Binary file (2.44 kB).
turing/__pycache__/dataset.cpython-312.pyc
ADDED
Binary file (10.1 kB).
turing/__pycache__/evaluate_model.cpython-312.pyc
ADDED
Binary file (6.33 kB).
turing/api/__init__.py
ADDED
File without changes
turing/api/app.py
ADDED
@@ -0,0 +1,115 @@
+import base64
+import os
+
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
+import gradio as gr
+from loguru import logger
+
+from turing.api.demo import create_demo
+from turing.api.schemas import PredictionRequest, PredictionResponse
+from turing.modeling.predict import ModelInference
+
+
+def get_logo_b64_src(filename="logo_header.svg"):
+    """read SVG and convert it into a string Base64 for HTML."""
+    try:
+        base_path = os.path.dirname(os.path.abspath(__file__))
+        target_path = os.path.join(base_path, "..", "..", "reports", "figures", filename)
+        target_path = os.path.normpath(target_path)
+
+        with open(target_path, "rb") as f:
+            encoded = base64.b64encode(f.read()).decode("utf-8")
+        return f"data:image/svg+xml;base64,{encoded}"
+    except Exception as e:
+        print(f"Unable to load logo for API: {e}")
+        return ""
+
+
+# load logo
+logo_src = get_logo_b64_src()
+
+# html
+logo_html_big = f"""
+<a href="/gradio">
+    <img src="{logo_src}" width="150" style="display: block; margin: 10px 0;">
+</a>
+"""
+
+# description
+description_md = f"""
+API for classifying code comments.
+
+You can interact with the model directly using the visual interface.
+Click the logo below to open it:
+
+{logo_html_big}
+
+"""
+
+app = FastAPI(
+    title="Turing Team Code Classification API",
+    description=description_md,
+    version="1.0.0"
+)
+
+@app.get("/manifest.json")
+def get_manifest():
+    return JSONResponse(content={
+        "name": "Turing App",
+        "short_name": "Turing",
+        "start_url": "/gradio",
+        "display": "standalone",
+        "background_color": "#ffffff",
+        "theme_color": "#000000",
+        "icons": []
+    })
+
+# Global inference engine instance
+inference_engine = ModelInference()
+
+demo = create_demo(inference_engine)
+app = gr.mount_gradio_app(app, demo, path="/gradio")
+
+@app.get("/")
+def health_check():
+    """
+    Root endpoint to verify API status.
+    """
+    return {"status": "ok", "message": "Turing Code Classification API is ready.", "ui_url": "/gradio"}
+
+
+@app.post("/predict", response_model=PredictionResponse)
+def predict(request: PredictionRequest):
+    """
+    Endpoint to classify a list of code comments.
+    Dynamically loads the model from MLflow based on the request parameters.
+    """
+    try:
+        logger.info(f"Received prediction request for language: {request.language}")
+
+        # Perform prediction using the inference engine
+        raw, predictions, run_id, artifact = inference_engine.predict_payload(
+            texts=request.texts, language=request.language
+        )
+
+        # Ensure predictions are serializable (convert numpy arrays to lists)
+        if hasattr(predictions, "tolist"):
+            predictions = predictions.tolist()
+
+        return PredictionResponse(
+            predictions=raw.tolist(),
+            labels=predictions,
+            model_info={"artifact": artifact, "language": request.language},
+        )
+
+    except Exception as e:
+        logger.error(f"Prediction failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# Entry point for running the API directly with python
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="127.0.0.1", port=7860)
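
With the service running, `/predict` accepts a JSON body matching `PredictionRequest`. A minimal client sketch, assuming the API is served locally on port 7860 and that `requests` is installed on the client side:

```python
import requests

# Payload fields mirror turing/api/schemas.py: a list of comments plus a language.
payload = {
    "texts": ["// For internal use only.", "/** Usage: new MyClass(arg1). */"],
    "language": "java",
}

resp = requests.post("http://localhost:7860/predict", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
print(body["labels"])      # human-readable labels
print(body["model_info"])  # e.g. {"artifact": ..., "language": "java"}
```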
turing/api/demo.py
ADDED
@@ -0,0 +1,302 @@
+import csv
+from datetime import datetime
+import os
+
+import gradio as gr
+
+# ---IMPORTS ---
+try:
+    from turing.modeling.models.codeBerta import CodeBERTa
+    from turing.modeling.predict import ModelInference
+except ImportError as e:
+    print(f"WARNING: Error importing real modules: {e}")
+    class CodeBERTa:
+        pass
+    class ModelInference:
+        pass
+
+# --- CONFIGURATION ---
+FEEDBACK_FILE = "reports/feedback/feedback_data.csv"
+
+LABELS_MAP = {
+    "java": ["summary", "Ownership", "Expand", "usage", "Pointer", "deprecation", "rational"],
+    "python": ["Usage", "Parameters", "DevelopmentNotes", "Expand", "Summary"],
+    "pharo": ["Keyimplementationpoints", "Example", "Responsibilities", "Intent", "Keymessages", "Collaborators"],
+}
+
+# --- CSS ---
+CSS = """
+:root {
+    --bg-primary: #fafaf9; --bg-secondary: #ffffff; --border-color: #e5e7eb;
+    --text-primary: #1f2937; --text-secondary: #6b7280; --accent-bg: #f3f4f6;
+    --primary-btn: #ea580c; --primary-btn-hover: #c2410c;
+}
+.dark, body.dark, .gradio-container.dark {
+    --bg-primary: #0f172a; --bg-secondary: #1e293b; --border-color: #374151;
+    --text-primary: #f3f4f6; --text-secondary: #9ca3af; --accent-bg: #334155;
+}
+body, .gradio-container {
+    background-color: var(--bg-primary) !important; color: var(--text-primary) !important;
+    font-family: 'Segoe UI', system-ui, sans-serif; transition: background 0.3s, color 0.3s;
+}
+.compact-header {
+    display: flex; align-items: center; justify-content: space-between; padding: 1.5rem 2rem;
+    border-bottom: 1px solid var(--border-color); margin-bottom: 2rem;
+    background-color: var(--bg-secondary); flex-wrap: wrap; gap: 1rem; border-radius: 0 0 12px 12px;
+}
+.input-card, .output-card {
+    background-color: var(--bg-secondary); border: 1px solid var(--border-color);
+    border-radius: 12px; padding: 1.5rem; margin-bottom: 1rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
+}
+.header-left { display: flex; align-items: center; gap: 1.5rem; }
+.logo-icon {
+    height: 55px; width: auto; padding: 0; background-color: transparent;
+    border: none; box-shadow: none; display: flex; align-items: center; justify-content: center; flex-shrink: 0;
+}
+.logo-icon svg { height: 100%; width: auto; fill: var(--primary-btn); }
+.title-group { display: flex; flex-direction: column; }
+.main-title { font-size: 1.6rem; font-weight: 800; margin: 0; line-height: 1.1; color: var(--text-primary); letter-spacing: -0.5px; }
+.subtitle { font-size: 0.95rem; color: var(--text-secondary); margin: 0; font-weight: 400; }
+.section-title { font-weight: 600; color: var(--text-primary); margin-bottom: 1rem; }
+.header-right { flex: 1; display: flex; justify-content: flex-end; align-items: center; min-width: 250px; }
+.dev-note-container {
+    background-color: var(--accent-bg); border: 1px solid var(--border-color); border-radius: 16px;
+    width: 520px; height: 64px; display: flex; align-items: center; justify-content: flex-start; padding: 0 24px; gap: 1rem;
+}
+.dev-note-container:hover { border-color: var(--primary-btn); }
+.dev-icon { font-size: 1.4rem; background: transparent !important; border: none !important; display: flex; align-items: center; flex-shrink: 0; }
+.dev-text {
+    font-family: 'Courier New', monospace; font-size: 0.95rem; color: var(--text-secondary);
+    transition: opacity 1.5s ease; white-space: normal; line-height: 1.2; text-align: left;
+    display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden;
+}
+.dev-text.hidden { opacity: 0; }
+.feedback-section { margin-top: 2rem; padding-top: 1.5rem; border-top: 1px dashed var(--border-color); }
+.feedback-title { font-size: 0.8rem; font-weight: 700; color: var(--text-secondary); text-transform: uppercase; margin-bottom: 0.8rem; }
+.gr-button-primary { background: var(--primary-btn) !important; border: none !important; color: white !important; }
+.gr-button-primary:hover { background: var(--primary-btn-hover) !important; }
+.gr-button-secondary { background: var(--bg-primary) !important; border: 1px solid var(--border-color) !important; color: var(--text-primary) !important; }
+.gr-box, .gr-input, .gr-dropdown { background: var(--bg-primary) !important; border-color: var(--border-color) !important; }
+#result-box textarea {
+    font-size: 1.25rem; font-weight: 700; text-align: center; color: var(--primary-btn);
+    background-color: transparent; border: none; overflow: hidden !important; resize: none; white-space: normal; line-height: 1.4;
+}
+"""
+
+# --- JAVASCRIPT ---
+JS_LOADER = """
+() => {
+    const notes = [
+        "Yes, even Pharo. Don’t ask why.",
+        "Is ‘deprecated’ significant? Asking for a friend.",
+        "Technical debt is just future-me's problem.",
+        "Comment first, code later. Obviously.",
+        "If it works, don't touch it.",
+        "Fixing bugs created by previous-me.",
+        "Legacy code: don't breathe on it.",
+        "Documentation is a love letter to your future self.",
+        "It works on my machine!",
+        "404: Motivation not found.",
+        "Compiling... please hold."
+    ];
+    let idx = 0;
+    function rotateNotes() {
+        const textEl = document.getElementById('dev-note-text');
+        if (!textEl) { setTimeout(rotateNotes, 500); return; }
+        textEl.classList.add('hidden');
+        setTimeout(() => {
+            idx = (idx + 1) % notes.length;
+            textEl.innerText = notes[idx];
+            textEl.classList.remove('hidden');
+        }, 1500);
+    }
+    setInterval(rotateNotes, 10000);
+}
+"""
+
+# --- UTILITIES ---
+def load_svg_content(filename="logo_header.svg"):
+    base_path = os.path.dirname(os.path.abspath(__file__))
+    target_path = os.path.join(base_path, "..", "..", "reports", "figures", filename)
+    target_path = os.path.normpath(target_path)
+
+    if os.path.exists(target_path):
+        with open(target_path, "r", encoding="utf-8") as f:
+            return f.read()
+    else:
+        print(f"[WARNING] Logo not found in: {target_path}")
+        return "<span style='color: var(--primary-btn); font-weight:bold;'>CCC</span>"
+
+def save_feedback_to_csv(text, language, predicted, suggested):
+    if not text:
+        return "No data."
+    try:
+        os.makedirs(os.path.dirname(FEEDBACK_FILE), exist_ok=True)
+        file_exists = os.path.isfile(FEEDBACK_FILE)
+        with open(FEEDBACK_FILE, mode='a', newline='', encoding='utf-8') as f:
+            writer = csv.writer(f)
+            if not file_exists:
+                writer.writerow(["Timestamp", "Input_Text", "Language", "Model_Prediction", "User_Correction"])
+
+            pred_label = predicted
+            if isinstance(predicted, dict):
+                pred_label = max(predicted, key=predicted.get) if predicted else "Unknown"
+
+            writer.writerow([
+                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                text.strip(),
+                language,
+                pred_label,
+                suggested
+            ])
+        return "Feedback saved successfully!"
+    except Exception as e:
+        return f"Error saving feedback: {str(e)}"
+
+# --- SYNTAX VALIDATION LOGIC ---
+def is_valid_syntax(text: str, language: str) -> bool:
+    """
+    Validates if the text follows the basic comment syntax for the given language.
+    """
+    text = text.strip()
+    if not text:
+        return False
+
+    if language == "java":
+        # Supports: // comment OR /* comment */
+        return text.startswith("//") or (text.startswith("/*") and text.endswith("*/"))
+
+    elif language == "python":
+        # Supports: # comment OR """ docstring """ OR ''' docstring '''
+        return text.startswith("#") or \
+            (text.startswith('"""') and text.endswith('"""')) or \
+            (text.startswith("'''") and text.endswith("'''"))
+
+    elif language == "pharo":
+        # Supports: " comment "
+        return text.startswith('"') and text.endswith('"')
+
+    return True
+
+# --- MAIN DEMO ---
+def create_demo(inference_engine: ModelInference):
+
+    def classify_comment(text: str, language: str):
+        """
+        Calls the inference engine only if syntax is valid.
+        """
+        if not text:
+            return None
+
+        # SYNTAX CHECK
+        if not is_valid_syntax(text, language):
+            error_msg = "Error: Invalid Syntax."
+            if language == "java":
+                error_msg += " Java comments must start with '//' or be enclosed in '/* ... */'."
+            elif language == "python":
+                error_msg += " Python comments must start with '#' or use docstrings ('\"\"\"' / \"'''\")."
+            elif language == "pharo":
+                error_msg += " Pharo comments must be enclosed in double quotes (e.g., \"comment\")."
+            return error_msg
+
+        # INFERENCE
+        try:
+            _, labels, _, _ = inference_engine.predict_payload(
+                texts=[text],
+                language=language
+            )
+
+            if labels and len(labels) > 0:
+                first_prediction = labels[0][0]
+                if isinstance(first_prediction, (list, tuple)):
+                    return first_prediction[0]
+                else:
+                    return str(first_prediction)
+
+            return "Unknown: Low confidence."
+
+        except Exception as e:
+            print(f"Prediction Error: {e}")
+            return f"System Error: Failed to process request for '{language}'."
+
+    def update_dropdown(language):
+        choices = LABELS_MAP.get(language, [])
+        return gr.Dropdown(choices=choices, value=None, interactive=True)
+
+    def clear_all():
+        return (None, "java", "", gr.Dropdown(choices=LABELS_MAP["java"], value=None, interactive=True), "")
+
+    logo_svg = load_svg_content("logo_header.svg")
+
+    with gr.Blocks(title="Code Comment Classifier") as demo:
+        gr.HTML(f"<style>{CSS}</style>")
+
+        # --- HEADER ---
+        gr.HTML(f"""
+        <div class="compact-header">
+            <div class="header-left">
+                <div class="logo-icon">{logo_svg}</div>
+                <div class="title-group">
+                    <h1 class="main-title">Code Comment Classifier</h1>
+                    <p class="subtitle">for Java, Python & Pharo</p>
+                </div>
+            </div>
+            <div class="header-right">
+                <div class="dev-note-container">
+                    <span class="dev-icon" style="color: var(--primary-btn);">💭</span>
+                    <span id="dev-note-text" class="dev-text">Initializing...</span>
+                </div>
+            </div>
+        </div>
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                gr.HTML('<div class="input-card"><div class="section-title">📝 Input Source</div></div>')
+                input_text = gr.Textbox(label="Code Comment", lines=8, show_label=False, placeholder="Enter code comment here...")
+                with gr.Row():
+                    input_lang = gr.Dropdown(["java", "python", "pharo"], label="Language", value="java", scale=2)
+                    submit_btn = gr.Button("⚡ Classify", variant="primary", scale=1)
+                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
+
+            with gr.Column():
+                gr.HTML('<div class="output-card"><div class="section-title">📊 Classification Result</div></div>')
+                output_tags = gr.Textbox(
+                    label="Predicted Category",
+                    show_label=False,
+                    elem_id="result-box",
+                    interactive=False,
+                    lines=2
+                )
+
+                gr.HTML('<div class="feedback-section"><div class="feedback-title">🛠️ Help Improve the Model</div></div>')
+                with gr.Row():
+                    correction_dropdown = gr.Dropdown(
+                        choices=LABELS_MAP["java"],
+                        label="Correct Label",
+                        show_label=False,
+                        container=False,
+                        scale=3,
+                        interactive=True
+                    )
+                    feedback_btn = gr.Button("📤 Save Feedback", variant="secondary", scale=1)
+                feedback_msg = gr.Markdown("", show_label=False)
+
+        gr.Examples(
+            examples=[
+                ["/** Validates the user session token. */", "java"],
+                ["# Retry logic for DB connection.", "python"],
+                ['"Manages the network connection lifecycle."', "pharo"]
+            ],
+            inputs=[input_text, input_lang],
+            label="Quick Examples"
+        )
+
+        input_lang.change(fn=update_dropdown, inputs=input_lang, outputs=correction_dropdown)
+        submit_btn.click(fn=classify_comment, inputs=[input_text, input_lang], outputs=[output_tags])
+        feedback_btn.click(fn=save_feedback_to_csv, inputs=[input_text, input_lang, output_tags, correction_dropdown], outputs=[feedback_msg])
+        clear_btn.click(fn=clear_all, inputs=None, outputs=[input_text, input_lang, output_tags, correction_dropdown, feedback_msg])
+
+        demo.load(None, js=JS_LOADER)
+
+    return demo
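
The syntax gate in `is_valid_syntax` is easy to check in isolation. A small sketch of the expected behaviour for the three supported languages:

```python
from turing.api.demo import is_valid_syntax

# Java accepts line comments or block comments.
assert is_valid_syntax("// For internal use only.", "java")
assert not is_valid_syntax("int x = 0;", "java")

# Python accepts '#' comments and triple-quoted docstrings.
assert is_valid_syntax("# TODO: refactor", "python")
assert is_valid_syntax('"""Summary of the module."""', "python")

# Pharo comments are wrapped in double quotes.
assert is_valid_syntax('"I am a simple arrow."', "pharo")
```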
turing/api/schemas.py
ADDED
@@ -0,0 +1,22 @@
+from typing import Any, List
+
+from pydantic import BaseModel, Field
+
+
+# Input Schema
+class PredictionRequest(BaseModel):
+    texts: List[str] = Field(
+        ...,
+        description="List of code comments to classify",
+        example=["public void main", "def init self"],
+    )
+    language: str = Field(
+        ..., description="Programming language (java, python, pharo)", example="java"
+    )
+
+
+# Output Schema
+class PredictionResponse(BaseModel):
+    predictions: List[Any] = Field(..., description="List of predicted labels")
+    labels: List[Any] = Field(..., description="List of human-readable labels")
+    model_info: dict = Field(..., description="Metadata about the model used")
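
The request schema doubles as a validator for client payloads. A minimal sketch of building and serializing a request (pydantic v1-style `.dict()` shown; pydantic v2 would use `.model_dump()`):

```python
from turing.api.schemas import PredictionRequest

req = PredictionRequest(texts=["# Create output directory"], language="python")
print(req.dict())
# {'texts': ['# Create output directory'], 'language': 'python'}
```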
turing/config.py
ADDED
@@ -0,0 +1,95 @@
+from pathlib import Path
+
+from dotenv import load_dotenv
+from loguru import logger
+
+# Load environment variables from .env file if it exists
+load_dotenv()
+
+# Paths
+PROJ_ROOT = Path(__file__).resolve().parents[1]
+logger.info(f"PROJ_ROOT path is: {PROJ_ROOT}")
+
+DATA_DIR = PROJ_ROOT / "data"
+RAW_DATA_DIR = DATA_DIR / "raw"
+INTERIM_DATA_DIR = DATA_DIR / "interim"
+PROCESSED_DATA_DIR = DATA_DIR / "processed"
+EXTERNAL_DATA_DIR = DATA_DIR / "external"
+
+MODELS_DIR = PROJ_ROOT / "models"
+
+REPORTS_DIR = PROJ_ROOT / "reports"
+FIGURES_DIR = REPORTS_DIR / "figures"
+
+# Dataset
+DATASET_HF_ID = "NLBSE/nlbse26-code-comment-classification"
+LANGS = ["java", "python", "pharo"]
+INPUT_COLUMN = "combo"
+LABEL_COLUMN = "labels"
+
+LABELS_MAP = {
+    "java": ["summary", "Ownership", "Expand", "usage", "Pointer", "deprecation", "rational"],
+    "python": ["Usage", "Parameters", "DevelopmentNotes", "Expand", "Summary"],
+    "pharo": [
+        "Keyimplementationpoints",
+        "Example",
+        "Responsibilities",
+        "Intent",
+        "Keymessages",
+        "Collaborators",
+    ],
+}
+
+TOTAL_CATEGORIES = sum(len(v) for v in LABELS_MAP.values())
+
+# Score parameters
+MAX_AVG_RUNTIME = 5.0  # seconds
+MAX_AVG_FLOPS = 5000.0  # GFLOPS
+
+# Training parameters
+DEFAULT_BATCH_SIZE = 32
+
+# Model configuration mapping
+MODEL_CONFIG = {
+    "codeberta": {
+        "model_name": "fine-tuned-CodeBERTa",
+        "exp_name": "fine-tuned-CodeBERTa",
+        "model_class_module": "turing.modeling.models.codeBerta",
+        "model_class_name": "CodeBERTa",
+    },
+    "graphcodebert": {
+        "model_name": "GraphCodeBERT",
+        "exp_name": "fine-tuned-GraphCodeBERT",
+        "model_class_module": "turing.modeling.models.graphCodeBert",
+        "model_class_name": "GraphCodeBERTClassifier",
+    },
+    "tinybert": {
+        "model_name": "TinyBERT",
+        "exp_name": "fine-tuned-TinyBERT",
+        "model_class_module": "turing.modeling.models.tinyBert",
+        "model_class_name": "TinyBERTClassifier",
+    },
+    "randomforest": {
+        "model_name": "RandomForest-TfIdf",
+        "exp_name": "RandomForest-TfIdf",
+        "model_class_module": "turing.modeling.models.randomForestTfIdf",
+        "model_class_name": "RandomForestTfIdf",
+    },
+}
+DEFAULT_NUM_ITERATIONS = 20
+
+# Existing model modules
+EXISTING_MODELS = [
+    "randomForestTfIdf",
+    "codeBerta",
+]
+
+# If tqdm is installed, configure loguru with tqdm.write
+# https://github.com/Delgan/loguru/issues/135
+try:
+    from tqdm import tqdm
+
+    logger.remove(0)
+    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
+except (ModuleNotFoundError, ValueError):
+    pass
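
`MODEL_CONFIG` maps a short key to the module path and class name that implement each model, so the class can be resolved dynamically instead of being imported up front. A minimal sketch of that lookup, assuming the listed modules are importable in the environment:

```python
import importlib

from turing.config import MODEL_CONFIG

cfg = MODEL_CONFIG["randomforest"]
module = importlib.import_module(cfg["model_class_module"])
model_cls = getattr(module, cfg["model_class_name"])  # -> RandomForestTfIdf
print(model_cls.__name__)
```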
turing/data_validation.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from deepchecks.tabular import Dataset, Suite
|
| 6 |
+
from deepchecks.tabular.checks import (
|
| 7 |
+
ConflictingLabels,
|
| 8 |
+
DataDuplicates,
|
| 9 |
+
LabelDrift,
|
| 10 |
+
OutlierSampleDetection,
|
| 11 |
+
TrainTestSamplesMix,
|
| 12 |
+
)
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
|
| 16 |
+
from turing.config import LABEL_COLUMN, LABELS_MAP
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from deepchecks.nlp import TextData
|
| 20 |
+
from deepchecks.nlp.checks import (
|
| 21 |
+
PropertyDrift,
|
| 22 |
+
TextEmbeddingsDrift,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
NLP_AVAILABLE = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
NLP_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _encode_labels_for_validation(
|
| 31 |
+
series: pd.Series, class_names: List[str]
|
| 32 |
+
) -> pd.Series:
|
| 33 |
+
def encode(lbl):
|
| 34 |
+
active_labels = []
|
| 35 |
+
for idx, is_active in enumerate(lbl):
|
| 36 |
+
if is_active:
|
| 37 |
+
if idx < len(class_names):
|
| 38 |
+
active_labels.append(class_names[idx])
|
| 39 |
+
else:
|
| 40 |
+
active_labels.append(f"Class_{idx}")
|
| 41 |
+
if not active_labels:
|
| 42 |
+
return "No_Label"
|
| 43 |
+
return " & ".join(active_labels)
|
| 44 |
+
|
| 45 |
+
return series.apply(encode)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _calculate_code_specific_properties(text_series: List[str]) -> pd.DataFrame:
|
| 49 |
+
props = []
|
| 50 |
+
for text in text_series:
|
| 51 |
+
s = str(text)
|
| 52 |
+
length = len(s)
|
| 53 |
+
non_alnum = sum(1 for c in s if not c.isalnum() and not c.isspace())
|
| 54 |
+
props.append(
|
| 55 |
+
{
|
| 56 |
+
"Text_Length": length,
|
| 57 |
+
"Symbol_Ratio": non_alnum / length if length > 0 else 0.0,
|
| 58 |
+
}
|
| 59 |
+
)
|
| 60 |
+
return pd.DataFrame(props)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _nuke_rogue_files():
|
| 64 |
+
"""
|
| 65 |
+
Delete stray .npy artifacts (e.g. embeddings.npy) that deepchecks may leave on disk.
|
| 66 |
+
"""
|
| 67 |
+
rogue_filenames = [
|
| 68 |
+
"embeddings.npy"
|
| 69 |
+
|
| 70 |
+
]
|
| 71 |
+
for fname in rogue_filenames:
|
| 72 |
+
p = Path(fname)
|
| 73 |
+
if p.exists():
|
| 74 |
+
try:
|
| 75 |
+
p.unlink()
|
| 76 |
+
except Exception:
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def run_custom_deepchecks(
|
| 81 |
+
df_train: pd.DataFrame,
|
| 82 |
+
df_test: pd.DataFrame,
|
| 83 |
+
output_dir: Path,
|
| 84 |
+
stage: str,
|
| 85 |
+
language: str,
|
| 86 |
+
):
|
| 87 |
+
print(f" [Deepchecks] Running Integrity Suite ({stage})...")
|
| 88 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 89 |
+
|
| 90 |
+
class_names = LABELS_MAP.get(language, [])
|
| 91 |
+
cols = ["f_length", "f_word_count", "f_starts_verb", "text_hash"]
|
| 92 |
+
|
| 93 |
+
for c in cols:
|
| 94 |
+
if c not in df_train.columns:
|
| 95 |
+
df_train[c] = 0
|
| 96 |
+
if c not in df_test.columns:
|
| 97 |
+
df_test[c] = 0
|
| 98 |
+
|
| 99 |
+
train_ds_df = df_train[cols].copy()
|
| 100 |
+
train_ds_df["target"] = _encode_labels_for_validation(
|
| 101 |
+
df_train[LABEL_COLUMN], class_names
|
| 102 |
+
)
|
| 103 |
+
test_ds_df = df_test[cols].copy()
|
| 104 |
+
test_ds_df["target"] = _encode_labels_for_validation(
|
| 105 |
+
df_test[LABEL_COLUMN], class_names
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
cat_features = ["text_hash", "f_starts_verb"]
|
| 109 |
+
train_ds = Dataset(train_ds_df, label="target", cat_features=cat_features)
|
| 110 |
+
test_ds = Dataset(test_ds_df, label="target", cat_features=cat_features)
|
| 111 |
+
|
| 112 |
+
check_conflicts = ConflictingLabels(columns=["text_hash"])
|
| 113 |
+
if hasattr(check_conflicts, "add_condition_ratio_of_conflicting_labels_not_greater_than"):
|
| 114 |
+
check_conflicts.add_condition_ratio_of_conflicting_labels_not_greater_than(0)
|
| 115 |
+
else:
|
| 116 |
+
check_conflicts.add_condition_ratio_of_conflicting_labels_less_or_equal(0)
|
| 117 |
+
|
| 118 |
+
check_duplicates = DataDuplicates()
|
| 119 |
+
if hasattr(check_duplicates, "add_condition_ratio_not_greater_than"):
|
| 120 |
+
check_duplicates.add_condition_ratio_not_greater_than(0.05)
|
| 121 |
+
else:
|
| 122 |
+
check_duplicates.add_condition_ratio_less_or_equal(0.05)
|
| 123 |
+
|
| 124 |
+
check_leakage = TrainTestSamplesMix(columns=["text_hash"])
|
| 125 |
+
try:
|
| 126 |
+
if hasattr(check_leakage, "add_condition_ratio_not_greater_than"):
|
| 127 |
+
check_leakage.add_condition_ratio_not_greater_than(0)
|
| 128 |
+
except Exception:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
check_outliers = OutlierSampleDetection()
|
| 132 |
+
try:
|
| 133 |
+
if hasattr(check_outliers, "add_condition_outlier_ratio_less_or_equal"):
|
| 134 |
+
check_outliers.add_condition_outlier_ratio_less_or_equal(0.05)
|
| 135 |
+
except Exception:
|
| 136 |
+
pass
|
| 137 |
+
|
| 138 |
+
custom_suite = Suite(
|
| 139 |
+
"Code Quality & Integrity",
|
| 140 |
+
check_conflicts,
|
| 141 |
+
check_duplicates,
|
| 142 |
+
check_leakage,
|
| 143 |
+
LabelDrift(),
|
| 144 |
+
check_outliers,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
try:
|
| 148 |
+
result = custom_suite.run(train_dataset=train_ds, test_dataset=test_ds)
|
| 149 |
+
report_path = output_dir / f"1_Integrity_{stage}.html"
|
| 150 |
+
result.save_as_html(str(report_path), as_widget=False)
|
| 151 |
+
print(f" [Deepchecks] Report Saved: {report_path}")
|
| 152 |
+
except Exception as e:
|
| 153 |
+
print(f" [Deepchecks] Error: {e}")
|
| 154 |
+
traceback.print_exc()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def run_targeted_nlp_checks(
|
| 158 |
+
df_train: pd.DataFrame,
|
| 159 |
+
df_test: pd.DataFrame,
|
| 160 |
+
output_dir: Path,
|
| 161 |
+
stage: str,
|
| 162 |
+
language: str = "english",
|
| 163 |
+
):
|
| 164 |
+
if not NLP_AVAILABLE:
|
| 165 |
+
print(" [Skip] NLP Suite skipped (libs not installed).")
|
| 166 |
+
return
|
| 167 |
+
|
| 168 |
+
from deepchecks.nlp import Suite as NLPSuite
|
| 169 |
+
|
| 170 |
+
print(f" [NLP Check] Running Semantic Analysis ({stage})...")
|
| 171 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 172 |
+
|
| 173 |
+
# Clean up any existing garbage before starting
|
| 174 |
+
_nuke_rogue_files()
|
| 175 |
+
|
| 176 |
+
DRIFT_THRESHOLD = 0.20
|
| 177 |
+
PROP_THRESHOLD = 0.35
|
| 178 |
+
SAMPLE_SIZE = 2000
|
| 179 |
+
df_tr = (
|
| 180 |
+
df_train.sample(n=SAMPLE_SIZE, random_state=42)
|
| 181 |
+
if len(df_train) > SAMPLE_SIZE
|
| 182 |
+
else df_train
|
| 183 |
+
)
|
| 184 |
+
df_te = (
|
| 185 |
+
df_test.sample(n=SAMPLE_SIZE, random_state=42)
|
| 186 |
+
if len(df_test) > SAMPLE_SIZE
|
| 187 |
+
else df_test
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
try: # START MAIN TRY BLOCK
|
| 191 |
+
y_tr = np.vstack(df_tr[LABEL_COLUMN].tolist())
|
| 192 |
+
y_te = np.vstack(df_te[LABEL_COLUMN].tolist())
|
| 193 |
+
|
| 194 |
+
train_ds = TextData(
|
| 195 |
+
df_tr["comment_sentence"].tolist(),
|
| 196 |
+
label=y_tr,
|
| 197 |
+
task_type="text_classification",
|
| 198 |
+
)
|
| 199 |
+
test_ds = TextData(
|
| 200 |
+
df_te["comment_sentence"].tolist(),
|
| 201 |
+
label=y_te,
|
| 202 |
+
task_type="text_classification",
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
print(" [NLP Check] Calculating custom code properties...")
|
| 206 |
+
train_props = _calculate_code_specific_properties(
|
| 207 |
+
df_tr["comment_sentence"].tolist()
|
| 208 |
+
)
|
| 209 |
+
test_props = _calculate_code_specific_properties(
|
| 210 |
+
df_te["comment_sentence"].tolist()
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
train_ds.set_properties(train_props)
|
| 214 |
+
test_ds.set_properties(test_props)
|
| 215 |
+
|
| 216 |
+
# In-memory calculation only.
|
| 217 |
+
train_ds.calculate_builtin_embeddings()
|
| 218 |
+
test_ds.calculate_builtin_embeddings()
|
| 219 |
+
|
| 220 |
+
check_embeddings = TextEmbeddingsDrift()
|
| 221 |
+
if hasattr(check_embeddings, "add_condition_drift_score_not_greater_than"):
|
| 222 |
+
check_embeddings.add_condition_drift_score_not_greater_than(DRIFT_THRESHOLD)
|
| 223 |
+
elif hasattr(check_embeddings, "add_condition_drift_score_less_than"):
|
| 224 |
+
check_embeddings.add_condition_drift_score_less_than(DRIFT_THRESHOLD)
|
| 225 |
+
|
| 226 |
+
check_len = PropertyDrift(custom_property_name="Text_Length")
|
| 227 |
+
if hasattr(check_len, "add_condition_drift_score_not_greater_than"):
|
| 228 |
+
check_len.add_condition_drift_score_not_greater_than(PROP_THRESHOLD)
|
| 229 |
+
elif hasattr(check_len, "add_condition_drift_score_less_than"):
|
| 230 |
+
check_len.add_condition_drift_score_less_than(PROP_THRESHOLD)
|
| 231 |
+
|
| 232 |
+
check_sym = PropertyDrift(custom_property_name="Symbol_Ratio")
|
| 233 |
+
if hasattr(check_sym, "add_condition_drift_score_not_greater_than"):
|
| 234 |
+
check_sym.add_condition_drift_score_not_greater_than(PROP_THRESHOLD)
|
| 235 |
+
elif hasattr(check_sym, "add_condition_drift_score_less_than"):
|
| 236 |
+
check_sym.add_condition_drift_score_less_than(PROP_THRESHOLD)
|
| 237 |
+
|
| 238 |
+
suite = NLPSuite(
|
| 239 |
+
"Code Comment Semantic Analysis",
|
| 240 |
+
check_embeddings,
|
| 241 |
+
check_len,
|
| 242 |
+
check_sym
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
res = suite.run(train_ds, test_ds)
|
| 246 |
+
|
| 247 |
+
report_path = output_dir / f"2_Semantic_{stage}.html"
|
| 248 |
+
res.save_as_html(str(report_path), as_widget=False)
|
| 249 |
+
print(f" [NLP Check] Report saved: {report_path}")
|
| 250 |
+
|
| 251 |
+
try:
|
| 252 |
+
passed = res.get_passed_checks()
|
| 253 |
+
n_passed = len(passed)
|
| 254 |
+
n_total = len(res.results)
|
| 255 |
+
print(f" [NLP Result] {n_passed}/{n_total} checks passed.")
|
| 256 |
+
|
| 257 |
+
if n_passed < n_total:
|
| 258 |
+
print(" [NLP Warning] Failed Checks details:")
|
| 259 |
+
for result in res.results:
|
| 260 |
+
if not result.passed_conditions():
|
| 261 |
+
print(f" - {result.check.name}: {result.conditions_results[0].details}")
|
| 262 |
+
except Exception:
|
| 263 |
+
pass
|
| 264 |
+
|
| 265 |
+
except Exception as e:
|
| 266 |
+
print(f" [NLP Check] Failed: {e}")
|
| 267 |
+
import traceback
|
| 268 |
+
traceback.print_exc()
|
| 269 |
+
|
| 270 |
+
finally:
|
| 271 |
+
_nuke_rogue_files()
|
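As a rough guide to how these checks are meant to be driven, here is a minimal sketch of invoking the integrity and NLP suites on one language split; the CSV paths are placeholders and the labels column is assumed to already contain multi-hot label lists rather than strings:

from pathlib import Path
import pandas as pd
from turing.data_validation import run_custom_deepchecks, run_targeted_nlp_checks

df_train = pd.read_csv("data/interim/base/java_train.csv")  # hypothetical paths
df_test = pd.read_csv("data/interim/base/java_test.csv")
report_dir = Path("reports/data/example/java")

run_custom_deepchecks(df_train, df_test, report_dir, stage="raw", language="java")
run_targeted_nlp_checks(df_train, df_test, report_dir, stage="raw", language="java")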
turing/dataset.py
ADDED
|
@@ -0,0 +1,210 @@
| 1 |
+
import ast
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from datasets import DatasetDict, load_dataset
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
import turing.config as config
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DatasetManager:
|
| 12 |
+
"""
|
| 13 |
+
Manages the loading, transformation, and access of project datasets.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, dataset_path: Path = None):
|
| 17 |
+
self.hf_id = config.DATASET_HF_ID
|
| 18 |
+
self.raw_data_dir = config.RAW_DATA_DIR
|
| 19 |
+
self.interim_data_dir = config.INTERIM_DATA_DIR
|
| 20 |
+
self.base_interim_path = self.interim_data_dir / "base"
|
| 21 |
+
|
| 22 |
+
if dataset_path:
|
| 23 |
+
self.dataset_path = dataset_path
|
| 24 |
+
else:
|
| 25 |
+
self.dataset_path = self.base_interim_path
|
| 26 |
+
|
| 27 |
+
def _format_labels_for_csv(self, example: dict) -> dict:
|
| 28 |
+
"""
|
| 29 |
+
Formats the labels list as a string for CSV storage.
|
| 30 |
+
(Private class method)
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
example (dict): A single example from the dataset.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
dict: The example with labels converted to string.
|
| 37 |
+
"""
|
| 38 |
+
labels = example.get("labels")
|
| 39 |
+
if isinstance(labels, list):
|
| 40 |
+
example["labels"] = str(labels)
|
| 41 |
+
return example
|
| 42 |
+
|
| 43 |
+
def download_dataset(self):
|
| 44 |
+
"""
|
| 45 |
+
Loads the dataset from Hugging Face and saves it into the "raw" folder.
|
| 46 |
+
"""
|
| 47 |
+
logger.info(f"Loading dataset: {self.hf_id}")
|
| 48 |
+
try:
|
| 49 |
+
ds = load_dataset(self.hf_id)
|
| 50 |
+
logger.success("Dataset loaded successfully.")
|
| 51 |
+
logger.info(f"Dataset splits: {ds}")
|
| 52 |
+
|
| 53 |
+
self.raw_data_dir.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
for split_name, dataset_split in ds.items():
|
| 56 |
+
output_path = os.path.join(
|
| 57 |
+
self.raw_data_dir, f"{split_name.replace('-', '_')}.parquet"
|
| 58 |
+
)
|
| 59 |
+
dataset_split.to_parquet(output_path)
|
| 60 |
+
|
| 61 |
+
logger.success(f"Dataset saved to {self.raw_data_dir}.")
|
| 62 |
+
except Exception as e:
|
| 63 |
+
logger.warning(f"Error during loading: {e}.")
|
| 64 |
+
|
| 65 |
+
def parquet_to_csv(self):
|
| 66 |
+
"""
|
| 67 |
+
Converts all parquet files in the raw data directory
|
| 68 |
+
to CSV format in the interim data directory.
|
| 69 |
+
"""
|
| 70 |
+
logger.info("Starting Parquet to CSV conversion...")
|
| 71 |
+
self.base_interim_path.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
|
| 73 |
+
for file_name in os.listdir(self.raw_data_dir):
|
| 74 |
+
if file_name.endswith(".parquet"):
|
| 75 |
+
part_name = file_name.replace(".parquet", "").replace("-", "_")
|
| 76 |
+
|
| 77 |
+
# Load the parquet file
|
| 78 |
+
dataset = load_dataset(
|
| 79 |
+
"parquet", data_files={part_name: str(self.raw_data_dir / file_name)}
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Map and format labels
|
| 83 |
+
dataset[part_name] = dataset[part_name].map(self._format_labels_for_csv)
|
| 84 |
+
|
| 85 |
+
# Save to CSV
|
| 86 |
+
csv_output_path = os.path.join(self.base_interim_path, f"{part_name}.csv")
|
| 87 |
+
dataset[part_name].to_csv(csv_output_path)
|
| 88 |
+
|
| 89 |
+
logger.info(f"Converted {file_name} to {csv_output_path}")
|
| 90 |
+
|
| 91 |
+
logger.success("Parquet -> CSV conversion complete.")
|
| 92 |
+
|
| 93 |
+
def get_dataset_name(self) -> str:
|
| 94 |
+
"""
|
| 95 |
+
Returns the name of the current dataset being used.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
str: The name of the dataset (e.g., 'clean-aug-soft-k5000').
|
| 99 |
+
"""
|
| 100 |
+
return self.dataset_path.name
|
| 101 |
+
|
| 102 |
+
def get_dataset(self) -> DatasetDict:
|
| 103 |
+
"""
|
| 104 |
+
Returns the processed dataset from the interim data directory
|
| 105 |
+
as a DatasetDict (loaded from CSVs).
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
DatasetDict: The complete dataset with train and test splits for each language.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
dataset_path = self.dataset_path
|
| 112 |
+
|
| 113 |
+
# Define the base filenames
|
| 114 |
+
data_files = {
|
| 115 |
+
"java_train": str(dataset_path / "java_train.csv"),
|
| 116 |
+
"java_test": str(dataset_path / "java_test.csv"),
|
| 117 |
+
"python_train": str(dataset_path / "python_train.csv"),
|
| 118 |
+
"python_test": str(dataset_path / "python_test.csv"),
|
| 119 |
+
"pharo_train": str(dataset_path / "pharo_train.csv"),
|
| 120 |
+
"pharo_test": str(dataset_path / "pharo_test.csv"),
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# Verify file existence before loading
|
| 124 |
+
logger.info("Loading CSV dataset from splits...")
|
| 125 |
+
existing_data_files = {}
|
| 126 |
+
for key, path in data_files.items():
|
| 127 |
+
if not os.path.exists(path):
|
| 128 |
+
found = False
|
| 129 |
+
if os.path.exists(dataset_path):
|
| 130 |
+
for f in os.listdir(dataset_path):
|
| 131 |
+
if f.startswith(key) and f.endswith(".csv"):
|
| 132 |
+
existing_data_files[key] = str(dataset_path / f)
|
| 133 |
+
found = True
|
| 134 |
+
break
|
| 135 |
+
if not found:
|
| 136 |
+
logger.warning(f"File not found for split '{key}': {path}")
|
| 137 |
+
else:
|
| 138 |
+
existing_data_files[key] = path
|
| 139 |
+
|
| 140 |
+
if not existing_data_files:
|
| 141 |
+
logger.error("No dataset CSV files found. Run 'parquet-to-csv' first.")
|
| 142 |
+
raise FileNotFoundError("Dataset CSV files not found.")
|
| 143 |
+
|
| 144 |
+
logger.info(f"Found files: {list(existing_data_files.keys())}")
|
| 145 |
+
|
| 146 |
+
full_dataset = load_dataset("csv", data_files=existing_data_files)
|
| 147 |
+
|
| 148 |
+
logger.info("Formatting labels (from string back to list)...")
|
| 149 |
+
for split in full_dataset:
|
| 150 |
+
full_dataset[split] = full_dataset[split].map(
|
| 151 |
+
lambda x: {
|
| 152 |
+
"labels": ast.literal_eval(x["labels"])
|
| 153 |
+
if isinstance(x["labels"], str)
|
| 154 |
+
else x["labels"]
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
logger.success("Dataset is ready for use.")
|
| 159 |
+
return full_dataset
|
| 160 |
+
|
| 161 |
+
def get_raw_dataset_from_hf(self) -> DatasetDict:
|
| 162 |
+
"""
|
| 163 |
+
Loads the raw dataset directly from Hugging Face without saving.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
DatasetDict: The raw dataset from Hugging Face.
|
| 167 |
+
"""
|
| 168 |
+
logger.info(f"Loading raw dataset '{self.hf_id}' from Hugging Face...")
|
| 169 |
+
try:
|
| 170 |
+
ds = load_dataset(self.hf_id)
|
| 171 |
+
logger.success(f"Successfully loaded '{self.hf_id}'.")
|
| 172 |
+
return ds
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Failed to load dataset from Hugging Face: {e}")
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
def search_file(self, file_name: str, search_directory: Path = None) -> list:
|
| 178 |
+
"""
|
| 179 |
+
Recursively searches for a file by name within a specified data directory.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
file_name (str): The name of the file to search for (e.g., "java_train.csv").
|
| 183 |
+
search_directory (Path, optional): The directory to search in.
|
| 184 |
+
Defaults to self.raw_data_dir.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
list: A list of Path objects for all found files.
|
| 188 |
+
"""
|
| 189 |
+
if search_directory is None:
|
| 190 |
+
search_directory = self.raw_data_dir
|
| 191 |
+
logger.info(f"Defaulting search to raw data directory: {search_directory}")
|
| 192 |
+
|
| 193 |
+
if not search_directory.is_dir():
|
| 194 |
+
logger.error(f"Search directory not found: {search_directory}")
|
| 195 |
+
return []
|
| 196 |
+
|
| 197 |
+
logger.info(f"Searching for '{file_name}' in '{search_directory}'...")
|
| 198 |
+
|
| 199 |
+
found_files = []
|
| 200 |
+
for root, dirs, files in os.walk(search_directory):
|
| 201 |
+
for file in files:
|
| 202 |
+
if file == file_name:
|
| 203 |
+
found_files.append(Path(root) / file)
|
| 204 |
+
|
| 205 |
+
if not found_files:
|
| 206 |
+
logger.warning(f"No files named '{file_name}' found in '{search_directory}'.")
|
| 207 |
+
else:
|
| 208 |
+
logger.success(f"Found {len(found_files)} matching file(s).")
|
| 209 |
+
|
| 210 |
+
return found_files
|
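A typical end-to-end use of DatasetManager, sketched under the default configuration paths (directory names come from turing.config):

from turing.dataset import DatasetManager

dm = DatasetManager()
dm.download_dataset()   # Hugging Face -> raw parquet files
dm.parquet_to_csv()     # raw parquet -> interim "base" CSV splits

ds = dm.get_dataset()   # DatasetDict keyed as {lang}_train / {lang}_test
print(ds["java_train"][0]["labels"])  # labels are parsed back into Python lists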
turing/evaluate_model.py
ADDED
|
@@ -0,0 +1,121 @@
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
from datasets import DatasetDict
|
| 4 |
+
from loguru import logger
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import turing.config as config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def calculate_submission_score(avg_f1: float, avg_runtime: float, avg_flops: float) -> float:
|
| 13 |
+
"""
|
| 14 |
+
Calculates the final competition score.
|
| 15 |
+
The score is a weighted sum of F1 score, runtime, and GFLOPS.
|
| 16 |
+
Weights:
|
| 17 |
+
- F1 Score: 60%
|
| 18 |
+
- Runtime: 20%
|
| 19 |
+
- GFLOPS: 20%
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
avg_f1 (float): Average F1 score across all categories.
|
| 23 |
+
avg_runtime (float): Average runtime in seconds.
|
| 24 |
+
avg_flops (float): Average GFLOPS.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
float: Final submission score.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
score_f1 = 0.6 * avg_f1
|
| 31 |
+
|
| 32 |
+
runtime_ratio = (config.MAX_AVG_RUNTIME - avg_runtime) / config.MAX_AVG_RUNTIME
|
| 33 |
+
score_runtime = 0.2 * max(runtime_ratio, 0)
|
| 34 |
+
|
| 35 |
+
flops_ratio = (config.MAX_AVG_FLOPS - avg_flops) / config.MAX_AVG_FLOPS
|
| 36 |
+
score_flops = 0.2 * max(flops_ratio, 0)
|
| 37 |
+
|
| 38 |
+
total_score = score_f1 + score_runtime + score_flops
|
| 39 |
+
|
| 40 |
+
logger.info(f" F1 Score (60%): {score_f1:.4f} (avg_f1: {avg_f1:.4f})")
|
| 41 |
+
logger.info(
|
| 42 |
+
f" Runtime Score (20%): {score_runtime:.4f} (avg_runtime: {avg_runtime:.4f}s / {config.MAX_AVG_RUNTIME}s)"
|
| 43 |
+
)
|
| 44 |
+
logger.info(
|
| 45 |
+
f" GFLOPS Score (20%): {score_flops:.4f} (avg_flops: {avg_flops:.4f} / {config.MAX_AVG_FLOPS})"
|
| 46 |
+
)
|
| 47 |
+
logger.info(" ====================")
|
| 48 |
+
logger.info(f" Final Score: {total_score:.4f}")
|
| 49 |
+
|
| 50 |
+
return total_score
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def evaluate_models(models: dict, dataset: DatasetDict):
|
| 54 |
+
"""
|
| 55 |
+
Evaluates the provided models on the test datasets for each language.
|
| 56 |
+
Computes precision, recall, and F1 score for each category and language.
|
| 57 |
+
Also measures average runtime and GFLOPS for model inference.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
models (dict): A dictionary mapping language codes to their respective models.
|
| 61 |
+
dataset (DatasetDict): A DatasetDict containing test datasets for each language.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
pd.DataFrame: DataFrame containing precision, recall, and F1 scores for each category and language.
|
| 65 |
+
float: Final submission score calculated from average F1, runtime, and GFLOPS.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
total_flops = 0
|
| 69 |
+
total_time = 0
|
| 70 |
+
scores = []
|
| 71 |
+
|
| 72 |
+
for lan in config.LANGS:
|
| 73 |
+
logger.info(f"\n--- Evaluating Language: {lan.upper()} ---")
|
| 74 |
+
model = models[lan]
|
| 75 |
+
|
| 76 |
+
with torch.profiler.profile(with_flops=True) as p:
|
| 77 |
+
test_data = dataset[f"{lan}_test"]
|
| 78 |
+
x = test_data[config.INPUT_COLUMN]
|
| 79 |
+
x = list(x) if hasattr(x, 'tolist') else x # Convert pandas Series to list
|
| 80 |
+
y_true = np.array(test_data[config.LABEL_COLUMN]).T
|
| 81 |
+
|
| 82 |
+
begin = time.time()
|
| 83 |
+
for i in range(10):
|
| 84 |
+
y_pred = model.predict(x)
|
| 85 |
+
y_pred = np.asarray(y_pred).T
|
| 86 |
+
total = time.time() - begin
|
| 87 |
+
total_time = total_time + total
|
| 88 |
+
|
| 89 |
+
total_flops = total_flops + (sum(k.flops for k in p.key_averages()) / 1e9)
|
| 90 |
+
|
| 91 |
+
for i in range(len(y_pred)):
|
| 92 |
+
assert len(y_pred[i]) == len(y_true[i])
|
| 93 |
+
tp = sum([true == pred == 1 for (true, pred) in zip(y_true[i], y_pred[i])])
|
| 94 |
+
#tn = sum([true == pred == 0 for (true, pred) in zip(y_true[i], y_pred[i])])
|
| 95 |
+
fp = sum([true == 0 and pred == 1 for (true, pred) in zip(y_true[i], y_pred[i])])
|
| 96 |
+
fn = sum([true == 1 and pred == 0 for (true, pred) in zip(y_true[i], y_pred[i])])
|
| 97 |
+
precision = tp / (tp + fp)
|
| 98 |
+
recall = tp / (tp + fn)
|
| 99 |
+
f1 = (2 * tp) / (2 * tp + fp + fn)
|
| 100 |
+
scores.append({
|
| 101 |
+
"lan": lan,
|
| 102 |
+
"cat": config.LABELS_MAP[lan][i],
|
| 103 |
+
"precision": precision,
|
| 104 |
+
"recall": recall,
|
| 105 |
+
"f1": f1,
|
| 106 |
+
})
|
| 107 |
+
|
| 108 |
+
logger.info(f"Compute in GFLOPs: {total_flops / 10}")
|
| 109 |
+
logger.info(f"Avg runtime in seconds: {total_time / 10}")
|
| 110 |
+
scores = pd.DataFrame(scores)
|
| 111 |
+
print(scores)
|
| 112 |
+
|
| 113 |
+
avg_f1 = scores["f1"].mean()
|
| 114 |
+
avg_runtime = total_time / 10
|
| 115 |
+
avg_flops = total_flops / 10
|
| 116 |
+
|
| 117 |
+
final_score = calculate_submission_score(avg_f1, avg_runtime, avg_flops)
|
| 118 |
+
|
| 119 |
+
logger.info(f"Final Score for {lan.upper()}: {final_score:.4f}")
|
| 120 |
+
|
| 121 |
+
return scores, final_score
|
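To make the score weighting concrete, here is a small worked example of the same formula with illustrative numbers; the two caps are placeholders standing in for config.MAX_AVG_RUNTIME and config.MAX_AVG_FLOPS, whose real values are defined in turing/config.py:

MAX_AVG_RUNTIME = 10.0  # placeholder cap, not the real config value
MAX_AVG_FLOPS = 5.0     # placeholder cap, not the real config value

avg_f1, avg_runtime, avg_flops = 0.70, 2.0, 1.5

score = (
    0.6 * avg_f1                                                        # 0.42
    + 0.2 * max((MAX_AVG_RUNTIME - avg_runtime) / MAX_AVG_RUNTIME, 0)   # 0.16
    + 0.2 * max((MAX_AVG_FLOPS - avg_flops) / MAX_AVG_FLOPS, 0)         # 0.14
)
# score == 0.72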
turing/features.py
ADDED
|
@@ -0,0 +1,678 @@
| 1 |
+
import ast
|
| 2 |
+
import hashlib
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
import nltk
|
| 9 |
+
from nltk.corpus import stopwords, wordnet
|
| 10 |
+
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 14 |
+
from sklearn.feature_selection import SelectKBest, chi2
|
| 15 |
+
import typer
|
| 16 |
+
|
| 17 |
+
from turing.config import (
|
| 18 |
+
INTERIM_DATA_DIR,
|
| 19 |
+
LABEL_COLUMN,
|
| 20 |
+
LANGS,
|
| 21 |
+
)
|
| 22 |
+
from turing.data_validation import run_custom_deepchecks, run_targeted_nlp_checks
|
| 23 |
+
from turing.dataset import DatasetManager
|
| 24 |
+
|
| 25 |
+
# --- NLTK Resource Check ---
|
| 26 |
+
REQUIRED_NLTK_PACKAGES = [
|
| 27 |
+
"stopwords",
|
| 28 |
+
"wordnet",
|
| 29 |
+
"omw-1.4",
|
| 30 |
+
"averaged_perceptron_tagger",
|
| 31 |
+
"punkt",
|
| 32 |
+
]
|
| 33 |
+
for package in REQUIRED_NLTK_PACKAGES:
|
| 34 |
+
try:
|
| 35 |
+
nltk.data.find(f"corpora/{package}")
|
| 36 |
+
except LookupError:
|
| 37 |
+
try:
|
| 38 |
+
nltk.download(package, quiet=True)
|
| 39 |
+
except Exception:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
app = typer.Typer()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# --- CONFIGURATION CLASS ---
|
| 46 |
+
class FeaturePipelineConfig:
|
| 47 |
+
"""
|
| 48 |
+
Configuration holder for the pipeline. Generates a unique ID based on parameters
|
| 49 |
+
to version the output directories.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
use_stopwords: bool,
|
| 55 |
+
use_lemmatization: bool,
|
| 56 |
+
use_combo_feature: bool,
|
| 57 |
+
max_features: int,
|
| 58 |
+
min_comment_length: int,
|
| 59 |
+
max_comment_length: int,
|
| 60 |
+
enable_augmentation: bool,
|
| 61 |
+
custom_tags: str = "base",
|
| 62 |
+
):
|
| 63 |
+
self.use_stopwords = use_stopwords
|
| 64 |
+
self.use_lemmatization = use_lemmatization
|
| 65 |
+
self.use_combo_feature = use_combo_feature
|
| 66 |
+
self.max_features = max_features
|
| 67 |
+
self.min_comment_length = min_comment_length
|
| 68 |
+
self.max_comment_length = max_comment_length
|
| 69 |
+
self.enable_augmentation = enable_augmentation
|
| 70 |
+
self.custom_tags = custom_tags
|
| 71 |
+
self.hash_id = self._generate_readable_id()
|
| 72 |
+
|
| 73 |
+
def _generate_readable_id(self) -> str:
|
| 74 |
+
tags = ["clean"]
|
| 75 |
+
if self.enable_augmentation:
|
| 76 |
+
tags.append("aug-soft")
|
| 77 |
+
tags.append(f"k{self.max_features}")
|
| 78 |
+
if self.custom_tags != "base":
|
| 79 |
+
tags.append(self.custom_tags)
|
| 80 |
+
return "-".join(tags)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# --- TEXT UTILITIES ---
|
| 84 |
+
class TextCanonicalizer:
|
| 85 |
+
"""
|
| 86 |
+
Reduces text to a 'canonical' form (stemmed, lowercase)
|
| 87 |
+
to detect semantic duplicates.
|
| 88 |
+
Preserves Javadoc tags to distinguish tag usage (@return) from summary text (Returns).
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self):
|
| 92 |
+
self.stemmer = PorterStemmer()
|
| 93 |
+
self.stop_words = set(stopwords.words("english"))
|
| 94 |
+
# Code keywords are preserved as they carry semantic weight
|
| 95 |
+
self.code_keywords = {
|
| 96 |
+
"return",
|
| 97 |
+
"true",
|
| 98 |
+
"false",
|
| 99 |
+
"null",
|
| 100 |
+
"if",
|
| 101 |
+
"else",
|
| 102 |
+
"void",
|
| 103 |
+
"int",
|
| 104 |
+
"boolean",
|
| 105 |
+
"param",
|
| 106 |
+
"throws",
|
| 107 |
+
"exception",
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def to_canonical(self, text: str) -> str:
|
| 111 |
+
if pd.isna(text):
|
| 112 |
+
return ""
|
| 113 |
+
text = str(text).lower()
|
| 114 |
+
text = re.sub(r"[^a-z0-9\s@]", " ", text)
|
| 115 |
+
|
| 116 |
+
words = text.split()
|
| 117 |
+
canonical_words = []
|
| 118 |
+
|
| 119 |
+
for w in words:
|
| 120 |
+
# If the word starts with @ (e.g., @return), keep it as is
|
| 121 |
+
if w.startswith("@"):
|
| 122 |
+
canonical_words.append(w)
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
if w in self.stop_words and w not in self.code_keywords:
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
stemmed = self.stemmer.stem(w)
|
| 129 |
+
canonical_words.append(stemmed)
|
| 130 |
+
|
| 131 |
+
return " ".join(canonical_words).strip()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class TextProcessor:
|
| 135 |
+
"""
|
| 136 |
+
Standard text cleaning logic for final feature extraction (TF-IDF).
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, config: FeaturePipelineConfig, language: str = "english"):
|
| 140 |
+
self.config = config
|
| 141 |
+
self.stop_words = set(stopwords.words(language))
|
| 142 |
+
self.lemmatizer = WordNetLemmatizer()
|
| 143 |
+
|
| 144 |
+
def clean_text(self, text: str) -> str:
|
| 145 |
+
if pd.isna(text):
|
| 146 |
+
return ""
|
| 147 |
+
text = str(text).lower()
|
| 148 |
+
# Remove heavy code markers but keep text structure
|
| 149 |
+
text = re.sub(r"(^\s*//+|^\s*/\*+|\*/$)", "", text)
|
| 150 |
+
# Keep only alpha characters for NLP model (plus pipe for combo)
|
| 151 |
+
text = re.sub(r"[^a-z\s|]", " ", text)
|
| 152 |
+
tokens = text.split()
|
| 153 |
+
if self.config.use_stopwords:
|
| 154 |
+
tokens = [w for w in tokens if w not in self.stop_words]
|
| 155 |
+
if self.config.use_lemmatization:
|
| 156 |
+
tokens = [self.lemmatizer.lemmatize(w) for w in tokens]
|
| 157 |
+
return " ".join(tokens)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# --- AUGMENTATION ---
|
| 161 |
+
class SafeAugmenter:
|
| 162 |
+
"""
|
| 163 |
+
Applies light text augmentation while protecting reserved code keywords from synonym replacement.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(self, aug_prob=0.3):
|
| 167 |
+
self.aug_prob = aug_prob
|
| 168 |
+
self.protected_words = {
|
| 169 |
+
"return",
|
| 170 |
+
"public",
|
| 171 |
+
"private",
|
| 172 |
+
"void",
|
| 173 |
+
"class",
|
| 174 |
+
"static",
|
| 175 |
+
"final",
|
| 176 |
+
"if",
|
| 177 |
+
"else",
|
| 178 |
+
"for",
|
| 179 |
+
"while",
|
| 180 |
+
"try",
|
| 181 |
+
"catch",
|
| 182 |
+
"import",
|
| 183 |
+
"package",
|
| 184 |
+
"null",
|
| 185 |
+
"true",
|
| 186 |
+
"false",
|
| 187 |
+
"self",
|
| 188 |
+
"def",
|
| 189 |
+
"todo",
|
| 190 |
+
"fixme",
|
| 191 |
+
"param",
|
| 192 |
+
"throw",
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
def get_synonyms(self, word):
|
| 196 |
+
synonyms = set()
|
| 197 |
+
for syn in wordnet.synsets(word):
|
| 198 |
+
for lemma in syn.lemmas():
|
| 199 |
+
name = lemma.name().replace("_", " ")
|
| 200 |
+
if name.isalpha() and name.lower() != word.lower():
|
| 201 |
+
synonyms.add(name)
|
| 202 |
+
return list(synonyms)
|
| 203 |
+
|
| 204 |
+
def augment(self, text: str) -> str:
|
| 205 |
+
if pd.isna(text) or not text:
|
| 206 |
+
return ""
|
| 207 |
+
words = text.split()
|
| 208 |
+
if len(words) < 2:
|
| 209 |
+
return text
|
| 210 |
+
new_words = []
|
| 211 |
+
for word in words:
|
| 212 |
+
word_lower = word.lower()
|
| 213 |
+
|
| 214 |
+
if word_lower in self.protected_words:
|
| 215 |
+
new_words.append(word)
|
| 216 |
+
continue
|
| 217 |
+
|
| 218 |
+
# Random Case Injection (Noise)
|
| 219 |
+
if random.random() < 0.1:
|
| 220 |
+
if word[0].isupper():
|
| 221 |
+
new_words.append(word.lower())
|
| 222 |
+
else:
|
| 223 |
+
new_words.append(word.capitalize())
|
| 224 |
+
continue
|
| 225 |
+
|
| 226 |
+
# Synonym Replacement
|
| 227 |
+
if random.random() < self.aug_prob and len(word) > 3:
|
| 228 |
+
syns = self.get_synonyms(word_lower)
|
| 229 |
+
if syns:
|
| 230 |
+
replacement = random.choice(syns)
|
| 231 |
+
if word[0].isupper():
|
| 232 |
+
replacement = replacement.capitalize()
|
| 233 |
+
new_words.append(replacement)
|
| 234 |
+
else:
|
| 235 |
+
new_words.append(word)
|
| 236 |
+
else:
|
| 237 |
+
new_words.append(word)
|
| 238 |
+
return " ".join(new_words)
|
| 239 |
+
|
| 240 |
+
def apply_balancing(
|
| 241 |
+
self, df: pd.DataFrame, min_samples: int = 100
|
| 242 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
| 243 |
+
"""
|
| 244 |
+
Generates synthetic data for minority classes.
|
| 245 |
+
Returns: (Balanced DataFrame, Report DataFrame)
|
| 246 |
+
"""
|
| 247 |
+
df["temp_label_str"] = df[LABEL_COLUMN].astype(str)
|
| 248 |
+
counts = df["temp_label_str"].value_counts()
|
| 249 |
+
print(
|
| 250 |
+
f"\n [Balance Check - PRE] Min class size: {counts.min()} | Max: {counts.max()}"
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
existing_sentences = set(df["comment_sentence"].str.strip())
|
| 254 |
+
new_rows = []
|
| 255 |
+
report_rows = []
|
| 256 |
+
|
| 257 |
+
for label_str, count in counts.items():
|
| 258 |
+
if count < min_samples:
|
| 259 |
+
needed = min_samples - count
|
| 260 |
+
class_subset = df[df["temp_label_str"] == label_str]
|
| 261 |
+
if class_subset.empty:
|
| 262 |
+
continue
|
| 263 |
+
|
| 264 |
+
samples = class_subset["comment_sentence"].tolist()
|
| 265 |
+
orig_label = class_subset[LABEL_COLUMN].iloc[0]
|
| 266 |
+
|
| 267 |
+
# Propagate 'combo' if present
|
| 268 |
+
orig_combo = None
|
| 269 |
+
if "combo" in class_subset.columns:
|
| 270 |
+
orig_combo = class_subset["combo"].iloc[0]
|
| 271 |
+
|
| 272 |
+
generated = 0
|
| 273 |
+
attempts = 0
|
| 274 |
+
# Cap attempts to avoid infinite loops if vocabulary is too small
|
| 275 |
+
while generated < needed and attempts < needed * 5:
|
| 276 |
+
attempts += 1
|
| 277 |
+
src = random.choice(samples)
|
| 278 |
+
aug_txt = self.augment(src).strip()
|
| 279 |
+
|
| 280 |
+
# Ensure Global Uniqueness
|
| 281 |
+
if aug_txt and aug_txt not in existing_sentences:
|
| 282 |
+
row = {
|
| 283 |
+
"comment_sentence": aug_txt,
|
| 284 |
+
LABEL_COLUMN: orig_label,
|
| 285 |
+
"partition": "train_aug",
|
| 286 |
+
"index": -1, # Placeholder
|
| 287 |
+
}
|
| 288 |
+
if orig_combo:
|
| 289 |
+
row["combo"] = orig_combo
|
| 290 |
+
|
| 291 |
+
new_rows.append(row)
|
| 292 |
+
report_rows.append(
|
| 293 |
+
{
|
| 294 |
+
"original_text": src,
|
| 295 |
+
"augmented_text": aug_txt,
|
| 296 |
+
"label": label_str,
|
| 297 |
+
"reason": f"Class has {count} samples (Target {min_samples})",
|
| 298 |
+
}
|
| 299 |
+
)
|
| 300 |
+
existing_sentences.add(aug_txt)
|
| 301 |
+
generated += 1
|
| 302 |
+
|
| 303 |
+
df = df.drop(columns=["temp_label_str"])
|
| 304 |
+
df_report = pd.DataFrame(report_rows)
|
| 305 |
+
|
| 306 |
+
if new_rows:
|
| 307 |
+
augmented_df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
|
| 308 |
+
augmented_df["index"] = range(len(augmented_df))
|
| 309 |
+
|
| 310 |
+
temp_counts = augmented_df[LABEL_COLUMN].astype(str).value_counts()
|
| 311 |
+
print(
|
| 312 |
+
f" [Balance Check - POST] Min class size: {temp_counts.min()} | Max: {temp_counts.max()}"
|
| 313 |
+
)
|
| 314 |
+
return augmented_df, df_report
|
| 315 |
+
|
| 316 |
+
return df, df_report
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# --- CLEANING LOGIC ---
|
| 320 |
+
def clean_training_data_smart(
|
| 321 |
+
df: pd.DataFrame, min_len: int, max_len: int, language: str = "english"
|
| 322 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
| 323 |
+
"""
|
| 324 |
+
Performs 'Smart Cleaning' on the Training Set with language-specific heuristics.
|
| 325 |
+
"""
|
| 326 |
+
canon = TextCanonicalizer()
|
| 327 |
+
dropped_rows = []
|
| 328 |
+
|
| 329 |
+
print(f" [Clean] Computing heuristics (Language: {language})...")
|
| 330 |
+
df["canon_key"] = df["comment_sentence"].apply(canon.to_canonical)
|
| 331 |
+
|
| 332 |
+
# 1. Token Length Filter
|
| 333 |
+
def count_code_tokens(text):
|
| 334 |
+
return len([t for t in re.split(r"[^a-zA-Z0-9]+", str(text)) if t])
|
| 335 |
+
|
| 336 |
+
df["temp_token_len"] = df["comment_sentence"].apply(count_code_tokens)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
MIN_ALPHA_CHARS = 6
|
| 340 |
+
MAX_SYMBOL_RATIO = 0.50
|
| 341 |
+
|
| 342 |
+
# 2. Heuristic Filters (Tiny/Huge/Code)
|
| 343 |
+
def get_heuristics(text):
|
| 344 |
+
s = str(text).strip()
|
| 345 |
+
char_len = len(s)
|
| 346 |
+
if char_len == 0:
|
| 347 |
+
return False, False, 1.0
|
| 348 |
+
|
| 349 |
+
alpha_len = sum(1 for c in s if c.isalpha())
|
| 350 |
+
|
| 351 |
+
non_alnum_chars = sum(1 for c in s if not c.isalnum() and not c.isspace())
|
| 352 |
+
symbol_ratio = non_alnum_chars / char_len if char_len > 0 else 0
|
| 353 |
+
|
| 354 |
+
is_tiny = alpha_len < MIN_ALPHA_CHARS
|
| 355 |
+
is_huge = char_len > 800
|
| 356 |
+
is_code = symbol_ratio > MAX_SYMBOL_RATIO
|
| 357 |
+
|
| 358 |
+
return is_tiny, is_huge, is_code
|
| 359 |
+
|
| 360 |
+
heuristics = df["comment_sentence"].apply(get_heuristics)
|
| 361 |
+
df["is_tiny"] = [x[0] for x in heuristics]
|
| 362 |
+
df["is_huge"] = [x[1] for x in heuristics]
|
| 363 |
+
df["symbol_ratio"] = [x[2] for x in heuristics]
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
df["is_code"] = df["symbol_ratio"] > 0.50
|
| 367 |
+
|
| 368 |
+
mask_keep = (
|
| 369 |
+
(df["temp_token_len"] >= min_len)
|
| 370 |
+
& (df["temp_token_len"] <= max_len)
|
| 371 |
+
& (~df["is_tiny"])
|
| 372 |
+
& (~df["is_huge"])
|
| 373 |
+
& (~df["is_code"])
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
df_dropped_qual = df[~mask_keep].copy()
|
| 377 |
+
if not df_dropped_qual.empty:
|
| 378 |
+
def reason(row):
|
| 379 |
+
if row["is_tiny"]:
|
| 380 |
+
return f"Too Tiny (<{MIN_ALPHA_CHARS} alpha)"
|
| 381 |
+
if row["is_huge"]:
|
| 382 |
+
return "Too Huge (>800 chars)"
|
| 383 |
+
if row["is_code"]:
|
| 384 |
+
return f"Pure Code (>{int(MAX_SYMBOL_RATIO*100)}% symbols)"
|
| 385 |
+
return f"Token Count ({row['temp_token_len']})"
|
| 386 |
+
|
| 387 |
+
df_dropped_qual["drop_reason"] = df_dropped_qual.apply(reason, axis=1)
|
| 388 |
+
dropped_rows.append(df_dropped_qual)
|
| 389 |
+
|
| 390 |
+
df = df[mask_keep].copy()
|
| 391 |
+
|
| 392 |
+
# 3. Semantic Conflicts (Ambiguity)
|
| 393 |
+
df["label_s"] = df[LABEL_COLUMN].astype(str)
|
| 394 |
+
conflict_counts = df.groupby("canon_key")["label_s"].nunique()
|
| 395 |
+
conflicting_keys = conflict_counts[conflict_counts > 1].index
|
| 396 |
+
|
| 397 |
+
mask_conflicts = df["canon_key"].isin(conflicting_keys)
|
| 398 |
+
df_dropped_conflicts = df[mask_conflicts].copy()
|
| 399 |
+
if not df_dropped_conflicts.empty:
|
| 400 |
+
df_dropped_conflicts["drop_reason"] = "Semantic Conflict"
|
| 401 |
+
dropped_rows.append(df_dropped_conflicts)
|
| 402 |
+
|
| 403 |
+
df = df[~mask_conflicts].copy()
|
| 404 |
+
|
| 405 |
+
# 4. Exact Duplicates
|
| 406 |
+
mask_dupes = df.duplicated(subset=["comment_sentence"], keep="first")
|
| 407 |
+
df_dropped_dupes = df[mask_dupes].copy()
|
| 408 |
+
if not df_dropped_dupes.empty:
|
| 409 |
+
df_dropped_dupes["drop_reason"] = "Exact Duplicate"
|
| 410 |
+
dropped_rows.append(df_dropped_dupes)
|
| 411 |
+
|
| 412 |
+
df = df[~mask_dupes].copy()
|
| 413 |
+
|
| 414 |
+
# Cleanup columns
|
| 415 |
+
cols_to_drop = [
|
| 416 |
+
"canon_key",
|
| 417 |
+
"label_s",
|
| 418 |
+
"temp_token_len",
|
| 419 |
+
"is_tiny",
|
| 420 |
+
"is_huge",
|
| 421 |
+
"is_code",
|
| 422 |
+
"symbol_ratio"
|
| 423 |
+
]
|
| 424 |
+
df = df.drop(columns=cols_to_drop, errors="ignore")
|
| 425 |
+
|
| 426 |
+
if dropped_rows:
|
| 427 |
+
df_report = pd.concat(dropped_rows, ignore_index=True)
|
| 428 |
+
cols_rep = ["index", "comment_sentence", LABEL_COLUMN, "drop_reason"]
|
| 429 |
+
final_cols = [c for c in cols_rep if c in df_report.columns]
|
| 430 |
+
df_report = df_report[final_cols]
|
| 431 |
+
else:
|
| 432 |
+
df_report = pd.DataFrame(columns=["index", "comment_sentence", "drop_reason"])
|
| 433 |
+
|
| 434 |
+
print(f" [Clean] Removed {len(df_report)} rows. Final: {len(df)}.")
|
| 435 |
+
return df, df_report
|
| 436 |
+
|
| 437 |
+
# --- FEATURE ENGINEERING ---
|
| 438 |
+
class FeatureEngineer:
|
| 439 |
+
def __init__(self, config: FeaturePipelineConfig):
|
| 440 |
+
self.config = config
|
| 441 |
+
self.processor = TextProcessor(config=config)
|
| 442 |
+
self.tfidf_vectorizer = TfidfVectorizer(max_features=config.max_features)
|
| 443 |
+
|
| 444 |
+
def extract_features_for_check(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 445 |
+
"""Extracts metadata features for analysis."""
|
| 446 |
+
|
| 447 |
+
def analyze(text):
|
| 448 |
+
s = str(text)
|
| 449 |
+
words = s.split()
|
| 450 |
+
n_words = len(words)
|
| 451 |
+
if n_words == 0:
|
| 452 |
+
return 0, 0, 0
|
| 453 |
+
first_word = words[0].lower()
|
| 454 |
+
starts_verb = (
|
| 455 |
+
1
|
| 456 |
+
if first_word.endswith("s")
|
| 457 |
+
or first_word.startswith("get")
|
| 458 |
+
or first_word.startswith("set")
|
| 459 |
+
else 0
|
| 460 |
+
)
|
| 461 |
+
return (len(s), n_words, starts_verb)
|
| 462 |
+
|
| 463 |
+
metrics = df["comment_sentence"].apply(analyze)
|
| 464 |
+
df["f_length"] = [x[0] for x in metrics]
|
| 465 |
+
df["f_word_count"] = [x[1] for x in metrics]
|
| 466 |
+
df["f_starts_verb"] = [x[2] for x in metrics]
|
| 467 |
+
# Calculate MD5 hash for efficient exact duplicate detection in Deepchecks
|
| 468 |
+
df["text_hash"] = df["comment_sentence"].apply(
|
| 469 |
+
lambda x: hashlib.md5(str(x).encode()).hexdigest()
|
| 470 |
+
)
|
| 471 |
+
return df
|
| 472 |
+
|
| 473 |
+
def vectorize_and_select(self, df_train, df_test):
|
| 474 |
+
def clean_fn(x):
|
| 475 |
+
return re.sub(r"[^a-zA-Z\s]", "", str(x).lower())
|
| 476 |
+
|
| 477 |
+
X_train = self.tfidf_vectorizer.fit_transform(
|
| 478 |
+
df_train["comment_sentence"].apply(clean_fn)
|
| 479 |
+
)
|
| 480 |
+
y_train = np.stack(df_train[LABEL_COLUMN].values)
|
| 481 |
+
|
| 482 |
+
# Handling multi-label for Chi2 (using sum or max)
|
| 483 |
+
y_train_sum = (
|
| 484 |
+
y_train.sum(axis=1) if len(y_train.shape) > 1 else y_train
|
| 485 |
+
)
|
| 486 |
+
selector = SelectKBest(
|
| 487 |
+
chi2, k=min(self.config.max_features, X_train.shape[1])
|
| 488 |
+
)
|
| 489 |
+
X_train = selector.fit_transform(X_train, y_train_sum)
|
| 490 |
+
|
| 491 |
+
X_test = self.tfidf_vectorizer.transform(
|
| 492 |
+
df_test["comment_sentence"].apply(clean_fn)
|
| 493 |
+
)
|
| 494 |
+
X_test = selector.transform(X_test)
|
| 495 |
+
|
| 496 |
+
vocab = [
|
| 497 |
+
self.tfidf_vectorizer.get_feature_names_out()[i]
|
| 498 |
+
for i in selector.get_support(indices=True)
|
| 499 |
+
]
|
| 500 |
+
return X_train, X_test, vocab
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# --- MAIN EXECUTION ---
|
| 504 |
+
def main(
|
| 505 |
+
feature_dir: Path = typer.Option(
|
| 506 |
+
INTERIM_DATA_DIR / "features", help="Output dir."
|
| 507 |
+
),
|
| 508 |
+
reports_root: Path = typer.Option(
|
| 509 |
+
Path("reports/data"), help="Reports root."
|
| 510 |
+
),
|
| 511 |
+
max_features: int = typer.Option(5000),
|
| 512 |
+
min_comment_length: int = typer.Option(
|
| 513 |
+
2, help="Remove comments with fewer tokens than this."
|
| 514 |
+
),
|
| 515 |
+
max_comment_length: int = typer.Option(300),
|
| 516 |
+
augment: bool = typer.Option(False, "--augment", help="Enable augmentation."),
|
| 517 |
+
balance_threshold: int = typer.Option(100, help="Min samples per class."),
|
| 518 |
+
run_vectorization: bool = typer.Option(False, "--run-vectorization"),
|
| 519 |
+
run_nlp_check: bool = typer.Option(
|
| 520 |
+
True, "--run-nlp", help="Run Deepchecks NLP suite."
|
| 521 |
+
),
|
| 522 |
+
custom_tags: str = typer.Option("base", help="Custom tags."),
|
| 523 |
+
save_full_csv: bool = typer.Option(False, "--save-full-csv"),
|
| 524 |
+
languages: List[str] = typer.Option(LANGS, show_default=False),
|
| 525 |
+
):
|
| 526 |
+
|
| 527 |
+
config = FeaturePipelineConfig(
|
| 528 |
+
True,
|
| 529 |
+
True,
|
| 530 |
+
True,
|
| 531 |
+
max_features,
|
| 532 |
+
min_comment_length,
|
| 533 |
+
max_comment_length,
|
| 534 |
+
augment,
|
| 535 |
+
custom_tags,
|
| 536 |
+
)
|
| 537 |
+
print(f"=== Pipeline ID: {config.hash_id} ===")
|
| 538 |
+
|
| 539 |
+
dm = DatasetManager()
|
| 540 |
+
full_dataset = dm.get_dataset()
|
| 541 |
+
fe = FeatureEngineer(config)
|
| 542 |
+
augmenter = SafeAugmenter()
|
| 543 |
+
|
| 544 |
+
feat_output_dir = feature_dir / config.hash_id
|
| 545 |
+
feat_output_dir.mkdir(parents=True, exist_ok=True)
|
| 546 |
+
report_output_dir = reports_root / config.hash_id
|
| 547 |
+
|
| 548 |
+
for lang in languages:
|
| 549 |
+
print(f"\n{'='*30}\nPROCESSING LANGUAGE: {lang.upper()}\n{'='*30}")
|
| 550 |
+
df_train = full_dataset[f"{lang}_train"].to_pandas()
|
| 551 |
+
df_test = full_dataset[f"{lang}_test"].to_pandas()
|
| 552 |
+
|
| 553 |
+
# Standardize Label Format
|
| 554 |
+
for df in [df_train, df_test]:
|
| 555 |
+
if isinstance(df[LABEL_COLUMN].iloc[0], str):
|
| 556 |
+
df[LABEL_COLUMN] = (
|
| 557 |
+
df[LABEL_COLUMN]
|
| 558 |
+
.str.replace(r"\s+", ", ", regex=True)
|
| 559 |
+
.apply(ast.literal_eval)
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
lang_report_dir = report_output_dir / lang
|
| 563 |
+
|
| 564 |
+
# 1. RAW AUDIT
|
| 565 |
+
print(" >>> Phase 1: Auditing RAW Data")
|
| 566 |
+
df_train_raw = fe.extract_features_for_check(df_train.copy())
|
| 567 |
+
df_test_raw = fe.extract_features_for_check(df_test.copy())
|
| 568 |
+
run_custom_deepchecks(
|
| 569 |
+
df_train_raw, df_test_raw, lang_report_dir, "raw", lang
|
| 570 |
+
)
|
| 571 |
+
if run_nlp_check:
|
| 572 |
+
run_targeted_nlp_checks(
|
| 573 |
+
df_train_raw, df_test_raw, lang_report_dir, "raw"
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# 2. CLEANING & AUGMENTATION
|
| 577 |
+
print("\n >>> Phase 2: Smart Cleaning & Augmentation")
|
| 578 |
+
df_train, df_dropped = clean_training_data_smart(
|
| 579 |
+
df_train, min_comment_length, max_comment_length, language=lang
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
if not df_dropped.empty:
|
| 583 |
+
dropped_path = lang_report_dir / "dropped_rows.csv"
|
| 584 |
+
df_dropped.to_csv(dropped_path, index=False)
|
| 585 |
+
print(f" [Report] Dropped rows details saved to: {dropped_path}")
|
| 586 |
+
|
| 587 |
+
if augment:
|
| 588 |
+
print(" [Augment] Applying Soft Balancing...")
|
| 589 |
+
df_train, df_aug_report = augmenter.apply_balancing(
|
| 590 |
+
df_train, min_samples=balance_threshold
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
if not df_aug_report.empty:
|
| 594 |
+
aug_path = lang_report_dir / "augmentation_report.csv"
|
| 595 |
+
df_aug_report.to_csv(aug_path, index=False)
|
| 596 |
+
print(
|
| 597 |
+
f" [Report] Augmentation details saved to: {aug_path}"
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# 3. PROCESSED AUDIT
|
| 601 |
+
print("\n >>> Phase 3: Auditing PROCESSED Data")
|
| 602 |
+
df_train = fe.extract_features_for_check(df_train)
|
| 603 |
+
df_test = fe.extract_features_for_check(df_test)
|
| 604 |
+
run_custom_deepchecks(
|
| 605 |
+
df_train, df_test, lang_report_dir, "processed", lang
|
| 606 |
+
)
|
| 607 |
+
if run_nlp_check:
|
| 608 |
+
run_targeted_nlp_checks(
|
| 609 |
+
df_train, df_test, lang_report_dir, "processed"
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# 4. FINAL PROCESSING & SAVING
|
| 613 |
+
print("\n >>> Phase 4: Final Processing & Save")
|
| 614 |
+
df_train["comment_clean"] = df_train["comment_sentence"].apply(
|
| 615 |
+
fe.processor.clean_text
|
| 616 |
+
)
|
| 617 |
+
df_test["comment_clean"] = df_test["comment_sentence"].apply(
|
| 618 |
+
fe.processor.clean_text
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
if config.use_combo_feature:
|
| 622 |
+
if "combo" in df_train.columns:
|
| 623 |
+
df_train["combo_clean"] = df_train["combo"].apply(
|
| 624 |
+
fe.processor.clean_text
|
| 625 |
+
)
|
| 626 |
+
if "combo" in df_test.columns:
|
| 627 |
+
df_test["combo_clean"] = df_test["combo"].apply(
|
| 628 |
+
fe.processor.clean_text
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
X_train, X_test, vocab = None, None, []
|
| 632 |
+
if run_vectorization:
|
| 633 |
+
print(" [Vectorization] TF-IDF & Chi2...")
|
| 634 |
+
X_train, X_test, vocab = fe.vectorize_and_select(df_train, df_test)
|
| 635 |
+
def format_label_robust(lbl):
|
| 636 |
+
if hasattr(lbl, "tolist"): # Check if numpy array
|
| 637 |
+
lbl = lbl.tolist()
|
| 638 |
+
return str(lbl)
|
| 639 |
+
|
| 640 |
+
df_train[LABEL_COLUMN] = df_train[LABEL_COLUMN].apply(format_label_robust)
|
| 641 |
+
df_test[LABEL_COLUMN] = df_test[LABEL_COLUMN].apply(format_label_robust)
|
| 642 |
+
|
| 643 |
+
cols_to_save = [
|
| 644 |
+
"index",
|
| 645 |
+
LABEL_COLUMN,
|
| 646 |
+
"comment_sentence",
|
| 647 |
+
"comment_clean",
|
| 648 |
+
]
|
| 649 |
+
if "combo" in df_train.columns:
|
| 650 |
+
cols_to_save.append("combo")
|
| 651 |
+
if "combo_clean" in df_train.columns:
|
| 652 |
+
cols_to_save.append("combo_clean")
|
| 653 |
+
meta_cols = [c for c in df_train.columns if c.startswith("f_")]
|
| 654 |
+
cols_to_save.extend(meta_cols)
|
| 655 |
+
|
| 656 |
+
print(f" [Save] Columns: {cols_to_save}")
|
| 657 |
+
df_train[cols_to_save].to_csv(
|
| 658 |
+
feat_output_dir / f"{lang}_train.csv", index=False
|
| 659 |
+
)
|
| 660 |
+
df_test[cols_to_save].to_csv(
|
| 661 |
+
feat_output_dir / f"{lang}_test.csv", index=False
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
if run_vectorization and X_train is not None:
|
| 665 |
+
from scipy.sparse import save_npz
|
| 666 |
+
|
| 667 |
+
save_npz(feat_output_dir / f"{lang}_train_tfidf.npz", X_train)
|
| 668 |
+
save_npz(feat_output_dir / f"{lang}_test_tfidf.npz", X_test)
|
| 669 |
+
with open(
|
| 670 |
+
feat_output_dir / f"{lang}_vocab.txt", "w", encoding="utf-8"
|
| 671 |
+
) as f:
|
| 672 |
+
f.write("\n".join(vocab))
|
| 673 |
+
|
| 674 |
+
print(f"\nAll Done. Reports in: {report_output_dir}")
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
if __name__ == "__main__":
|
| 678 |
+
typer.run(main)
|
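The "Semantic Conflict" filter in clean_training_data_smart groups rows by their canonical key and drops any key that maps to more than one label. The same idea in isolation, on a tiny illustrative DataFrame (column names match the pipeline's):

import pandas as pd
from turing.features import TextCanonicalizer

df = pd.DataFrame({
    "comment_sentence": ["Returns the size.", "returns the size", "Sets the size."],
    "labels": ["[1, 0]", "[0, 1]", "[0, 1]"],
})

canon = TextCanonicalizer()
df["canon_key"] = df["comment_sentence"].apply(canon.to_canonical)

# Keys whose rows disagree on the label are ambiguous and get dropped
conflicts = df.groupby("canon_key")["labels"].nunique()
df = df[~df["canon_key"].isin(conflicts[conflicts > 1].index)]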
turing/modeling/__init__.py
ADDED
|
File without changes
|
turing/modeling/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (164 Bytes).
|
|
|
turing/modeling/__pycache__/baseModel.cpython-312.pyc
ADDED
|
Binary file (4.52 kB).
|
|
|
turing/modeling/baseModel.py
ADDED
|
@@ -0,0 +1,111 @@
from abc import ABC, abstractmethod
import os
import shutil

from loguru import logger
import mlflow
from numpy import ndarray


class BaseModel(ABC):
    """
    Abstract base class for training models.
    Subclasses should define the model and implement specific logic
    for training, evaluation, and model persistence.
    """

    def __init__(self, language, path=None):
        """
        Initialize the trainer.

        Args:
            language (str): Language for the model.
            path (str, optional): Path to load a pre-trained model. Defaults to None.
                If None, a new model is initialized.
        """

        self.language = language
        self.model = None
        if path:
            self.load(path)
        else:
            self.setup_model()

    @abstractmethod
    def setup_model(self):
        """
        Initialize or build the model.
        Called in __init__ of subclass.
        """
        pass

    @abstractmethod
    def train(self, X_train, y_train) -> dict[str, any]:
        """
        Main training logic for the model.

        Args:
            X_train: Input training data.
            y_train: True labels for training data.
        """
        pass

    @abstractmethod
    def evaluate(self, X_test, y_test) -> dict[str, any]:
        """
        Evaluation logic for the model.

        Args:
            X_test: Input test data.
            y_test: True labels for test data.
        """
        pass

    @abstractmethod
    def predict(self, X) -> ndarray:
        """
        Make predictions using the trained model.

        Args:
            X: Input data for prediction.

        Returns:
            Predictions made by the model.
        """
        pass

    def save(self, path, model_name):
        """
        Save model and log to MLflow.

        Args:
            path (str): Path to save the model.
            model_name (str): Name to use when saving the model (without extension).
        """

        if self.model is None:
            raise ValueError("Model is not trained. Cannot save uninitialized model.")

        complete_path = os.path.join(path, f"{model_name}_{self.language}")
        if os.path.exists(complete_path) and os.path.isdir(complete_path):
            shutil.rmtree(complete_path)
        mlflow.sklearn.save_model(self.model, complete_path)

        try:
            mlflow.log_artifact(complete_path)
        except Exception as e:
            logger.error(f"Failed to log model to MLflow: {e}")

        logger.info(f"Model saved to: {complete_path}")

    def load(self, model_path):
        """
        Load model from specified local path or mlflow model URI.

        Args:
            model_path (str): Path to load the model from (local or mlflow URI).
        """

        self.model = mlflow.sklearn.load_model(model_path)
        logger.info(f"Model loaded from: {model_path}")
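As a reading aid (not part of the commit), a minimal sketch of what the BaseModel contract requires from a subclass; the DummyModel name and its internals are hypothetical:

from sklearn.dummy import DummyClassifier
from turing.modeling.baseModel import BaseModel


class DummyModel(BaseModel):
    # Hypothetical subclass: each abstract hook is filled with the simplest possible logic.
    def setup_model(self):
        self.model = DummyClassifier(strategy="most_frequent")

    def train(self, X_train, y_train):
        self.model.fit(X_train, y_train)
        return {"strategy": "most_frequent"}

    def evaluate(self, X_test, y_test):
        return {"accuracy": self.model.score(X_test, y_test)}

    def predict(self, X):
        return self.model.predict(X)

# Usage: DummyModel(language="java").train(X, y); the inherited save()/load() then work via mlflow.sklearn.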
turing/modeling/model_selector.py
ADDED
|
@@ -0,0 +1,145 @@
from typing import Optional

from loguru import logger
from mlflow.tracking import MlflowClient


def get_best_model_by_tag(
    language: str,
    tag_key: str = "best_model",
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Retrieve the best model for a specific language using MLflow tags.

    Args:
        language: Programming language (java, python, pharo)
        tag_key: Tag key to search for (default: "best_model")
        metric: Metric to use for ordering (default: "f1_score")

    Returns:
        Dict with run_id and artifact_name of the best model or None if not found
    """

    client = MlflowClient()
    experiments = client.search_experiments()
    if not experiments:
        logger.error("No experiments found in MLflow")
        return None

    try:
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.{tag_key} = 'true' and tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )

        if not runs:
            logger.warning(f"No runs found with tag '{tag_key}' for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_name = client.get_experiment(best_run.info.experiment_id).name
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")

        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }

    except Exception as e:
        logger.error(f"Error searching for best model: {e}")
        return None


def get_best_model_info(
    language: str,
    fallback_registry: dict = None
) -> dict:
    """
    Retrieve the best model information for a language.
    First searches by tag, then falls back to hardcoded registry.

    Args:
        language: Programming language
        fallback_registry: Fallback registry with run_id and artifact

    Returns:
        Dict with run_id and artifact of the model
    """

    model_info = get_best_model_by_tag(language, "best_model")

    if model_info:
        logger.info(f"Using tagged best model for {language}")
        return model_info

    if fallback_registry and language in fallback_registry:
        logger.warning(f"No tagged model found for {language}, using fallback registry")
        return fallback_registry[language]

    model_info = get_best_model_by_metric(language)

    if model_info:
        logger.warning(f"Using best model by metric for {language}")
        return model_info

    raise ValueError(f"No model found for language {language}")


def get_best_model_by_metric(
    language: str,
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Find the model with the best metric for a language.

    Args:
        language: Programming language
        metric: Metric to use for ordering

    Returns:
        Dict with run_id and artifact of the model or None
    """

    client = MlflowClient()
    experiments = client.search_experiments()
    if not experiments:
        logger.error("No experiments found in MLflow")
        return None

    try:
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )

        if not runs:
            logger.warning(f"No runs found for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_name = client.get_experiment(best_run.info.experiment_id).name
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")

        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }

    except Exception as e:
        logger.error(f"Error finding best model by metric: {e}")
        return None
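A short usage sketch for the selector, assuming an MLflow tracking server is configured; the fallback registry contents shown here are placeholders, not real run IDs:

from turing.modeling.model_selector import get_best_model_info

# Hypothetical fallback registry; real run IDs live in MLflow.
FALLBACK = {"java": {"run_id": "<run-id>", "artifact": "codeberta_java", "model_id": None}}

info = get_best_model_info("java", fallback_registry=FALLBACK)
model_uri = f"runs:/{info['run_id']}/{info['artifact']}"  # URI shape accepted by the models' load()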
turing/modeling/models/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
"""
Model classes for code comment classification.
"""

from turing.modeling.models.codeBerta import CodeBERTa
from turing.modeling.models.graphCodeBert import GraphCodeBERTClassifier
from turing.modeling.models.randomForestTfIdf import RandomForestTfIdf
from turing.modeling.models.tinyBert import TinyBERTClassifier

__all__ = [
    "CodeBERTa",
    "RandomForestTfIdf",
    "TinyBERTClassifier",
    "GraphCodeBERTClassifier",
]
turing/modeling/models/__pycache__/miniLM.cpython-312.pyc
ADDED
|
Binary file (15.4 kB).
|
|
|
turing/modeling/models/__pycache__/miniLmWithClassificationHead.cpython-312.pyc
ADDED
|
Binary file (1.37 kB).
|
|
|
turing/modeling/models/__pycache__/randomForestTfIdf.cpython-312.pyc
ADDED
|
Binary file (6.2 kB).
|
|
|
turing/modeling/models/codeBerta.py
ADDED
|
@@ -0,0 +1,463 @@
import os
import shutil
import warnings

from loguru import logger
import mlflow
import numpy as np
from numpy import ndarray
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

from turing.config import MODELS_DIR

from ..baseModel import BaseModel

warnings.filterwarnings("ignore")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Sigmoid function to convert logits to probabilities
    probs = 1 / (1 + np.exp(-predictions))

    # Apply threshold of 0.5 (becomes 1 if > 0.5, otherwise 0)
    preds = (probs > 0.5).astype(int)

    # Calculate F1 score (macro average for multi-label)
    f1 = f1_score(labels, preds, average='macro')
    precision = precision_score(labels, preds, average='macro', zero_division=0)
    recall = recall_score(labels, preds, average='macro', zero_division=0)

    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }


class CodeBERTaDataset(Dataset):
    """
    Internal Dataset class for CodeBERTa.
    """

    def __init__(self, encodings, labels=None, num_labels=None):
        """
        Initialize the InternalDataset.
        Args:
            encodings (dict): Tokenized encodings.
            labels (list or np.ndarray, optional): Corresponding labels.
            num_labels (int, optional): Total number of classes. Required for auto-converting indices to one-hot.
        """

        self.encodings = {key: torch.tensor(val) for key, val in encodings.items()}

        if labels is not None:
            if not isinstance(labels, (np.ndarray, torch.Tensor)):
                labels = np.array(labels)

            # Case A: labels are indices (integers)
            if num_labels is not None and (len(labels.shape) == 1 or (len(labels.shape) == 2 and labels.shape[1] == 1)):
                labels_flat = labels.flatten()

                # Create one-hot encoded matrix
                one_hot = np.zeros((len(labels_flat), num_labels), dtype=np.float32)

                # Set the corresponding index to 1
                valid_indices = labels_flat < num_labels
                one_hot[valid_indices, labels_flat[valid_indices]] = 1.0

                self.labels = torch.tensor(one_hot, dtype=torch.float)

            # Case B: labels are already vectors (e.g., One-Hot or Multi-Hot)
            else:
                self.labels = torch.tensor(labels, dtype=torch.float)
        else:
            self.labels = None

    def __getitem__(self, idx):
        """
        Retrieve item at index idx.

        Args:
            idx (int): Index of the item to retrieve.

        Returns:
            dict: Dictionary containing input_ids, attention_mask, and labels (if available).
        """

        item = {key: val[idx] for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        """
        Return the length of the dataset.

        Returns:
            int: Length of the dataset.
        """

        return len(self.encodings['input_ids'])


class CodeBERTa(BaseModel):
    """
    HuggingFace implementation of BaseModel for Code Comment Classification.
    Uses CodeBERTa-small-v1 for efficient inference.
    """

    def __init__(self, language, path=None):
        """
        Initialize the CodeBERTa model with configuration parameters.

        Args:
            language (str): Language for the model.
            path (str, optional): Path to load a pre-trained model. Defaults to None.
        """

        self.params = {
            "model_name_hf": "huggingface/CodeBERTa-small-v1",
            "num_labels": 7 if language == "java" else 5 if language == "python" else 6,
            "max_length": 128,
            "epochs": 15,
            "batch_size_train": 16,
            "batch_size_eval": 64,
            "learning_rate": 1e-5,
            "weight_decay": 0.02,
            "train_size": 0.8,
            "early_stopping_patience": 3,
            "early_stopping_threshold": 0.005
        }

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = None

        super().__init__(language, path)

    def setup_model(self):
        """
        Initialize the CodeBERTa tokenizer and model.
        """

        logger.info(f"Initializing {self.params['model_name_hf']} on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(self.params["model_name_hf"])
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.params["model_name_hf"],
            num_labels=self.params["num_labels"],
            problem_type="multi_label_classification"
        ).to(self.device)
        logger.info("CodeBERTa model initialized.")

    def _tokenize(self, texts):
        """
        Helper to tokenize list of texts efficiently.

        Args:
            texts (list): List of text strings to tokenize.

        Returns:
            dict: Tokenized encodings.
        """

        safe_texts = []
        for t in texts:
            if t is None:
                safe_texts.append("")
            elif isinstance(t, (int, float)):
                if t != t:  # NaN check
                    safe_texts.append("")
                else:
                    safe_texts.append(str(t))
            else:
                safe_texts.append(str(t))

        return self.tokenizer(
            safe_texts,
            truncation=True,
            padding=True,
            max_length=self.params["max_length"]
        )

    def train(self, X_train, y_train) -> dict[str, any]:
        """
        Train the model using HF Trainer and log to MLflow.

        Args:
            X_train (list): Training input texts.
            y_train (list or np.ndarray): Training labels.

        Returns:
            dict[str, any]: Dictionary of parameters used for training.
        """

        if self.model is None:
            raise ValueError("Model is not initialized. Call setup_model() before training.")

        # log parameters to MLflow without model_name_hf
        params_to_log = {k: v for k, v in self.params.items() if k != "model_name_hf" and k != "num_labels"}

        logger.info(f"Starting training for: {self.language.upper()}")

        # Prepare dataset (train/val split)
        train_encodings = self._tokenize(X_train)
        full_dataset = CodeBERTaDataset(train_encodings, y_train, num_labels=self.params["num_labels"])
        train_size = int(self.params["train_size"] * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

        temp_ckpt_dir = os.path.join(MODELS_DIR, "temp_checkpoints")

        use_fp16 = torch.cuda.is_available()
        if not use_fp16:
            logger.info("Mixed Precision (fp16) disabled because CUDA is not available.")

        training_args = TrainingArguments(
            output_dir=temp_ckpt_dir,
            num_train_epochs=self.params["epochs"],
            per_device_train_batch_size=self.params["batch_size_train"],
            per_device_eval_batch_size=self.params["batch_size_eval"],
            learning_rate=self.params["learning_rate"],
            weight_decay=self.params["weight_decay"],
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            save_total_limit=2,
            logging_dir='./logs',
            logging_steps=50,
            fp16=use_fp16,
            optim="adamw_torch",
            report_to="none",
            no_cuda=not torch.cuda.is_available()
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=self.params["early_stopping_patience"], early_stopping_threshold=self.params["early_stopping_threshold"])]
        )
        trainer.train()
        logger.info(f"Training for {self.language.upper()} completed.")

        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir)

        return params_to_log

    def evaluate(self, X_test, y_test) -> dict[str, any]:
        """
        Evaluate model on test data, return metrics and log to MLflow.
        Handles automatic conversion of y_test to match multi-label prediction shape.

        Args:
            X_test (list): Input test data.
            y_test (list or np.ndarray): True labels for test data.

        Returns:
            dict[str, any]: Dictionary of evaluation metrics.
        """

        # Obtain predictions
        y_pred = self.predict(X_test)

        # Convert y_test to numpy array if needed
        if not isinstance(y_test, (np.ndarray, torch.Tensor)):
            y_test_np = np.array(y_test)
        elif isinstance(y_test, torch.Tensor):
            y_test_np = y_test.cpu().numpy()
        else:
            y_test_np = y_test

        num_labels = self.params["num_labels"]
        is_multilabel_pred = (y_pred.ndim == 2 and y_pred.shape[1] > 1)
        is_flat_truth = (y_test_np.ndim == 1) or (y_test_np.ndim == 2 and y_test_np.shape[1] == 1)

        if is_multilabel_pred and is_flat_truth:
            # Create a zero matrix
            y_test_expanded = np.zeros((y_test_np.shape[0], num_labels), dtype=int)

            # Flatten y_test for iteration
            indices = y_test_np.flatten()

            # Use indices to set the correct column to 1
            for i, label_idx in enumerate(indices):
                idx = int(label_idx)
                if 0 <= idx < num_labels:
                    y_test_expanded[i, idx] = 1

            y_test_np = y_test_expanded

        # Generate classification report
        report = classification_report(y_test_np, y_pred, zero_division=0)
        print("\n" + "=" * 50)
        print("CLASSIFICATION REPORT")
        print(report)
        print("=" * 50 + "\n")

        metrics = {
            "accuracy": accuracy_score(y_test_np, y_pred),
            "precision": precision_score(y_test_np, y_pred, average="macro", zero_division=0),
            "recall": recall_score(y_test_np, y_pred, average="macro", zero_division=0),
            "f1_score": f1_score(y_test_np, y_pred, average="macro"),
        }

        mlflow.log_metrics(metrics)

        logger.info(
            f"Evaluation completed — Accuracy: {metrics['accuracy']:.3f}, F1: {metrics['f1_score']:.3f}"
        )
        return metrics

    def predict(self, X) -> ndarray:
        """
        Make predictions for Multi-Label classification.
        Returns Binary Matrix (Multi-Hot) where multiple classes can be 1.

        Args:
            X (list): Input texts for prediction.

        Returns:
            np.ndarray: Multi-Hot Encoded predictions (e.g., [[0, 1, 1, 0], ...])
        """

        if self.model is None:
            raise ValueError("Model is not trained. Call train() or load() before prediction.")

        # Set model to evaluation mode
        self.model.eval()

        encodings = self._tokenize(X)
        # Pass None as labels because we are in inference
        dataset = CodeBERTaDataset(encodings, labels=None)

        use_fp16 = torch.cuda.is_available()

        training_args = TrainingArguments(
            output_dir="./pred_temp",
            per_device_eval_batch_size=self.params["batch_size_eval"],
            fp16=use_fp16,
            report_to="none",
            no_cuda=not torch.cuda.is_available()
        )

        trainer = Trainer(model=self.model, args=training_args)
        output = trainer.predict(dataset)

        # Clean up temporary prediction directory
        if os.path.exists("./pred_temp"):
            shutil.rmtree("./pred_temp")

        # Convert logits to probabilities
        logits = output.predictions
        probs = 1 / (1 + np.exp(-logits))

        # Apply a threshold of 0.5 (if prob > 0.5, predict 1 else 0)
        preds_binary = (probs > 0.5).astype(int)

        return preds_binary

    def save(self, path, model_name):
        """
        Save model locally and log to MLflow as artifact.

        Args:
            path (str): Directory path to save the model.
            model_name (str): Name for the saved model.
        """

        if self.model is None:
            raise ValueError("Model is not trained. Cannot save uninitialized model.")

        # Local Saving
        complete_path = os.path.join(path, f"{model_name}_{self.language}")

        # Remove existing directory if it exists
        if os.path.exists(complete_path) and os.path.isdir(complete_path):
            shutil.rmtree(complete_path)

        # Save model and tokenizer
        logger.info(f"Saving model to: {complete_path}")
        self.model.save_pretrained(complete_path)
        self.tokenizer.save_pretrained(complete_path)
        logger.info("Model saved locally.")

        try:
            # Log to MLflow
            logger.info("Logging artifacts to MLflow...")
            mlflow.log_artifacts(local_dir=complete_path, artifact_path=f"{model_name}_{self.language}")
        except Exception as e:
            logger.error(f"Failed to log model artifacts to MLflow: {e}")

    def load(self, model_path):
        """
        Load model from a local path OR an MLflow URI.

        Args:
            model_path (str): Local path or MLflow URI to load the model from.
        """

        logger.info(f"Loading model from: {model_path}")
        local_model_path = model_path

        # Downloading model from MLflow and saving to local path
        if model_path.startswith("models:/") or model_path.startswith("runs:/"):
            try:
                logger.info("Detected MLflow model URI. Attempting to load from MLflow...")
                local_model_path = os.path.join(MODELS_DIR, "mlflow_temp_models")
                local_model_path = mlflow.artifacts.download_artifacts(artifact_uri=model_path, dst_path=local_model_path)
                logger.info(f"Model downloaded from MLflow to: {local_model_path}")
            except Exception as e:
                logger.error(f"Failed to load from MLflow: {e}")
                raise e

        # Loading from local path
        try:
            if not os.path.exists(local_model_path):
                raise FileNotFoundError(f"Model path not found: {local_model_path}")

            # Load tokenizer and model from local path
            self.tokenizer = AutoTokenizer.from_pretrained(local_model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                local_model_path
            ).to(self.device)
            logger.info("Model loaded from local path successfully.")

        except Exception as e:
            logger.error(f"Failed to load model from local path: {e}")
            raise e

        # Set model to evaluation mode
        self.model.eval()
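A hedged usage sketch for this class (the toy texts and two-sample multi-hot labels below are placeholders; real training uses the prepared per-language datasets, and the first call downloads the CodeBERTa-small-v1 weights from the Hub):

from turing.modeling.models.codeBerta import CodeBERTa

# Hypothetical toy data; 7 label columns because language="java".
texts = ["// returns the current index", "/* TODO: refactor */"]
labels = [[1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]]  # multi-hot label matrix

clf = CodeBERTa(language="java")
clf.train(texts, labels)                                   # fine-tunes with the HF Trainer
preds = clf.predict(["// deprecated, use parse() instead"])  # multi-hot prediction matrix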
turing/modeling/models/graphCodeBert.py
ADDED
|
@@ -0,0 +1,469 @@
import os
import shutil
import warnings

from loguru import logger
import mlflow
import numpy as np
from numpy import ndarray
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

from turing.config import MODELS_DIR

from ..baseModel import BaseModel

warnings.filterwarnings("ignore")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Sigmoid function to convert logits to probabilities
    probs = 1 / (1 + np.exp(-predictions))

    # Apply threshold of 0.5 (becomes 1 if > 0.5, otherwise 0)
    preds = (probs > 0.5).astype(int)

    # Calculate F1 score (macro average for multi-label)
    f1 = f1_score(labels, preds, average="macro")
    precision = precision_score(labels, preds, average="macro", zero_division=0)
    recall = recall_score(labels, preds, average="macro", zero_division=0)

    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }


class GraphCodeBERTDataset(Dataset):
    """
    Internal Dataset class for GraphCodeBERT.
    """

    def __init__(self, encodings, labels=None, num_labels=None):
        """
        Initialize the InternalDataset.
        Args:
            encodings (dict): Tokenized encodings.
            labels (list or np.ndarray, optional): Corresponding labels.
            num_labels (int, optional): Total number of classes. Required for auto-converting indices to one-hot.
        """

        self.encodings = {key: torch.tensor(val) for key, val in encodings.items()}

        if labels is not None:
            if not isinstance(labels, (np.ndarray, torch.Tensor)):
                labels = np.array(labels)

            # Case A: labels are indices (integers)
            if num_labels is not None and (
                len(labels.shape) == 1 or (len(labels.shape) == 2 and labels.shape[1] == 1)
            ):
                labels_flat = labels.flatten()

                # Create one-hot encoded matrix
                one_hot = np.zeros((len(labels_flat), num_labels), dtype=np.float32)

                # Set the corresponding index to 1
                valid_indices = labels_flat < num_labels
                one_hot[valid_indices, labels_flat[valid_indices]] = 1.0

                self.labels = torch.tensor(one_hot, dtype=torch.float)

            # Case B: labels are already vectors (e.g., One-Hot or Multi-Hot)
            else:
                self.labels = torch.tensor(labels, dtype=torch.float)
        else:
            self.labels = None

    def __getitem__(self, idx):
        """
        Retrieve item at index idx.

        Args:
            idx (int): Index of the item to retrieve.

        Returns:
            dict: Dictionary containing input_ids, attention_mask, and labels (if available).
        """

        item = {key: val[idx] for key, val in self.encodings.items()}
        if self.labels is not None:
            item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        """
        Return the length of the dataset.

        Returns:
            int: Length of the dataset.
        """

        return len(self.encodings["input_ids"])


class GraphCodeBERTClassifier(BaseModel):
    """
    HuggingFace implementation of BaseModel for Code Comment Classification.
    Uses GraphCodeBERT (microsoft/graphcodebert-base) for code understanding via data flow graphs.
    """

    def __init__(self, language, path=None):
        """
        Initialize the GraphCodeBERT model with configuration parameters.

        Args:
            language (str): Language for the model.
            path (str, optional): Path to load a pre-trained model. Defaults to None.
        """

        self.params = {
            "model_name_hf": "microsoft/graphcodebert-base",
            "num_labels": 7 if language == "java" else 5 if language == "python" else 6,
            "max_length": 256,
            "epochs": 15,
            "batch_size_train": 16,
            "batch_size_eval": 64,
            "learning_rate": 2e-5,
            "weight_decay": 0.01,
            "train_size": 0.8,
            "early_stopping_patience": 3,
            "early_stopping_threshold": 0.0,
            "warmup_steps": 500,
            "seed": 42,
        }

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = None

        super().__init__(language, path)

    def setup_model(self):
        """
        Initialize the GraphCodeBERT tokenizer and model.
        """

        logger.info(f"Initializing {self.params['model_name_hf']} on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(self.params["model_name_hf"])
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.params["model_name_hf"],
            num_labels=self.params["num_labels"],
            problem_type="multi_label_classification",
            use_safetensors=True,  # Force use of safetensors for security
        ).to(self.device)
        logger.info("GraphCodeBERT model initialized.")

    def _tokenize(self, texts):
        """
        Helper to tokenize list of texts efficiently.

        Args:
            texts (list): List of text strings to tokenize.

        Returns:
            dict: Tokenized encodings.
        """

        safe_texts = []
        for t in texts:
            if t is None:
                safe_texts.append("")
            elif isinstance(t, (int, float)):
                if t != t:  # NaN check
                    safe_texts.append("")
                else:
                    safe_texts.append(str(t))
            else:
                safe_texts.append(str(t))

        return self.tokenizer(
            safe_texts, truncation=True, padding=True, max_length=self.params["max_length"]
        )

    def train(self, X_train, y_train) -> dict[str, any]:
        """
        Train the model using HF Trainer and log to MLflow.

        Args:
            X_train (list): Training input texts.
            y_train (list or np.ndarray): Training labels.

        Returns:
            dict[str, any]: Dictionary of parameters used for training.
        """

        if self.model is None:
            raise ValueError("Model is not initialized. Call setup_model() before training.")

        # log parameters to MLflow without model_name_hf
        params_to_log = {
            k: v for k, v in self.params.items() if k != "model_name_hf" and k != "num_labels"
        }

        logger.info(f"Starting training for: {self.language.upper()}")

        # Prepare dataset (train/val split)
        train_encodings = self._tokenize(X_train)
        full_dataset = GraphCodeBERTDataset(
            train_encodings, y_train, num_labels=self.params["num_labels"]
        )
        train_size = int(self.params["train_size"] * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size]
        )

        temp_ckpt_dir = os.path.join(MODELS_DIR, "temp_checkpoints")

        use_fp16 = torch.cuda.is_available()
        if not use_fp16:
            logger.info("Mixed Precision (fp16) disabled because CUDA is not available.")

        training_args = TrainingArguments(
            output_dir=temp_ckpt_dir,
            num_train_epochs=self.params["epochs"],
            per_device_train_batch_size=self.params["batch_size_train"],
            per_device_eval_batch_size=self.params["batch_size_eval"],
            learning_rate=self.params["learning_rate"],
            weight_decay=self.params["weight_decay"],
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=50,
            fp16=use_fp16,
            optim="adamw_torch",
            report_to="none",
            no_cuda=not torch.cuda.is_available(),
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics,
            callbacks=[
                EarlyStoppingCallback(
                    early_stopping_patience=self.params["early_stopping_patience"],
                    early_stopping_threshold=self.params["early_stopping_threshold"],
                )
            ],
        )
        trainer.train()
        logger.info(f"Training for {self.language.upper()} completed.")

        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir)

        return params_to_log

    def evaluate(self, X_test, y_test) -> dict[str, any]:
        """
        Evaluate model on test data, return metrics and log to MLflow.
        Handles automatic conversion of y_test to match multi-label prediction shape.

        Args:
            X_test (list): Input test data.
            y_test (list or np.ndarray): True labels for test data.

        Returns:
            dict[str, any]: Dictionary of evaluation metrics.
        """

        # Obtain predictions
        y_pred = self.predict(X_test)

        # Convert y_test to numpy array if needed
        if not isinstance(y_test, (np.ndarray, torch.Tensor)):
            y_test_np = np.array(y_test)
        elif isinstance(y_test, torch.Tensor):
            y_test_np = y_test.cpu().numpy()
        else:
            y_test_np = y_test

        num_labels = self.params["num_labels"]
        is_multilabel_pred = y_pred.ndim == 2 and y_pred.shape[1] > 1
        is_flat_truth = (y_test_np.ndim == 1) or (y_test_np.ndim == 2 and y_test_np.shape[1] == 1)

        if is_multilabel_pred and is_flat_truth:
            # Create a zero matrix
            y_test_expanded = np.zeros((y_test_np.shape[0], num_labels), dtype=int)

            # Flatten y_test for iteration
            indices = y_test_np.flatten()

            # Use indices to set the correct column to 1
            for i, label_idx in enumerate(indices):
                idx = int(label_idx)
                if 0 <= idx < num_labels:
                    y_test_expanded[i, idx] = 1

            y_test_np = y_test_expanded

        # Generate classification report
        report = classification_report(y_test_np, y_pred, zero_division=0)
        print("\n" + "=" * 50)
        print("CLASSIFICATION REPORT")
        print(report)
        print("=" * 50 + "\n")

        metrics = {
            "accuracy": accuracy_score(y_test_np, y_pred),
            "precision": precision_score(y_test_np, y_pred, average="macro", zero_division=0),
            "recall": recall_score(y_test_np, y_pred, average="macro", zero_division=0),
            "f1_score": f1_score(y_test_np, y_pred, average="macro", zero_division=0),
        }

        mlflow.log_metrics(metrics)

        logger.info(
            f"Evaluation completed — Accuracy: {metrics['accuracy']:.3f}, F1: {metrics['f1_score']:.3f}"
        )
        return metrics

    def predict(self, X) -> ndarray:
        """
        Make predictions for Multi-Label classification.
        Returns Binary Matrix (Multi-Hot) where multiple classes can be 1.

        Args:
            X (list): Input texts for prediction.

        Returns:
            np.ndarray: Multi-Hot Encoded predictions (e.g., [[0, 1, 1, 0], ...])
        """

        if self.model is None:
            raise ValueError("Model is not trained. Call train() or load() before prediction.")

        # Set model to evaluation mode
        self.model.eval()

        encodings = self._tokenize(X)
        # Pass None as labels because we are in inference
        dataset = GraphCodeBERTDataset(encodings, labels=None)

        use_fp16 = torch.cuda.is_available()

        training_args = TrainingArguments(
            output_dir="./pred_temp",
            per_device_eval_batch_size=self.params["batch_size_eval"],
            fp16=use_fp16,
            report_to="none",
            no_cuda=not torch.cuda.is_available(),
        )

        trainer = Trainer(model=self.model, args=training_args)
        output = trainer.predict(dataset)

        # Clean up temporary prediction directory
        if os.path.exists("./pred_temp"):
            shutil.rmtree("./pred_temp")

        # Convert logits to probabilities
        logits = output.predictions
        probs = 1 / (1 + np.exp(-logits))

        # Apply a threshold of 0.5 (if prob > 0.5, predict 1 else 0)
        preds_binary = (probs > 0.5).astype(int)

        return preds_binary

    def save(self, path, model_name):
        """
        Save model locally and log to MLflow as artifact.

        Args:
            path (str): Directory path to save the model.
            model_name (str): Name for the saved model.
        """

        if self.model is None:
            raise ValueError("Model is not trained. Cannot save uninitialized model.")

        # Local Saving
        complete_path = os.path.join(path, f"{model_name}_{self.language}")

        # Remove existing directory if it exists
        if os.path.exists(complete_path) and os.path.isdir(complete_path):
            shutil.rmtree(complete_path)

        # Save model and tokenizer
        logger.info(f"Saving model to: {complete_path}")
        self.model.save_pretrained(complete_path)
        self.tokenizer.save_pretrained(complete_path)
        logger.info("Model saved locally.")

        try:
            # Log to MLflow
            logger.info("Logging artifacts to MLflow...")
            mlflow.log_artifacts(
                local_dir=complete_path, artifact_path=f"{model_name}_{self.language}"
            )
        except Exception as e:
            logger.error(f"Failed to log model artifacts to MLflow: {e}")

    def load(self, model_path):
        """
        Load model from a local path OR an MLflow URI.

        Args:
            model_path (str): Local path or MLflow URI to load the model from.
        """

        logger.info(f"Loading model from: {model_path}")
        local_model_path = model_path

        # Downloading model from MLflow and saving to local path
        if model_path.startswith("models:/") or model_path.startswith("runs:/"):
            try:
                logger.info("Detected MLflow model URI. Attempting to load from MLflow...")
                local_model_path = os.path.join(MODELS_DIR, "mlflow_temp_models")
                local_model_path = mlflow.artifacts.download_artifacts(
                    artifact_uri=model_path, dst_path=local_model_path
                )
                logger.info(f"Model downloaded from MLflow to: {local_model_path}")
            except Exception as e:
                logger.error(f"Failed to load from MLflow: {e}")
                raise e

        # Loading from local path
        try:
            if not os.path.exists(local_model_path):
                raise FileNotFoundError(f"Model path not found: {local_model_path}")

            # Load tokenizer and model from local path
            self.tokenizer = AutoTokenizer.from_pretrained(local_model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(local_model_path).to(
                self.device
            )
            logger.info("Model loaded from local path successfully.")

        except Exception as e:
            logger.error(f"Failed to load model from local path: {e}")
            raise e

        # Set model to evaluation mode
        self.model.eval()
turing/modeling/models/randomForestTfIdf.py
ADDED
|
@@ -0,0 +1,153 @@
import warnings

from loguru import logger
from numpy import ndarray
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import GridSearchCV
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline

from ..baseModel import BaseModel

warnings.filterwarnings("ignore")


class RandomForestTfIdf(BaseModel):
    """
    Sklearn implementation of BaseModel with integrated Grid Search.
    Builds a TF-IDF + RandomForest pipeline for multi-output text classification.
    """

    def __init__(self, language, path=None):
        """
        Initialize the RandomForestTfIdf model with configuration parameters.

        Args:
            language (str): Language for the model.
            path (str, optional): Path to load a pre-trained model. Defaults to None.
                If None, a new model is initialized.
        """

        self.params = {"stop_words": "english", "random_state": 42, "cv_folds": 5}

        self.grid_params = {
            "clf__estimator__n_estimators": [50, 100, 200],
            "clf__estimator__max_depth": [None, 10, 20],
            "tfidf__max_features": [3000, 5000, 8000],
        }

        super().__init__(language, path)

    def setup_model(self):
        """
        Initialize the scikit-learn pipeline with TF-IDF vectorizer and RandomForest classifier.
        """

        base_estimator = RandomForestClassifier(
            random_state=self.params["random_state"], n_jobs=-1
        )

        self.pipeline = Pipeline(
            [
                (
                    "tfidf",
                    TfidfVectorizer(ngram_range=(1, 2), stop_words=self.params["stop_words"]),
                ),
                ("clf", MultiOutputClassifier(base_estimator, n_jobs=-1)),
            ]
        )

        self.model = self.pipeline
        logger.info("Scikit-learn pipeline initialized.")

    def train(self, X_train, y_train) -> dict[str, any]:
        """
        Train the model using Grid Search to find the best hyperparameters.

        Args:
            X_train: Input training data.
            y_train: True labels for training data.
        """

        if self.model is None:
            raise ValueError(
                "Model pipeline is not initialized. Call setup_model() before training."
            )

        logger.info(f"Starting training for: {self.language.upper()}")
        logger.info("Performing Grid Search for best hyperparameters...")
        grid_search = GridSearchCV(
            self.pipeline,
            param_grid=self.grid_params,
            cv=self.params["cv_folds"],
            scoring="f1_weighted",
            n_jobs=-1,
            verbose=1,
        )
        grid_search.fit(X_train, y_train)

        logger.success(f"Best params found: {grid_search.best_params_}")

        parameters_to_log = {
            "max_features": grid_search.best_params_["tfidf__max_features"],
            "n_estimators": grid_search.best_params_["clf__estimator__n_estimators"],
            "max_depth": grid_search.best_params_["clf__estimator__max_depth"],
        }

        self.model = grid_search.best_estimator_
        logger.success(f"Training for {self.language.upper()} completed.")

        return parameters_to_log

    def evaluate(self, X_test, y_test) -> dict[str, any]:
        """
        Evaluate model on test data and return metrics.

        Args:
            X_test: Input test data.
            y_test: True labels for test data.
        """

        y_pred = self.predict(X_test)

        report = classification_report(y_test, y_pred, zero_division=0)
        print("\n" + "=" * 50)
        print("CLASSIFICATION REPORT")
        print(report)
        print("=" * 50 + "\n")

        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average="macro", zero_division=0),
            "recall": recall_score(y_test, y_pred, average="macro", zero_division=0),
            "f1_score": f1_score(y_test, y_pred, average="weighted"),
        }

        logger.info(
            f"Evaluation completed — Accuracy: {metrics['accuracy']:.3f}, F1: {metrics['f1_score']:.3f}"
        )
        return metrics

    def predict(self, X) -> ndarray:
        """
        Make predictions using the trained model.

        Args:
            X: Input data for prediction.

        Returns:
            Predictions made by the model.
        """

        if self.model is None:
            raise ValueError("Model is not trained. Call train() or load() before prediction.")

        return self.model.predict(X)
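A usage sketch for the TF-IDF baseline, assuming toy-sized data invented here; because cv_folds is 5, the grid search needs at least a handful of samples, and the real pipeline runs on the full per-language training sets:

import numpy as np
from turing.modeling.models.randomForestTfIdf import RandomForestTfIdf

# Hypothetical toy corpus and a two-column multi-hot label matrix.
X = ["returns current index", "todo refactor method", "parser constructor", "deprecated helper",
     "returns parsed value", "todo clean imports", "builds syntax tree", "deprecated wrapper",
     "returns token count", "todo add tests"]
y = np.array([[1, 0]] * 5 + [[0, 1]] * 5)

rf = RandomForestTfIdf(language="java")
best_params = rf.train(X, y)   # GridSearchCV over n_estimators, max_depth and tfidf max_features
print(best_params)
print(rf.predict(["returns the parser index"]))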
turing/modeling/models/tinyBert.py
ADDED
|
@@ -0,0 +1,441 @@
| 1 |
+
"""
|
| 2 |
+
Ultra-lightweight multi-label text classification model for code comment analysis.
|
| 3 |
+
|
| 4 |
+
This module implements a specialized neural architecture combining TinyBERT
|
| 5 |
+
(prajjwal1/bert-tiny, a ~15MB compressed BERT encoder) with a custom multi-label classification head.
|
| 6 |
+
Designed for efficient inference on resource-constrained environments while
|
| 7 |
+
maintaining competitive performance on code comment classification tasks.
|
| 8 |
+
|
| 9 |
+
Architecture:
|
| 10 |
+
- Encoder: TinyBERT (prajjwal1/bert-tiny)
|
| 11 |
+
- Hidden dimension: read from the encoder config (encoder.config.hidden_size)
|
| 12 |
+
- Classification layers: encoder hidden size -> 128 (ReLU) -> num_labels (Sigmoid)
|
| 13 |
+
- Regularization: Dropout(0.2) for preventing overfitting
|
| 14 |
+
- Loss function: Binary Cross-Entropy for multi-label classification
|
| 15 |
+
|
| 16 |
+
Performance characteristics:
|
| 17 |
+
- Model size: ~15MB
|
| 18 |
+
- Inference latency: ~50ms per sample
|
| 19 |
+
- Memory footprint: ~200MB during training
|
| 20 |
+
- Supports multi-label outputs via sigmoid activation
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from typing import List
|
| 24 |
+
|
| 25 |
+
from loguru import logger
|
| 26 |
+
import numpy as np
|
| 27 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 28 |
+
import torch
|
| 29 |
+
from torch import nn
|
| 30 |
+
from torch.optim import Adam
|
| 31 |
+
|
| 32 |
+
import turing.config as config
|
| 33 |
+
from turing.modeling.baseModel import BaseModel
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from transformers import AutoModel, AutoTokenizer
|
| 37 |
+
except ImportError:
|
| 38 |
+
logger.error("transformers library required. Install with: pip install transformers torch")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TinyBERTClassifier(BaseModel):
|
| 42 |
+
"""
|
| 43 |
+
Ultra-lightweight multi-label classifier for code comment analysis.
|
| 44 |
+
|
| 45 |
+
Combines TinyBERT encoder with a custom classification head optimized for
|
| 46 |
+
multi-label code comment classification across Java, Python, and Pharo.
|
| 47 |
+
|
| 48 |
+
Attributes:
|
| 49 |
+
device (torch.device): Computation device (CPU/GPU).
|
| 50 |
+
model (nn.ModuleDict): Container for encoder and classifier components.
|
| 51 |
+
tokenizer (AutoTokenizer): Hugging Face tokenizer for text preprocessing.
|
| 52 |
+
classifier (nn.Sequential): Custom multi-label classification head.
|
| 53 |
+
num_labels (int): Number of output classes per language.
|
| 54 |
+
labels_map (list): Mapping of label indices to semantic categories.
|
| 55 |
+
|
| 56 |
+
References:
|
| 57 |
+
TinyBERT: https://huggingface.co/prajjwal1/bert-tiny
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, language: str, path: str = None):
|
| 61 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 62 |
+
logger.info(f"TinyBERT using device: {self.device}")
|
| 63 |
+
self.model = None
|
| 64 |
+
self.tokenizer = None
|
| 65 |
+
self.classifier = None
|
| 66 |
+
self.mlb = MultiLabelBinarizer()
|
| 67 |
+
self.labels_map = config.LABELS_MAP.get(language, [])
|
| 68 |
+
self.num_labels = len(self.labels_map)
|
| 69 |
+
self.params = {
|
| 70 |
+
"model": "TinyBERT",
|
| 71 |
+
"model_size": "15MB",
|
| 72 |
+
"epochs": 15,
|
| 73 |
+
"batch_size": 8,
|
| 74 |
+
"learning_rate": 1e-3,
|
| 75 |
+
}
|
| 76 |
+
super().__init__(language=language, path=path)
|
| 77 |
+
|
| 78 |
+
def setup_model(self):
|
| 79 |
+
"""
|
| 80 |
+
Initialize TinyBERT encoder and custom classification head.
|
| 81 |
+
|
| 82 |
+
Loads the pre-trained TinyBERT model from Hugging Face model hub and
|
| 83 |
+
constructs a custom multi-label classification head with:
|
| 84 |
+
- Input: the encoder's [CLS] embedding (size taken from the pre-trained config)
|
| 85 |
+
- Hidden layer: 128 units with ReLU activation
|
| 86 |
+
- Dropout: 0.2 for regularization
|
| 87 |
+
- Output: num_labels units with Sigmoid activation
|
| 88 |
+
|
| 89 |
+
Raises:
|
| 90 |
+
Exception: If model initialization fails due to network or missing dependencies.
|
| 91 |
+
"""
|
| 92 |
+
self._initialize_model()
|
| 93 |
+
|
| 94 |
+
def _initialize_model(self):
|
| 95 |
+
"""
|
| 96 |
+
Initialize TinyBERT encoder and custom classification head.
|
| 97 |
+
|
| 98 |
+
Loads the pre-trained TinyBERT model from Hugging Face model hub and
|
| 99 |
+
constructs a custom multi-label classification head with:
|
| 100 |
+
- Input: the encoder's [CLS] embedding (size taken from the pre-trained config)
|
| 101 |
+
- Hidden layer: 128 units with ReLU activation
|
| 102 |
+
- Dropout: 0.2 for regularization
|
| 103 |
+
- Output: num_labels units with Sigmoid activation
|
| 104 |
+
|
| 105 |
+
Raises:
|
| 106 |
+
Exception: If model initialization fails due to network or missing dependencies.
|
| 107 |
+
"""
|
| 108 |
+
try:
|
| 109 |
+
model_name = "prajjwal1/bert-tiny"
|
| 110 |
+
|
| 111 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 112 |
+
encoder = AutoModel.from_pretrained(model_name)
|
| 113 |
+
encoder.to(self.device)
|
| 114 |
+
|
| 115 |
+
hidden_dim = encoder.config.hidden_size
|
| 116 |
+
|
| 117 |
+
self.classifier = nn.Sequential(
|
| 118 |
+
nn.Linear(hidden_dim, 128),
|
| 119 |
+
nn.ReLU(),
|
| 120 |
+
nn.Dropout(0.2),
|
| 121 |
+
nn.Linear(128, self.num_labels),
|
| 122 |
+
nn.Sigmoid(),
|
| 123 |
+
).to(self.device)
|
| 124 |
+
|
| 125 |
+
self.model = nn.ModuleDict({"encoder": encoder, "classifier": self.classifier})
|
| 126 |
+
|
| 127 |
+
logger.success(f"Initialized TinyBERTClassifier for {self.language}")
|
| 128 |
+
logger.info(f"Model size: ~15MB | Labels: {self.num_labels}")
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"Error initializing model: {e}")
|
| 132 |
+
raise
|
| 133 |
+
|
| 134 |
+
def train(
|
| 135 |
+
self,
|
| 136 |
+
X_train: List[str],
|
| 137 |
+
y_train: np.ndarray,
|
| 138 |
+
path: str = None,
|
| 139 |
+
model_name: str = "tinybert_classifier",
|
| 140 |
+
epochs: int = 15,
|
| 141 |
+
batch_size: int = 8,
|
| 142 |
+
learning_rate: float = 1e-3,
|
| 143 |
+
) -> dict:
|
| 144 |
+
"""
|
| 145 |
+
Train the classifier using binary cross-entropy loss.
|
| 146 |
+
|
| 147 |
+
Implements mini-batch optimization of the classification head with the Adam optimizer.
|
| 148 |
+
Supports checkpoint saving for model persistence and recovery.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
X_train (List[str]): Training text samples (code comments).
|
| 152 |
+
y_train (np.ndarray): Binary label matrix of shape (n_samples, n_labels).
|
| 153 |
+
path (str, optional): Directory path for model checkpoint saving.
|
| 154 |
+
model_name (str): Identifier for saved model artifacts.
|
| 155 |
+
epochs (int): Number of complete training iterations. Default: 15.
|
| 156 |
+
batch_size (int): Number of samples per gradient update. Default: 8.
|
| 157 |
+
learning_rate (float): Adam optimizer learning rate. Default: 1e-3.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
dict: Training configuration including hyperparameters and model metadata.
|
| 161 |
+
|
| 162 |
+
Raises:
|
| 163 |
+
Exception: If training fails due to data inconsistency or resource exhaustion.
|
| 164 |
+
"""
|
| 165 |
+
try:
|
| 166 |
+
if self.model is None:
|
| 167 |
+
self._initialize_model()
|
| 168 |
+
|
| 169 |
+
optimizer = Adam(self.classifier.parameters(), lr=learning_rate)
|
| 170 |
+
criterion = nn.BCELoss()
|
| 171 |
+
|
| 172 |
+
num_samples = len(X_train)
|
| 173 |
+
num_batches = (num_samples + batch_size - 1) // batch_size
|
| 174 |
+
|
| 175 |
+
logger.info(f"Starting training: {epochs} epochs, {num_batches} batches per epoch")
|
| 176 |
+
|
| 177 |
+
for epoch in range(epochs):
|
| 178 |
+
total_loss = 0.0
|
| 179 |
+
|
| 180 |
+
for batch_idx in range(num_batches):
|
| 181 |
+
start_idx = batch_idx * batch_size
|
| 182 |
+
end_idx = min(start_idx + batch_size, num_samples)
|
| 183 |
+
|
| 184 |
+
batch_texts = X_train[start_idx:end_idx]
|
| 185 |
+
batch_labels = y_train[start_idx:end_idx]
|
| 186 |
+
|
| 187 |
+
optimizer.zero_grad()
|
| 188 |
+
|
| 189 |
+
tokens = self.tokenizer(
|
| 190 |
+
batch_texts,
|
| 191 |
+
padding=True,
|
| 192 |
+
truncation=True,
|
| 193 |
+
max_length=128,
|
| 194 |
+
return_tensors="pt",
|
| 195 |
+
).to(self.device)
|
| 196 |
+
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
encoder_output = self.model["encoder"](**tokens)
|
| 199 |
+
cls_token = encoder_output.last_hidden_state[:, 0, :]
|
| 200 |
+
|
| 201 |
+
logits = self.classifier(cls_token)
|
| 202 |
+
|
| 203 |
+
labels_tensor = torch.tensor(batch_labels, dtype=torch.float32).to(self.device)
|
| 204 |
+
loss = criterion(logits, labels_tensor)
|
| 205 |
+
|
| 206 |
+
loss.backward()
|
| 207 |
+
optimizer.step()
|
| 208 |
+
|
| 209 |
+
total_loss += loss.item()
|
| 210 |
+
|
| 211 |
+
avg_loss = total_loss / num_batches
|
| 212 |
+
logger.info(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}")
|
| 213 |
+
|
| 214 |
+
logger.success(f"Training completed for {self.language}")
|
| 215 |
+
|
| 216 |
+
if path:
|
| 217 |
+
self.save(path, model_name)
|
| 218 |
+
|
| 219 |
+
return {
|
| 220 |
+
"epochs": epochs,
|
| 221 |
+
"batch_size": batch_size,
|
| 222 |
+
"learning_rate": learning_rate,
|
| 223 |
+
"model_size_mb": 15,
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
logger.error(f"Error training model: {e}")
|
| 228 |
+
raise
|
| 229 |
+
|
| 230 |
+
def predict(self, texts: List[str], threshold: float = 0.3) -> np.ndarray:
|
| 231 |
+
"""
|
| 232 |
+
Generate multi-label predictions for code comments.
|
| 233 |
+
|
| 234 |
+
Performs inference in evaluation mode without gradient computation.
|
| 235 |
+
Applies probability threshold to convert sigmoid outputs to binary labels.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
texts (List[str]): Code comment samples for classification.
|
| 239 |
+
threshold (float): Decision boundary for label assignment. Default: 0.3.
|
| 240 |
+
Values below threshold are mapped to 0, above to 1.
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
np.ndarray: Binary predictions matrix of shape (n_samples, n_labels).
|
| 244 |
+
|
| 245 |
+
Raises:
|
| 246 |
+
ValueError: If model is not initialized.
|
| 247 |
+
Exception: If inference fails due to incompatible input dimensions.
|
| 248 |
+
"""
|
| 249 |
+
if self.model is None:
|
| 250 |
+
raise ValueError("Model not initialized. Train or load a model first.")
|
| 251 |
+
|
| 252 |
+
self.model.eval()
|
| 253 |
+
predictions = []
|
| 254 |
+
|
| 255 |
+
# Convert various types to list: pandas Series, Dataset Column, etc.
|
| 256 |
+
if hasattr(texts, "tolist"):
|
| 257 |
+
texts = texts.tolist()
|
| 258 |
+
elif hasattr(texts, "__iter__") and not isinstance(texts, list):
|
| 259 |
+
texts = list(texts)
|
| 260 |
+
|
| 261 |
+
try:
|
| 262 |
+
with torch.no_grad():
|
| 263 |
+
tokens = self.tokenizer(
|
| 264 |
+
texts, padding=True, truncation=True, max_length=128, return_tensors="pt"
|
| 265 |
+
).to(self.device)
|
| 266 |
+
|
| 267 |
+
encoder_output = self.model["encoder"](**tokens)
|
| 268 |
+
cls_token = encoder_output.last_hidden_state[:, 0, :]
|
| 269 |
+
|
| 270 |
+
logits = self.classifier(cls_token)
|
| 271 |
+
probabilities = logits.cpu().numpy()
|
| 272 |
+
|
| 273 |
+
predictions = (probabilities > threshold).astype(int)
|
| 274 |
+
|
| 275 |
+
return predictions
|
| 276 |
+
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.error(f"Error during prediction: {e}")
|
| 279 |
+
raise
|
| 280 |
+
|
| 281 |
+
def evaluate(self, X_test: List[str], y_test: np.ndarray) -> dict:
|
| 282 |
+
"""
|
| 283 |
+
Evaluate classification performance on test set.
|
| 284 |
+
|
| 285 |
+
Computes per-label and macro-averaged metrics:
|
| 286 |
+
- Precision: TP / (TP + FP) - correctness of positive predictions
|
| 287 |
+
- Recall: TP / (TP + FN) - coverage of actual positive instances
|
| 288 |
+
- F1-Score: 2 * (P * R) / (P + R) - harmonic mean of precision and recall
|
| 289 |
+
- Accuracy: Per-sample exact match rate
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
X_test (List[str]): Test text samples for evaluation.
|
| 293 |
+
y_test (np.ndarray): Ground truth binary label matrix or indices.
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
dict: Evaluation metrics including f1_score, precision, recall, accuracy.
|
| 297 |
+
|
| 298 |
+
Raises:
|
| 299 |
+
Exception: If evaluation fails due to prediction errors.
|
| 300 |
+
"""
|
| 301 |
+
try:
|
| 302 |
+
predictions = self.predict(X_test)
|
| 303 |
+
|
| 304 |
+
# Convert y_test to numpy array if needed
|
| 305 |
+
if not isinstance(y_test, (np.ndarray, torch.Tensor)):
|
| 306 |
+
y_test_np = np.array(y_test)
|
| 307 |
+
elif isinstance(y_test, torch.Tensor):
|
| 308 |
+
y_test_np = y_test.cpu().numpy()
|
| 309 |
+
else:
|
| 310 |
+
y_test_np = y_test
|
| 311 |
+
|
| 312 |
+
# Handle conversion from flat indices to multi-hot encoding if needed
|
| 313 |
+
is_multilabel_pred = predictions.ndim == 2 and predictions.shape[1] > 1
|
| 314 |
+
is_flat_truth = (y_test_np.ndim == 1) or (
|
| 315 |
+
y_test_np.ndim == 2 and y_test_np.shape[1] == 1
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if is_multilabel_pred and is_flat_truth:
|
| 319 |
+
# Create zero matrix for multi-hot encoding
|
| 320 |
+
y_test_expanded = np.zeros((y_test_np.shape[0], self.num_labels), dtype=int)
|
| 321 |
+
indices = y_test_np.flatten()
|
| 322 |
+
|
| 323 |
+
# Set columns to 1 based on indices
|
| 324 |
+
for i, label_idx in enumerate(indices):
|
| 325 |
+
idx = int(label_idx)
|
| 326 |
+
if 0 <= idx < self.num_labels:
|
| 327 |
+
y_test_expanded[i, idx] = 1
|
| 328 |
+
|
| 329 |
+
y_test_np = y_test_expanded
|
| 330 |
+
|
| 331 |
+
tp = np.sum((predictions == 1) & (y_test_np == 1), axis=0)
|
| 332 |
+
fp = np.sum((predictions == 1) & (y_test_np == 0), axis=0)
|
| 333 |
+
fn = np.sum((predictions == 0) & (y_test_np == 1), axis=0)
|
| 334 |
+
|
| 335 |
+
precision_per_label = tp / (tp + fp + 1e-10)
|
| 336 |
+
recall_per_label = tp / (tp + fn + 1e-10)
|
| 337 |
+
f1_per_label = (
|
| 338 |
+
2
|
| 339 |
+
* (precision_per_label * recall_per_label)
|
| 340 |
+
/ (precision_per_label + recall_per_label + 1e-10)
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
metrics = {
|
| 344 |
+
"f1_score": float(np.mean(f1_per_label)),
|
| 345 |
+
"precision": float(np.mean(precision_per_label)),
|
| 346 |
+
"recall": float(np.mean(recall_per_label)),
|
| 347 |
+
"accuracy": float(np.mean(predictions == y_test_np)),
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
logger.info(f"Evaluation metrics: {metrics}")
|
| 351 |
+
return metrics
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logger.error(f"Error evaluating model: {e}")
|
| 355 |
+
raise
|
| 356 |
+
|
| 357 |
+
def save(self, path: str, model_name: str = "tinybert_classifier"):
|
| 358 |
+
"""
|
| 359 |
+
Persist model artifacts including weights, tokenizer, and configuration.
|
| 360 |
+
|
| 361 |
+
Saves the following components:
|
| 362 |
+
- classifier.pt: PyTorch state dictionary of classification head
|
| 363 |
+
- tokenizer configuration: Hugging Face tokenizer files
|
| 364 |
+
- config.json: Model metadata and label mappings
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
path (str): Parent directory for model checkpoint storage.
|
| 368 |
+
model_name (str): Model identifier used as subdirectory name.
|
| 369 |
+
|
| 370 |
+
Raises:
|
| 371 |
+
Exception: If file I/O or serialization fails.
|
| 372 |
+
"""
|
| 373 |
+
try:
|
| 374 |
+
import os
|
| 375 |
+
|
| 376 |
+
model_path = os.path.join(path, model_name)
|
| 377 |
+
os.makedirs(model_path, exist_ok=True)
|
| 378 |
+
|
| 379 |
+
if self.classifier:
|
| 380 |
+
torch.save(self.classifier.state_dict(), os.path.join(model_path, "classifier.pt"))
|
| 381 |
+
|
| 382 |
+
if self.tokenizer:
|
| 383 |
+
self.tokenizer.save_pretrained(model_path)
|
| 384 |
+
|
| 385 |
+
config_data = {
|
| 386 |
+
"language": self.language,
|
| 387 |
+
"num_labels": self.num_labels,
|
| 388 |
+
"labels_map": self.labels_map,
|
| 389 |
+
"model_type": "tinybert_classifier",
|
| 390 |
+
"model_name": model_name,
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
import json
|
| 394 |
+
|
| 395 |
+
with open(os.path.join(model_path, "config.json"), "w") as f:
|
| 396 |
+
json.dump(config_data, f, indent=2)
|
| 397 |
+
|
| 398 |
+
logger.success(f"Model saved to {model_path}")
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
logger.error(f"Error saving model: {e}")
|
| 402 |
+
raise
|
| 403 |
+
|
| 404 |
+
def load(self, path: str):
|
| 405 |
+
"""
|
| 406 |
+
Restore model state from checkpoint directory.
|
| 407 |
+
|
| 408 |
+
Loads classifier weights from serialized PyTorch tensors and reinitializes
|
| 409 |
+
the tokenizer from saved configuration. Restores language-specific label
|
| 410 |
+
mappings from JSON metadata.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
path (str): Directory containing model checkpoint files.
|
| 414 |
+
|
| 415 |
+
Raises:
|
| 416 |
+
Exception: If file not found or deserialization fails.
|
| 417 |
+
"""
|
| 418 |
+
try:
|
| 419 |
+
import json
|
| 420 |
+
import os
|
| 421 |
+
|
| 422 |
+
self._initialize_model()
|
| 423 |
+
|
| 424 |
+
classifier_path = os.path.join(path, "classifier.pt")
|
| 425 |
+
if os.path.exists(classifier_path):
|
| 426 |
+
self.classifier.load_state_dict(
|
| 427 |
+
torch.load(classifier_path, map_location=self.device)
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
config_path = os.path.join(path, "config.json")
|
| 431 |
+
if os.path.exists(config_path):
|
| 432 |
+
with open(config_path, "r") as f:
|
| 433 |
+
config_data = json.load(f)
|
| 434 |
+
self.language = config_data.get("language", self.language)
|
| 435 |
+
self.labels_map = config_data.get("labels_map", self.labels_map)
|
| 436 |
+
|
| 437 |
+
logger.success(f"Model loaded from {path}")
|
| 438 |
+
|
| 439 |
+
except Exception as e:
|
| 440 |
+
logger.error(f"Error loading model: {e}")
|
| 441 |
+
raise
|
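As a rough usage sketch of the TinyBERTClassifier defined above (the sample comments and multi-hot targets are invented, the label-matrix width is taken from model.num_labels, and it is assumed that constructing the class without a checkpoint path leaves model initialization to train()):

# Illustrative only: train the classification head for one epoch, then predict.
import numpy as np
from turing.modeling.models.tinyBert import TinyBERTClassifier

model = TinyBERTClassifier(language="python")

X_train = ["Fetches the user profile.", "TODO: refactor this entire block."]
y_train = np.zeros((len(X_train), model.num_labels), dtype=np.float32)
y_train[0, 0] = 1.0  # arbitrary targets, one label per sample, for illustration
y_train[1, 1] = 1.0

model.train(X_train, y_train, epochs=1, batch_size=2)  # BCE loss on sigmoid outputs
predictions = model.predict(X_train, threshold=0.3)    # binary (n_samples, n_labels) matrix
print(predictions)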
turing/modeling/predict.py
ADDED
|
@@ -0,0 +1,195 @@
| 1 |
+
import importlib
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import dagshub
|
| 5 |
+
from loguru import logger
|
| 6 |
+
import mlflow
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from turing.config import INPUT_COLUMN, LABELS_MAP, LANGS, MODEL_CONFIG, MODELS_DIR
|
| 11 |
+
from turing.dataset import DatasetManager
|
| 12 |
+
from turing.modeling.model_selector import get_best_model_info
|
| 13 |
+
from turing.modeling.models.codeBerta import CodeBERTa
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ModelInference:
|
| 17 |
+
# Model Configuration (Fallback Registry)
|
| 18 |
+
FALLBACK_MODEL_REGISTRY = {
|
| 19 |
+
"java": {
|
| 20 |
+
"run_id": "446f4459780347da8c796e619129be37",
|
| 21 |
+
"artifact": "fine-tuned-CodeBERTa_java",
|
| 22 |
+
"model_id": "codeberta",
|
| 23 |
+
},
|
| 24 |
+
"python": {
|
| 25 |
+
"run_id": "ef5fd8ebf33a412087dcf02afd9e3147",
|
| 26 |
+
"artifact": "fine-tuned-CodeBERTa_python",
|
| 27 |
+
"model_id": "codeberta",
|
| 28 |
+
},
|
| 29 |
+
"pharo": {
|
| 30 |
+
"run_id": "97822c6d84fc40c5b2363c9201a39997",
|
| 31 |
+
"artifact": "fine-tuned-CodeBERTa_pharo",
|
| 32 |
+
"model_id": "codeberta",
|
| 33 |
+
},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def __init__(self, repo_owner="se4ai2526-uniba", repo_name="Turing", use_best_model_tags=True):
|
| 38 |
+
dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
|
| 39 |
+
warnings.filterwarnings("ignore")
|
| 40 |
+
self.dataset_manager = DatasetManager()
|
| 41 |
+
self.use_best_model_tags = use_best_model_tags
|
| 42 |
+
|
| 43 |
+
# Initialize model registry based on configuration
|
| 44 |
+
if use_best_model_tags:
|
| 45 |
+
logger.info("Using MLflow tags to find best models")
|
| 46 |
+
|
| 47 |
+
self.model_registry = {}
|
| 48 |
+
for lang in LANGS:
|
| 49 |
+
try:
|
| 50 |
+
model_info = get_best_model_info(
|
| 51 |
+
lang, fallback_registry=self.FALLBACK_MODEL_REGISTRY
|
| 52 |
+
)
|
| 53 |
+
self.model_registry[lang] = model_info
|
| 54 |
+
logger.info(f"Loaded model info for {lang}: {model_info}")
|
| 55 |
+
|
| 56 |
+
# raise error if any required info is missing
|
| 57 |
+
if not all(k in model_info for k in ("run_id", "artifact", "model_id")):
|
| 58 |
+
raise ValueError(f"Incomplete model info for {lang}: {model_info}")
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.warning(f"Could not load model info for {lang}: {e}")
|
| 62 |
+
if lang in self.FALLBACK_MODEL_REGISTRY:
|
| 63 |
+
self.model_registry[lang] = self.FALLBACK_MODEL_REGISTRY[lang]
|
| 64 |
+
|
| 65 |
+
# Pre-cache models locally
|
| 66 |
+
run_id = self.model_registry[lang]["run_id"]
|
| 67 |
+
artifact = self.model_registry[lang]["artifact"]
|
| 68 |
+
self._get_cached_model_path(run_id, artifact, lang)
|
| 69 |
+
else:
|
| 70 |
+
logger.info("Using hardcoded model registry")
|
| 71 |
+
self.model_registry = self.FALLBACK_MODEL_REGISTRY
|
| 72 |
+
|
| 73 |
+
def _decode_predictions(self, raw_predictions, language: str):
|
| 74 |
+
"""
|
| 75 |
+
Converts the binary matrix from the model into human-readable labels.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
raw_predictions: Numpy array or similar with binary predictions
|
| 79 |
+
language: Programming language for label mapping
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
labels_map = LABELS_MAP.get(language, [])
|
| 83 |
+
decoded_results = []
|
| 84 |
+
|
| 85 |
+
# Ensure input is a numpy array for processing
|
| 86 |
+
if isinstance(raw_predictions, list):
|
| 87 |
+
raw_array = np.array(raw_predictions)
|
| 88 |
+
elif isinstance(raw_predictions, pd.DataFrame):
|
| 89 |
+
raw_array = raw_predictions.values
|
| 90 |
+
else:
|
| 91 |
+
raw_array = raw_predictions
|
| 92 |
+
|
| 93 |
+
# Iterate over rows
|
| 94 |
+
for row in raw_array:
|
| 95 |
+
indices = np.where(row == 1)[0]
|
| 96 |
+
# Map indices to labels safely
|
| 97 |
+
row_labels = [labels_map[i] for i in indices if i < len(labels_map)]
|
| 98 |
+
decoded_results.append(row_labels)
|
| 99 |
+
|
| 100 |
+
return decoded_results
|
| 101 |
+
|
| 102 |
+
def _get_cached_model_path(self, run_id: str, artifact_name: str, language: str) -> str:
|
| 103 |
+
"""Checks if model exists locally; if not, downloads it from MLflow."""
|
| 104 |
+
# Define local path: models/mlflow_temp_models/language/artifact_name
|
| 105 |
+
local_path = MODELS_DIR / "mlflow_temp_models" / language / artifact_name
|
| 106 |
+
|
| 107 |
+
if local_path.exists():
|
| 108 |
+
logger.info(f"Loading {language} model from local cache: {local_path}")
|
| 109 |
+
return str(local_path)
|
| 110 |
+
|
| 111 |
+
logger.info(
|
| 112 |
+
f"Model not found locally. Downloading {language} model from MLflow (Run ID: {run_id})..."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Ensure parent directory exists
|
| 116 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 117 |
+
|
| 118 |
+
# Download artifacts to the parent directory (artifact_name folder will be created inside)
|
| 119 |
+
mlflow.artifacts.download_artifacts(
|
| 120 |
+
run_id=run_id, artifact_path=artifact_name, dst_path=str(local_path.parent)
|
| 121 |
+
)
|
| 122 |
+
logger.success(f"Model downloaded and cached at: {local_path}")
|
| 123 |
+
|
| 124 |
+
return str(local_path)
|
| 125 |
+
|
| 126 |
+
def predict_payload(self, texts: list[str], language: str):
|
| 127 |
+
"""
|
| 128 |
+
API Prediction: Automatically fetches the correct model from the registry based on language.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
texts: List of code comments to classify
|
| 132 |
+
language: Programming language
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
# 1. Validate Language and Fetch Config
|
| 136 |
+
if language not in self.model_registry:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
f"Language '{language}' is not supported or the model is not configured."
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
model_config = self.model_registry[language]
|
| 142 |
+
run_id = model_config["run_id"]
|
| 143 |
+
artifact_name = model_config["artifact"]
|
| 144 |
+
model_id = model_config["model_id"]
|
| 145 |
+
|
| 146 |
+
# Dynamically import model class
|
| 147 |
+
config_entry = MODEL_CONFIG[model_id]
|
| 148 |
+
module_name = config_entry["model_class_module"]
|
| 149 |
+
class_name = config_entry["model_class_name"]
|
| 150 |
+
module = importlib.import_module(module_name)
|
| 151 |
+
model_class = getattr(module, class_name)
|
| 152 |
+
|
| 153 |
+
# 2. Get Model Path (Local Cache or Download)
|
| 154 |
+
model_path = self._get_cached_model_path(run_id, artifact_name, language)
|
| 155 |
+
|
| 156 |
+
# Load Model
|
| 157 |
+
model = model_class(language=language, path=model_path)
|
| 158 |
+
|
| 159 |
+
# 3. Predict
|
| 160 |
+
raw_predictions = model.predict(texts)
|
| 161 |
+
|
| 162 |
+
# 4. Decode Labels
|
| 163 |
+
decoded_labels = self._decode_predictions(raw_predictions, language)
|
| 164 |
+
|
| 165 |
+
return raw_predictions, decoded_labels, run_id, artifact_name
|
| 166 |
+
|
| 167 |
+
def predict_from_mlflow(
|
| 168 |
+
self, mlflow_run_id: str, artifact_name: str, language: str, model_class=CodeBERTa
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Legacy method for CML/CLI: Predicts on the test dataset stored on disk.
|
| 172 |
+
"""
|
| 173 |
+
# Load Dataset
|
| 174 |
+
try:
|
| 175 |
+
full_dataset = self.dataset_manager.get_dataset()
|
| 176 |
+
dataset_key = f"{language}_test"
|
| 177 |
+
if dataset_key not in full_dataset:
|
| 178 |
+
raise ValueError(f"Dataset key '{dataset_key}' not found.")
|
| 179 |
+
test_ds = full_dataset[dataset_key]
|
| 180 |
+
X_test = test_ds[INPUT_COLUMN]
|
| 181 |
+
except Exception as e:
|
| 182 |
+
logger.error(f"Error loading dataset: {e}")
|
| 183 |
+
raise e
|
| 184 |
+
|
| 185 |
+
# Load Model (Local Cache or Download)
|
| 186 |
+
model_path = self._get_cached_model_path(mlflow_run_id, artifact_name, language)
|
| 187 |
+
model = model_class(language=language, path=model_path)
|
| 188 |
+
|
| 189 |
+
raw_predictions = model.predict(X_test)
|
| 190 |
+
|
| 191 |
+
# Decode output
|
| 192 |
+
readable_predictions = self._decode_predictions(raw_predictions, language)
|
| 193 |
+
|
| 194 |
+
logger.info("Dataset prediction completed.")
|
| 195 |
+
return readable_predictions
|
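A hedged sketch of calling the ModelInference class above from application code; it needs DagsHub/MLflow credentials in the environment (see docker-compose.yml), and the printed labels are only an example of the output shape:

# Illustrative only: language-routed prediction via the fallback model registry.
from turing.modeling.predict import ModelInference

inference = ModelInference(use_best_model_tags=False)  # skip the MLflow tag lookup
raw, decoded, run_id, artifact = inference.predict_payload(
    ["/** Returns the user ID. */"], language="java"
)
print(decoded)  # e.g. [['summary']], depending on the model that gets loaded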
turing/modeling/train.py
ADDED
|
@@ -0,0 +1,212 @@
| 1 |
+
from importlib import import_module
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import dagshub
|
| 6 |
+
from loguru import logger
|
| 7 |
+
import mlflow
|
| 8 |
+
from mlflow.tracking import MlflowClient
|
| 9 |
+
import numpy as np
|
| 10 |
+
import typer
|
| 11 |
+
|
| 12 |
+
import turing.config as config
|
| 13 |
+
from turing.dataset import DatasetManager
|
| 14 |
+
from turing.evaluate_model import evaluate_models
|
| 15 |
+
|
| 16 |
+
dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Turing", mlflow=True)
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore")
|
| 19 |
+
|
| 20 |
+
DEFAULT_MODEL = "codeberta"
|
| 21 |
+
_default_cfg = config.MODEL_CONFIG[DEFAULT_MODEL]
|
| 22 |
+
|
| 23 |
+
MODEL_CLASS_MODULE = _default_cfg["model_class_module"]
|
| 24 |
+
MODEL_CLASS_NAME = _default_cfg["model_class_name"]
|
| 25 |
+
MODEL_CLASS = __import__(MODEL_CLASS_MODULE, fromlist=[MODEL_CLASS_NAME])
|
| 26 |
+
MODEL_CLASS = getattr(MODEL_CLASS, MODEL_CLASS_NAME)
|
| 27 |
+
EXP_NAME = _default_cfg["exp_name"]
|
| 28 |
+
MODEL_NAME = _default_cfg["model_name"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
app = typer.Typer()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def tag_best_models(
|
| 36 |
+
metric: str = "f1_score"
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
Tag the best existing models in MLflow based on the specified metric.
|
| 40 |
+
Remove previous best_model tags before tagging the new best models.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
metric: Metric to use for determining the best model
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Turing", mlflow=True)
|
| 47 |
+
client = MlflowClient()
|
| 48 |
+
|
| 49 |
+
# Get all experiments from Mlflow
|
| 50 |
+
experiments = client.search_experiments()
|
| 51 |
+
if not experiments:
|
| 52 |
+
logger.error("No experiments found in MLflow")
|
| 53 |
+
return
|
| 54 |
+
|
| 55 |
+
# Find the best run for each language
|
| 56 |
+
experiments_ids = [exp.experiment_id for exp in experiments]
|
| 57 |
+
for lang in config.LANGS:
|
| 58 |
+
# Get all runs for the language
|
| 59 |
+
runs = client.search_runs(
|
| 60 |
+
experiment_ids=experiments_ids,
|
| 61 |
+
filter_string=f"tags.Language = '{lang}'",
|
| 62 |
+
order_by=[f"metrics.{metric} DESC"]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if not runs:
|
| 66 |
+
logger.warning(f"No runs found for language {lang}")
|
| 67 |
+
continue
|
| 68 |
+
logger.info(f"Found {len(runs)} runs for {lang}")
|
| 69 |
+
|
| 70 |
+
# Get the best run for the language
|
| 71 |
+
best_run = runs[0]
|
| 72 |
+
run_id = best_run.info.run_id
|
| 73 |
+
|
| 74 |
+
# Remove previous best_model tags for this language
|
| 75 |
+
for run in runs[1:]:
|
| 76 |
+
try:
|
| 77 |
+
client.delete_tag(run.info.run_id, "best_model")
|
| 78 |
+
except Exception:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
# Tag the best model
|
| 82 |
+
client.set_tag(run_id, "best_model", "true")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def show_tagged_models():
|
| 86 |
+
"""
|
| 87 |
+
Show all models tagged as best_model.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Turing", mlflow=True)
|
| 91 |
+
client = MlflowClient()
|
| 92 |
+
|
| 93 |
+
# Get all experiments from Mlflow
|
| 94 |
+
experiments = client.search_experiments()
|
| 95 |
+
if not experiments:
|
| 96 |
+
logger.error("No experiments found in MLflow")
|
| 97 |
+
return
|
| 98 |
+
|
| 99 |
+
# Find all runs tagged as best_model
|
| 100 |
+
runs = client.search_runs(
|
| 101 |
+
experiment_ids=[exp.experiment_id for exp in experiments],
|
| 102 |
+
filter_string="tags.best_model = 'true'",
|
| 103 |
+
order_by=["tags.Language ASC"]
|
| 104 |
+
)
|
| 105 |
+
logger.info(f"\nFound {len(runs)} best models in experiments:\n")
|
| 106 |
+
|
| 107 |
+
# Display details of each tagged best model
|
| 108 |
+
for run in runs:
|
| 109 |
+
language = run.data.tags.get("Language", "unknown")
|
| 110 |
+
exp_name = client.get_experiment(run.info.experiment_id).name
|
| 111 |
+
run_id = run.info.run_id
|
| 112 |
+
run_name = run.data.tags.get("mlflow.runName", "N/A")
|
| 113 |
+
dataset_name = run.data.tags.get("dataset_name", "unknown")
|
| 114 |
+
|
| 115 |
+
logger.info(f"Language: {language}")
|
| 116 |
+
logger.info(f" Run: {exp_name}/{run_name} ({run_id})")
|
| 117 |
+
logger.info(f" Dataset: {dataset_name}")
|
| 118 |
+
|
| 119 |
+
if run.data.metrics:
|
| 120 |
+
for metric in run.data.metrics:
|
| 121 |
+
logger.info(f" {metric}: {run.data.metrics[metric]:.4f}")
|
| 122 |
+
|
| 123 |
+
logger.info("")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@app.command()
|
| 127 |
+
def main(model: str = typer.Option("codeberta", help="Model to train: codeberta, graphcodebert, tinybert, or randomforest"), dataset: str = typer.Option(None, help="Dataset to use for training")):
|
| 128 |
+
# Get model configuration from config
|
| 129 |
+
model_key = model.lower()
|
| 130 |
+
if model_key not in config.MODEL_CONFIG:
|
| 131 |
+
logger.error(f"Unknown model: {model_key}. Available models: {list(config.MODEL_CONFIG.keys())}")
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
model_cfg = config.MODEL_CONFIG[model_key]
|
| 135 |
+
model_name = model_cfg["model_name"]
|
| 136 |
+
exp_name = model_cfg["exp_name"]
|
| 137 |
+
|
| 138 |
+
# Dynamically import model class
|
| 139 |
+
module = import_module(model_cfg["model_class_module"])
|
| 140 |
+
model_class = getattr(module, model_cfg["model_class_name"])
|
| 141 |
+
|
| 142 |
+
logger.info(f"Training model: {model_name}")
|
| 143 |
+
|
| 144 |
+
# Load dataset
|
| 145 |
+
dataset_path = config.INTERIM_DATA_DIR / "features" / dataset
|
| 146 |
+
dataset_manager = DatasetManager(dataset_path=dataset_path)
|
| 147 |
+
try:
|
| 148 |
+
full_dataset = dataset_manager.get_dataset()
|
| 149 |
+
dataset_name = dataset_manager.get_dataset_name()
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logger.error(f"Error loading dataset: {e}")
|
| 152 |
+
return
|
| 153 |
+
logger.info(f"Dataset loaded successfully: {dataset_name}")
|
| 154 |
+
|
| 155 |
+
# Train and evaluate models for each language
|
| 156 |
+
mlflow.set_experiment(exp_name)
|
| 157 |
+
models = {}
|
| 158 |
+
for lang in config.LANGS:
|
| 159 |
+
# Prepare training and testing data
|
| 160 |
+
train_ds = full_dataset[f"{lang}_train"]
|
| 161 |
+
test_ds = full_dataset[f"{lang}_test"]
|
| 162 |
+
X_train = train_ds[config.INPUT_COLUMN]
|
| 163 |
+
y_train = train_ds[config.LABEL_COLUMN]
|
| 164 |
+
X_test = test_ds[config.INPUT_COLUMN]
|
| 165 |
+
y_test = test_ds[config.LABEL_COLUMN]
|
| 166 |
+
X_train = list(X_train)
|
| 167 |
+
X_test = list(X_test)
|
| 168 |
+
y_train = np.array(y_train)
|
| 169 |
+
|
| 170 |
+
# Initialize model
|
| 171 |
+
model = model_class(language=lang)
|
| 172 |
+
|
| 173 |
+
# Train and evaluate model within an MLflow run
|
| 174 |
+
try:
|
| 175 |
+
with mlflow.start_run(run_name=f"{model_name}_{lang}"):
|
| 176 |
+
mlflow.set_tag("Language", lang)
|
| 177 |
+
mlflow.set_tag("dataset_name", dataset_name)
|
| 178 |
+
mlflow.set_tag("model_id", model_key)
|
| 179 |
+
mlflow.log_params(model.params)
|
| 180 |
+
parameters_to_log = model.train(
|
| 181 |
+
X_train,
|
| 182 |
+
y_train
|
| 183 |
+
)
|
| 184 |
+
mlflow.log_params(parameters_to_log)
|
| 185 |
+
model.save(os.path.join(config.MODELS_DIR, exp_name),model_name=model_name)
|
| 186 |
+
metrics = model.evaluate(X_test, y_test)
|
| 187 |
+
mlflow.log_metrics(metrics)
|
| 188 |
+
|
| 189 |
+
# Log model name for later retrieval
|
| 190 |
+
mlflow.set_tag("model_name", f"{model_name}_{lang}")
|
| 191 |
+
|
| 192 |
+
except Exception as e:
|
| 193 |
+
logger.error(f"Error training/evaluating model for {lang}: {e}")
|
| 194 |
+
return
|
| 195 |
+
|
| 196 |
+
# Store trained model
|
| 197 |
+
models[lang] = model
|
| 198 |
+
logger.success(f"All {model_name} models trained and evaluated.")
|
| 199 |
+
|
| 200 |
+
# Competition-style evaluation of trained models
|
| 201 |
+
logger.info("Starting competition-style evaluation of trained models...")
|
| 202 |
+
evaluate_models(models, full_dataset)
|
| 203 |
+
logger.success("Evaluation completed.")
|
| 204 |
+
|
| 205 |
+
logger.info("Tagging best models in MLflow...")
|
| 206 |
+
tag_best_models()
|
| 207 |
+
logger.info("Best models:")
|
| 208 |
+
show_tagged_models()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
app()
|
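Because train.py wraps main() in a Typer app, training is normally launched from the command line; an illustrative invocation (the dataset name is a placeholder for a folder under INTERIM_DATA_DIR / "features"):

python turing/modeling/train.py --model tinybert --dataset <feature_dataset_name>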
turing/plots.py
ADDED
|
@@ -0,0 +1,29 @@
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import typer
|
| 6 |
+
|
| 7 |
+
from turing.config import FIGURES_DIR, PROCESSED_DATA_DIR
|
| 8 |
+
|
| 9 |
+
app = typer.Typer()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@app.command()
|
| 13 |
+
def main(
|
| 14 |
+
# ---- REPLACE DEFAULT PATHS AS APPROPRIATE ----
|
| 15 |
+
input_path: Path = PROCESSED_DATA_DIR / "dataset.csv",
|
| 16 |
+
output_path: Path = FIGURES_DIR / "plot.png",
|
| 17 |
+
# -----------------------------------------
|
| 18 |
+
):
|
| 19 |
+
# ---- REPLACE THIS WITH YOUR OWN CODE ----
|
| 20 |
+
logger.info("Generating plot from data...")
|
| 21 |
+
for i in tqdm(range(10), total=10):
|
| 22 |
+
if i == 5:
|
| 23 |
+
logger.info("Something happened for iteration 5.")
|
| 24 |
+
logger.success("Plot generation complete.")
|
| 25 |
+
# -----------------------------------------
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
app()
|
turing/reporting.py
ADDED
|
@@ -0,0 +1,173 @@
| 1 |
+
from datetime import datetime
|
| 2 |
+
import platform
|
| 3 |
+
import sys
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from loguru import logger
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from turing.config import REPORTS_DIR
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestReportGenerator:
|
| 13 |
+
"""
|
| 14 |
+
Handles the generation of structured Markdown reports specifically for test execution results.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, context_name: str, report_category: str):
|
| 18 |
+
self.context_name = context_name
|
| 19 |
+
self.report_category = report_category
|
| 20 |
+
self.timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 21 |
+
self.content = []
|
| 22 |
+
self.output_dir = REPORTS_DIR / self.report_category
|
| 23 |
+
|
| 24 |
+
def add_header(self, text: str, level: int = 1):
|
| 25 |
+
self.content.append(f"\n{'#' * level} {text}\n")
|
| 26 |
+
|
| 27 |
+
def add_divider(self, style: str = "thin"):
|
| 28 |
+
"""Add a visual divider line."""
|
| 29 |
+
dividers = {
|
| 30 |
+
"thin": "---",
|
| 31 |
+
"thick": "___",
|
| 32 |
+
"section": "\n---\n",
|
| 33 |
+
}
|
| 34 |
+
self.content.append(f"\n{dividers.get(style, dividers['thin'])}\n")
|
| 35 |
+
|
| 36 |
+
def add_code_block(self, content: str, language: str = ""):
|
| 37 |
+
"""Add a code block."""
|
| 38 |
+
self.content.append(f"\n```{language}\n{content}\n```\n")
|
| 39 |
+
|
| 40 |
+
def add_alert_box(self, message: str, box_type: str = "info"):
|
| 41 |
+
"""Add a styled alert box using blockquotes."""
|
| 42 |
+
box_headers = {
|
| 43 |
+
"info": "INFO",
|
| 44 |
+
"success": "SUCCESS",
|
| 45 |
+
"warning": "WARNING",
|
| 46 |
+
"error": "ERROR",
|
| 47 |
+
}
|
| 48 |
+
header = box_headers.get(box_type, "INFO")
|
| 49 |
+
self.content.append(f"\n> **{header}**: {message}\n")
|
| 50 |
+
|
| 51 |
+
def add_progress_bar(self, passed: int, total: int, width: int = 50):
|
| 52 |
+
"""Add an ASCII progress bar."""
|
| 53 |
+
if total == 0:
|
| 54 |
+
percentage = 0
|
| 55 |
+
filled = 0
|
| 56 |
+
else:
|
| 57 |
+
percentage = (passed / total * 100)
|
| 58 |
+
filled = int(width * passed / total)
|
| 59 |
+
|
| 60 |
+
empty = width - filled
|
| 61 |
+
bar = "█" * filled + "░" * empty
|
| 62 |
+
self.add_code_block(f"Progress: [{bar}] {percentage:.1f}%\nPassed: {passed}/{total} tests", "")
|
| 63 |
+
|
| 64 |
+
def add_summary_box(self, total: int, passed: int, failed: int, skipped: int = 0):
|
| 65 |
+
"""Add a visually enhanced summary box."""
|
| 66 |
+
success_rate = (passed / total * 100) if total > 0 else 0
|
| 67 |
+
|
| 68 |
+
# Determine status
|
| 69 |
+
if success_rate == 100:
|
| 70 |
+
status = "ALL TESTS PASSED"
|
| 71 |
+
elif success_rate >= 80:
|
| 72 |
+
status = "MOSTLY PASSED"
|
| 73 |
+
elif success_rate >= 50:
|
| 74 |
+
status = "PARTIAL SUCCESS"
|
| 75 |
+
else:
|
| 76 |
+
status = "NEEDS ATTENTION"
|
| 77 |
+
|
| 78 |
+
self.add_header("Executive Summary", level=2)
|
| 79 |
+
self.add_text(f"**Overall Status:** {status}")
|
| 80 |
+
self.add_text(f"**Success Rate:** {success_rate:.1f}%")
|
| 81 |
+
|
| 82 |
+
# Summary table
|
| 83 |
+
summary_data = [
|
| 84 |
+
["Total Tests", str(total)],
|
| 85 |
+
["Passed", str(passed)],
|
| 86 |
+
["Failed", str(failed)],
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
if skipped > 0:
|
| 90 |
+
summary_data.append(["Skipped", str(skipped)])
|
| 91 |
+
|
| 92 |
+
summary_data.append(["Success Rate", f"{success_rate:.1f}%"])
|
| 93 |
+
|
| 94 |
+
df = pd.DataFrame(summary_data, columns=["Metric", "Count"])
|
| 95 |
+
self.add_dataframe(df, title=None, align=("left", "right"))
|
| 96 |
+
|
| 97 |
+
# Progress bar
|
| 98 |
+
self.add_text("**Visual Progress:**")
|
| 99 |
+
self.add_progress_bar(passed, total)
|
| 100 |
+
|
| 101 |
+
def add_environment_metadata(self):
|
| 102 |
+
"""Add enhanced environment metadata."""
|
| 103 |
+
self.add_header("Environment Information", level=2)
|
| 104 |
+
|
| 105 |
+
metadata = [
|
| 106 |
+
["Timestamp", datetime.now().strftime("%Y-%m-%d %H:%M:%S")],
|
| 107 |
+
["Context", self.context_name.upper()],
|
| 108 |
+
["Python Version", sys.version.split()[0]],
|
| 109 |
+
["Platform", platform.platform()],
|
| 110 |
+
["Architecture", platform.machine()],
|
| 111 |
+
]
|
| 112 |
+
df = pd.DataFrame(metadata, columns=["Parameter", "Value"])
|
| 113 |
+
self.add_dataframe(df, title=None, align=("left", "left"))
|
| 114 |
+
|
| 115 |
+
def add_text(self, text: str):
|
| 116 |
+
self.content.append(f"\n{text}\n")
|
| 117 |
+
|
| 118 |
+
def add_category_stats(self, df: pd.DataFrame, category: str):
|
| 119 |
+
"""Add statistics for a test category."""
|
| 120 |
+
total = len(df)
|
| 121 |
+
passed = len(df[df['Result'] == "PASS"])
|
| 122 |
+
failed = len(df[df['Result'] == "FAIL"])
|
| 123 |
+
skipped = len(df[df['Result'] == "SKIP"])
|
| 124 |
+
|
| 125 |
+
stats = [
|
| 126 |
+
["Total", str(total)],
|
| 127 |
+
["Passed", f"{passed} ({passed/total*100:.1f}%)" if total > 0 else "0"],
|
| 128 |
+
["Failed", f"{failed} ({failed/total*100:.1f}%)" if total > 0 else "0"],
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
if skipped > 0:
|
| 132 |
+
stats.append(["Skipped", f"{skipped} ({skipped/total*100:.1f}%)"])
|
| 133 |
+
|
| 134 |
+
stats_df = pd.DataFrame(stats, columns=["Status", "Count"])
|
| 135 |
+
self.add_dataframe(stats_df, title="Statistics", align=("left", "right"))
|
| 136 |
+
|
| 137 |
+
def add_dataframe(self, df: pd.DataFrame, title: Optional[str] = None, align: tuple = None):
|
| 138 |
+
"""Add a formatted dataframe table."""
|
| 139 |
+
if title:
|
| 140 |
+
self.add_header(title, level=3)
|
| 141 |
+
|
| 142 |
+
if df.empty:
|
| 143 |
+
self.content.append("\n_No data available._\n")
|
| 144 |
+
return
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
if not align:
|
| 148 |
+
align = tuple(["left"] * len(df.columns))
|
| 149 |
+
|
| 150 |
+
table_md = df.to_markdown(index=False, tablefmt="pipe", colalign=align)
|
| 151 |
+
self.content.append(f"\n{table_md}\n")
|
| 152 |
+
except Exception as e:
|
| 153 |
+
logger.warning(f"Tabulate error: {e}. Using simple text.")
|
| 154 |
+
self.content.append(f"\n```text\n{df.to_string(index=False)}\n```\n")
|
| 155 |
+
|
| 156 |
+
def save(self, filename: str = "test_report.md") -> str:
|
| 157 |
+
"""Save the report to a file."""
|
| 158 |
+
try:
|
| 159 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 160 |
+
file_path = self.output_dir / filename
|
| 161 |
+
|
| 162 |
+
# Add footer
|
| 163 |
+
self.add_divider("section")
|
| 164 |
+
self.add_text(f"*Report generated on {datetime.now().strftime('%Y-%m-%d at %H:%M:%S')}*")
|
| 165 |
+
self.add_text("*Powered by Turing Test Suite*")
|
| 166 |
+
|
| 167 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
| 168 |
+
f.write("\n".join(self.content))
|
| 169 |
+
logger.info(f"Test report saved: {file_path}")
|
| 170 |
+
return str(file_path)
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.error(f"Save failed: {e}")
|
| 173 |
+
raise
|
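All methods used below are defined in reporting.py above; the test counts are made up. A minimal sketch of assembling and saving a report:

# Illustrative only: build a small Markdown test report under REPORTS_DIR / "unit_tests".
from turing.reporting import TestReportGenerator

report = TestReportGenerator(context_name="unit", report_category="unit_tests")
report.add_header("Unit Test Report", level=1)
report.add_environment_metadata()
report.add_summary_box(total=10, passed=9, failed=1)
report.save("report.md")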
turing/tests/behavioral/test_directional.py
ADDED
|
@@ -0,0 +1,183 @@
| 1 |
+
# These tests check that adding or removing keywords logically changes the prediction
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_java_directional_add_deprecation(java_model, get_predicted_labels):
|
| 5 |
+
"""Tests that adding '@deprecated' ADDs the 'deprecation' label"""
|
| 6 |
+
# Base comment should be a 'Pointer' due to the link
|
| 7 |
+
base_comment = "/** Use {@link #newUserMethod()} instead. */"
|
| 8 |
+
# Perturbed comment adds a keyword
|
| 9 |
+
pert_comment = "/** @deprecated Use {@link #newUserMethod()} instead. */"
|
| 10 |
+
|
| 11 |
+
preds_base = get_predicted_labels(java_model, base_comment, "java")
|
| 12 |
+
preds_pert = get_predicted_labels(java_model, pert_comment, "java")
|
| 13 |
+
|
| 14 |
+
# The base comment should not have 'deprecation'
|
| 15 |
+
assert "deprecation" not in preds_base
|
| 16 |
+
# The perturbed comment must have 'deprecation'
|
| 17 |
+
assert "deprecation" in preds_pert
|
| 18 |
+
# The original 'Pointer' label should still be there
|
| 19 |
+
assert "Pointer" in preds_base
|
| 20 |
+
assert "Pointer" in preds_pert
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_python_directional_remove_todo(python_model, get_predicted_labels):
|
| 24 |
+
"""Tests that removing 'TODO' REMOVES the 'DevelopmentNotes' labe."""
|
| 25 |
+
base_comment = "# TODO: Refactor this entire block."
|
| 26 |
+
pert_comment = "# Refactor this entire block."
|
| 27 |
+
|
| 28 |
+
preds_base = get_predicted_labels(python_model, base_comment, "python")
|
| 29 |
+
preds_pert = get_predicted_labels(python_model, pert_comment, "python")
|
| 30 |
+
|
| 31 |
+
# The base comment must have 'DevelopmentNotes'
|
| 32 |
+
assert "DevelopmentNotes" in preds_base
|
| 33 |
+
# The perturbed comment must not have 'DevelopmentNotes'
|
| 34 |
+
assert "DevelopmentNotes" not in preds_pert
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_pharo_directional_add_responsibility(pharo_model, get_predicted_labels):
|
| 38 |
+
"""Tests that adding 'i am responsible for' adds the 'Responsibilities' label"""
|
| 39 |
+
base_comment = '"i am a simple arrow"'
|
| 40 |
+
pert_comment = '"i am a simple arrow. i am responsible for drawing."'
|
| 41 |
+
|
| 42 |
+
preds_base = get_predicted_labels(pharo_model, base_comment, "pharo")
|
| 43 |
+
preds_pert = get_predicted_labels(pharo_model, pert_comment, "pharo")
|
| 44 |
+
|
| 45 |
+
# base comment should have 'Intent'
|
| 46 |
+
assert "Intent" in preds_base
|
| 47 |
+
# base comment should not have 'Responsibilities'
|
| 48 |
+
assert "Responsibilities" not in preds_base
|
| 49 |
+
# perturbed comment must have 'Responsibilities'
|
| 50 |
+
assert "Responsibilities" in preds_pert
|
| 51 |
+
# original 'Intent' label should still be there
|
| 52 |
+
assert "Intent" in preds_pert
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_java_directional_contrast_rational(java_model, get_predicted_labels):
|
| 56 |
+
"""
|
| 57 |
+
Tests that adding a design rationale adds the 'rational' label
|
| 58 |
+
"""
|
| 59 |
+
# Base comment is a simple summary
|
| 60 |
+
base_comment = "/** Returns the user ID. */"
|
| 61 |
+
# Perturbed comment adds a design rationale
|
| 62 |
+
pert_comment = "/** Returns the user ID. This is cached for performance. */"
|
| 63 |
+
|
| 64 |
+
preds_base = get_predicted_labels(java_model, base_comment, "java")
|
| 65 |
+
    preds_pert = get_predicted_labels(java_model, pert_comment, "java")

    # Base comment should be a 'summary'
    assert "summary" in preds_base
    # Base comment should not have 'rational'
    assert "rational" not in preds_base
    # Perturbed comment must now have 'rational'
    assert "rational" in preds_pert
    # Perturbed comment should ideally still be a 'summary'
    assert "summary" in preds_pert


def test_python_directional_contrast_todo(python_model, get_predicted_labels):
    """
    Tests that adding a "TODO" clause adds the 'DevelopmentNotes' label
    """
    # Base comment is a simple summary
    base_comment = "Fetches the user profile."
    # Perturbed comment adds a development note
    pert_comment = "Fetches the user profile. TODO: This is deprecated."

    preds_base = get_predicted_labels(python_model, base_comment, "python")
    preds_pert = get_predicted_labels(python_model, pert_comment, "python")

    # Base comment should be a 'Summary'
    assert "Summary" in preds_base
    # Base comment should not have 'DevelopmentNotes'
    assert "DevelopmentNotes" not in preds_base
    # Perturbed comment must now have 'DevelopmentNotes'
    assert "DevelopmentNotes" in preds_pert
    # Perturbed comment should ideally still be a 'Summary'
    assert "Summary" in preds_pert


def test_pharo_directional_contrast_collaborators(pharo_model, get_predicted_labels):
    """
    Tests that adding a 'but i work with' clause adds the 'Collaborators' label
    """
    # Base comment is a simple intent
    base_comment = '"i am a simple arrow like arrowhead."'
    pert_comment = '"i am a simple arrow, but i work with BlSpace to position."'

    preds_base = get_predicted_labels(pharo_model, base_comment, "pharo")
    preds_pert = get_predicted_labels(pharo_model, pert_comment, "pharo")

    # Base comment should be 'Intent'
    assert "Intent" in preds_base
    # Base comment should not have 'Collaborators'
    assert "Collaborators" not in preds_base
    # Perturbed comment must now have 'Collaborators'
    assert "Collaborators" in preds_pert
    # Perturbed comment should ideally still have 'Intent'
    assert "Intent" in preds_pert


def test_java_directional_shift_summary_to_expand(java_model, get_predicted_labels):
    """
    Tests that replacing a simple 'summary' with an 'Expand' implementation note
    shifts the primary classification from 'summary' to 'Expand'
    """
    # Base comment is a simple summary
    base_comment = "/** Returns the user ID. */"
    # Perturbed comment shifts the focus entirely to implementation details
    pert_comment = "/** Implementation Note: This delegates to the old system. */"

    preds_base = get_predicted_labels(java_model, base_comment, "java")
    preds_pert = get_predicted_labels(java_model, pert_comment, "java")

    # Base comment must have 'summary'
    assert "summary" in preds_base
    # Perturbed comment must not have 'summary'
    assert "summary" not in preds_pert
    # Perturbed comment must now have 'Expand'
    assert "Expand" in preds_pert


def test_python_directional_shift_summary_to_devnotes(python_model, get_predicted_labels):
    """
    Tests that replacing a 'Summary' with a critical development note (deprecated)
    shifts the classification from 'Summary' to 'DevelopmentNotes'
    """
    print(f"\n[DEBUG] Python model object: {python_model}, language: {python_model.language}")
    # Base comment is a clear Summary
    base_comment = "Fetches the user profile."
    # Perturbed comment shifts the focus entirely to a note about future work
    pert_comment = "DEPRECATED: This function is scheduled for removal in v2.0."

    preds_base = get_predicted_labels(python_model, base_comment, "python")
    preds_pert = get_predicted_labels(python_model, pert_comment, "python")

    # Base comment must have 'Summary'
    assert "Summary" in preds_base
    # Perturbed comment must not have 'Summary'
    assert "Summary" not in preds_pert
    # Perturbed comment must now have 'DevelopmentNotes'
    assert "DevelopmentNotes" in preds_pert


def test_pharo_directional_shift_to_example(pharo_model, get_predicted_labels):
    """
    Tests that changing a comment from a 'Responsibility' statement to an
    explicit 'Example' statement shifts the primary classification
    """
    # Base comment is a clear 'Responsibilities'
    base_comment = '"i provide a data structure independent api"'
    # Perturbed comment replaces the responsibility claim with an explicit example pattern
    pert_comment = '"[Example] run the data structure independent api."'

    preds_base = get_predicted_labels(pharo_model, base_comment, "pharo")
    preds_pert = get_predicted_labels(pharo_model, pert_comment, "pharo")

    # Base comment must have Responsibilities
    assert "Responsibilities" in preds_base
    # Base comment should not have Example
    assert "Example" not in preds_base
    # Perturbed comment must now have Example
    assert "Example" in preds_pert
    # Perturbed comment should not have Responsibilities
    assert "Responsibilities" not in preds_pert

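The directional cases above all repeat the same contrast/shift pattern: perturb one comment and assert which labels must appear, disappear, or survive. A standalone sketch of that check with a toy predictor follows; the helper name, toy labels, and toy predictor are illustrative and not part of the commit.

def check_gains_label(predict_labels, base, perturbed, must_gain, must_keep=()):
    """Assert that the perturbation adds 'must_gain' while every 'must_keep' label survives."""
    preds_base = predict_labels(base)
    preds_pert = predict_labels(perturbed)
    assert must_gain not in preds_base
    assert must_gain in preds_pert
    for label in must_keep:
        assert label in preds_base and label in preds_pert


def toy_predict(comment):
    # Toy stand-in for a real model plus the get_predicted_labels fixture.
    labels = {"Summary"}
    if "TODO" in comment:
        labels.add("DevelopmentNotes")
    return labels


check_gains_label(
    toy_predict,
    "Fetches the user profile.",
    "Fetches the user profile. TODO: tighten error handling.",
    must_gain="DevelopmentNotes",
    must_keep=("Summary",),
)
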
turing/tests/behavioral/test_invariance.py
ADDED
@@ -0,0 +1,117 @@
import pytest

# These tests check that "noise" (like capitalization or punctuation) does not change the prediction


@pytest.mark.parametrize(
    "comment",
    [
        ":param user_id: The ID of the user.",  # Base
        ":PARAM USER_ID: THE ID OF THE USER.",  # Uppercase
        "   :param user_id: The ID of the user .   ",  # Whitespace
        ":param user_id: The ID of the user!!!",  # Punctuation
    ],
)
def test_python_invariance_parameters(python_model, comment, get_predicted_labels):
    """Tests that noise doesn't break ':param' detection."""
    expected = {"Parameters"}
    preds = get_predicted_labels(python_model, comment, "python")
    assert preds == expected


def test_java_invariance_deprecation(java_model, get_predicted_labels):
    """Tests that noise doesn't break '@deprecated' detection"""
    base_comment = "/** @deprecated Use newUserMethod() */"
    pert_comment = "/** @DEPRECATED... Use newUserMethod()!!! */"

    preds_base = get_predicted_labels(java_model, base_comment, "java")
    preds_pert = get_predicted_labels(java_model, pert_comment, "java")

    assert {"deprecation"} <= preds_base
    assert preds_base == preds_pert


def test_python_invariance_summary(python_model, get_predicted_labels):
    """Tests that noise doesn't break a simple 'Summary' detection"""

    base_comment = "a service specific account of type bar."
    expected = {"Summary"}

    # Perturbations
    variants = [
        base_comment,
        "A SERVICE SPECIFIC ACCOUNT OF TYPE BAR.",
        "   a service specific account of type bar.   ",
        "a service specific account of type bar!!!",
    ]

    for comment in variants:
        preds = get_predicted_labels(python_model, comment, "python")
        assert preds == expected


def test_pharo_invariance_intent(pharo_model, get_predicted_labels):
    """Tests that noise doesn't break Pharo's 'Intent' detection"""

    base_comment = '"i am a simple arrow like arrowhead."'
    expected = {"Intent"}

    # Perturbations
    variants = [
        base_comment,
        '"I AM A SIMPLE ARROW LIKE ARROWHEAD."',
        '   "i am a simple arrow like arrowhead."   ',
        '"i am a simple arrow like arrowhead !!"',
    ]

    for comment in variants:
        preds = get_predicted_labels(pharo_model, comment, "pharo")
        assert preds == expected


def test_python_invariance_typos_parameters(python_model, get_predicted_labels):
    """
    Tests typo tolerance
    """

    # Define the single expected outcome
    expected_labels = {"Parameters"}

    # Define the base case and all its variants (with typos)
    variants = [
        ":param user_id: The ID of the user.",
        ":paramater user_id: The ID of the user.",
        ":pram user_id: The ID of teh user.",
    ]

    # Loop through all variants and assert they all produce the *exact* expected outcome
    for comment in variants:
        preds = get_predicted_labels(python_model, comment, "python")
        assert preds == expected_labels


def test_java_invariance_semantic_summary(java_model, get_predicted_labels):
    """
    Tests semantic invariance
    """

    # Get the prediction for the base comment
    base_comment = "/** Returns the user ID. */"
    base_preds = get_predicted_labels(java_model, base_comment, "java")

    # Define semantic paraphrases of the base comment
    variants = [
        base_comment,
        "/** Gets the user ID. */",
        "/** Fetches the ID for the user. */",
        "/** A method to return the user's ID. */",
    ]

    # Check that the base prediction is valid (summary)
    assert "summary" in base_preds

    for comment in variants:
        preds = get_predicted_labels(java_model, comment, "java")
        assert preds == base_preds

turing/tests/behavioral/test_minimum_functionality.py
ADDED
@@ -0,0 +1,52 @@
import pytest

# These tests check for basic, obvious classifications


@pytest.mark.parametrize(
    "comment, expected_labels",
    [
        ("test getfilestatus and related listing operations.", {"summary"}),
        ("/* @deprecated Use something else. */", {"deprecation"}),
        ("code source of this file http grepcode.com", {"Pointer"}),
        ("this is balanced if each pool is balanced.", {"rational"}),
        ("// For internal use only.", {"Ownership"}),
        ("this impl delegates to the old filesystem", {"Expand"}),
        ("/** Usage: new MyClass(arg1). */", {"usage"}),
    ],
)
def test_java_mft(java_model, comment, expected_labels, get_predicted_labels):
    preds = get_predicted_labels(java_model, comment, "java")
    assert preds == expected_labels


@pytest.mark.parametrize(
    "comment, expected_labels",
    [
        ("a service specific account of type bar.", {"Summary"}),
        (":param user_id: The ID of the user.", {"Parameters"}),
        ("# TODO: Refactor this entire block.", {"DevelopmentNotes"}),
        ("use this class if you want access to all of the mechanisms", {"Usage"}),
        ("# create a new list by filtering duplicates from the input", {"Expand"}),
    ],
)
def test_python_mft(python_model, comment, expected_labels, get_predicted_labels):
    preds = get_predicted_labels(python_model, comment, "python")
    assert preds == expected_labels


@pytest.mark.parametrize(
    "comment, expected_labels",
    [
        ("i am a simple arrow like arrowhead.", {"Intent"}),
        ("the example below shows how to create a simple element", {"Example"}),
        ("i provide a data structure independent api", {"Responsibilities"}),
        ("the cache is cleared after each test to ensure isolation.", {"Keyimplementationpoints"}),
        ("it is possible hovewer to customize a length fraction", {"Keymessages"}),
        ("collaborators: BlElement, BlSpace", {"Collaborators"}),
    ],
)
def test_pharo_mft(pharo_model, comment, expected_labels, get_predicted_labels):
    """Tests basic keyword-to-label mapping for Pharo (e.g., 'I am...')."""
    preds = get_predicted_labels(pharo_model, comment, "pharo")
    assert preds == expected_labels

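Each of these minimum-functionality cases compares a whole predicted label set against an expected set; that set comes from decoding the model's multi-hot output row, which is what the get_predicted_labels fixture in the conftest below does. A minimal sketch of that decoding step, with an illustrative label order (the real mapping lives in turing.config.LABELS_MAP):

labels_map = ["summary", "Ownership", "Expand", "usage", "Pointer", "deprecation", "rational"]  # illustrative order only

def to_label_set(prediction_row):
    """Return the set of label names whose indicator flag in the row is 1."""
    return {labels_map[i] for i, flag in enumerate(prediction_row) if flag == 1}

assert to_label_set([1, 0, 0, 0, 0, 0, 0]) == {"summary"}
assert to_label_set([0, 0, 0, 1, 0, 1, 0]) == {"usage", "deprecation"}
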
turing/tests/conftest.py
ADDED
@@ -0,0 +1,305 @@
import os
from pathlib import Path
import sys

import numpy as np
import pandas as pd
import pytest

import turing.config as config
from turing.dataset import DatasetManager
from turing.reporting import TestReportGenerator

# --- Path Setup ---
script_dir = os.path.dirname(os.path.abspath(__file__))
proj_root = os.path.dirname(os.path.dirname(script_dir))
sys.path.append(proj_root)

train_dir = os.path.join(proj_root, "turing", "modeling")
sys.path.insert(1, train_dir)


try:
    # Import train.py
    import turing.modeling.train as train
except ImportError as e:
    pytest.skip(
        f"Could not import 'train.py'. Check sys.path. Error: {e}", allow_module_level=True
    )

# --- Reporting Setup ---
execution_results = []
active_categories = set()


def clean_test_name(nodeid):
    """Cleans the test name by stripping overly long parameter suffixes."""
    parts = nodeid.split("::")
    test_name = parts[-1]
    if len(test_name) > 50:
        test_name = test_name[:47] + "..."
    return test_name


def format_error_message(long_repr):
    """Extracts only the main error line."""
    if not long_repr:
        return ""
    lines = str(long_repr).split("\n")
    last_line = lines[-1]
    clean_msg = last_line.replace("|", "-").strip()
    if len(clean_msg) > 60:
        clean_msg = clean_msg[:57] + "..."
    return clean_msg


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    outcome = yield
    report = outcome.get_result()

    if report.when == "call":
        path_str = str(item.fspath)
        category = "GENERAL"

        if "unit" in path_str:
            category = "UNIT"
        elif "behavioral" in path_str:
            category = "BEHAVIORAL"
        elif "modeling" in path_str:
            category = "MODELING"

        active_categories.add(category)

        # Simplified status mapping
        status_map = {"passed": "PASS", "failed": "FAIL", "skipped": "SKIP"}
        status_str = status_map.get(report.outcome, report.outcome.upper())

        execution_results.append(
            {
                "Category": category,
                "Module": item.fspath.basename,
                "Test Case": clean_test_name(item.nodeid),
                "Result": status_str,
                "Time": f"{report.duration:.2f}s",
                "Message": format_error_message(report.longrepr) if report.failed else "",
            }
        )


def pytest_sessionfinish(session, exitstatus):
    """Generate enhanced test report at session end."""
    if not execution_results:
        return

    report_type = (
        f"{list(active_categories)[0].lower()}_tests"
        if len(active_categories) == 1
        else "unit_and_behavioral_tests"
    )

    try:
        manager = TestReportGenerator(context_name="turing", report_category=report_type)

        # Main title
        manager.add_header("Turing Test Execution Report")
        manager.add_divider("section")

        # Environment info
        manager.add_environment_metadata()
        manager.add_divider("thin")

        df = pd.DataFrame(execution_results)

        # Summary (match the values produced by status_map above)
        total = len(df)
        passed = len(df[df["Result"] == "PASS"])
        failed = len(df[df["Result"] == "FAIL"])
        summary = pd.DataFrame(
            [
                {
                    "Total": total,
                    "Passed": passed,
                    "Failed": failed,
                    "Success Rate": f"{(passed / total) * 100:.1f}%",
                }
            ]
        )
        manager.add_dataframe(summary, title="Executive Summary")

        # Detailed breakdown by category
        cols = ["Module", "Test Case", "Result", "Time", "Message"]

        if len(active_categories) > 1:
            manager.add_header("Detailed Test Results by Category", level=2)
            manager.add_divider("thin")

            for cat in sorted(active_categories):
                subset = df[df["Category"] == cat][cols]
                manager.add_dataframe(subset, title=f"{cat} Tests")
        else:
            manager.add_alert_box(
                "All tests passed successfully!",
                box_type="success"
            )

        manager.save("report.md")
    except Exception as e:
        print(f"\nError generating report: {e}")


# --- Fixtures ---


@pytest.fixture(scope="function")
def manager() -> DatasetManager:
    """
    Provides an instance of DatasetManager for each test.
    """
    return DatasetManager()


@pytest.fixture(scope="function")
def fake_csv_data_dir(tmp_path: Path) -> Path:
    """
    Creates a temporary directory structure mocking 'data/interim/features/clean-aug-soft-k5000'
    and populates it with minimal, valid CSV files for testing.

    Returns:
        Path: The path to the *parent* of 'features' (e.g., the mocked INTERIM_DATA_DIR).
    """
    interim_dir = tmp_path / "interim_test"
    features_dir = interim_dir / "features" / "clean-aug-soft-k5000"
    features_dir.mkdir(parents=True, exist_ok=True)

    # Define minimal valid CSV content
    csv_content = (
        "combo,labels\n"
        '"java code text","[1, 0, 0, 0, 0, 0, 0]"\n'
        '"other java code","[0, 1, 0, 0, 0, 0, 0]"\n'
    )

    # Write mock files
    (features_dir / "java_train.csv").write_text(csv_content)
    (features_dir / "java_test.csv").write_text(csv_content)

    # Return the root of the mocked interim directory
    return interim_dir


@pytest.fixture(scope="session")
def mock_data():
    """
    Provides a minimal, consistent, session-scoped dataset for model testing.
    This simulates the (X, y) data structure used for training and evaluation.
    """
    X = [
        "this is java code for summary",
        "python is great for parameters",
        "a java example for usage",
        "running python script for development notes",
        "pharo is a language for intent",
        "another java rational example",
    ]

    # Mock labels for a 'java' model (7 categories)
    # Shape (6 samples, 7 features)
    y = np.array(
        [
            [1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0],
            [1, 0, 0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0, 0, 1],
        ]
    )
    return {"X": X, "y": y}


@pytest.fixture(scope="module")
def trained_rf_model(mock_data, tmp_path_factory):
    """
    Provides a fully-trained RandomForestTfIdf model instance.
    """
    # Import locally to ensure proj_root is set
    from modeling.models.randomForestTfIdf import RandomForestTfIdf

    # Arrange
    model = RandomForestTfIdf(language="java")

    # Monkeypatch grid search parameters for maximum speed
    model.grid_params = {
        "tfidf__max_features": [10, 20],  # Use minimal features
        "clf__estimator__n_estimators": [2, 5],  # Use minimal trees
    }
    model.params["cv_folds"] = 2  # Use minimal CV folds

    # Create a persistent temp dir for this module's run
    model_path = tmp_path_factory.mktemp("trained_rf_model")

    # Act: Train the model
    model.train(mock_data["X"], mock_data["y"], path=str(model_path), model_name="test_model")

    # Yield the trained model and its save path
    yield model, model_path


MODEL_CLASS_TO_TEST = train.MODEL_CLASS
MODEL_EXPERIMENT_NAME = train.EXP_NAME
MODEL_NAME_BASE = train.MODEL_NAME


@pytest.fixture(scope="session")
def get_predicted_labels():
    def _helper(model, comment_sentence: str, lang: str) -> set:
        if config.INPUT_COLUMN == "combo":
            combo_input = f"DummyClass.{lang} | {comment_sentence}"
            input_data = [combo_input]
        else:
            input_data = [comment_sentence]

        prediction_array = model.predict(input_data)[0]
        labels_map = config.LABELS_MAP[lang]
        predicted_labels = {labels_map[i] for i, val in enumerate(prediction_array) if val == 1}
        return predicted_labels

    return _helper


@pytest.fixture(scope="module")
def java_model():
    """Loads the Java model from the config path"""
    model_path = os.path.join(config.MODELS_DIR, MODEL_EXPERIMENT_NAME, f"{MODEL_NAME_BASE}_java")
    if not os.path.exists(model_path):
        pytest.skip(
            "Production model not found. Skipping behavioral tests for Java.",
            allow_module_level=True,
        )
    return MODEL_CLASS_TO_TEST(language="java", path=model_path)


@pytest.fixture(scope="module")
def python_model():
    """Loads the Python model from the config path"""
    model_path = os.path.join(
        config.MODELS_DIR, MODEL_EXPERIMENT_NAME, f"{MODEL_NAME_BASE}_python"
    )
    if not os.path.exists(model_path):
        pytest.skip(
            "Production model not found. Skipping behavioral tests for Python.",
            allow_module_level=True,
        )
    return MODEL_CLASS_TO_TEST(language="python", path=model_path)


@pytest.fixture(scope="module")
def pharo_model():
    """Loads the Pharo model from the config path"""
    model_path = os.path.join(config.MODELS_DIR, MODEL_EXPERIMENT_NAME, f"{MODEL_NAME_BASE}_pharo")
    if not os.path.exists(model_path):
        pytest.skip(
            "Production model not found. Skipping behavioral tests for Pharo.",
            allow_module_level=True,
        )
    return MODEL_CLASS_TO_TEST(language="pharo", path=model_path)

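Because the reporting hooks live in this conftest, any pytest invocation that selects these suites also produces the markdown report at session end. A minimal sketch of driving a run programmatically; the directory arguments are assumptions about how the suites are usually selected.

import pytest

# Behavioral tests only: pytest_sessionfinish above derives the "behavioral_tests" report category.
exit_code = pytest.main(["turing/tests/behavioral", "-q"])

# Selecting unit and behavioral tests together falls back to "unit_and_behavioral_tests".
# exit_code = pytest.main(["turing/tests", "-q"])
print("pytest exit code:", exit_code)
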
turing/tests/unit/test_api.py
ADDED
@@ -0,0 +1,201 @@
from unittest.mock import patch

from fastapi.testclient import TestClient
import numpy as np
import pytest

from turing.api.app import app
from turing.api.schemas import PredictionRequest, PredictionResponse


@pytest.fixture
def client():
    """Fixture that provides a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def mock_inference_engine():
    """Fixture that provides a mocked inference engine."""
    with patch('turing.api.app.inference_engine') as mock:
        yield mock


class TestHealthCheck:
    """Test suite for the health check endpoint."""

    def test_health_check_returns_ok(self, client):
        """Test that the health check endpoint returns status ok."""
        response = client.get("/")
        assert response.status_code == 200
        assert response.json() == {
            "status": "ok",
            "message": "Turing Code Classification API is ready."
        }


class TestPredictEndpoint:
    """Test suite for the predict endpoint."""

    def test_predict_success_java(self, client, mock_inference_engine):
        """Test successful prediction for Java code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([0, 1]),  # raw predictions as numpy array
            ["class", "method"],  # labels
            "run_id_123",  # run_id
            "models:/CodeBERTa_java/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["public class Main", "public void test()"],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert "predictions" in data
        assert "labels" in data
        assert "model_info" in data
        assert data["labels"] == ["class", "method"]
        assert data["model_info"]["language"] == "java"

    def test_predict_success_python(self, client, mock_inference_engine):
        """Test successful prediction for Python code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([1, 0]),  # raw predictions as numpy array
            ["function", "class"],  # labels
            "run_id_456",  # run_id
            "models:/CodeBERTa_python/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["def main():", "class MyClass:"],
            "language": "python"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert data["labels"] == ["function", "class"]
        assert data["model_info"]["language"] == "python"

    def test_predict_success_pharo(self, client, mock_inference_engine):
        """Test successful prediction for Pharo code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([0]),  # raw predictions as numpy array
            ["method"],  # labels
            "run_id_789",  # run_id
            "models:/CodeBERTa_pharo/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["initialize"],
            "language": "pharo"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert data["labels"] == ["method"]
        assert data["model_info"]["language"] == "pharo"

    def test_predict_missing_texts(self, client):
        """Test that prediction fails when texts are missing."""
        request_data = {
            "language": "java"
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422  # Validation error

    def test_predict_missing_language(self, client):
        """Test that prediction fails when language is missing."""
        request_data = {
            "texts": ["public class Main"]
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422  # Validation error

    def test_predict_empty_texts(self, client, mock_inference_engine):
        """Test prediction with empty texts list."""
        mock_inference_engine.predict_payload.return_value = (
            np.array([]),  # raw predictions as empty numpy array
            [],  # labels
            "run_id_000",  # run_id
            "models:/CodeBERTa_java/Production"  # artifact
        )

        request_data = {
            "texts": [],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Should succeed with empty results
        assert response.status_code == 200
        data = response.json()
        assert data["predictions"] == []
        assert data["labels"] == []

    def test_predict_error_handling(self, client, mock_inference_engine):
        """Test that prediction endpoint handles errors gracefully."""
        # Setup mock to raise an exception
        mock_inference_engine.predict_payload.side_effect = Exception("Model loading failed")

        request_data = {
            "texts": ["public class Main"],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Should return 500 error
        assert response.status_code == 500
        assert "Model loading failed" in response.json()["detail"]

    def test_predict_invalid_language(self, client, mock_inference_engine):
        """Test prediction with invalid language parameter."""
        # The model might raise an error for unsupported language
        mock_inference_engine.predict_payload.side_effect = ValueError("Unsupported language: cobol")

        request_data = {
            "texts": ["IDENTIFICATION DIVISION."],
            "language": "cobol"
        }
        response = client.post("/predict", json=request_data)

        # Should return 500 error
        assert response.status_code == 500
        assert "Unsupported language" in response.json()["detail"]


class TestAPISchemas:
    """Test suite for API schemas validation."""

    def test_prediction_request_valid(self):
        """Test that PredictionRequest validates correct data."""
        request = PredictionRequest(
            texts=["public void main"],
            language="java"
        )
        assert request.texts == ["public void main"]
        assert request.language == "java"

    def test_prediction_response_valid(self):
        """Test that PredictionResponse validates correct data."""
        response = PredictionResponse(
            predictions=[0, 1],
            labels=["class", "method"],
            model_info={"artifact": "models:/CodeBERTa_java/Production", "language": "java"}
        )
        assert response.predictions == [0, 1]
        assert response.labels == ["class", "method"]
        assert response.model_info["language"] == "java"

|
turing/tests/unit/test_config.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib
from pathlib import Path
from unittest.mock import patch

import pytest

# Import the module to be tested
import turing.config as config


@pytest.mark.config
class TestConfig:
    """
    Test suite for validating the project's configuration module (config.py).

    These tests verify that paths are structured correctly, critical constants
    are of the expected type and value, and module-level logic
    (like calculations and .env loading) executes as intended.
    """

    def test_proj_root_is_correctly_identified(self):
        """
        Validates that PROJ_ROOT is a Path object and points to the
        actual project root directory (which should contain 'pyproject.toml').
        """
        assert isinstance(config.PROJ_ROOT, Path)
        assert config.PROJ_ROOT.is_dir()

        # A common "sanity check" is to look for a known file at the root
        expected_file = config.PROJ_ROOT / "pyproject.toml"
        assert expected_file.is_file(), (
            f"PROJ_ROOT ({config.PROJ_ROOT}) does not seem to be the project root. "
            f"Could not find {expected_file}"
        )

    def test_directory_paths_are_correctly_structured(self):
        """
        Ensures all key directory variables are Path objects
        and are correctly parented under PROJ_ROOT.
        """
        # List of all directory variables defined in config.py
        path_vars = [
            config.DATA_DIR,
            config.RAW_DATA_DIR,
            config.INTERIM_DATA_DIR,
            config.PROCESSED_DATA_DIR,
            config.EXTERNAL_DATA_DIR,
            config.MODELS_DIR,
            config.REPORTS_DIR,
            config.FIGURES_DIR,
        ]

        for path_var in path_vars:
            assert isinstance(path_var, Path)
            # Check that PROJ_ROOT is an ancestor of this path
            assert config.PROJ_ROOT in path_var.parents

        # Spot-check a few for correct relative paths
        assert config.DATA_DIR == config.PROJ_ROOT / "data"
        assert config.RAW_DATA_DIR == config.PROJ_ROOT / "data" / "raw"
        assert config.FIGURES_DIR == config.PROJ_ROOT / "reports" / "figures"

    def test_dataset_constants_are_valid(self):
        """
        Validates that critical dataset constants are non-empty and of
        the correct type.
        """
        assert isinstance(config.DATASET_HF_ID, str)
        assert config.DATASET_HF_ID == "NLBSE/nlbse26-code-comment-classification"

        assert isinstance(config.LANGS, list)
        assert len(config.LANGS) == 3
        assert "java" in config.LANGS

        assert isinstance(config.INPUT_COLUMN, str) and config.INPUT_COLUMN
        assert isinstance(config.LABEL_COLUMN, str) and config.LABEL_COLUMN

    def test_labels_map_and_total_categories_are_correct(self):
        """
        Validates the LABELS_MAP structure and ensures TOTAL_CATEGORIES
        is correctly calculated from it.
        """
        assert isinstance(config.LABELS_MAP, dict)

        # Ensure all languages in LANGS are keys in LABELS_MAP
        for lang in config.LANGS:
            assert lang in config.LABELS_MAP
            assert isinstance(config.LABELS_MAP[lang], list)
            assert len(config.LABELS_MAP[lang]) > 0

        # Validate the derived calculation
        expected_total = (
            len(config.LABELS_MAP["java"])
            + len(config.LABELS_MAP["python"])
            + len(config.LABELS_MAP["pharo"])
        )
        assert config.TOTAL_CATEGORIES == expected_total
        assert config.TOTAL_CATEGORIES == 18  # 7 + 5 + 6

    def test_numeric_parameters_are_positive(self):
        """
        Ensures that numeric scoring and training parameters are positive
        and of the correct type.
        """
        numeric_params = {
            "MAX_AVG_RUNTIME": config.MAX_AVG_RUNTIME,
            "MAX_AVG_FLOPS": config.MAX_AVG_FLOPS,
            "DEFAULT_BATCH_SIZE": config.DEFAULT_BATCH_SIZE,
            "DEFAULT_NUM_ITERATIONS": config.DEFAULT_NUM_ITERATIONS,
        }

        for name, value in numeric_params.items():
            assert isinstance(value, (int, float)), f"{name} is not numeric"
            assert value > 0, f"{name} must be positive"

    @patch("dotenv.load_dotenv")
    def test_load_dotenv_is_called_on_module_load(self, mock_load_dotenv):
        """
        Tests that the load_dotenv() function is executed when the
        config.py module is loaded.

        This requires reloading the module, as it's likely already been
        imported by pytest or conftest.
        """
        # Arrange (Patch is active)

        # Act
        # Reload the config module to trigger its top-level statements
        importlib.reload(config)

        # Assert
        # Check that the patched load_dotenv was called
        mock_load_dotenv.assert_called_once()

turing/tests/unit/test_dataset.py
ADDED
@@ -0,0 +1,95 @@
from pathlib import Path

import pytest

# Project modules are importable thanks to conftest.py
import turing.config as config
from turing.dataset import DatasetManager


@pytest.mark.data_loader
class TestDatasetManager:
    """
    Unit tests for the DatasetManager class.
    This test suite validates initialization, data transformation logic,
    and data loading mechanisms, including error handling.
    """

    def test_initialization_paths_are_correct(self, manager: DatasetManager):
        """
        Verifies that the DatasetManager initializes with the correct
        Hugging Face ID and constructs its paths as expected.
        """
        assert manager.hf_id == "NLBSE/nlbse26-code-comment-classification"
        assert "data/raw" in str(manager.raw_data_dir)
        # base_interim_path should contain either 'base' or 'features'
        path_str = str(manager.base_interim_path)
        assert "data/interim" in path_str and ("base" in path_str or "features" in path_str)

    @pytest.mark.parametrize(
        "input_labels, expected_output",
        [
            ([1, 0, 1], "[1, 0, 1]"),  # Case: Standard list
            ("[1, 0, 1]", "[1, 0, 1]"),  # Case: Already a string
            ([], "[]"),  # Case: Empty list
            (None, None),  # Case: None value
        ],
    )
    def test_format_labels_for_csv(self, manager: DatasetManager, input_labels, expected_output):
        """
        Tests the internal _format_labels_for_csv method to ensure
        it correctly serializes label lists (or handles other inputs) to strings.
        """
        # Arrange
        example = {"labels": input_labels}

        # Act
        formatted_example = manager._format_labels_for_csv(example)

        # Assert
        assert formatted_example["labels"] == expected_output

    def test_get_dataset_raises_file_not_found(self, monkeypatch):
        """
        Ensures that get_dataset() raises a FileNotFoundError when
        the target interim CSV files do not exist.
        """
        # Arrange
        # Patch the config to point to a non-existent directory
        fake_dir = Path("/path/that/is/totally/fake")
        monkeypatch.setattr(config, "INTERIM_DATA_DIR", fake_dir)

        # Manager must be initialized *after* patching config
        manager_with_fake_path = DatasetManager()

        # Act & Assert
        with pytest.raises(FileNotFoundError, match="Dataset CSV files not found."):
            manager_with_fake_path.get_dataset()

    def test_get_dataset_success_and_label_parsing(self, fake_csv_data_dir: Path, monkeypatch):
        """
        Verifies that get_dataset() successfully loads data from mock CSVs
        and correctly parses the string-formatted labels back into lists.
        """
        # Arrange
        # Point the config at our temporary fixture directory
        monkeypatch.setattr(config, "INTERIM_DATA_DIR", fake_csv_data_dir)
        manager = DatasetManager()

        # Act
        dataset = manager.get_dataset()

        # Assert
        # Check that the correct splits were loaded
        assert "java_train" in dataset
        assert "java_test" in dataset
        assert "python_train" not in dataset  # Confirms only found files are loaded

        # Check content integrity
        assert len(dataset["java_train"]) == 2
        assert dataset["java_train"][0]["combo"] == "java code text"

        # Check that the string '[1, 0, ...]' was parsed back to a list
        expected_labels = [1, 0, 0, 0, 0, 0, 0]
        assert dataset["java_train"][0]["labels"] == expected_labels
        assert isinstance(dataset["java_train"][0]["labels"], list)