推荐使用PAI-EAS提供的官方SDK进行服务调用，从而有效减少编写调用逻辑的时间并提高调用稳定性。本文介绍官方Golang SDK的接口详情，并以常见类型的输入输出为例，提供使用Golang SDK进行服务调用的完整程序示例。

背景信息

使用Golang SDK进行服务调用时，由于Golang的包管理工具在编译代码时会自动从GitHub将Golang SDK的代码下载到本地，因此您无需提前安装Golang SDK。如果您需要自定义部分调用逻辑，可以先下载Golang SDK代码，再对其进行修改。

接口列表

PredictClient：
- NewPredictClient(endpoint string, serviceName string) *PredictClient：创建PredictClient对象。
- SetEndpoint(endpointName string)
- SetServiceName(serviceName string)
- SetEndpointType(endpointType string)
- SetToken(token string)
- SetHttpTransport(transport *http.Transport)
- SetRetryCount(max_retry_count int)
- SetTimeout(timeout int)
- Init()：初始化PredictClient对象。请在完成上述参数设置后再调用Init()。
- Predict(request Request) Response：Request为接口类型，可取StringRequest、TFRequest或TorchRequest；Response为接口类型，对应StringResponse、TFResponse或TorchResponse。
- StringPredict(request string) string
- TorchPredict(request TorchRequest) TorchResponse
- TFPredict(request TFRequest) TFResponse
  （StringPredict、TorchPredict和TFPredict还会额外返回一个error值，请参见下文示例。）

TFRequest：
- TFRequest(signatureName string)
- AddFeed(?)(inputName string, shape []int64{}, content []?)：?表示输入数据类型，例如AddFeedInt32()。
- AddFetch(outputName string)

TFResponse：
- GetTensorShape(outputName string) []int64
- Get(?)Val(outputName string) [](?)：?表示输出数据类型，例如GetFloatVal()。

TorchRequest：
- TorchRequest()
- AddFeed(?)(index int, shape []int64{}, content []?)：?表示输入数据类型，例如AddFeedInt32()。
- AddFetch(outputIndex int)

TorchResponse：
- GetTensorShape(outputIndex int) []int64
- Get(?)Val(outputIndex int) [](?)：?表示输出数据类型，例如GetFloatVal()。

QueueClient：
- NewQueueClient(endpoint, queueName, token string) (*QueueClient, error)
- Truncate(ctx context.Context, index uint64) error
- Put(ctx context.Context, data []byte, tags types.Tags) (index uint64, requestId string, err error)
- GetByIndex(ctx context.Context, index uint64) (dfs []types.DataFrame, err error)
- GetByRequestId(ctx context.Context, requestId string) (dfs []types.DataFrame, err error)
- Get(ctx context.Context, index uint64, length int, timeout time.Duration, autoDelete bool, tags types.Tags) (dfs []types.DataFrame, err error)
- Del(ctx context.Context, indexes ...uint64)
- Attributes() (attrs types.Attributes, err error)
- Watch(ctx context.Context, index, window uint64, indexOnly bool, autocommit bool) (watcher types.Watcher, err error)
- Commit(ctx context.Context, indexes ...uint64) error

types.Watcher：
- FrameChan() <-chan types.DataFrame
- Close()
package main

import (
        "fmt"
        "github.com/pai-eas/eas-golang-sdk/eas"
)

// main sends the same JSON string request to the scorecard_pmml_example
// service 100 times and prints either the response or the error of each call.
func main() {
    const (
        endpoint = "182848887922****.cn-shanghai.pai-eas.aliyuncs.com"
        service  = "scorecard_pmml_example"
        token    = "YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****"
        rounds   = 100
    )

    predictor := eas.NewPredictClient(endpoint, service)
    predictor.SetToken(token)
    predictor.Init()

    // A raw string literal avoids escaping the quotes of the JSON payload.
    payload := `[{"fea1": 1, "fea2": 2}]`
    for round := 0; round < rounds; round++ {
        out, callErr := predictor.StringPredict(payload)
        if callErr != nil {
            fmt.Printf("failed to predict: %v\n", callErr.Error())
            continue
        }
        fmt.Printf("%v\n", out)
    }
}
package main

import (
        "fmt"
        "github.com/pai-eas/eas-golang-sdk/eas"
)

// main sends a TensorFlow request to the mnist_saved_model_example service
// 100 times and prints each response or error. The request uses the
// "predict_images" signature and a single float32 "images" input of shape
// [1, 784] filled with zeros.
func main() {
    client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "mnist_saved_model_example")
    client.SetToken("YTg2ZjE0ZjM4ZmE3OTc0NzYxZDMyNmYzMTJjZTQ1YmU0N2FjMTAy****")
    client.Init()

    tfreq := eas.TFRequest{}
    tfreq.SetSignatureName("predict_images")
    // The data length is the product of the shape dimensions (1*784), written
    // out so that shape and buffer size visibly agree.
    tfreq.AddFeedFloat32("images", []int64{1, 784}, make([]float32, 1*784))

    for i := 0; i < 100; i++ {
        resp, err := client.TFPredict(tfreq)
        if err != nil {
            // End the line like the success branch does; without "\n"
            // consecutive errors were printed on a single line.
            fmt.Printf("failed to predict: %v\n", err)
            continue
        }
        fmt.Printf("%v\n", resp)
    }
}
package main

import (
        "fmt"
        "github.com/pai-eas/eas-golang-sdk/eas"
)

// main sends a PyTorch request to the pytorch_resnet_example service 10 times
// and prints the shape and float values of output 0 for each successful call.
// The request feeds input 0 with a zero-filled float32 tensor of shape
// [1, 3, 224, 224] and fetches output 0.
func main() {
    client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "pytorch_resnet_example")
    client.SetTimeout(500)
    client.SetToken("ZjdjZDg1NWVlMWI2NTU5YzJiMmY5ZmE5OTBmYzZkMjI0YjlmYWVl****")
    client.Init()

    req := eas.TorchRequest{}
    // The data length is the product of the shape dimensions
    // (1*3*224*224 = 150528), written out so shape and buffer size agree.
    req.AddFeedFloat32(0, []int64{1, 3, 224, 224}, make([]float32, 1*3*224*224))
    req.AddFetch(0)

    for i := 0; i < 10; i++ {
        resp, err := client.TorchPredict(req)
        if err != nil {
            // End the line; without "\n" consecutive errors ran together.
            fmt.Printf("failed to predict: %v\n", err)
            continue
        }
        fmt.Println(resp.GetTensorShape(0), resp.GetFloatVal(0))
    }
}
package main

import (
        "fmt"
        "github.com/pai-eas/eas-golang-sdk/eas"
)

// main calls the scorecard_pmml_example service through the VPC endpoint
// using the eas.EndpointTypeDirect endpoint type. It sends the same JSON
// string request 100 times and prints each response or error.
func main() {
    const (
        vpcEndpoint = "pai-eas-vpc.cn-shanghai.aliyuncs.com"
        service     = "scorecard_pmml_example"
        token       = "YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****"
    )

    predictor := eas.NewPredictClient(vpcEndpoint, service)
    predictor.SetToken(token)
    predictor.SetEndpointType(eas.EndpointTypeDirect)
    predictor.Init()

    payload := `[{"fea1": 1, "fea2": 2}]`
    for n := 0; n < 100; n++ {
        out, callErr := predictor.StringPredict(payload)
        if callErr != nil {
            fmt.Printf("failed to predict: %v\n", callErr.Error())
            continue
        }
        fmt.Printf("%v\n", out)
    }
}
package main

import (
        "fmt"
        "github.com/pai-eas/eas-golang-sdk/eas"
)

func main() {
    client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "network_test")
    client.SetToken("MDAwZDQ3NjE3OThhOTI4ODFmMjJiYzE0MDk1NWRkOGI1MmVhMGI0****")
    client.SetEndpointType(eas.EndpointTypeDirect)
    client.SetHttpTransport(&http.Transport{
        MaxConnsPerHost:       300,
        TLSHandshakeTimeout:   100 * time.Millisecond,
        ResponseHeaderTimeout: 200 * time.Millisecond,
        ExpectContinueTimeout: 200 * time.Millisecond,
    })
}
    // Connection parameters of the queue service; replace them with the
    // values of your own service.
    const (
        QueueEndpoint = "1828488879222746.cn-shanghai.pai-eas.aliyuncs.com"
        QueueName     = "test_group.qservice"
        QueueToken    = "YmE3NDkyMzdiMzNmMGM3ZmE4ZmNjZDk0M2NiMDA3OTZmNzc1MTUxNg=="
    )
    queue, err := NewQueueClient(QueueEndpoint, QueueName, QueueToken)
    if err != nil {
        fmt.Printf("Failed to create the queue client: %v\n", err)
        return
    }

    // Truncate all messages in the queue: read the index of the last entry
    // from the queue attributes and truncate up to lastEntry+1. This cleanup
    // is best-effort, so failures are reported but do not stop the example.
    attrs, err := queue.Attributes()
    if err != nil {
        fmt.Printf("Failed to get the queue attributes: %v\n", err)
    }
    if index, ok := attrs["stream.lastEntry"]; ok {
        idx, err := strconv.ParseUint(index, 10, 64)
        if err != nil {
            fmt.Printf("Failed to parse the last entry index %q: %v\n", index, err)
        } else if err := queue.Truncate(context.Background(), idx+1); err != nil {
            fmt.Printf("Failed to truncate the queue: %v\n", err)
        }
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    // Producer goroutine: put an increasing counter into the queue on every
    // tick until ctx is cancelled. done is closed when it exits so the
    // caller can wait for it.
    done := make(chan struct{})
    go func() {
        defer close(done)
        // Create the ticker once and stop it on exit. Allocating a new ticker
        // in every select iteration leaks tickers.
        ticker := time.NewTicker(time.Microsecond * 1)
        defer ticker.Stop()
        i := 0
        for {
            select {
            case <-ticker.C:
                _, _, err := queue.Put(context.Background(), []byte(strconv.Itoa(i)), types.Tags{})
                if err != nil {
                    fmt.Printf("Error occured, retry to handle it: %v\n", err)
                }
                i++
            case <-ctx.Done():
                // return, not break: break only leaves the select, and the
                // loop would then spin forever after cancellation.
                return
            }
        }
    }()

    // Create a watcher to watch the messages from the queue.
    watcher, err := queue.Watch(context.Background(), 0, 5, false, false)
    if err != nil {
        fmt.Printf("Failed to create a watcher to watch the queue: %v\n", err)
        cancel()
        <-done
        return
    }

    // Read messages from the queue and commit each one manually.
    frames := watcher.FrameChan()
    for i := 0; i < 100; i++ {
        df, ok := <-frames
        if !ok {
            // Stop if the watcher's channel is closed rather than committing
            // zero-value frames.
            fmt.Printf("Watcher channel closed after %d messages\n", i)
            break
        }
        if err := queue.Commit(context.Background(), df.Index.Uint64()); err != nil {
            fmt.Printf("Failed to commit index: %v(%v)\n", df.Index, err)
        }
    }

    // Everything is done: close the watcher, stop the producer and wait for
    // it to exit.
    watcher.Close()
    cancel()
    <-done