# NOTE(review): the following banner text ("Spaces: / Build error") is residue
# from the hosting page this file was scraped from — it is not part of the source.
| import re | |
| import datasets | |
| import tensorflow as tf | |
| import promptsource.utils | |
| def feature_to_spec(feature, length=False): | |
| if isinstance(feature, datasets.ClassLabel): | |
| return tf.TensorSpec(shape=() if not length else (None if length == -1 else length,), dtype=tf.int64) | |
| elif isinstance(feature, datasets.Value): | |
| return tf.TensorSpec( | |
| shape=() if not length else (None if length == -1 else length,), dtype=getattr(tf.dtypes, feature.dtype) | |
| ) | |
| elif hasattr(feature, "dtype") and hasattr(feature, "shape"): | |
| return tf.TensorSpec(shape=feature.shape, dtype=feature.dtype) | |
| elif isinstance(feature, datasets.Sequence): | |
| return feature_to_spec(feature.feature, length=feature.length) | |
| elif isinstance(feature, list): | |
| return [feature_to_spec(f, length=length) for f in feature] | |
| elif isinstance(feature, dict): | |
| return {k: feature_to_spec(v, length=length) for k, v in feature.items()} | |
| else: | |
| raise ValueError(f"Unparseable feature type {type(feature)}") | |
| def hf_dataset_to_tf_dataset(dataset): | |
| return tf.data.Dataset.from_generator( | |
| dataset.__iter__, output_signature={k: feature_to_spec(v) for k, v in dataset.features.items()} | |
| ) | |
| def apply_template(dataset, template): | |
| def map_fn(ex): | |
| ex = promptsource.utils.removeHyphen(ex) | |
| inputs_and_targets = template.apply(ex) | |
| answer_choices = template.get_answer_choices_list(ex) | |
| if len(inputs_and_targets) == 2: | |
| inputs, targets = inputs_and_targets | |
| if targets == "": | |
| ex = {"inputs": inputs, "targets": "<NO LABEL>"} | |
| else: | |
| ex = {"inputs": inputs, "targets": targets} | |
| # When template results in an empty example, template.apply returns [""] | |
| # Also, if the template gets split wrong, len can be > 2 | |
| # We will filter these out later | |
| else: | |
| ex = {"inputs": "", "targets": ""} | |
| if answer_choices: | |
| ex["answer_choices"] = answer_choices | |
| return ex | |
| def filter_fn(ex): | |
| return len(ex["inputs"]) > 0 and len(ex["targets"]) > 0 | |
| original_columns = dataset.column_names | |
| dataset = dataset.map(map_fn).filter(filter_fn) | |
| # map keeps original columns, remove them | |
| return dataset.remove_columns(set(original_columns) - {"inputs", "targets", "answer_choices"}) | |
| def get_dataset_splits(dataset_name, subset_name=None): | |
| info = datasets.get_dataset_infos(dataset_name) | |
| subset_name = subset_name or list(info.keys())[0] | |
| return info[subset_name].splits | |
| def task_clean(text): | |
| # Clean the text according to allowed characters for a task name | |
| return re.sub(r"[^\w\d\._]+", "_", text) | |
| def get_task_name(dataset_name, subset_name, template_name): | |
| return task_clean(dataset_name + (f"_{subset_name}_" if subset_name is not None else "_") + template_name) | |