fix: 修复wss链接的证书问题

This commit is contained in:
CaptainB 2024-09-04 13:23:16 +08:00 committed by 刘瑞斌
parent b500404a41
commit b8ba2458c0
4 changed files with 33 additions and 13 deletions

View File

@ -18,7 +18,7 @@ from hashlib import sha256
from io import BytesIO from io import BytesIO
from typing import Dict from typing import Dict
from urllib.parse import urlparse from urllib.parse import urlparse
import ssl
import websockets import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -61,6 +61,10 @@ NO_COMPRESSION = 0b0000
GZIP = 0b0001 GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111 CUSTOM_COMPRESSION = 0b1111
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
def generate_header( def generate_header(
version=PROTOCOL_VERSION, version=PROTOCOL_VERSION,
@ -292,7 +296,8 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
header = self.token_auth() header = self.token_auth()
elif self.auth_method == "signature": elif self.auth_method == "signature":
header = self.signature_auth(full_client_request) header = self.signature_auth(full_client_request)
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws: async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
# 发送 full client request # 发送 full client request
await ws.send(full_client_request) await ws.send(full_client_request)
res = await ws.recv() res = await ws.recv()
@ -319,7 +324,8 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
header = self.token_auth() header = self.token_auth()
async def check(): async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws: async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
pass pass
asyncio.run(check()) asyncio.run(check())

View File

@ -14,7 +14,7 @@ import gzip
import json import json
import uuid import uuid
from typing import Dict from typing import Dict
import ssl
import websockets import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -35,6 +35,10 @@ MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression
# reserved data: 0x00 (1 byte) # reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00') default_header = bytearray(b'\x11\x10\x11\x00')
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
volcanic_app_id: str volcanic_app_id: str
@ -68,7 +72,8 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
header = self.token_auth() header = self.token_auth()
async def check(): async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws: async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
pass pass
asyncio.run(check()) asyncio.run(check())
@ -113,7 +118,8 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
full_client_request.extend(payload_bytes) # payload full_client_request.extend(payload_bytes) # payload
header = {"Authorization": f"Bearer; {self.volcanic_token}"} header = {"Authorization": f"Bearer; {self.volcanic_token}"}
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws: async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
await ws.send(full_client_request) await ws.send(full_client_request)
return await self.parse_response(ws) return await self.parse_response(ws)

View File

@ -11,7 +11,7 @@ import json
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
import ssl
import websockets import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -21,6 +21,10 @@ STATUS_FIRST_FRAME = 0 # 第一帧的标识
STATUS_CONTINUE_FRAME = 1 # 中间帧标识 STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识 STATUS_LAST_FRAME = 2 # 最后一帧的标识
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_app_id: str spark_app_id: str
@ -86,14 +90,14 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
def check_auth(self): def check_auth(self):
async def check(): async def check():
async with websockets.connect(self.create_url()) as ws: async with websockets.connect(self.create_url(), ssl=ssl_context) as ws:
pass pass
asyncio.run(check()) asyncio.run(check())
def speech_to_text(self, file): def speech_to_text(self, file):
async def handle(): async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws: async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request # 发送 full client request
await self.send(ws, file) await self.send(ws, file)
return await self.handle_message(ws) return await self.handle_message(ws)

View File

@ -14,7 +14,7 @@ import os
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
import ssl
import websockets import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -24,6 +24,10 @@ STATUS_FIRST_FRAME = 0 # 第一帧的标识
STATUS_CONTINUE_FRAME = 1 # 中间帧标识 STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识 STATUS_LAST_FRAME = 2 # 最后一帧的标识
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
spark_app_id: str spark_app_id: str
@ -89,7 +93,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
def check_auth(self): def check_auth(self):
async def check(): async def check():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws: async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
pass pass
asyncio.run(check()) asyncio.run(check())
@ -99,7 +103,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
# 使用小语种须使用以下方式此处的unicode指的是 utf16小端的编码方式即"UTF-16LE"” # 使用小语种须使用以下方式此处的unicode指的是 utf16小端的编码方式即"UTF-16LE"”
# self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
async def handle(): async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws: async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request # 发送 full client request
await self.send(ws, text) await self.send(ws, text)
return await self.handle_message(ws) return await self.handle_message(ws)