推荐使用PAI-EAS提供的官方SDK进行服务调用，从而有效减少编写调用逻辑的时间并提高调用稳定性。本文介绍官方Golang SDK的接口详情，并以常见类型的输入输出为例，提供使用Golang SDK进行服务调用的完整程序示例。
背景信息
使用Golang SDK进行服务调用时，由于在编译代码时Golang的包管理工具会自动从GitHub上将Golang SDK的代码下载到本地，因此您无需提前安装Golang SDK。如果您需要自定义部分调用逻辑，可以先下载Golang SDK代码，再对其进行修改。
接口列表如下。其中`(?)`表示数据类型占位符，例如AddFeed(?)对应AddFeedInt32()，Get(?)Val对应GetFloatVal()。
- PredictClient
  - NewPredictClient(endpoint string, serviceName string) *PredictClient
  - SetEndpoint(endpointName string)
  - SetServiceName(serviceName string)
  - SetEndpointType(endpointType string)
  - SetToken(token string)
  - SetHttpTransport(transport *http.Transport)
  - SetRetryCount(max_retry_count int)
  - SetTimeout(timeout int)
  - Init()
  - Predict(request Request) Response，其中Request为interface(StringRequest, TFRequest, TorchRequest)，Response为interface(StringResponse, TFResponse, TorchResponse)
  - StringPredict(request string) string
  - TorchPredict(request TorchRequest) TorchResponse
  - TFPredict(request TFRequest) TFResponse
- TFRequest
  - TFRequest(signatureName string)
  - AddFeed(?)(inputName string, shape []int64{}, content []?)，例如AddFeedInt32()
  - AddFetch(outputName string)
- TFResponse
  - GetTensorShape(outputName string) []int64
  - Get(?)Val(outputName string) [](?)，例如GetFloatVal()
- TorchRequest
  - TorchRequest()
  - AddFeed(?)(index int, shape []int64{}, content []?)，例如AddFeedInt32()
  - AddFetch(outputIndex int)
- TorchResponse
  - GetTensorShape(outputIndex int) []int64
  - Get(?)Val(outputIndex int) [](?)，例如GetFloatVal()
- QueueClient
  - NewQueueClient(endpoint, queueName, token string) (*QueueClient, error)
  - Truncate(ctx context.Context, index uint64) error
  - Put(ctx context.Context, data []byte, tags types.Tags) (index uint64, requestId string, err error)
  - GetByIndex(ctx context.Context, index uint64) (dfs []types.DataFrame, err error)
  - GetByRequestId(ctx context.Context, requestId string) (dfs []types.DataFrame, err error)
  - Get(ctx context.Context, index uint64, length int, timeout time.Duration, autoDelete bool, tags types.Tags) (dfs []types.DataFrame, err error)
  - Del(ctx context.Context, indexes ...uint64)
  - Attributes() (attrs types.Attributes, err error)
  - Watch(ctx context.Context, index, window uint64, indexOnly bool, autocommit bool) (watcher types.Watcher, err error)
  - Commit(ctx context.Context, indexes ...uint64) error
- Watcher
  - FrameChan() <-chan types.DataFrame
  - Close()
package main
import (
"fmt"
"github.com/pai-eas/eas-golang-sdk/eas"
)
// main sends the same JSON request to the scorecard_pmml_example service
// 100 times and prints every response, or the error when a call fails.
func main() {
	predictor := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "scorecard_pmml_example")
	predictor.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
	predictor.Init()

	// A JSON array holding one record with the features fea1 and fea2.
	req := "[{\"fea1\": 1, \"fea2\": 2}]"
	for attempt := 0; attempt < 100; attempt++ {
		resp, err := predictor.StringPredict(req)
		if err != nil {
			fmt.Printf("failed to predict: %v\n", err.Error())
			continue
		}
		fmt.Printf("%v\n", resp)
	}
}
package main
import (
"fmt"
"github.com/pai-eas/eas-golang-sdk/eas"
)
// main sends a TensorFlow request to the mnist_saved_model_example service
// 100 times and prints every response, or the error when a call fails.
func main() {
	client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "mnist_saved_model_example")
	client.SetToken("YTg2ZjE0ZjM4ZmE3OTc0NzYxZDMyNmYzMTJjZTQ1YmU0N2FjMTAy****")
	client.Init()

	// Build the request once and reuse it for every call: signature
	// "predict_images" with one float32 input "images" of shape [1, 784],
	// filled with zeros.
	tfreq := eas.TFRequest{}
	tfreq.SetSignatureName("predict_images")
	tfreq.AddFeedFloat32("images", []int64{1, 784}, make([]float32, 784))

	for i := 0; i < 100; i++ {
		resp, err := client.TFPredict(tfreq)
		if err != nil {
			// Terminate the line so consecutive failures are printed on
			// separate lines.
			fmt.Printf("failed to predict: %v\n", err)
			continue
		}
		fmt.Printf("%v\n", resp)
	}
}
package main
import (
"fmt"
"github.com/pai-eas/eas-golang-sdk/eas"
)
// main sends a PyTorch request to the pytorch_resnet_example service 10 times
// and prints the shape and float values of output 0, or the error when a call
// fails.
func main() {
	client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "pytorch_resnet_example")
	// NOTE(review): the unit of this timeout is not shown here (presumably
	// milliseconds) — confirm against the SDK documentation.
	client.SetTimeout(500)
	client.SetToken("ZjdjZDg1NWVlMWI2NTU5YzJiMmY5ZmE5OTBmYzZkMjI0YjlmYWVl****")
	client.Init()

	// Input 0 is a zero-filled float32 tensor of shape [1, 3, 224, 224]; the
	// content length is written as the product of the shape so the two cannot
	// drift apart (1*3*224*224 = 150528).
	req := eas.TorchRequest{}
	req.AddFeedFloat32(0, []int64{1, 3, 224, 224}, make([]float32, 1*3*224*224))
	// Request output 0 in the response.
	req.AddFetch(0)

	for i := 0; i < 10; i++ {
		resp, err := client.TorchPredict(req)
		if err != nil {
			// Terminate the line so consecutive failures are printed on
			// separate lines.
			fmt.Printf("failed to predict: %v\n", err)
			continue
		}
		fmt.Println(resp.GetTensorShape(0), resp.GetFloatVal(0))
	}
}
package main
import (
"fmt"
"github.com/pai-eas/eas-golang-sdk/eas"
)
// main calls the scorecard_pmml_example service 100 times through the VPC
// host pai-eas-vpc.cn-shanghai.aliyuncs.com with the endpoint type set to
// eas.EndpointTypeDirect, printing every response or the error of a failed
// call.
func main() {
	direct := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "scorecard_pmml_example")
	direct.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
	direct.SetEndpointType(eas.EndpointTypeDirect)
	direct.Init()

	// A JSON array holding one record with the features fea1 and fea2.
	req := "[{\"fea1\": 1, \"fea2\": 2}]"
	for round := 1; round <= 100; round++ {
		resp, err := direct.StringPredict(req)
		if err != nil {
			fmt.Printf("failed to predict: %v\n", err.Error())
			continue
		}
		fmt.Printf("%v\n", resp)
	}
}
package main
import (
	"fmt"
	"net/http"
	"time"

	"github.com/pai-eas/eas-golang-sdk/eas"
)
// main configures a predict client for the network_test service with the
// endpoint type eas.EndpointTypeDirect on the VPC host, replaces the SDK's
// HTTP transport with a custom one, and sends one request to exercise the
// configuration.
func main() {
	client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "network_test")
	client.SetToken("MDAwZDQ3NjE3OThhOTI4ODFmMjJiYzE0MDk1NWRkOGI1MmVhMGI0****")
	client.SetEndpointType(eas.EndpointTypeDirect)
	// Custom transport: cap the number of connections per host and use tight
	// timeouts for the TLS handshake, response headers and 100-continue.
	client.SetHttpTransport(&http.Transport{
		MaxConnsPerHost:       300,
		TLSHandshakeTimeout:   100 * time.Millisecond,
		ResponseHeaderTimeout: 200 * time.Millisecond,
		ExpectContinueTimeout: 200 * time.Millisecond,
	})
	// Initialize after all settings are applied, as in the other examples;
	// without Init() and a request the configuration is never used.
	client.Init()

	// NOTE: replace req with the payload that your own service expects.
	req := "[{\"fea1\": 1, \"fea2\": 2}]"
	resp, err := client.StringPredict(req)
	if err != nil {
		fmt.Printf("failed to predict: %v\n", err)
		return
	}
	fmt.Printf("%v\n", resp)
}
const (
QueueEndpoint = "1828488879222746.cn-shanghai.pai-eas.aliyuncs.com"
QueueName = "test_group.qservice"
QueueToken = "YmE3NDkyMzdiMzNmMGM3ZmE4ZmNjZDk0M2NiMDA3OTZmNzc1MTUxNg=="
)
queue, err := NewQueueClient(QueueEndpoint, QueueName, QueueToken)
// truncate all messages in the queue
attrs, err := queue.Attributes()
if index, ok := attrs["stream.lastEntry"]; ok {
idx, _ := strconv.ParseUint(index, 10, 64)
queue.Truncate(context.Background(), idx+1)
}
ctx, cancel := context.WithCancel(context.Background())
// create a goroutine to send messages to the queue
go func() {
i := 0
for {
select {
case <-time.NewTicker(time.Microsecond * 1).C:
_, _, err := queue.Put(context.Background(), []byte(strconv.Itoa(i)), types.Tags{})
if err != nil {
fmt.Printf("Error occured, retry to handle it: %v\n", err)
}
i += 1
case <-ctx.Done():
break
}
}
}()
// create a watcher to watch the messages from the queue
watcher, err := queue.Watch(context.Background(), 0, 5, false, false)
if err != nil {
fmt.Printf("Failed to create a watcher to watch the queue: %v\n", err)
return
}
// read messages from the queue and commit manually
for i := 0; i < 100; i++ {
df := <-watcher.FrameChan()
err := queue.Commit(context.Background(), df.Index.Uint64())
if err != nil {
fmt.Printf("Failed to commit index: %v(%v)\n", df.Index, err)
}
}
// everything is done, close the watcher
watcher.Close()
cancel()