Source code for gptchem.extractor
import re
from typing import Union
from fastcore.basics import basic_repr
from fastcore.foundation import L
class BaseExtractor:
    """Base class for converting raw completion strings into typed values.

    Subclasses implement :meth:`extract` for a single completion. The
    ``extract_many*`` helpers and ``__call__`` apply it over collections.
    """

    # Marker at which a completion is cut; text after it is discarded.
    _stop_sequence = "@@"

    def floatify(self, value: str) -> Union[float, None]:
        """Return ``float(value)``, or ``None`` if the conversion fails."""
        try:
            return float(value)
        except (ValueError, TypeError):
            return None

    def intify(self, value: str) -> Union[int, None]:
        """Return ``value`` as an ``int`` (via ``float``), or ``None`` on failure.

        Going through ``floatify`` lets strings such as ``"1.0"`` parse; the
        fractional part is truncated by ``int``.
        """
        try:
            return int(self.floatify(value))
        except (ValueError, TypeError, OverflowError):
            # TypeError: floatify returned None.
            # ValueError: NaN cannot be converted to int.
            # OverflowError: +/-infinity (e.g. "inf", "1e999") cannot either.
            return None

    def split(self, value: str) -> str:
        """Return the part of ``value`` before the first stop sequence.

        If the stop sequence does not occur, ``value`` is returned unchanged.
        """
        # str.split always yields at least one element, so index 0 is safe.
        return value.split(self._stop_sequence)[0]

    def extract(self, data, **kwargs) -> Union[str, float, int]:
        """Extract a value from one completion; must be overridden."""
        raise NotImplementedError

    def extract_many(self, data, **kwargs) -> L:
        """Apply :meth:`extract` to every entry of ``data`` and return an ``L``."""
        return L([self.extract(entry, **kwargs) for entry in data])

    def extract_many_from_dict(self, data, key="choices", **kwargs) -> L:
        """Extract from ``entry[key]`` of every ``entry`` in ``data``, flattened.

        Args:
            data: Iterable of mappings; each ``entry[key]`` is an iterable of
                completions handed to :meth:`extract_many`.
            key: Key under which the completions are stored in each entry.
            **kwargs: Forwarded to :meth:`extract_many`.

        Returns:
            A single flat ``L`` with the results in input order.
        """
        # Flatten in one pass instead of repeated concatenation via sum().
        return L(
            [
                result
                for entry in data
                for result in self.extract_many(entry[key], **kwargs)
            ]
        )

    def __call__(self, data, key="choices", **kwargs):
        """Shorthand for :meth:`extract_many_from_dict`."""
        return self.extract_many_from_dict(data, key=key, **kwargs)

    # Representation generated by fastcore's ``basic_repr`` helper.
    __repr__ = basic_repr()
[docs]
class ClassificationExtractor(BaseExtractor):
    """Parse an integer class label from a classification completion."""

    def extract(self, data, **kwargs) -> int:
        """Return the text before the stop sequence, converted by ``intify``.

        The completion is cut at the stop sequence and stripped of
        surrounding whitespace before conversion.
        """
        completion = self.split(data)
        label_text = completion.strip()
        return self.intify(label_text)
[docs]
class FewShotClassificationExtractor(BaseExtractor):
    """Parse an integer class label from a few-shot classification completion."""

    # Matches a run of one or more digits.
    _FIRST_NUMBER_REGEX = re.compile(r"(\d+)")

    def extract(self, data, **kwargs) -> int:
        """Return the first run of digits in ``data`` as an int, else ``None``."""
        match = self._FIRST_NUMBER_REGEX.search(data)
        if match is None:
            return None
        return self.intify(match.group(1))
[docs]
class FewShotRegressionExtractor(BaseExtractor):
    """Extract floats from completions of few-shot regression tasks."""

    # Matches a decimal or integer literal with an optional sign and an
    # optional exponent.
    # - The sign is only taken when it is not glued to a preceding word
    #   character, so the hyphen in e.g. "x-5" is not read as a minus sign.
    # - The exponent part keeps values such as "1e-05" intact instead of
    #   truncating them to "1".
    _FIRST_NUMBER_REGEX = re.compile(
        r"(?:(?<!\w)[-+])?(?:\d+\.\d+|\d+)(?:[eE][-+]?\d+)?"
    )

    def extract(self, data, **kwargs) -> Union[float, None]:
        """Return the first number found in ``data`` as a float.

        Args:
            data: Completion text to search.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The first number in ``data`` (sign and exponent included), or
            ``None`` if ``data`` contains no number.
        """
        match = self._FIRST_NUMBER_REGEX.search(data)
        if match is None:
            return None
        return self.floatify(match.group(0))
[docs]
class RegressionExtractor(BaseExtractor):
    """Parse a float from a regression completion."""

    def extract(self, data, **kwargs) -> float:
        """Return the text before the stop sequence, converted by ``floatify``.

        The completion is cut at the stop sequence and stripped of
        surrounding whitespace before conversion.
        """
        completion = self.split(data)
        number_text = completion.strip()
        return self.floatify(number_text)
[docs]
class InverseExtractor(BaseExtractor):
    """Extract strings from completions of inverse tasks."""

    def extract(self, data, **kwargs) -> Union[str, None]:
        """Return the first whitespace-delimited token of the completion.

        Args:
            data: Completion text; only the part before the stop sequence is
                considered.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The first token, or ``None`` when the completion is empty or
            contains only whitespace before the stop sequence.
        """
        tokens = self.split(data).split()
        # An empty completion yields no tokens; return None instead of
        # raising IndexError, as the other extractors do on failure.
        if not tokens:
            return None
        # Tokens produced by a whitespace split() carry no surrounding
        # whitespace, so no further strip() is needed.
        return tokens[0]
[docs]
class SolventExtractor(BaseExtractor):
    """Parse solvent names and their amounts from a solvent-task completion."""

    # A decimal amount followed by one whitespace character and a solvent
    # token made of word characters, parentheses, "=" and "@".
    _SOLVENT_REGEX = re.compile(r"(\d+\.\d+)(\s[\w\(\)=\@]+)")

    def _find_solvent(self, data):
        """Map each solvent token in ``data`` to its amount as a float.

        Returns ``None`` when nothing matches. If the same token occurs
        more than once, the last amount wins.
        """
        matches = self._SOLVENT_REGEX.findall(data)
        if not matches:
            return None
        return {name.strip(): float(amount) for amount, name in matches}

    def extract(self, data, **kwargs) -> dict:
        """Return the solvent mapping found before the stop sequence."""
        # NOTE(review): the stop sequence is "@@", so a solvent token that
        # itself contains "@@" is cut at that point -- confirm this is intended.
        completion = self.split(data)
        return self._find_solvent(completion)