Skip to content

Fastembed

FastEmbedEmbedding #

Bases: BaseEmbedding

Qdrant FastEmbed embedding models. FastEmbed is a lightweight, fast Python library built for embedding generation. See more documentation at:

* https://github.com/qdrant/fastembed/
* https://qdrant.github.io/fastembed/

To use this class, you must install the fastembed Python package.

Install the package with `pip install fastembed`.

Example:

    from llama_index.embeddings.fastembed import FastEmbedEmbedding

    fastembed = FastEmbedEmbedding()

Source code in llama-index-integrations/embeddings/llama-index-embeddings-fastembed/llama_index/embeddings/fastembed/base.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class FastEmbedEmbedding(BaseEmbedding):
    """
    Qdrant FastEmbedding models.
    FastEmbed is a lightweight, fast, Python library built for embedding generation.
    See more documentation at:
    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/.

    To use this class, you must install the `fastembed` Python package.

    `pip install fastembed`
    Example:
        from llama_index.embeddings.fastembed import FastEmbedEmbedding
        fastembed = FastEmbedEmbedding()

    The declared fields mirror the constructor arguments; the underlying
    ``TextEmbedding`` instance is kept in the private ``_model`` attribute and
    is used by every embedding method below.
    """

    model_name: str = Field(
        "BAAI/bge-small-en-v1.5",
        description="Name of the FastEmbedding model to use.\n"
        "Defaults to 'BAAI/bge-small-en-v1.5'.\n"
        "Find the list of supported models at "
        "https://qdrant.github.io/fastembed/examples/Supported_Models/",
    )

    max_length: int = Field(
        512,
        description="The maximum number of tokens. Defaults to 512.\n"
        "Unknown behavior for values > 512.",
    )

    cache_dir: Optional[str] = Field(
        None,
        description="The path to the cache directory.\n"
        "Defaults to `local_cache` in the parent directory",
    )

    threads: Optional[int] = Field(
        None,
        description="The number of threads single onnxruntime session can use.\n"
        "Defaults to None",
    )

    doc_embed_type: Literal["default", "passage"] = Field(
        "default",
        description="Type of embedding method to use for documents.\n"
        "Available options are 'default' and 'passage'.",
    )

    providers: Optional[List[str]] = Field(
        default=None,
        description="The ONNX providers to use for the embedding model.",
    )

    # Backing TextEmbedding instance; assigned in __init__ after the fields
    # have been initialised by the base class.
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "FastEmbedEmbedding"

    def __init__(
        self,
        model_name: Optional[str] = "BAAI/bge-small-en-v1.5",
        max_length: Optional[int] = 512,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        doc_embed_type: Literal["default", "passage"] = "default",
        providers: Optional[List[str]] = None,
    ):
        """
        Initialise the fields and build the underlying ``TextEmbedding`` model.

        Args:
            model_name: Name of the model handed to ``TextEmbedding``.
            max_length: Maximum number of tokens handed to ``TextEmbedding``.
            cache_dir: Cache directory handed to ``TextEmbedding``.
            threads: Thread count handed to ``TextEmbedding``.
            doc_embed_type: ``"passage"`` routes document embedding through
                ``passage_embed``; any other value uses ``embed``.
            providers: ONNX providers handed to ``TextEmbedding``.
        """
        # cache_dir is forwarded as well so that the declared field reflects
        # the value actually used to build the model instead of staying None.
        super().__init__(
            model_name=model_name,
            max_length=max_length,
            cache_dir=cache_dir,
            threads=threads,
            doc_embed_type=doc_embed_type,
            providers=providers,
        )

        self._model = TextEmbedding(
            model_name=model_name,
            max_length=max_length,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """
        Embed a single document.

        Uses ``passage_embed`` when ``doc_embed_type`` is ``"passage"`` and
        ``embed`` otherwise, and returns the first resulting vector as a list.
        """
        # Wrap the text so the model always receives an iterable of documents
        # rather than a bare string.
        documents = [text]
        embeddings: List[np.ndarray]
        if self.doc_embed_type == "passage":
            embeddings = list(self._model.passage_embed(documents))
        else:
            embeddings = list(self._model.embed(documents))
        return embeddings[0].tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query with ``query_embed`` and return the first vector as a list."""
        query_embeddings: np.ndarray = next(self._model.query_embed(query))
        return query_embeddings.tolist()

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async entry point; delegates to the synchronous ``_get_query_embedding``."""
        return self._get_query_embedding(query)