from typing import Tuple

Skip to content

Vertex endpoint

VertexEndpointEmbedding

Bases: BaseEmbedding

Source code in llama-index-integrations/embeddings/llama-index-embeddings-vertex-endpoint/llama_index/embeddings/vertex_endpoint/base.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class VertexEndpointEmbedding(BaseEmbedding):
    """Embedding model backed by a deployed Vertex AI endpoint.

    Input texts have newlines replaced with spaces. They are serialized with
    ``content_handler.serialize_input`` and sent to the endpoint through
    ``aiplatform.Endpoint.predict`` (sync) or ``predict_async`` (async). The
    raw response is converted back into embeddings with
    ``content_handler.deserialize_output``.
    """

    endpoint_id: str = Field(description="Vertex AI endpoint ID")
    project_id: str = Field(description="GCP Project ID")
    location: str = Field(description="GCP Region for Vertex AI")
    endpoint_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional kwargs for the predict request.",
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="kwargs to pass to the model.",
    )
    content_handler: BaseIOHandler = Field(
        default=DEFAULT_IO_HANDLER,
        description="used to format input/output",
    )
    service_account_file: Optional[str] = Field(
        default=None, description="Path to the service account JSON file."
    )
    service_account_info: Optional[Dict[str, str]] = Field(
        default=None, description="Directly provide service account credentials."
    )
    timeout: Optional[float] = Field(
        default=60.0,
        description="Timeout for API requests in seconds.",
        ge=0,
    )
    _client: aiplatform.Endpoint = PrivateAttr()
    _verbose: bool = PrivateAttr()

    def __init__(
        self,
        endpoint_id: str,
        project_id: str,
        location: str,
        content_handler: BaseIOHandler = DEFAULT_IO_HANDLER,
        endpoint_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        service_account_file: Optional[str] = None,
        service_account_info: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = 60.0,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> None:
        """Create the embedding model and its Vertex AI endpoint client.

        Args:
            endpoint_id: Vertex AI endpoint ID (passed as ``endpoint_name``).
            project_id: GCP project ID.
            location: GCP region of the endpoint.
            content_handler: Serializes input texts and deserializes the
                endpoint response into embeddings.
            endpoint_kwargs: Extra keyword arguments for every predict call.
                ``None`` means no extra arguments.
            model_kwargs: Default ``parameters`` sent with every predict call.
                ``None`` means no parameters.
            service_account_file: Path to a service-account JSON file. It
                takes precedence over ``service_account_info``.
            service_account_info: Service-account credentials given directly.
            timeout: Request timeout in seconds. It overrides any ``timeout``
                key in ``endpoint_kwargs``.
            embed_batch_size: Batch size forwarded to ``BaseEmbedding``.
            callback_manager: Callback manager forwarded to ``BaseEmbedding``.
            verbose: Stored on the private ``_verbose`` attribute.

        Raises:
            ValueError: If constructing the ``aiplatform.Endpoint`` client fails.
        """
        super().__init__(
            endpoint_id=endpoint_id,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            project_id=project_id,
            location=location,
            content_handler=content_handler,
            endpoint_kwargs=endpoint_kwargs or {},
            model_kwargs=model_kwargs or {},
            timeout=timeout,
        )
        # NOTE: the service-account arguments are only used to build the
        # credentials below; they are intentionally not passed to the model
        # fields above (original behavior), so secrets are not kept on the
        # model's public fields.

        # Credential precedence: file, then info, else default application
        # credentials (signalled to the SDK by passing None).
        if service_account_file:
            credentials = service_account.Credentials.from_service_account_file(
                service_account_file
            )
        elif service_account_info:
            credentials = service_account.Credentials.from_service_account_info(
                service_account_info
            )
        else:
            credentials = None

        try:
            self._client = aiplatform.Endpoint(
                endpoint_name=endpoint_id,
                project=project_id,
                location=location,
                credentials=credentials,
            )
        except Exception as e:
            raise ValueError("Please verify the provided credentials.") from e

        self._verbose = verbose

    @classmethod
    def class_name(cls) -> str:
        return "VertexEndpointEmbedding"

    def _request_kwargs(
        self, call_kwargs: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build ``(endpoint_kwargs, model_parameters)`` for one predict call.

        ``self.timeout`` overrides any ``timeout`` in ``self.endpoint_kwargs``.
        Per-call ``call_kwargs`` override same-named ``self.model_kwargs``.
        """
        endpoint_kwargs = {**self.endpoint_kwargs, "timeout": self.timeout}
        model_kwargs = {**self.model_kwargs, **call_kwargs}
        return endpoint_kwargs, model_kwargs

    def _get_embedding(self, payload: List[str], **kwargs: Any) -> List[Embedding]:
        """Send ``payload`` to the endpoint synchronously; return its embeddings."""
        endpoint_kwargs, model_kwargs = self._request_kwargs(kwargs)
        response = self._client.predict(
            instances=self.content_handler.serialize_input(payload),
            parameters=model_kwargs,
            **endpoint_kwargs,
        )
        # The content handler decides how embeddings are extracted from the
        # raw prediction response.
        return self.content_handler.deserialize_output(response)

    async def _aget_embedding(
        self, payload: List[str], **kwargs: Any
    ) -> List[Embedding]:
        """Send ``payload`` to the endpoint asynchronously; return its embeddings."""
        endpoint_kwargs, model_kwargs = self._request_kwargs(kwargs)
        response = await self._client.predict_async(
            instances=self.content_handler.serialize_input(payload),
            parameters=model_kwargs,
            **endpoint_kwargs,
        )
        return self.content_handler.deserialize_output(response)

    def _get_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        query = query.replace("\n", " ")
        return self._get_embedding([query], **kwargs)[0]

    def _get_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        text = text.replace("\n", " ")
        return self._get_embedding([text], **kwargs)[0]

    def _get_text_embeddings(self, texts: List[str], **kwargs: Any) -> List[Embedding]:
        texts = [text.replace("\n", " ") for text in texts]
        return self._get_embedding(texts, **kwargs)

    async def _aget_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        query = query.replace("\n", " ")
        # Await first, then index: `await coro(...)[0]` would subscript the
        # coroutine object itself and raise TypeError.
        embeddings = await self._aget_embedding([query], **kwargs)
        return embeddings[0]

    async def _aget_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        text = text.replace("\n", " ")
        embeddings = await self._aget_embedding([text], **kwargs)
        return embeddings[0]

    async def _aget_text_embeddings(
        self, texts: List[str], **kwargs: Any
    ) -> List[Embedding]:
        texts = [text.replace("\n", " ") for text in texts]
        return await self._aget_embedding(texts, **kwargs)