Source code for gptchem.querier

from typing import Optional

import openai
import pandas as pd
from fastcore.basics import basic_repr, chunked

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    return openai.Completion.create(**kwargs)
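
# Usage sketch: completion_with_backoff is a pass-through wrapper around
# openai.Completion.create that retries failed calls with jittered exponential
# backoff (waits of 1-60 s) and gives up after six attempts. The model name
# and prompt below are placeholders, not values used by this module.
#
#     response = completion_with_backoff(
#         model="ada",
#         prompt="This is a test",
#         max_tokens=2,
#         temperature=0,
#     )
#     text = response["choices"][0]["text"]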


_PRESETS = {
    "classification": {
        "max_tokens": 2,  # first token is whitespace
    },
    "inverse": {
        "max_tokens": 800,
    },
    "regression": {
        "max_tokens": 5,
    },
}
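
# Usage sketch: each preset is just a set of keyword arguments for Querier, so
# Querier.from_preset(model, "classification") (defined below) is equivalent to
# Querier(model, max_tokens=2). Two tokens suffice because the first completion
# token is whitespace, which leaves one token for the class label. The model
# name "ada" is a placeholder:
#
#     querier = Querier.from_preset("ada", "classification")
#     assert querier.max_tokens == 2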


class Querier:
    """Wrapper around the OpenAI API for querying a model for completions.

    This class tries to be as efficient as possible by querying the API in
    batches. It also handles the rate limiting of the API.

    Example:
        >>> querier = Querier("ada")
        >>> df = pd.DataFrame({"prompt": ["This is a test", "This is another test"]})
        >>> completions = querier.query(df)
        >>> len(completions["choices"]) == 2
        True
        >>> all(isinstance(c, str) for c in completions["choices"])
        True
    """

    _parallel_max = 20
    _sleep = 5
    _stop = "@@@"

    def __init__(self, modelname: str, max_tokens: int = 10):
        self.modelname = modelname
        self.max_tokens = max_tokens

    @classmethod
    def from_preset(cls, modelname: str, preset: str = "classification"):
        """Factory method to create a Querier from a preset.

        These presets set the max_tokens parameter to a value that is
        appropriate for the task.
        """
        if preset not in _PRESETS:
            raise ValueError(
                f"Invalid preset: {preset}. Valid presets are: {list(_PRESETS.keys())}"
            )
        return cls(modelname, **_PRESETS[preset])
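
    # Usage sketch ("ada" is a placeholder model name): a known preset fills in
    # max_tokens, an unknown preset name raises with the list of valid presets.
    #
    #     Querier.from_preset("ada", "regression")   # Querier with max_tokens=5
    #     Querier.from_preset("ada", "translation")  # raises ValueError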

    def query(
        self, df: pd.DataFrame, temperature: float = 0, logprobs: Optional[int] = None
    ) -> dict:
        """Query the model for completions.

        Args:
            df (pd.DataFrame): DataFrame containing a column named "prompt".
            temperature (float): Temperature of the softmax. Defaults to 0.
            logprobs (Optional[int]): The number of logprobs to return.
                For classification, set it to the number of classes.
                Defaults to None.

        Raises:
            ValueError: If df is not a pandas DataFrame.
            ValueError: If df does not have a column named "prompt".
            AssertionError: If temperature is < 0.

        Returns:
            dict: Dictionary containing the completions and logprobs.
        """
        if not isinstance(df, pd.DataFrame):
            raise ValueError("df must be a pandas DataFrame")
        if "prompt" not in df.columns:
            raise ValueError("df must have a column named 'prompt'")
        assert temperature >= 0, "temperature must be >= 0"

        completions = []

        settings = {}
        if logprobs is not None and isinstance(logprobs, int):
            settings["logprobs"] = logprobs

        # Submit the prompts in batches of up to _parallel_max per API request.
        for chunk in chunked(df["prompt"], self._parallel_max):
            completions_ = completion_with_backoff(
                model=self.modelname,
                prompt=chunk,
                temperature=temperature,
                max_tokens=self.max_tokens,
                stop=self._stop,
                **settings,
            )
            completions.append(completions_)

        # Flatten the per-batch responses into lists aligned with the input prompts.
        completions = {
            "choices": [choice["text"] for c in completions for choice in c["choices"]],
            "logprobs": [choice["logprobs"] for c in completions for choice in c["choices"]],
            "model": self.modelname,
        }
        return completions
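
    # Usage sketch ("ada" stands in for a fine-tuned model name): for a
    # two-class classification task, request the top-2 logprobs and read the
    # aligned "choices"/"logprobs" lists from the returned dict.
    #
    #     querier = Querier.from_preset("ada", "classification")
    #     df = pd.DataFrame({"prompt": ["This is a test", "This is another test"]})
    #     result = querier.query(df, logprobs=2)
    #     result["choices"]   # one completion string per prompt
    #     result["logprobs"]  # per-completion logprob data from the API
    #     result["model"]     # name of the model that was queried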

    def __call__(
        self, df: pd.DataFrame, temperature: float = 0, logprobs: Optional[int] = None
    ) -> dict:
        return self.query(df, temperature, logprobs)

    __repr__ = basic_repr("modelname,max_tokens")
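
# Sketch: a Querier instance is itself callable, so for a DataFrame `df` with a
# "prompt" column the two lines below are equivalent; the repr is built by
# fastcore's basic_repr from the modelname and max_tokens attributes.
#
#     result = querier.query(df, temperature=0, logprobs=2)
#     result = querier(df, temperature=0, logprobs=2)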